Fix tests.

This commit is contained in:
Francois Chollet 2023-06-27 21:37:53 -07:00
parent f885af4960
commit 1d81c47283

@ -146,54 +146,51 @@ class TorchTrainer(base_trainer.Trainer):
self.predict_function = one_step_on_data
def _symbolic_build(self, data_batch):
try:
model_unbuilt = not all(
layer.built for layer in self._flatten_layers()
)
compile_metrics_unbuilt = (
self._compile_metrics is not None
and not self._compile_metrics.built
)
if model_unbuilt or compile_metrics_unbuilt:
# Create symbolic tensors matching an input batch.
model_unbuilt = not all(layer.built for layer in self._flatten_layers())
compile_metrics_unbuilt = (
self._compile_metrics is not None
and not self._compile_metrics.built
)
if model_unbuilt or compile_metrics_unbuilt:
# Create symbolic tensors matching an input batch.
def to_symbolic_input(v):
if is_tensor(v):
return KerasTensor(v.shape, standardize_dtype(v.dtype))
return v
def to_symbolic_input(v):
if is_tensor(v):
return KerasTensor(v.shape, standardize_dtype(v.dtype))
return v
data_batch = tf.nest.map_structure(
to_symbolic_input, data_batch
data_batch = tf.nest.map_structure(to_symbolic_input, data_batch)
(
x,
y,
sample_weight,
) = data_adapter_utils.unpack_x_y_sample_weight(data_batch)
# Build all model state with `backend.compute_output_spec`.
try:
y_pred = backend.compute_output_spec(self, x)
except:
raise RuntimeError(
"Unable to automatically build the model. "
"Please build it yourself before calling "
"fit/evaluate/predict. "
"A model is 'built' when its variables have "
"been created and its `self.built` attribute "
"is True. Usually, calling the model on a batch "
"of data is the right way to build it."
)
(
if compile_metrics_unbuilt:
# Build all metric state with `backend.compute_output_spec`.
backend.compute_output_spec(
self.compute_metrics,
x,
y,
sample_weight,
) = data_adapter_utils.unpack_x_y_sample_weight(data_batch)
# Build all model state with `backend.compute_output_spec`.
y_pred = backend.compute_output_spec(self, x)
if compile_metrics_unbuilt:
# Build all metric state with `backend.compute_output_spec`.
backend.compute_output_spec(
self.compute_metrics,
x,
y,
y_pred,
sample_weight=sample_weight,
)
if self.optimizer is not None and not self.optimizer.built:
# Build optimizer
self.optimizer.build(self.trainable_variables)
self._post_build()
except:
raise RuntimeError(
"Unable to automatically build the model. "
"Please build it yourself before calling fit/evaluate/predict. "
"A model is 'built' when its variables have been created and "
"its `self.built` attribute is True. Usually, "
"calling the model on a batch of data "
"is the right way to build it."
)
y_pred,
sample_weight=sample_weight,
)
if self.optimizer is not None and not self.optimizer.built:
# Build optimizer
self.optimizer.build(self.trainable_variables)
self._post_build()
@traceback_utils.filter_traceback
def fit(