# keras/keras_core/trainers/trainer.py

import warnings

from keras_core import backend
from keras_core import metrics as metrics_module
from keras_core import operations as ops
from keras_core import optimizers
from keras_core.saving import serialization_lib
from keras_core.trainers.compile_utils import CompileLoss
from keras_core.trainers.compile_utils import CompileMetrics
from keras_core.utils import tracking


class Trainer:
    def __init__(self):
        self._lock = False
        self._run_eagerly = False
        self._jit_compile = True
        self.compiled = False
        self.steps_per_execution = 1

    @tracking.no_automatic_dependency_tracking
    def compile(
        self,
        optimizer="rmsprop",
        loss=None,
        loss_weights=None,
        metrics=None,
        weighted_metrics=None,
        run_eagerly=False,
        steps_per_execution=1,
        jit_compile=True,
    ):
        self.optimizer = optimizers.get(optimizer)
        if loss is not None:
            self._compile_loss = CompileLoss(loss, loss_weights)
        else:
            self._compile_loss = None
        if metrics is not None:
            self._compile_metrics = CompileMetrics(metrics, weighted_metrics)
        else:
            self._compile_metrics = None
        if jit_compile and run_eagerly:
            jit_compile = False
            warnings.warn(
                "If `run_eagerly` is True, then `jit_compile` "
                "cannot also be True. Disabling `jit_compile`.",
                stacklevel=2,
            )
        self.jit_compile = jit_compile
        self.run_eagerly = run_eagerly
        self.stop_training = False
        self.compiled = True
        self._loss_tracker = metrics_module.Mean(name="loss")
        self.steps_per_execution = steps_per_execution
        self._compile_config = serialization_lib.SerializableDict(
            optimizer=optimizer,
            loss=loss,
            loss_weights=loss_weights,
            metrics=metrics,
            weighted_metrics=weighted_metrics,
            run_eagerly=run_eagerly,
            steps_per_execution=steps_per_execution,
            jit_compile=jit_compile,
        )
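
    # Illustrative usage sketch (not executed here; `model` is a hypothetical
    # model class that mixes in this `Trainer`):
    #
    #   model.compile(
    #       optimizer="adam",
    #       loss="sparse_categorical_crossentropy",
    #       metrics=["accuracy"],
    #       jit_compile=True,  # forced to False if run_eagerly=True
    #   )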

    @property
    def jit_compile(self):
        return self._jit_compile

    @jit_compile.setter
    def jit_compile(self, value):
        self._jit_compile = value

    @property
    def run_eagerly(self):
        return self._run_eagerly

    @run_eagerly.setter
    def run_eagerly(self, value):
        self._run_eagerly = value

    @property
    def metrics(self):
        metrics = [self._loss_tracker]
        metrics.extend(self._metrics[:])
        if self._compile_metrics is not None and self._compile_metrics.built:
            metrics += [self._compile_metrics]
        return metrics

    @property
    def metrics_variables(self):
        vars = []
        for metric in self.metrics:
            vars.extend(metric.variables)
        return vars

    def reset_metrics(self):
        for m in self.metrics:
            m.reset_state()

    def compute_loss(self, x=None, y=None, y_pred=None, sample_weight=None):
        """Compute the total loss, validate it, and return it.

        Subclasses can optionally override this method to provide custom loss
        computation logic.

        Example:

        ```python
        class MyModel(Model):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.loss_tracker = metrics.Mean(name='loss')

            def compute_loss(self, x, y, y_pred, sample_weight):
                loss = ops.mean((y_pred - y) ** 2)
                loss += ops.sum(self.losses)
                self.loss_tracker.update_state(loss)
                return loss

            def reset_metrics(self):
                self.loss_tracker.reset_state()

            @property
            def metrics(self):
                return [self.loss_tracker]

        inputs = layers.Input(shape=(10,), name='my_input')
        outputs = layers.Dense(10)(inputs)
        model = MyModel(inputs, outputs)
        model.add_loss(ops.sum(outputs))

        optimizer = SGD()
        model.compile(optimizer, loss='mse', steps_per_execution=10)
        dataset = ...
        model.fit(dataset, epochs=2, steps_per_epoch=10)
        print(f"Custom loss: {model.loss_tracker.result()}")
        ```

        Args:
            x: Input data.
            y: Target data.
            y_pred: Predictions returned by the model (output of `model(x)`).
            sample_weight: Sample weights for weighting the loss function.

        Returns:
            The total loss as a scalar tensor, or `None` if no loss results
            (which is the case when called by `Model.test_step`).
        """
        del x  # The default implementation does not use `x`.
        losses = []
        if self._compile_loss is not None:
            loss = self._compile_loss(y, y_pred, sample_weight)
            if loss is not None:
                losses.append(loss)
        for loss in self.losses:
            losses.append(ops.cast(loss, dtype=backend.floatx()))
        if len(losses) == 0:
            raise ValueError(
                "No loss to compute. Provide a `loss` argument in `compile()`."
            )
        if len(losses) == 1:
            total_loss = losses[0]
        else:
            total_loss = ops.sum(losses)
        return total_loss
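
    # Sketch of the aggregation above, assuming one compiled loss and one
    # `add_loss` regularization term (values are made up):
    #
    #   losses = [compiled_mse, regularization]  # e.g. [0.30, 0.02]
    #   total_loss = ops.sum(losses)             # 0.32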
2023-04-16 01:51:10 +00:00
2023-04-18 21:49:38 +00:00
def compute_metrics(self, x, y, y_pred, sample_weight=None):
2023-04-16 01:51:10 +00:00
"""Update metric states and collect all metrics to be returned.
Subclasses can optionally override this method to provide custom metric
updating and collection logic.
Example:
```python
class MyModel(Sequential):
def compute_metrics(self, x, y, y_pred, sample_weight):
# This super call updates `self.compiled_metrics` and returns
# results for all metrics listed in `self.metrics`.
2023-04-27 03:42:23 +00:00
metric_results = super().compute_metrics(
x, y, y_pred, sample_weight)
2023-04-16 01:51:10 +00:00
# Note that `self.custom_metric` is not listed in `self.metrics`.
self.custom_metric.update_state(x, y, y_pred, sample_weight)
metric_results['custom_metric_name'] = self.custom_metric.result()
return metric_results
```
Args:
x: Input data.
y: Target data.
2023-04-27 03:42:23 +00:00
y_pred: Predictions returned by the model output of `model.call(x)`.
2023-04-16 01:51:10 +00:00
sample_weight: Sample weights for weighting the loss function.
Returns:
A `dict` containing values that will be passed to
2023-04-27 03:42:23 +00:00
`tf.keras.callbacks.CallbackList.on_train_batch_end()`. Typically,
the values of the metrics listed in `self.metrics` are returned.
Example: `{'loss': 0.2, 'accuracy': 0.7}`.
2023-04-16 01:51:10 +00:00
"""
del x # The default implementation does not use `x`.
2023-04-16 21:54:13 +00:00
if self._compile_metrics is not None:
self._compile_metrics.update_state(y, y_pred, sample_weight)
2023-04-16 01:51:10 +00:00
return self.get_metrics_result()
    def get_metrics_result(self):
        """Returns the model's metrics values as a dict.

        If any of the metric results is a dict (containing multiple metrics),
        each of them gets added to the top level returned dict of this method.

        Returns:
            A `dict` containing values of the metrics listed in
            `self.metrics`. Example: `{'loss': 0.2, 'accuracy': 0.7}`.
        """
        return_metrics = {}
        for metric in self.metrics:
            result = metric.result()
            if isinstance(result, dict):
                return_metrics.update(result)
            else:
                return_metrics[metric.name] = result
        return return_metrics
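
    # Sketch of the dict flattening above: a metric whose `result()` is
    # itself a dict (a hypothetical per-class F1 metric) is merged into the
    # top-level dict rather than nested:
    #
    #   {'loss': 0.2} merged with {'f1_class_0': 0.7, 'f1_class_1': 0.5}
    #   -> {'loss': 0.2, 'f1_class_0': 0.7, 'f1_class_1': 0.5}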

    def fit(
        self,
        x=None,
        y=None,
        batch_size=None,
        epochs=1,
        verbose="auto",
        callbacks=None,
        validation_split=0.0,
        validation_data=None,
        shuffle=True,
        class_weight=None,
        sample_weight=None,
        initial_epoch=0,
        steps_per_epoch=None,
        validation_steps=None,
        validation_batch_size=None,
        validation_freq=1,
    ):
        raise NotImplementedError

    def evaluate(
        self,
        x=None,
        y=None,
        batch_size=None,
        verbose="auto",
        sample_weight=None,
        steps=None,
        callbacks=None,
        return_dict=False,
        **kwargs,
    ):
        raise NotImplementedError

    def predict(
        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
    ):
        raise NotImplementedError

    def get_compile_config(self):
        """Returns a serialized config with information for compiling the
        model.

        This method returns a config dictionary containing all the
        information (optimizer, loss, metrics, etc.) with which the model was
        compiled.

        Returns:
            A dict containing information for compiling the model.
        """
        if self.compiled and hasattr(self, "_compile_config"):
            return self._compile_config.serialize()

    def compile_from_config(self, config):
        """Compiles the model with the information given in config.

        This method uses the information in the config (optimizer, loss,
        metrics, etc.) to compile the model.

        Args:
            config: Dict containing information for compiling the model.
        """
        has_overridden_compile = self.__class__.compile != Trainer.compile
        if has_overridden_compile:
            warnings.warn(
                "`compile()` was not called as part of model loading "
                "because the model's `compile()` method is custom. "
                "All subclassed Models that have `compile()` "
                "overridden should also override "
                "`get_compile_config()` and `compile_from_config(config)`. "
                "Alternatively, you can "
                "call `compile()` manually after loading.",
                stacklevel=2,
            )
            return
        config = serialization_lib.deserialize_keras_object(config)
        self.compile(**config)
        if hasattr(self, "optimizer") and self.built:
            # Create optimizer variables.
            self.optimizer.build(self.trainable_variables)
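
    # Round-trip sketch: `get_compile_config()` and `compile_from_config()`
    # are save/load counterparts used during model serialization. Assuming a
    # compiled `model` and a freshly deserialized `restored` model:
    #
    #   config = model.get_compile_config()
    #   restored.compile_from_config(config)  # re-compiles; if the model is
    #                                         # built, optimizer variables
    #                                         # are created as well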

    def _should_eval(self, epoch, validation_freq):
        epoch = epoch + 1  # one-index the user-facing epoch.
        if isinstance(validation_freq, int):
            return epoch % validation_freq == 0
        elif isinstance(validation_freq, list):
            return epoch in validation_freq
        else:
            raise ValueError(
                "Expected `validation_freq` to be a list or int. "
                f"Received: validation_freq={validation_freq} of the "
                f"type {type(validation_freq)}."
            )
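
    # `validation_freq` semantics, as implemented above (epoch is
    # zero-indexed on input, one-indexed for the check):
    #
    #   _should_eval(epoch=3, validation_freq=2)          -> True  (4 % 2 == 0)
    #   _should_eval(epoch=3, validation_freq=[1, 2, 10]) -> False (4 not in it)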

    def _pythonify_logs(self, logs):
        result = {}
        for key, value in logs.items():
            try:
                value = float(value)
            except (TypeError, ValueError):
                # Keep values that cannot be cast to a Python float
                # (e.g. non-scalar arrays) as-is.
                pass
            result[key] = value
        return result
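
    # Example of the coercion above: scalar values become Python floats,
    # anything `float()` rejects passes through unchanged (`some_array` is a
    # hypothetical non-scalar value):
    #
    #   _pythonify_logs({"loss": np.float32(0.25), "outputs": some_array})
    #   -> {"loss": 0.25, "outputs": some_array}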

    def _flatten_metrics_in_order(self, logs):
        """Turns `logs` dict into a list as per key order of
        `metrics_names`."""
        metric_names = [m.name for m in self.metrics]
        results = []
        for name in metric_names:
            if name in logs:
                results.append(logs[name])
        for key in sorted(logs.keys()):
            if key not in metric_names:
                results.append(logs[key])
        if len(results) == 1:
            return results[0]
        return results
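
    # Ordering sketch for the flattening above: metrics appear in
    # `self.metrics` order first, then any remaining log keys sorted
    # alphabetically. Assuming `self.metrics` yields names
    # ['loss', 'accuracy']:
    #
    #   _flatten_metrics_in_order(
    #       {'b': 3, 'accuracy': 0.7, 'loss': 0.2, 'a': 1}
    #   )
    #   -> [0.2, 0.7, 1, 3]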