2023-04-18 04:26:04 +00:00
|
|
|
import warnings
|
2023-04-18 21:49:38 +00:00
|
|
|
|
2023-04-16 21:54:13 +00:00
|
|
|
from keras_core import backend
|
2023-04-18 04:26:04 +00:00
|
|
|
from keras_core import metrics as metrics_module
|
2023-04-18 21:49:38 +00:00
|
|
|
from keras_core import operations as ops
|
2023-04-25 01:46:03 +00:00
|
|
|
from keras_core.saving import serialization_lib
|
2023-04-16 21:54:13 +00:00
|
|
|
from keras_core.trainers.compile_utils import CompileLoss
|
|
|
|
from keras_core.trainers.compile_utils import CompileMetrics
|
2023-04-19 23:25:56 +00:00
|
|
|
from keras_core.utils import tracking
|
2023-04-16 21:54:13 +00:00
|
|
|
|
|
|
|
|
2023-04-09 19:21:45 +00:00
|
|
|
class Trainer:
    """Mixin providing the `compile()` / `fit()` training API surface.

    Backend-specific model classes mix this in and implement `fit()`,
    `evaluate()` and `predict()`. This class manages compilation state
    (optimizer, compiled loss, compiled metrics), metric bookkeeping, and
    serialization of the compile configuration.

    NOTE(review): several attributes read here (`self._metrics`,
    `self.losses`, `self.built`, `self.trainable_variables`) are never
    defined in this class — presumably they come from a `Layer` base class
    in the final MRO. Confirm against the concrete model classes.
    """

    def __init__(self):
        # Execution-mode flags, exposed through the `run_eagerly` /
        # `jit_compile` properties below.
        self._run_eagerly = False
        self._jit_compile = True
        # Flipped to True by `compile()`.
        self.compiled = False

    @tracking.no_automatic_dependency_tracking
    def compile(
        self,
        optimizer,
        loss=None,
        loss_weights=None,
        metrics=None,
        weighted_metrics=None,
        run_eagerly=False,
        jit_compile=True,
    ):
        """Configures the model for training.

        Args:
            optimizer: Optimizer instance used to apply gradients.
            loss: Loss(es) to minimize. When `None`, no compiled loss is
                created and only `self.losses` (if any) contribute to the
                total loss.
            loss_weights: Optional weighting applied to the loss(es) by
                `CompileLoss`.
            metrics: Metrics to track on unweighted outputs.
            weighted_metrics: Metrics to track with `sample_weight` applied.
            run_eagerly: If True, run the train step eagerly instead of
                compiling it. Mutually exclusive with `jit_compile`.
            jit_compile: If True, JIT-compile the train step. Automatically
                disabled (with a warning) when `run_eagerly=True`.
        """
        # TODO: get from module
        self.optimizer = optimizer
        if loss is not None:
            self._compile_loss = CompileLoss(loss, loss_weights)
        else:
            self._compile_loss = None
        if metrics is not None:
            self._compile_metrics = CompileMetrics(metrics, weighted_metrics)
        else:
            self._compile_metrics = None
        if jit_compile and run_eagerly:
            # Eager execution and JIT compilation are mutually exclusive;
            # eager mode wins and JIT is disabled with a warning.
            jit_compile = False
            warnings.warn(
                "If `run_eagerly` is True, then `jit_compile` cannot also be True. "
                "Disabling `jit_compile`.",
                stacklevel=2,
            )
        self.jit_compile = jit_compile
        self.run_eagerly = run_eagerly
        self.stop_training = False
        self.compiled = True
        # Tracks the running mean of the total loss across batches.
        self._loss_tracker = metrics_module.Mean(name="loss")
        # Snapshot of the compile arguments, used by `get_compile_config()`
        # to re-compile the model after deserialization.
        self._compile_config = serialization_lib.SerializableDict(
            optimizer=optimizer,
            loss=loss,
            loss_weights=loss_weights,
            metrics=metrics,
            weighted_metrics=weighted_metrics,
            run_eagerly=run_eagerly,
            jit_compile=jit_compile,
        )

    @property
    def jit_compile(self):
        """Whether the train step should be JIT-compiled."""
        return self._jit_compile

    @jit_compile.setter
    def jit_compile(self, value):
        self._jit_compile = value

    @property
    def run_eagerly(self):
        """Whether the train step should run eagerly (uncompiled)."""
        return self._run_eagerly

    @run_eagerly.setter
    def run_eagerly(self, value):
        self._run_eagerly = value

    @property
    def metrics(self):
        """All metrics tracked by the model.

        Always includes the loss tracker first, then any user/layer metrics,
        then the compiled metrics container (only once it has been built).
        """
        metrics = [self._loss_tracker]
        # `self._metrics` is not defined in this class — assumed to be
        # provided by a base class (TODO confirm).
        metrics.extend(self._metrics[:])
        if self._compile_metrics is not None and self._compile_metrics.built:
            metrics += [self._compile_metrics]
        return metrics

    @property
    def metrics_variables(self):
        """Flat list of the state variables of all tracked metrics."""
        # Renamed from `vars`, which shadowed the builtin.
        variables = []
        for metric in self.metrics:
            variables.extend(metric.variables)
        return variables

    def reset_metrics(self):
        """Resets the state of all tracked metrics (including the loss)."""
        for m in self.metrics:
            m.reset_state()

    def compute_loss(self, x=None, y=None, y_pred=None, sample_weight=None):
        """Compute the total loss, validate it, and return it.

        Subclasses can optionally override this method to provide custom loss
        computation logic.

        Example:

        ```python
        class MyModel(Model):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.loss_tracker = metrics.Mean(name='loss')

            def compute_loss(self, x, y, y_pred, sample_weight):
                loss = ops.mean((y_pred - y) ** 2)
                loss += ops.sum(self.losses)
                self.loss_tracker.update_state(loss)
                return loss

            def reset_metrics(self):
                self.loss_tracker.reset_state()

            @property
            def metrics(self):
                return [self.loss_tracker]

        inputs = layers.Input(shape=(10,), name='my_input')
        outputs = layers.Dense(10)(inputs)
        model = MyModel(inputs, outputs)
        model.add_loss(ops.sum(outputs))

        optimizer = SGD()
        model.compile(optimizer, loss='mse', steps_per_execution=10)
        dataset = ...
        model.fit(dataset, epochs=2, steps_per_epoch=10)
        print(f"Custom loss: {model.loss_tracker.result()}")
        ```

        Args:
            x: Input data.
            y: Target data.
            y_pred: Predictions returned by the model (output of `model(x)`)
            sample_weight: Sample weights for weighting the loss function.

        Returns:
            The total loss as a scalar tensor, or `None` if no loss results
            (which is the case when called by `Model.test_step`).
        """
        del x  # The default implementation does not use `x`.
        losses = []
        if self._compile_loss is not None:
            loss = self._compile_loss(y, y_pred, sample_weight)
            if loss is not None:
                losses.append(loss)
        # Add any regularization/auxiliary losses. `self.losses` is assumed
        # to come from a base class (TODO confirm).
        for loss in self.losses:
            losses.append(ops.cast(loss, dtype=backend.floatx()))
        if len(losses) == 0:
            raise ValueError(
                "No loss to compute. Provide a `loss` argument in `compile()`."
            )
        if len(losses) == 1:
            total_loss = losses[0]
        else:
            total_loss = ops.sum(losses)
        return total_loss

    def compute_metrics(self, x, y, y_pred, sample_weight=None):
        """Update metric states and collect all metrics to be returned.

        Subclasses can optionally override this method to provide custom metric
        updating and collection logic.

        Example:

        ```python
        class MyModel(Sequential):
            def compute_metrics(self, x, y, y_pred, sample_weight):
                # This super call updates `self.compiled_metrics` and returns
                # results for all metrics listed in `self.metrics`.
                metric_results = super().compute_metrics(x, y, y_pred, sample_weight)

                # Note that `self.custom_metric` is not listed in `self.metrics`.
                self.custom_metric.update_state(x, y, y_pred, sample_weight)
                metric_results['custom_metric_name'] = self.custom_metric.result()
                return metric_results
        ```

        Args:
            x: Input data.
            y: Target data.
            y_pred: Predictions returned by the model (output of `model.call(x)`)
            sample_weight: Sample weights for weighting the loss function.

        Returns:
            A `dict` containing values that will be passed to
            `tf.keras.callbacks.CallbackList.on_train_batch_end()`. Typically, the
            values of the metrics listed in `self.metrics` are returned. Example:
            `{'loss': 0.2, 'accuracy': 0.7}`.
        """
        del x  # The default implementation does not use `x`.
        if self._compile_metrics is not None:
            self._compile_metrics.update_state(y, y_pred, sample_weight)
        return self.get_metrics_result()

    def get_metrics_result(self):
        """Returns the model's metrics values as a dict.

        If any of the metric result is a dict (containing multiple metrics),
        each of them gets added to the top level returned dict of this method.

        Returns:
            A `dict` containing values of the metrics listed in `self.metrics`.
            Example:
            `{'loss': 0.2, 'accuracy': 0.7}`.
        """
        return_metrics = {}
        for metric in self.metrics:
            result = metric.result()
            if isinstance(result, dict):
                # Metric containers may report several named results at once;
                # merge them into the top-level dict.
                return_metrics.update(result)
            else:
                return_metrics[metric.name] = result
        return return_metrics

    def fit(
        self,
        x=None,
        y=None,
        batch_size=None,
        epochs=1,
        verbose="auto",
        callbacks=None,
        validation_split=0.0,
        validation_data=None,
        shuffle=True,
        class_weight=None,
        sample_weight=None,
        initial_epoch=0,
        steps_per_epoch=None,
        validation_steps=None,
        validation_batch_size=None,
        validation_freq=1,
    ):
        """Trains the model. Must be implemented by backend subclasses."""
        raise NotImplementedError

    def evaluate(
        self,
        x=None,
        y=None,
        batch_size=None,
        verbose="auto",
        sample_weight=None,
        steps=None,
        callbacks=None,
        return_dict=False,
        **kwargs,
    ):
        """Evaluates the model. Must be implemented by backend subclasses."""
        raise NotImplementedError

    def predict(
        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
    ):
        """Generates predictions. Must be implemented by backend subclasses."""
        raise NotImplementedError

    def get_compile_config(self):
        """Returns a serialized config with information for compiling the model.

        This method returns a config dictionary containing all the information
        (optimizer, loss, metrics, etc.) with which the model was compiled.

        Returns:
            A dict containing information for compiling the model, or `None`
            (implicitly) if the model has not been compiled.
        """
        if self.compiled and hasattr(self, "_compile_config"):
            return self._compile_config.serialize()

    def compile_from_config(self, config):
        """Compiles the model with the information given in config.

        This method uses the information in the config (optimizer, loss,
        metrics, etc.) to compile the model.

        Args:
            config: Dict containing information for compiling the model.
        """
        has_overridden_compile = self.__class__.compile != Trainer.compile
        if has_overridden_compile:
            # A custom `compile()` may take different arguments than the saved
            # config; refuse to guess and let the user compile manually.
            warnings.warn(
                "`compile()` was not called as part of model loading "
                "because the model's `compile()` method is custom. "
                "All subclassed Models that have `compile()` "
                "overridden should also override "
                "`get_compile_config()` and `compile_from_config(config)`. "
                "Alternatively, you can "
                "call `compile()` manually after loading.",
                stacklevel=2,
            )
            return
        config = serialization_lib.deserialize_keras_object(config)
        self.compile(**config)
        if hasattr(self, "optimizer") and self.built:
            # Create optimizer variables.
            self.optimizer.build(self.trainable_variables)

    def _should_eval(self, epoch, validation_freq):
        """Whether validation should run at the end of `epoch` (0-indexed).

        `validation_freq` is either an int N (validate every N epochs) or a
        list of 1-indexed epoch numbers at which to validate.
        """
        epoch = epoch + 1  # one-index the user-facing epoch.
        if isinstance(validation_freq, int):
            return epoch % validation_freq == 0
        elif isinstance(validation_freq, list):
            return epoch in validation_freq
        else:
            raise ValueError(
                "Expected `validation_freq` to be a list or int. "
                f"Received: validation_freq={validation_freq} of the "
                f"type {type(validation_freq)}."
            )

    def _pythonify_logs(self, logs):
        """Best-effort conversion of log values (e.g. scalar tensors) to floats.

        Values that cannot be converted are passed through unchanged.
        """
        result = {}
        for key, value in logs.items():
            try:
                value = float(value)
            except Exception:
                # Was a bare `except:`, which also swallowed SystemExit /
                # KeyboardInterrupt. Conversion is deliberately best-effort,
                # so any ordinary failure leaves the value as-is.
                pass
            result[key] = value
        return result

    def _flatten_metrics_in_order(self, logs):
        """Turns the `logs` dict into a list as per key order of `metrics_names`."""
        metric_names = [m.name for m in self.metrics]
        results = []
        for name in metric_names:
            if name in logs:
                results.append(logs[name])
        # Any log entries not corresponding to a known metric are appended
        # afterwards, in sorted-key order for determinism.
        for key in sorted(logs.keys()):
            if key not in metric_names:
                results.append(logs[key])
        if len(results) == 1:
            return results[0]
        return results
|