Sort metrics in base trainer.

Francois Chollet 2023-06-12 11:58:24 -07:00
parent d579f815cd
commit 02076132b8
5 changed files with 57 additions and 309 deletions

@@ -29,11 +29,6 @@ history = model.fit(
model.evaluate(x, y, verbose=0)
model.predict(x, verbose=0)
# Test on batch functions
model.train_on_batch(x, y)
model.test_on_batch(x, y)
model.predict_on_batch(x)
# Test functional model.
inputs = keras_core.Input(shape=(32, 32, 3))
outputs = layers.Conv2D(filters=10, kernel_size=3)(inputs)

@@ -18,100 +18,45 @@ class TorchTrainer(base_trainer.Trainer):
self.test_function = None
self.predict_function = None
def make_train_function(self, force=False):
if self.train_function is not None and not force:
return self.train_function
def train_step(self, data):
data = data[0]
x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data)
def one_step_on_data(data):
"""Runs a single training step on a batch of data."""
data = data[0]
x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(
data
)
# Compute prediction error
if self._call_has_training_arg():
y_pred = self(x, training=True)
else:
y_pred = self(x)
# Compute prediction error
if self._call_has_training_arg():
y_pred = self(x, training=True)
else:
y_pred = self(x)
# Call torch.nn.Module.zero_grad() to clear the leftover gradients for
# the weights from the previous train step.
self.zero_grad()
# Call torch.nn.Module.zero_grad() to clear the leftover gradients
# for the weights from the previous train step.
self.zero_grad()
loss = self.compute_loss(
x=x, y=y, y_pred=y_pred, sample_weight=sample_weight
)
self._loss_tracker.update_state(loss)
loss = self.compute_loss(
x=x, y=y, y_pred=y_pred, sample_weight=sample_weight
)
self._loss_tracker.update_state(loss)
# Compute gradients
if self.trainable_weights:
# Backpropagation
trainable_weights = [v for v in self.trainable_weights]
# Compute gradients
if self.trainable_weights:
# Backpropagation
trainable_weights = [v for v in self.trainable_weights]
# Call torch.Tensor.backward() on the loss to compute gradients for
# the weights.
loss.backward()
# Call torch.Tensor.backward() on the loss to compute gradients
# for the weights.
loss.backward()
gradients = [v.value.grad for v in trainable_weights]
gradients = [v.value.grad for v in trainable_weights]
# Update weights
with torch.no_grad():
self.optimizer.apply_gradients(
zip(gradients, trainable_weights)
)
else:
warnings.warn("The model does not have any trainable weights.")
return self.compute_metrics(
x, y, y_pred, sample_weight=sample_weight
)
self.train_function = one_step_on_data
def make_test_function(self, force=False):
if self.test_function is not None and not force:
return self.test_function
def one_step_on_data(data):
"""Runs a single test step on a batch of data."""
# Update weights
with torch.no_grad():
data = data[0]
(
x,
y,
sample_weight,
) = data_adapter_utils.unpack_x_y_sample_weight(data)
if self._call_has_training_arg():
y_pred = self(x, training=False)
else:
y_pred = self(x)
loss = self.compute_loss(
x=x, y=y, y_pred=y_pred, sample_weight=sample_weight
)
self._loss_tracker.update_state(loss)
return self.compute_metrics(
x, y, y_pred, sample_weight=sample_weight
self.optimizer.apply_gradients(
zip(gradients, trainable_weights)
)
else:
warnings.warn("The model does not have any trainable weights.")
self.test_function = one_step_on_data
def make_predict_function(self, force=False):
if self.predict_function is not None and not force:
return self.predict_function
def one_step_on_data(data):
"""Runs a predict test step on a batch of data."""
with torch.no_grad():
data = data[0]
x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data)
if self._call_has_training_arg():
y_pred = self(x, training=False)
else:
y_pred = self(x)
return y_pred
self.predict_function = one_step_on_data
return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)
def fit(
self,
@@ -182,7 +127,6 @@ class TorchTrainer(base_trainer.Trainer):
)
self.stop_training = False
self.make_train_function()
callbacks.on_train_begin()
for epoch in range(initial_epoch, epochs):
@@ -197,7 +141,7 @@ class TorchTrainer(base_trainer.Trainer):
# Callbacks
callbacks.on_train_batch_begin(step)
logs = self.train_function(data)
logs = self.train_step(data)
# Callbacks
callbacks.on_train_batch_end(step, self._pythonify_logs(logs))
@@ -253,6 +197,19 @@ class TorchTrainer(base_trainer.Trainer):
callbacks.on_train_end(logs=training_logs)
return self.history
def test_step(self, data):
data = data[0]
x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data)
if self._call_has_training_arg():
y_pred = self(x, training=False)
else:
y_pred = self(x)
loss = self.compute_loss(
x=x, y=y, y_pred=y_pred, sample_weight=sample_weight
)
self._loss_tracker.update_state(loss)
return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)
def evaluate(
self,
x=None,
@@ -299,13 +256,13 @@ class TorchTrainer(base_trainer.Trainer):
# Switch the torch Module back to testing mode.
self.eval()
self.make_test_function()
callbacks.on_test_begin()
logs = None
self.reset_metrics()
for step, data in epoch_iterator.enumerate_epoch(return_type="np"):
callbacks.on_test_batch_begin(step)
logs = self.test_function(data)
with torch.no_grad():
logs = self.test_step(data)
callbacks.on_test_batch_end(step, self._pythonify_logs(logs))
logs = self.get_metrics_result()
callbacks.on_test_end(logs)
@@ -314,6 +271,15 @@ class TorchTrainer(base_trainer.Trainer):
return logs
return self._flatten_metrics_in_order(logs)
def predict_step(self, data):
data = data[0]
x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data)
if self._call_has_training_arg():
y_pred = self(x, training=False)
else:
y_pred = self(x)
return y_pred
def predict(
self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
):
@@ -356,126 +322,15 @@ class TorchTrainer(base_trainer.Trainer):
# Switch the torch Module back to testing mode.
self.eval()
self.make_predict_function()
callbacks.on_predict_begin()
outputs = None
for step, data in epoch_iterator.enumerate_epoch(return_type="np"):
callbacks.on_predict_batch_begin(step)
batch_outputs = self.predict_function(data)
with torch.no_grad():
batch_outputs = self.predict_step(data)
outputs = append_to_outputs(batch_outputs, outputs)
callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
callbacks.on_predict_end()
return tf.__internal__.nest.map_structure_up_to(
batch_outputs, np.concatenate, outputs
)
def train_on_batch(
self,
x,
y=None,
sample_weight=None,
class_weight=None,
return_dict=False,
):
"""Runs a single gradient update on a single batch of data.
Args:
x: Input data. Must be array-like.
y: Target data. Must be array-like.
sample_weight: Optional array of the same length as x, containing
weights to apply to the model's loss for each sample.
In the case of temporal data, you can pass a 2D array
with shape `(samples, sequence_length)`, to apply a different
weight to every timestep of every sample.
class_weight: Optional dictionary mapping class indices (integers)
to a weight (float) to apply to the model's loss for the samples
from this class during training. This can be useful to tell the
model to "pay more attention" to samples from an
under-represented class. When `class_weight` is specified
and targets have a rank of 2 or greater, either `y` must
be one-hot encoded, or an explicit final dimension of 1
must be included for sparse class labels.
return_dict: If `True`, loss and metric results are returned as a
dict, with each key being the name of the metric. If `False`,
they are returned as a list.
Returns:
A scalar loss value (when no metrics and `return_dict=False`),
a list of loss and metric values
(if there are metrics and `return_dict=False`), or a dict of
metric and loss values (if `return_dict=True`).
"""
self._assert_compile_called("train_on_batch")
self.make_train_function()
if class_weight is not None:
if sample_weight is not None:
raise ValueError(
"Arguments `sample_weight` and `class_weight` "
"cannot be specified at the same time. "
f"Received: sample_weight={sample_weight}, "
f"class_weight={class_weight}"
)
sample_weight = data_adapter_utils.class_weight_to_sample_weights(
y, class_weight
)
data = (x, y, sample_weight)
logs = self.train_function([data])
logs = tf.nest.map_structure(lambda x: np.array(x), logs)
if return_dict:
return logs
return self._flatten_metrics_in_order(logs)
def test_on_batch(
self,
x,
y=None,
sample_weight=None,
return_dict=False,
):
"""Test the model on a single batch of samples.
Args:
x: Input data. Must be array-like.
y: Target data. Must be array-like.
sample_weight: Optional array of the same length as x, containing
weights to apply to the model's loss for each sample.
In the case of temporal data, you can pass a 2D array
with shape `(samples, sequence_length)`, to apply a different
weight to every timestep of every sample.
return_dict: If `True`, loss and metric results are returned as a
dict, with each key being the name of the metric. If `False`,
they are returned as a list.
Returns:
A scalar loss value (when no metrics and `return_dict=False`),
a list of loss and metric values
(if there are metrics and `return_dict=False`), or a dict of
metric and loss values (if `return_dict=True`).
"""
self._assert_compile_called("test_on_batch")
self.make_test_function()
data = (x, y, sample_weight)
logs = self.test_function([data])
logs = tf.nest.map_structure(lambda x: np.array(x), logs)
if return_dict:
return logs
return self._flatten_metrics_in_order(logs)
def predict_on_batch(self, x):
"""Returns predictions for a single batch of samples.
Args:
x: Input data. It must be array-like.
Returns:
NumPy array(s) of predictions.
"""
self.make_predict_function()
batch_outputs = self.predict_function((x,))
batch_outputs = tf.nest.map_structure(
lambda x: np.array(x), batch_outputs
)
return batch_outputs
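The refactor above replaces the compiled make_train_function / make_test_function / make_predict_function wrappers with plain train_step, test_step, and predict_step methods, and drops the train_on_batch / test_on_batch / predict_on_batch helpers from this trainer. The ordering of operations the new train_step documents (clear stale gradients, forward pass, loss, backward pass, weight update, with evaluation and prediction wrapped in torch.no_grad()) is the standard PyTorch training-step pattern. A minimal plain-PyTorch sketch of that pattern, using illustrative names rather than keras_core APIs:

    import torch

    def train_step(model, optimizer, loss_fn, x, y):
        # Clear gradients left over from the previous step
        # (the keras_core trainer calls self.zero_grad() for the same reason).
        model.zero_grad()
        model.train()
        y_pred = model(x)          # forward pass
        loss = loss_fn(y_pred, y)  # compute the loss
        loss.backward()            # backpropagation: fills .grad on parameters
        with torch.no_grad():      # apply the update outside the autograd graph
            optimizer.step()
        return float(loss)

    # Hypothetical single-batch usage, roughly what train_on_batch used to do:
    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss = train_step(model, optimizer, torch.nn.functional.mse_loss,
                      torch.randn(8, 4), torch.randn(8, 1))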

@@ -152,9 +152,6 @@ class Functional(Function, Model):
self._layers = self.layers
self.built = True
# We will convert directly (to the correct dtype per input).
self._convert_input_args = False
self._allow_non_tensor_positional_args = True
self._post_build()
@property
@@ -235,16 +232,6 @@ class Functional(Function, Model):
# Otherwise both ref inputs and inputs will already be in same order.
return nest.flatten(inputs)
def _convert_inputs_to_tensors(self, flat_inputs):
flat_dtypes = [x.dtype for x in self._inputs]
converted = []
for x, dtype in zip(flat_inputs, flat_dtypes):
if backend.is_tensor(x):
converted.append(backend.cast(x, dtype=dtype))
else:
converted.append(backend.convert_to_tensor(x, dtype=dtype))
return converted
def _adjust_input_rank(self, flat_inputs):
flat_ref_shapes = [x.shape for x in self._inputs]
adjusted = []
@@ -276,7 +263,6 @@ class Functional(Function, Model):
def _standardize_inputs(self, inputs):
flat_inputs = self._flatten_to_reference_inputs(inputs)
flat_inputs = self._convert_inputs_to_tensors(flat_inputs)
return self._adjust_input_rank(flat_inputs)
@property
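For reference, the _convert_inputs_to_tensors helper shown in the hunk above cast every flattened positional input to the dtype declared on the corresponding Input before calling the model (via backend.cast / backend.convert_to_tensor). A simplified stand-alone sketch of that behavior, using NumPy in place of the keras_core backend and hypothetical names:

    import numpy as np

    def convert_inputs(flat_inputs, ref_dtypes):
        # Cast each positional input to the dtype of the matching Input,
        # mirroring the per-input standardization shown in the diff above.
        return [
            np.asarray(x, dtype=dtype)
            for x, dtype in zip(flat_inputs, ref_dtypes)
        ]

    # With Inputs declared as float16 and int32, default-float64 data is cast:
    float_data, int_data = convert_inputs(
        [np.ones((2, 2)), np.ones((2, 2))], ["float16", "int32"]
    )
    assert float_data.dtype == np.float16
    assert int_data.dtype == np.int32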

@@ -171,16 +171,6 @@ class FunctionalTest(testing.TestCase):
out_val = model(np.random.random((2, 3)))
self.assertEqual(out_val.shape, (2, 3, 3))
def test_dtype_standardization(self):
float_input = Input(shape=(2,), dtype="float16")
int_input = Input(shape=(2,), dtype="int32")
float_output = float_input + 2
int_output = int_input + 2
model = Functional((float_input, int_input), (float_output, int_output))
float_data, int_data = model((np.ones((2, 2)), np.ones((2, 2))))
self.assertEqual(backend.standardize_dtype(float_data.dtype), "float16")
self.assertEqual(backend.standardize_dtype(int_data.dtype), "int32")
def test_serialization(self):
# Test basic model
inputs = Input(shape=(3,), batch_size=2)

@@ -277,84 +277,6 @@ class Trainer:
):
raise NotImplementedError
def train_on_batch(
self,
x,
y=None,
sample_weight=None,
class_weight=None,
return_dict=False,
):
"""Runs a single gradient update on a single batch of data.
Args:
x: Input data. Must be array-like.
y: Target data. Must be array-like.
sample_weight: Optional array of the same length as x, containing
weights to apply to the model's loss for each sample.
In the case of temporal data, you can pass a 2D array
with shape `(samples, sequence_length)`, to apply a different
weight to every timestep of every sample.
class_weight: Optional dictionary mapping class indices (integers)
to a weight (float) to apply to the model's loss for the samples
from this class during training. This can be useful to tell the
model to "pay more attention" to samples from an
under-represented class. When `class_weight` is specified
and targets have a rank of 2 or greater, either `y` must
be one-hot encoded, or an explicit final dimension of 1
must be included for sparse class labels.
return_dict: If `True`, loss and metric results are returned as a
dict, with each key being the name of the metric. If `False`,
they are returned as a list.
Returns:
A scalar loss value (when no metrics and `return_dict=False`),
a list of loss and metric values
(if there are metrics and `return_dict=False`), or a dict of
metric and loss values (if `return_dict=True`).
"""
raise NotImplementedError
def test_on_batch(
self,
x,
y=None,
sample_weight=None,
return_dict=False,
):
"""Test the model on a single batch of samples.
Args:
x: Input data. Must be array-like.
y: Target data. Must be array-like.
sample_weight: Optional array of the same length as x, containing
weights to apply to the model's loss for each sample.
In the case of temporal data, you can pass a 2D array
with shape `(samples, sequence_length)`, to apply a different
weight to every timestep of every sample.
return_dict: If `True`, loss and metric results are returned as a
dict, with each key being the name of the metric. If `False`,
they are returned as a list.
Returns:
A scalar loss value (when no metrics and `return_dict=False`),
a list of loss and metric values
(if there are metrics and `return_dict=False`), or a dict of
metric and loss values (if `return_dict=True`).
"""
raise NotImplementedError
def predict_on_batch(self, x):
"""Returns predictions for a single batch of samples.
Args:
x: Input data. It must be array-like.
Returns:
NumPy array(s) of predictions.
"""
raise NotImplementedError
def get_compile_config(self):
"""Returns a serialized config with information for compiling the model.
@@ -410,7 +332,7 @@ class Trainer:
def _pythonify_logs(self, logs):
result = {}
for key, value in logs.items():
for key, value in sorted(logs.items()):
try:
value = float(value)
except:
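The change that gives the commit its title is this last hunk: the base trainer now iterates over sorted(logs.items()), so metric results come back in a deterministic, alphabetically ordered form regardless of how the logs dict was assembled. A simplified sketch of the effect; the loop body after the except clause is not shown in the hunk and is assumed here:

    def pythonify_logs(logs):
        # Simplified stand-in for Trainer._pythonify_logs: sort the keys,
        # then coerce each value to a plain Python float where possible.
        result = {}
        for key, value in sorted(logs.items()):
            try:
                value = float(value)
            except (TypeError, ValueError):
                pass
            result[key] = value
        return result

    print(list(pythonify_logs({"mae": 0.25, "accuracy": 0.9, "loss": 1.2})))
    # ['accuracy', 'loss', 'mae'] -- no longer depends on insertion order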