Fix tests.
This commit is contained in:
parent
f885af4960
commit
1d81c47283
@ -146,10 +146,7 @@ class TorchTrainer(base_trainer.Trainer):
|
|||||||
self.predict_function = one_step_on_data
|
self.predict_function = one_step_on_data
|
||||||
|
|
||||||
def _symbolic_build(self, data_batch):
|
def _symbolic_build(self, data_batch):
|
||||||
try:
|
model_unbuilt = not all(layer.built for layer in self._flatten_layers())
|
||||||
model_unbuilt = not all(
|
|
||||||
layer.built for layer in self._flatten_layers()
|
|
||||||
)
|
|
||||||
compile_metrics_unbuilt = (
|
compile_metrics_unbuilt = (
|
||||||
self._compile_metrics is not None
|
self._compile_metrics is not None
|
||||||
and not self._compile_metrics.built
|
and not self._compile_metrics.built
|
||||||
@ -162,16 +159,25 @@ class TorchTrainer(base_trainer.Trainer):
|
|||||||
return KerasTensor(v.shape, standardize_dtype(v.dtype))
|
return KerasTensor(v.shape, standardize_dtype(v.dtype))
|
||||||
return v
|
return v
|
||||||
|
|
||||||
data_batch = tf.nest.map_structure(
|
data_batch = tf.nest.map_structure(to_symbolic_input, data_batch)
|
||||||
to_symbolic_input, data_batch
|
|
||||||
)
|
|
||||||
(
|
(
|
||||||
x,
|
x,
|
||||||
y,
|
y,
|
||||||
sample_weight,
|
sample_weight,
|
||||||
) = data_adapter_utils.unpack_x_y_sample_weight(data_batch)
|
) = data_adapter_utils.unpack_x_y_sample_weight(data_batch)
|
||||||
# Build all model state with `backend.compute_output_spec`.
|
# Build all model state with `backend.compute_output_spec`.
|
||||||
|
try:
|
||||||
y_pred = backend.compute_output_spec(self, x)
|
y_pred = backend.compute_output_spec(self, x)
|
||||||
|
except:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Unable to automatically build the model. "
|
||||||
|
"Please build it yourself before calling "
|
||||||
|
"fit/evaluate/predict. "
|
||||||
|
"A model is 'built' when its variables have "
|
||||||
|
"been created and its `self.built` attribute "
|
||||||
|
"is True. Usually, calling the model on a batch "
|
||||||
|
"of data is the right way to build it."
|
||||||
|
)
|
||||||
if compile_metrics_unbuilt:
|
if compile_metrics_unbuilt:
|
||||||
# Build all metric state with `backend.compute_output_spec`.
|
# Build all metric state with `backend.compute_output_spec`.
|
||||||
backend.compute_output_spec(
|
backend.compute_output_spec(
|
||||||
@ -185,15 +191,6 @@ class TorchTrainer(base_trainer.Trainer):
|
|||||||
# Build optimizer
|
# Build optimizer
|
||||||
self.optimizer.build(self.trainable_variables)
|
self.optimizer.build(self.trainable_variables)
|
||||||
self._post_build()
|
self._post_build()
|
||||||
except:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Unable to automatically build the model. "
|
|
||||||
"Please build it yourself before calling fit/evaluate/predict. "
|
|
||||||
"A model is 'built' when its variables have been created and "
|
|
||||||
"its `self.built` attribute is True. Usually, "
|
|
||||||
"calling the model on a batch of data "
|
|
||||||
"is the right way to build it."
|
|
||||||
)
|
|
||||||
|
|
||||||
@traceback_utils.filter_traceback
|
@traceback_utils.filter_traceback
|
||||||
def fit(
|
def fit(
|
||||||
|
Loading…
Reference in New Issue
Block a user