# keras/keras_core/trainers/trainer.py
# Francois Chollet 57d8ac26f1 Nits.
# 2023-06-20 15:16:03 -07:00
#
# 743 lines
# 32 KiB
# Python

import platform
import warnings
from keras_core import backend
from keras_core import metrics as metrics_module
from keras_core import operations as ops
from keras_core import optimizers
from keras_core.saving import serialization_lib
from keras_core.trainers.compile_utils import CompileLoss
from keras_core.trainers.compile_utils import CompileMetrics
from keras_core.utils import traceback_utils
from keras_core.utils import tracking
class Trainer:
    """Base class holding compile state and the training entry points.

    The constructor only records default, not-yet-compiled state; the
    actual configuration happens in `compile()`.
    """

    def __init__(self):
        """Initialize the default (uncompiled) trainer state."""
        # No `compile()` call has happened yet.
        self.compiled = False
        self.steps_per_execution = 1
        # `None` means "never set"; the `jit_compile` property resolves it
        # on first access.
        self._jit_compile = None
        self._run_eagerly = False
        self._lock = False
@traceback_utils.filter_traceback
@tracking.no_automatic_dependency_tracking
def compile(
    self,
    optimizer="rmsprop",
    loss=None,
    loss_weights=None,
    metrics=None,
    weighted_metrics=None,
    run_eagerly=False,
    steps_per_execution=1,
    jit_compile="auto",
):
    """Configure the model for training.

    Args:
        optimizer: Optimizer identifier or instance; resolved through
            `optimizers.get()`. Defaults to `"rmsprop"`.
        loss: Loss specification. When not `None` it is wrapped,
            together with `loss_weights`, in a `CompileLoss`.
        loss_weights: Optional weights forwarded to `CompileLoss`.
        metrics: Optional metrics specification forwarded to
            `CompileMetrics`.
        weighted_metrics: Optional weighted-metrics specification
            forwarded to `CompileMetrics`.
        run_eagerly: Stored as `self.run_eagerly`. When truthy, JIT
            compilation is disabled.
        steps_per_execution: Stored as `self.steps_per_execution`.
        jit_compile: `"auto"`, or a boolean. `"auto"` resolves to `True`
            only when `run_eagerly` is falsy and
            `model_supports_jit(self)` is true. If both `jit_compile`
            and `run_eagerly` end up truthy, `jit_compile` is forced to
            `False` and a warning is emitted.
    """
    self.optimizer = optimizers.get(optimizer)
    if loss is not None:
        self._compile_loss = CompileLoss(loss, loss_weights)
    else:
        self._compile_loss = None
    # Build the metrics container when EITHER kind of metric was given.
    # Checking `metrics` alone silently dropped `weighted_metrics` for
    # callers that passed only weighted metrics.
    if metrics is not None or weighted_metrics is not None:
        self._compile_metrics = CompileMetrics(metrics, weighted_metrics)
    else:
        self._compile_metrics = None
    if jit_compile == "auto":
        if not run_eagerly and model_supports_jit(self):
            jit_compile = True
        else:
            jit_compile = False
    if jit_compile and run_eagerly:
        jit_compile = False
        warnings.warn(
            "If `run_eagerly` is True, then `jit_compile` "
            "cannot also be True. Disabling `jit_compile`.",
            stacklevel=2,
        )
    self.jit_compile = jit_compile
    self.run_eagerly = run_eagerly
    self.stop_training = False
    self.compiled = True
    self._loss_tracker = metrics_module.Mean(name="loss")
    self.steps_per_execution = steps_per_execution
    # Keep the arguments around so `get_compile_config()` can serialize
    # them. Note that `jit_compile` is recorded after "auto" resolution.
    self._compile_config = serialization_lib.SerializableDict(
        optimizer=optimizer,
        loss=loss,
        loss_weights=loss_weights,
        metrics=metrics,
        weighted_metrics=weighted_metrics,
        run_eagerly=run_eagerly,
        steps_per_execution=steps_per_execution,
        jit_compile=jit_compile,
    )
@property
def jit_compile(self):
    """Whether JIT compilation is enabled.

    If the flag was never set (`_jit_compile` is `None`), it is resolved
    on first access via `model_supports_jit(self)` and cached.
    """
    resolved = self._jit_compile
    if resolved is None:
        # Lazily decide, then remember the decision.
        resolved = model_supports_jit(self)
        self._jit_compile = resolved
    return resolved


@jit_compile.setter
def jit_compile(self, value):
    """Store the JIT-compilation flag as given, without validation."""
    self._jit_compile = value
@property
def run_eagerly(self):
    """The stored eager-execution flag (`_run_eagerly`)."""
    return self._run_eagerly


@run_eagerly.setter
def run_eagerly(self, value):
    """Store the eager-execution flag as given, without validation."""
    self._run_eagerly = value
@property
def metrics(self):
    """List of metric objects tracked by the model.

    The list always starts with the loss tracker, followed by the entries
    of `self._metrics`. The compiled-metrics container is appended last,
    but only once it exists and reports itself as built.
    """
    collected = [self._loss_tracker, *self._metrics]
    compiled = self._compile_metrics
    if compiled is not None and compiled.built:
        collected.append(compiled)
    return collected
@property
def metrics_names(self):
    """Names of the objects in `self.metrics`, in the same order."""
    names = []
    for metric in self.metrics:
        names.append(metric.name)
    return names
@property
def metrics_variables(self):
    """Flat list of the variables of every object in `self.metrics`."""
    return [
        variable
        for metric in self.metrics
        for variable in metric.variables
    ]
def reset_metrics(self):
    """Call `reset_state()` on every object in `self.metrics`."""
    for tracked in self.metrics:
        tracked.reset_state()
def compute_loss(self, x=None, y=None, y_pred=None, sample_weight=None):
    """Compute the total loss, validate it, and return it.

    Subclasses can optionally override this method to provide custom loss
    computation logic.

    Example:

    ```python
    class MyModel(Model):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.loss_tracker = metrics.Mean(name='loss')

        def compute_loss(self, x, y, y_pred, sample_weight):
            loss = ops.mean((y_pred - y) ** 2)
            loss += ops.sum(self.losses)
            self.loss_tracker.update_state(loss)
            return loss

        def reset_metrics(self):
            self.loss_tracker.reset_state()

        @property
        def metrics(self):
            return [self.loss_tracker]

    inputs = layers.Input(shape=(10,), name='my_input')
    outputs = layers.Dense(10)(inputs)
    model = MyModel(inputs, outputs)
    model.add_loss(ops.sum(outputs))

    optimizer = SGD()
    model.compile(optimizer, loss='mse', steps_per_execution=10)
    dataset = ...
    model.fit(dataset, epochs=2, steps_per_epoch=10)
    print(f"Custom loss: {model.loss_tracker.result()}")
    ```

    Args:
        x: Input data. Unused by the default implementation.
        y: Target data.
        y_pred: Predictions returned by the model (output of `model(x)`).
        sample_weight: Sample weights for weighting the loss function.

    Returns:
        The total loss: the single collected loss when there is exactly
        one, otherwise `ops.sum()` over all collected losses.

    Raises:
        ValueError: If neither the compiled loss nor `self.losses`
            contributes any loss value.
    """
    del x  # The default implementation does not use `x`.
    collected = []
    if self._compile_loss is not None:
        compiled_value = self._compile_loss(y, y_pred, sample_weight)
        if compiled_value is not None:
            collected.append(compiled_value)
    # Extra losses from `self.losses` are cast to the backend float type.
    collected.extend(
        ops.cast(extra, dtype=backend.floatx()) for extra in self.losses
    )
    if not collected:
        raise ValueError(
            "No loss to compute. Provide a `loss` argument in `compile()`."
        )
    if len(collected) == 1:
        return collected[0]
    return ops.sum(collected)
def compute_metrics(self, x, y, y_pred, sample_weight=None):
    """Update metric states and collect all metrics to be returned.

    Subclasses can optionally override this method to provide custom metric
    updating and collection logic.

    Example:

    ```python
    class MyModel(Sequential):
        def compute_metrics(self, x, y, y_pred, sample_weight):
            # This super call updates `self.compiled_metrics` and returns
            # results for all metrics listed in `self.metrics`.
            metric_results = super().compute_metrics(
                x, y, y_pred, sample_weight)

            # Note that `self.custom_metric` is not listed
            # in `self.metrics`.
            self.custom_metric.update_state(x, y, y_pred, sample_weight)
            metric_results['metric_name'] = self.custom_metric.result()
            return metric_results
    ```

    Args:
        x: Input data. Unused by the default implementation.
        y: Target data.
        y_pred: Predictions returned by the model (output of
            `model.call(x)`).
        sample_weight: Sample weights for weighting the loss function.

    Returns:
        A `dict` containing values that will be passed to
        `keras_core.callbacks.CallbackList.on_train_batch_end()`. Typically,
        the values of the metrics listed in `self.metrics` are returned.
        Example: `{'loss': 0.2, 'accuracy': 0.7}`.
    """
    del x  # The default implementation does not use `x`.
    compiled = self._compile_metrics
    if compiled is not None:
        compiled.update_state(y, y_pred, sample_weight)
    return self.get_metrics_result()
def get_metrics_result(self):
    """Returns the model's metrics values as a dict.

    If any of the metric results is itself a dict (containing multiple
    metrics), its entries are merged into the top-level dict returned by
    this method.

    Returns:
        A `dict` containing values of the metrics listed in `self.metrics`,
        post-processed by `self._pythonify_logs()`.
        Example: `{'loss': 0.2, 'accuracy': 0.7}`.
    """
    collected = {}
    for metric in self.metrics:
        value = metric.result()
        if isinstance(value, dict):
            # Multi-valued metric: hoist its entries to the top level.
            collected.update(value)
            continue
        collected[metric.name] = value
    return self._pythonify_logs(collected)
def fit(
    self,
    x=None,
    y=None,
    batch_size=None,
    epochs=1,
    verbose="auto",
    callbacks=None,
    validation_split=0.0,
    validation_data=None,
    shuffle=True,
    class_weight=None,
    sample_weight=None,
    initial_epoch=0,
    steps_per_epoch=None,
    validation_steps=None,
    validation_batch_size=None,
    validation_freq=1,
):
    """Trains the model for a fixed number of epochs (dataset iterations).

    This base implementation only defines the interface and always raises
    `NotImplementedError`.

    Args:
        x: Input data. It could be:
            - A NumPy array (or array-like), or a list of arrays
              (in case the model has multiple inputs).
            - A tensor, or a list of tensors
              (in case the model has multiple inputs).
            - A dict mapping input names to the corresponding
              array/tensors, if the model has named inputs.
            - A `tf.data.Dataset`. Should return a tuple
              of either `(inputs, targets)` or
              `(inputs, targets, sample_weights)`.
            - A `keras_core.utils.PyDataset` returning `(inputs,
              targets)` or `(inputs, targets, sample_weights)`.
        y: Target data. Like the input data `x`,
            it could be either NumPy array(s) or backend-native tensor(s).
            If `x` is a dataset, generator,
            or `keras_core.utils.PyDataset` instance, `y` should
            not be specified (since targets will be obtained from `x`).
        batch_size: Integer or `None`.
            Number of samples per gradient update.
            If unspecified, `batch_size` will default to 32.
            Do not specify the `batch_size` if your data is in the
            form of datasets, generators, or `keras_core.utils.PyDataset`
            instances (since they generate batches).
        epochs: Integer. Number of epochs to train the model.
            An epoch is an iteration over the entire `x` and `y`
            data provided
            (unless the `steps_per_epoch` flag is set to
            something other than None).
            Note that in conjunction with `initial_epoch`,
            `epochs` is to be understood as "final epoch".
            The model is not trained for a number of iterations
            given by `epochs`, but merely until the epoch
            of index `epochs` is reached.
        verbose: `"auto"`, 0, 1, or 2. Verbosity mode.
            0 = silent, 1 = progress bar, 2 = one line per epoch.
            "auto" becomes 1 for most cases.
            Note that the progress bar is not
            particularly useful when logged to a file,
            so `verbose=2` is recommended when not running interactively
            (e.g., in a production environment). Defaults to `"auto"`.
        callbacks: List of `keras_core.callbacks.Callback` instances.
            List of callbacks to apply during training.
            See `keras_core.callbacks`. Note
            `keras_core.callbacks.ProgbarLogger` and
            `keras_core.callbacks.History` callbacks are created
            automatically and need not be passed to `model.fit()`.
            `keras_core.callbacks.ProgbarLogger` is created
            or not based on the `verbose` argument in `model.fit()`.
        validation_split: Float between 0 and 1.
            Fraction of the training data to be used as validation data.
            The model will set apart this fraction of the training data,
            will not train on it, and will evaluate
            the loss and any model metrics
            on this data at the end of each epoch.
            The validation data is selected from the last samples
            in the `x` and `y` data provided, before shuffling. This
            argument is not supported when `x` is a dataset, generator or
            `keras_core.utils.PyDataset` instance.
            If both `validation_data` and `validation_split` are provided,
            `validation_data` will override `validation_split`.
        validation_data: Data on which to evaluate
            the loss and any model metrics at the end of each epoch.
            The model will not be trained on this data. Thus, note the fact
            that the validation loss of data provided using
            `validation_split` or `validation_data` is not affected by
            regularization layers like noise and dropout.
            `validation_data` will override `validation_split`.
            `validation_data` could be:
            - A tuple `(x_val, y_val)` of NumPy arrays or tensors.
            - A tuple `(x_val, y_val, val_sample_weights)` of NumPy
              arrays.
            - A `tf.data.Dataset`.
            - A Python generator or `keras_core.utils.PyDataset` returning
              `(inputs, targets)` or `(inputs, targets, sample_weights)`.
        shuffle: Boolean, whether to shuffle the training data
            before each epoch. This argument is
            ignored when `x` is a generator or a `tf.data.Dataset`.
        class_weight: Optional dictionary mapping class indices (integers)
            to a weight (float) value, used for weighting the loss function
            (during training only).
            This can be useful to tell the model to
            "pay more attention" to samples from
            an under-represented class. When `class_weight` is specified
            and targets have a rank of 2 or greater, either `y` must be
            one-hot encoded, or an explicit final dimension of `1` must
            be included for sparse class labels.
        sample_weight: Optional NumPy array of weights for
            the training samples, used for weighting the loss function
            (during training only). You can either pass a flat (1D)
            NumPy array with the same length as the input samples
            (1:1 mapping between weights and samples),
            or in the case of temporal data,
            you can pass a 2D array with shape
            `(samples, sequence_length)`,
            to apply a different weight to every timestep of every sample.
            This argument is not supported when `x` is a dataset, generator,
            or `keras_core.utils.PyDataset` instance; instead, provide the
            sample_weights as the third element of `x`.
            Note that sample weighting does not apply to metrics specified
            via the `metrics` argument in `compile()`. To apply sample
            weighting to your metrics, you can specify them via the
            `weighted_metrics` in `compile()` instead.
        initial_epoch: Integer.
            Epoch at which to start training
            (useful for resuming a previous training run).
        steps_per_epoch: Integer or `None`.
            Total number of steps (batches of samples)
            before declaring one epoch finished and starting the
            next epoch. When training with input tensors such as
            backend-native tensors, the default `None` is equal to
            the number of samples in your dataset divided by
            the batch size, or 1 if that cannot be determined. If `x` is a
            `tf.data.Dataset`, and `steps_per_epoch`
            is `None`, the epoch will run until the input dataset is
            exhausted. When passing an infinitely repeating dataset, you
            must specify the `steps_per_epoch` argument. If
            `steps_per_epoch=-1` the training will run indefinitely with an
            infinitely repeating dataset.
        validation_steps: Only relevant if `validation_data` is provided.
            Total number of steps (batches of
            samples) to draw before stopping when performing validation
            at the end of every epoch. If `validation_steps` is `None`,
            validation will run until the `validation_data` dataset is
            exhausted. In the case of an infinitely repeated dataset, it
            will run into an infinite loop. If `validation_steps` is
            specified and only part of the dataset will be consumed, the
            evaluation will start from the beginning of the dataset at each
            epoch. This ensures that the same validation samples are used
            every time.
        validation_batch_size: Integer or `None`.
            Number of samples per validation batch.
            If unspecified, will default to `batch_size`.
            Do not specify the `validation_batch_size` if your data is in
            the form of datasets or `keras_core.utils.PyDataset`
            instances (since they generate batches).
        validation_freq: Only relevant if validation data is provided.
            Specifies how many training epochs to run
            before a new validation run is performed, e.g.
            `validation_freq=2` runs validation every 2 epochs.

    Unpacking behavior for iterator-like inputs:
        A common pattern is to pass an iterator-like object such as a
        `tf.data.Dataset` or a `keras_core.utils.PyDataset` to `fit()`,
        which will in fact yield not only features (`x`)
        but optionally targets (`y`) and sample weights (`sample_weight`).
        Keras requires that the output of such iterator-likes be
        unambiguous. The iterator should return a tuple
        of length 1, 2, or 3, where the optional second and third elements
        will be used for `y` and `sample_weight` respectively.
        Any other type provided will be wrapped in
        a length-one tuple, effectively treating everything as `x`. When
        yielding dicts, they should still adhere to the top-level tuple
        structure,
        e.g. `({"x0": x0, "x1": x1}, y)`. Keras will not attempt to separate
        features, targets, and weights from the keys of a single dict.
        A notable unsupported data type is the `namedtuple`. The reason is
        that it behaves like both an ordered datatype (tuple) and a mapping
        datatype (dict). So given a namedtuple of the form:
        `namedtuple("example_tuple", ["y", "x"])`
        it is ambiguous whether to reverse the order of the elements when
        interpreting the value. Even worse is a tuple of the form:
        `namedtuple("other_tuple", ["x", "y", "z"])`
        where it is unclear if the tuple was intended to be unpacked
        into `x`, `y`, and `sample_weight` or passed through
        as a single element to `x`.

    Returns:
        A `History` object. Its `History.history` attribute is
        a record of training loss values and metrics values
        at successive epochs, as well as validation loss values
        and validation metrics values (if applicable).

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
def evaluate(
    self,
    x=None,
    y=None,
    batch_size=None,
    verbose="auto",
    sample_weight=None,
    steps=None,
    callbacks=None,
    return_dict=False,
    **kwargs,
):
    """Returns the loss value & metrics values for the model in test mode.

    Computation is done in batches (see the `batch_size` arg.)

    This base implementation only defines the interface and always raises
    `NotImplementedError`.

    Args:
        x: Input data. It could be:
            - A NumPy array (or array-like), or a list of arrays
              (in case the model has multiple inputs).
            - A tensor, or a list of tensors
              (in case the model has multiple inputs).
            - A dict mapping input names to the corresponding
              array/tensors, if the model has named inputs.
            - A `tf.data.Dataset`. Should return a tuple
              of either `(inputs, targets)` or
              `(inputs, targets, sample_weights)`.
            - A generator or `keras_core.utils.PyDataset` returning
              `(inputs, targets)` or `(inputs, targets, sample_weights)`.
        y: Target data. Like the input data `x`, it could be either NumPy
            array(s) or backend-native tensor(s).
            If `x` is a `tf.data.Dataset` or `keras_core.utils.PyDataset`
            instance, `y` should not be specified
            (since targets will be obtained from the iterator/dataset).
        batch_size: Integer or `None`. Number of samples per batch of
            computation. If unspecified, `batch_size` will default to 32. Do
            not specify the `batch_size` if your data is in the form of a
            dataset, generators, or `keras_core.utils.PyDataset` instances
            (since they generate batches).
        verbose: `"auto"`, 0, 1, or 2. Verbosity mode.
            0 = silent, 1 = progress bar, 2 = single line.
            `"auto"` becomes 1 for most cases.
            Note that the progress bar is not
            particularly useful when logged to a file, so `verbose=2` is
            recommended when not running interactively
            (e.g. in a production environment). Defaults to `"auto"`.
        sample_weight: Optional NumPy array of weights for the test samples,
            used for weighting the loss function. You can either pass a flat
            (1D) NumPy array with the same length as the input samples
            (1:1 mapping between weights and samples), or in the case of
            temporal data, you can pass a 2D array with shape `(samples,
            sequence_length)`, to apply a different weight to every
            timestep of every sample. This argument is not supported when
            `x` is a dataset; instead, pass sample weights as the third
            element of `x`.
        steps: Integer or `None`. Total number of steps (batches of samples)
            before declaring the evaluation round finished. Ignored with the
            default value of `None`. If `x` is a `tf.data.Dataset` and
            `steps` is `None`, evaluation will run until the dataset
            is exhausted.
        callbacks: List of `keras_core.callbacks.Callback` instances.
            List of callbacks to apply during evaluation.
        return_dict: If `True`, loss and metric results are returned as a
            dict, with each key being the name of the metric.
            If `False`, they are returned as a list.
        **kwargs: Accepted by the signature; not used by this base
            implementation.

    Returns:
        Scalar test loss (if the model has a single output and no metrics)
        or list of scalars (if the model has multiple outputs
        and/or metrics). The attribute `model.metrics_names` will give you
        the display labels for the scalar outputs.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
def predict(
    self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
):
    """Generates output predictions for the input samples.

    Computation is done in batches. This method is designed for batch
    processing of large numbers of inputs. It is not intended for use inside
    of loops that iterate over your data and process small numbers of inputs
    at a time.

    For small numbers of inputs that fit in one batch,
    directly use `__call__()` for faster execution, e.g.,
    `model(x)`, or `model(x, training=False)` if you have layers such as
    `BatchNormalization` that behave differently during
    inference.

    Note: See [this FAQ entry](
    https://keras.io/getting_started/faq/#whats-the-difference-between-model-methods-predict-and-call)
    for more details about the difference between `Model` methods
    `predict()` and `__call__()`.

    This base implementation only defines the interface and always raises
    `NotImplementedError`.

    Args:
        x: Input samples. It could be:
            - A NumPy array (or array-like), or a list of arrays
              (in case the model has multiple inputs).
            - A tensor, or a list of tensors
              (in case the model has multiple inputs).
            - A `tf.data.Dataset`.
            - A `keras_core.utils.PyDataset` instance.
        batch_size: Integer or `None`.
            Number of samples per batch.
            If unspecified, `batch_size` will default to 32.
            Do not specify the `batch_size` if your data is in the
            form of a dataset, generators, or `keras_core.utils.PyDataset`
            instances (since they generate batches).
        verbose: `"auto"`, 0, 1, or 2. Verbosity mode.
            0 = silent, 1 = progress bar, 2 = single line.
            `"auto"` becomes 1 for most cases. Note that the progress bar
            is not particularly useful when logged to a file,
            so `verbose=2` is recommended when not running interactively
            (e.g. in a production environment). Defaults to `"auto"`.
        steps: Total number of steps (batches of samples)
            before declaring the prediction round finished.
            Ignored with the default value of `None`.
            If `x` is a `tf.data.Dataset` and `steps` is `None`,
            `predict()` will run until the input dataset is exhausted.
        callbacks: List of `keras_core.callbacks.Callback` instances.
            List of callbacks to apply during prediction.

    Returns:
        NumPy array(s) of predictions.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
def train_on_batch(
    self,
    x,
    y=None,
    sample_weight=None,
    class_weight=None,
    return_dict=False,
):
    """Runs a single gradient update on a single batch of data.

    This base implementation only defines the interface and always raises
    `NotImplementedError`.

    Args:
        x: Input data. Must be array-like.
        y: Target data. Must be array-like.
        sample_weight: Optional array of the same length as x, containing
            weights to apply to the model's loss for each sample.
            In the case of temporal data, you can pass a 2D array
            with shape `(samples, sequence_length)`, to apply a different
            weight to every timestep of every sample.
        class_weight: Optional dictionary mapping class indices (integers)
            to a weight (float) to apply to the model's loss for the samples
            from this class during training. This can be useful to tell the
            model to "pay more attention" to samples from an
            under-represented class. When `class_weight` is specified
            and targets have a rank of 2 or greater, either `y` must
            be one-hot encoded, or an explicit final dimension of 1
            must be included for sparse class labels.
        return_dict: If `True`, loss and metric results are returned as a
            dict, with each key being the name of the metric. If `False`,
            they are returned as a list.

    Returns:
        A scalar loss value (when no metrics and `return_dict=False`),
        a list of loss and metric values
        (if there are metrics and `return_dict=False`), or a dict of
        metric and loss values (if `return_dict=True`).

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
def test_on_batch(
self,
x,
y=None,
sample_weight=None,
return_dict=False,
):
"""Test the model on a single batch of samples.
Args:
x: Input data. Must be array-like.
y: Target data. Must be array-like.
sample_weight: Optional array of the same length as x, containing
weights to apply to the model's loss for each sample.
In the case of temporal data, you can pass a 2D array
with shape `(samples, sequence_length)`, to apply a different
weight to every timestep of every sample.
return_dict: If `True`, loss and metric results are returned as a
dict, with each key being the name of the metric. If `False`,
they are returned as a list.
Returns:
A scalar loss value (when no metrics and `return_dict=False`),
a list of loss and metric values
(if there are metrics and `return_dict=False`), or a dict of
metric and loss values (if `return_dict=True`).
"""
raise NotImplementedError
def predict_on_batch(self, x):
    """Returns predictions for a single batch of samples.

    This base implementation only defines the interface and always raises
    `NotImplementedError`.

    Args:
        x: Input data. It must be array-like.

    Returns:
        NumPy array(s) of predictions.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
def get_compile_config(self):
    """Returns a serialized config with information for compiling the model.

    This method returns a config dictionary containing all the information
    (optimizer, loss, metrics, etc.) with which the model was compiled.

    Returns:
        The result of serializing the stored compile config, or `None` when
        the model has not been compiled or no compile config was recorded.
    """
    if not self.compiled or not hasattr(self, "_compile_config"):
        return None
    return self._compile_config.serialize()
def compile_from_config(self, config):
    """Compiles the model with the information given in config.

    This method uses the information in the config (optimizer, loss,
    metrics, etc.) to compile the model. If the model's class overrides
    `compile()`, nothing is compiled: a warning is emitted and the method
    returns early.

    Args:
        config: Dict containing information for compiling the model.
    """
    if self.__class__.compile != Trainer.compile:
        # A custom `compile()` may not accept the stored arguments, so
        # compilation is left to the caller.
        warnings.warn(
            "`compile()` was not called as part of model loading "
            "because the model's `compile()` method is custom. "
            "All subclassed Models that have `compile()` "
            "overridden should also override "
            "`get_compile_config()` and `compile_from_config(config)`. "
            "Alternatively, you can "
            "call `compile()` manually after loading.",
            stacklevel=2,
        )
        return

    compile_kwargs = serialization_lib.deserialize_keras_object(config)
    self.compile(**compile_kwargs)
    if not hasattr(self, "optimizer"):
        return
    if self.built:
        # Create optimizer variables.
        self.optimizer.build(self.trainable_variables)
def _should_eval(self, epoch, validation_freq):
epoch = epoch + 1 # one-index the user-facing epoch.
if isinstance(validation_freq, int):
return epoch % validation_freq == 0
elif isinstance(validation_freq, list):
return epoch in validation_freq
else:
raise ValueError(
"Expected `validation_freq` to be a list or int. "
f"Received: validation_freq={validation_freq} of the "
f"type {type(validation_freq)}."
)
def _pythonify_logs(self, logs):
result = {}
for key, value in sorted(logs.items()):
try:
value = float(value)
except:
pass
result[key] = value
return result
def _flatten_metrics_in_order(self, logs):
"""Turns `logs` dict into a list as per key order of `metrics_names`."""
metric_names = [m.name for m in self.metrics]
results = []
for name in metric_names:
if name in logs:
results.append(logs[name])
for key in sorted(logs.keys()):
if key not in metric_names:
results.append(logs[key])
if len(results) == 1:
return results[0]
return results
def _assert_compile_called(self, method_name=None):
if not self.compiled:
msg = "You must call `compile()` before "
if metrics_module:
msg += "using the model."
else:
msg += f"calling `{method_name}()`."
raise ValueError(msg)
def model_supports_jit(model):
    """Report whether `model` can be JIT-compiled.

    Returns `False` on macOS ("Darwin") with an "arm" processor when the
    backend is TensorFlow and TensorFlow lists at least one physical GPU.
    Otherwise the answer is whether every layer yielded by
    `model._flatten_layers()` has a truthy `supports_jit`.

    Args:
        model: Object exposing `_flatten_layers()`.

    Returns:
        `True` if JIT compilation is considered supported, else `False`.
    """
    on_arm_mac = (
        platform.system() == "Darwin"
        and "arm" in platform.processor().lower()
    )
    if on_arm_mac and backend.backend() == "tensorflow":
        # Imported lazily so TensorFlow is only loaded on this path.
        import tensorflow as tf

        if tf.config.list_physical_devices("GPU"):
            return False
    return all(layer.supports_jit for layer in model._flatten_layers())