Fix tests.

This commit is contained in:
Francois Chollet 2023-06-27 21:37:53 -07:00
parent f885af4960
commit 1d81c47283

@ -146,10 +146,7 @@ class TorchTrainer(base_trainer.Trainer):
self.predict_function = one_step_on_data self.predict_function = one_step_on_data
def _symbolic_build(self, data_batch): def _symbolic_build(self, data_batch):
try: model_unbuilt = not all(layer.built for layer in self._flatten_layers())
model_unbuilt = not all(
layer.built for layer in self._flatten_layers()
)
compile_metrics_unbuilt = ( compile_metrics_unbuilt = (
self._compile_metrics is not None self._compile_metrics is not None
and not self._compile_metrics.built and not self._compile_metrics.built
@ -162,16 +159,25 @@ class TorchTrainer(base_trainer.Trainer):
return KerasTensor(v.shape, standardize_dtype(v.dtype)) return KerasTensor(v.shape, standardize_dtype(v.dtype))
return v return v
data_batch = tf.nest.map_structure( data_batch = tf.nest.map_structure(to_symbolic_input, data_batch)
to_symbolic_input, data_batch
)
( (
x, x,
y, y,
sample_weight, sample_weight,
) = data_adapter_utils.unpack_x_y_sample_weight(data_batch) ) = data_adapter_utils.unpack_x_y_sample_weight(data_batch)
# Build all model state with `backend.compute_output_spec`. # Build all model state with `backend.compute_output_spec`.
try:
y_pred = backend.compute_output_spec(self, x) y_pred = backend.compute_output_spec(self, x)
except:
raise RuntimeError(
"Unable to automatically build the model. "
"Please build it yourself before calling "
"fit/evaluate/predict. "
"A model is 'built' when its variables have "
"been created and its `self.built` attribute "
"is True. Usually, calling the model on a batch "
"of data is the right way to build it."
)
if compile_metrics_unbuilt: if compile_metrics_unbuilt:
# Build all metric state with `backend.compute_output_spec`. # Build all metric state with `backend.compute_output_spec`.
backend.compute_output_spec( backend.compute_output_spec(
@ -185,15 +191,6 @@ class TorchTrainer(base_trainer.Trainer):
# Build optimizer # Build optimizer
self.optimizer.build(self.trainable_variables) self.optimizer.build(self.trainable_variables)
self._post_build() self._post_build()
except:
raise RuntimeError(
"Unable to automatically build the model. "
"Please build it yourself before calling fit/evaluate/predict. "
"A model is 'built' when its variables have been created and "
"its `self.built` attribute is True. Usually, "
"calling the model on a batch of data "
"is the right way to build it."
)
@traceback_utils.filter_traceback @traceback_utils.filter_traceback
def fit( def fit(