Fix tests.
This commit is contained in:
parent
f885af4960
commit
1d81c47283
@ -146,54 +146,51 @@ class TorchTrainer(base_trainer.Trainer):
|
||||
self.predict_function = one_step_on_data
|
||||
|
||||
def _symbolic_build(self, data_batch):
|
||||
try:
|
||||
model_unbuilt = not all(
|
||||
layer.built for layer in self._flatten_layers()
|
||||
)
|
||||
compile_metrics_unbuilt = (
|
||||
self._compile_metrics is not None
|
||||
and not self._compile_metrics.built
|
||||
)
|
||||
if model_unbuilt or compile_metrics_unbuilt:
|
||||
# Create symbolic tensors matching an input batch.
|
||||
model_unbuilt = not all(layer.built for layer in self._flatten_layers())
|
||||
compile_metrics_unbuilt = (
|
||||
self._compile_metrics is not None
|
||||
and not self._compile_metrics.built
|
||||
)
|
||||
if model_unbuilt or compile_metrics_unbuilt:
|
||||
# Create symbolic tensors matching an input batch.
|
||||
|
||||
def to_symbolic_input(v):
|
||||
if is_tensor(v):
|
||||
return KerasTensor(v.shape, standardize_dtype(v.dtype))
|
||||
return v
|
||||
def to_symbolic_input(v):
|
||||
if is_tensor(v):
|
||||
return KerasTensor(v.shape, standardize_dtype(v.dtype))
|
||||
return v
|
||||
|
||||
data_batch = tf.nest.map_structure(
|
||||
to_symbolic_input, data_batch
|
||||
data_batch = tf.nest.map_structure(to_symbolic_input, data_batch)
|
||||
(
|
||||
x,
|
||||
y,
|
||||
sample_weight,
|
||||
) = data_adapter_utils.unpack_x_y_sample_weight(data_batch)
|
||||
# Build all model state with `backend.compute_output_spec`.
|
||||
try:
|
||||
y_pred = backend.compute_output_spec(self, x)
|
||||
except:
|
||||
raise RuntimeError(
|
||||
"Unable to automatically build the model. "
|
||||
"Please build it yourself before calling "
|
||||
"fit/evaluate/predict. "
|
||||
"A model is 'built' when its variables have "
|
||||
"been created and its `self.built` attribute "
|
||||
"is True. Usually, calling the model on a batch "
|
||||
"of data is the right way to build it."
|
||||
)
|
||||
(
|
||||
if compile_metrics_unbuilt:
|
||||
# Build all metric state with `backend.compute_output_spec`.
|
||||
backend.compute_output_spec(
|
||||
self.compute_metrics,
|
||||
x,
|
||||
y,
|
||||
sample_weight,
|
||||
) = data_adapter_utils.unpack_x_y_sample_weight(data_batch)
|
||||
# Build all model state with `backend.compute_output_spec`.
|
||||
y_pred = backend.compute_output_spec(self, x)
|
||||
if compile_metrics_unbuilt:
|
||||
# Build all metric state with `backend.compute_output_spec`.
|
||||
backend.compute_output_spec(
|
||||
self.compute_metrics,
|
||||
x,
|
||||
y,
|
||||
y_pred,
|
||||
sample_weight=sample_weight,
|
||||
)
|
||||
if self.optimizer is not None and not self.optimizer.built:
|
||||
# Build optimizer
|
||||
self.optimizer.build(self.trainable_variables)
|
||||
self._post_build()
|
||||
except:
|
||||
raise RuntimeError(
|
||||
"Unable to automatically build the model. "
|
||||
"Please build it yourself before calling fit/evaluate/predict. "
|
||||
"A model is 'built' when its variables have been created and "
|
||||
"its `self.built` attribute is True. Usually, "
|
||||
"calling the model on a batch of data "
|
||||
"is the right way to build it."
|
||||
)
|
||||
y_pred,
|
||||
sample_weight=sample_weight,
|
||||
)
|
||||
if self.optimizer is not None and not self.optimizer.built:
|
||||
# Build optimizer
|
||||
self.optimizer.build(self.trainable_variables)
|
||||
self._post_build()
|
||||
|
||||
@traceback_utils.filter_traceback
|
||||
def fit(
|
||||
|
Loading…
Reference in New Issue
Block a user