Fix bugs.

Francois Chollet 2023-05-09 12:17:31 -07:00
parent 0f1a13b80b
commit 0a1f3da5a9
3 changed files with 18 additions and 12 deletions

@@ -74,9 +74,8 @@ def train_step(data):
     with tf.GradientTape() as tape:
         y_pred = model(x)
         loss = loss_fn(y, y_pred)
-    # !! Glitch to be resolved !!
     gradients = tape.gradient(
-        loss, [v.value for v in model.trainable_variables]
+        loss, model.trainable_variables
     )
     optimizer.apply_gradients(zip(gradients, model.trainable_variables))
     return loss
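
Note on the hunk above: once the TF backend `Variable` exposes its resource handle (next hunk), `tape.gradient` accepts the wrapped variables directly, so the `[v.value for v in ...]` unwrapping can be dropped. A minimal, self-contained sketch of the fixed pattern, using plain tf.keras stand-ins (the model, loss, and data names here are illustrative, not part of this commit):

import tensorflow as tf

# Illustrative stand-ins for the demo's model / loss / optimizer.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
loss_fn = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

x = tf.random.normal((8, 4))
y = tf.random.normal((8, 1))

with tf.GradientTape() as tape:
    y_pred = model(x)
    loss = loss_fn(y, y_pred)
# Pass the variables directly; no `.value` unwrapping needed.
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))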

@@ -19,6 +19,12 @@ DYNAMIC_SHAPES_OK = True
 
 class Variable(KerasVariable, tf.__internal__.types.Tensor):
+    _should_act_as_resource_variable = True
+
+    @property
+    def handle(self):
+        return self.value.handle
+
     def _initialize(self, value):
         self._value = tf.Variable(
             value, dtype=self._dtype, trainable=self.trainable
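
Note: `tf.GradientTape` and the TF optimizers resolve resource variables through their underlying handle, so forwarding `.handle` (together with the `_should_act_as_resource_variable` flag) lets the Keras wrapper participate like a native `tf.Variable`. A tiny sketch of that lookup with a plain `tf.Variable` (illustrative only):

import tensorflow as tf

v = tf.Variable(3.0)
with tf.GradientTape() as tape:
    y = v * v
# The tape keys watched variables by their resource handle; a wrapper
# that forwards `.handle`, as above, resolves the same way.
print(tape.gradient(y, v))  # tf.Tensor(6.0, shape=(), dtype=float32)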

@@ -433,6 +433,8 @@ class Layer(Operation):
                 outputs = super().__call__(*args, **kwargs)
             else:
                 outputs = super().__call__(*args, **kwargs)
+            if not self.built:
+                self.built = True
             # Record activity regularizer loss.
             if self.activity_regularizer is not None:
                 for output in nest.flatten(outputs):
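
Note: the added `self.built = True` guarantees the flag flips after the first successful `__call__`, even for layers whose `build()` does nothing. The first-call contract it preserves, sketched with a tf.keras layer (illustrative, not this repo's `Layer`):

import tensorflow as tf

layer = tf.keras.layers.Dense(2)
print(layer.built)           # False: no weights created yet
_ = layer(tf.zeros((1, 3)))  # first call builds the state
print(layer.built)           # True from here on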
@@ -763,6 +765,9 @@ class Layer(Operation):
                     f"{list(shapes_dict.keys())}"
                 )
         if failure:
+            if call_spec.eager:
+                # Will let the actual eager call do the state-building
+                return
             raise ValueError(
                 f"Layer '{self.name}' looks like it has "
                 "unbuilt state, but Keras is not able to "
@@ -797,11 +802,7 @@ class Layer(Operation):
         try:
             backend.compute_output_spec(self.call, input_tensors)
             return True
-        except Exception as e:
-            warnings.warn(
-                "Error when attempting to automatically build "
-                f"the layer by tracing it: {e}"
-            )
+        except:
             return False
 
     def _build_by_run_for_kwargs(self, shapes_dict):
@@ -816,11 +817,7 @@ class Layer(Operation):
             try:
                 backend.compute_output_spec(self.call, **input_tensors)
                 return True
-            except Exception as e:
-                warnings.warn(
-                    "Error when attempting to automatically build "
-                    f"the layer by tracing it: {e}"
-                )
+            except:
                 return False
         else:
             # Not supported: nested input keyword arguments.
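
Note: the two hunks above demote build-by-tracing from a warning-emitting step to a silent probe: any failure just means "this path cannot build the layer" and control falls through to the explicit error. The probe pattern in isolation (names are illustrative):

def can_build_by_tracing(fn, *args, **kwargs):
    # Opportunistic: success means the symbolic trace worked; any
    # exception simply reports that this path cannot build the state.
    try:
        fn(*args, **kwargs)
        return True
    except Exception:
        return False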
@@ -962,6 +959,10 @@ class CallSpec:
         self.tensor_arguments_names = tensor_arg_names
         self.nested_tensor_argument_names = nested_tensor_arg_names
         self.first_arg = arg_dict[arg_names[0]]
+        if all(backend.is_tensor(x) for x in self.tensor_arguments_dict.values()):
+            self.eager = True
+        else:
+            self.eager = False
 
 
 def get_arguments_dict(fn, args, kwargs):
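
Note: the new `eager` flag marks calls whose tensor arguments are all concrete backend tensors, as opposed to symbolic KerasTensors, which is what the `if call_spec.eager:` early return above keys on. An equivalent check with TF primitives (illustrative, not this repo's `backend.is_tensor`):

import tensorflow as tf

def is_eager_call(tensor_args):
    # Concrete eager tensors expose .numpy(); symbolic ones do not.
    return all(tf.is_tensor(x) and hasattr(x, "numpy") for x in tensor_args)

print(is_eager_call([tf.constant(1.0), tf.ones((2, 2))]))  # True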