# keras/keras_core/backend/jax/trainer.py
import jax
import numpy as np
import tensorflow as tf # for nest
from keras_core import backend
from keras_core import callbacks as callbacks_module
from keras_core import optimizers as optimizers_module
from keras_core.trainers import trainer as base_trainer
from keras_core.trainers.data_adapters import data_adapter_utils
from keras_core.trainers.epoch_iterator import EpochIterator
class JAXTrainer(base_trainer.Trainer):
    """Trainer implementation for the JAX backend.

    The train/eval/predict steps are written as *stateless* functions over
    explicit tuples of variable values so they can be transformed with
    `jax.value_and_grad` and `jax.jit`. The training loops thread the state
    tuple through each step and write the final values back onto the model's
    variables only once per epoch (assigning after every step would be a
    major performance bottleneck).
    """

    def compute_loss_and_updates(
        self, trainable_variables, non_trainable_variables, x, y, sample_weight
    ):
        """This method is stateless and is intended for use with jax.grad.

        Runs the forward pass via `stateless_call` and computes the loss.
        Returns `(loss, (y_pred, non_trainable_variables))` so it can be
        used with `jax.value_and_grad(..., has_aux=True)`.
        """
        y_pred, non_trainable_variables = self.stateless_call(
            trainable_variables, non_trainable_variables, x
        )
        loss = self.compute_loss(x, y, y_pred, sample_weight)
        return loss, (y_pred, non_trainable_variables)

    def fit(
        self,
        x=None,
        y=None,
        batch_size=None,
        epochs=1,
        verbose="auto",
        callbacks=None,
        validation_split=0.0,
        validation_data=None,
        shuffle=True,
        class_weight=None,
        sample_weight=None,
        initial_epoch=0,
        steps_per_epoch=None,
        validation_steps=None,
        validation_batch_size=None,
        validation_freq=1,
    ):
        """Train the model, JAX-style: a functional train step threaded
        through an epoch loop, with variable values synced back per epoch.
        """
        if not self.compiled:
            raise ValueError(
                "You must call `compile()` before calling `fit()`."
            )

        # TODO: respect compiled trainable state
        if validation_split and validation_data is None:
            # Create the validation data using the training data. Only
            # supported for TF/numpy/jax arrays.
            (
                x,
                y,
                sample_weight,
            ), validation_data = data_adapter_utils.train_validation_split(
                (x, y, sample_weight), validation_split=validation_split
            )

        if validation_data:
            (
                val_x,
                val_y,
                val_sample_weight,
            ) = data_adapter_utils.unpack_x_y_sample_weight(validation_data)

        # Create an iterator that yields batches for one epoch.
        epoch_iterator = EpochIterator(
            x=x,
            y=y,
            sample_weight=sample_weight,
            batch_size=batch_size,
            steps_per_epoch=steps_per_epoch,
            shuffle=shuffle,
            class_weight=class_weight,
            steps_per_execution=self.steps_per_execution,
        )

        compile_metrics_unbuilt = (
            self._compile_metrics is not None
            and not self._compile_metrics.built
        )
        if not self.built or compile_metrics_unbuilt:
            # Build the model on one batch of data.
            for _, data in epoch_iterator.enumerate_epoch(return_type="np"):
                data = data[0]
                (
                    x,
                    y,
                    sample_weight,
                ) = data_adapter_utils.unpack_x_y_sample_weight(data)
                # Build model
                y_pred = self(x)
                if compile_metrics_unbuilt:
                    # Build metrics
                    self.compute_metrics(
                        x, y, y_pred, sample_weight=sample_weight
                    )
                break
        if not self.optimizer.built:
            # Build optimizer
            self.optimizer.build(self.trainable_variables)

        # Container that configures and calls callbacks.
        if not isinstance(callbacks, callbacks_module.CallbackList):
            callbacks = callbacks_module.CallbackList(
                callbacks,
                add_history=True,
                add_progbar=verbose != 0,
                verbose=verbose,
                epochs=epochs,
                steps=epoch_iterator.num_batches,
                model=self,
            )

        grad_fn = jax.value_and_grad(
            self.compute_loss_and_updates, has_aux=True
        )

        def _train_step(state, data):
            # One gradient step over a single batch. `state` is the tuple of
            # (trainable, non-trainable, optimizer, metrics) variable values.
            (
                trainable_variables,
                non_trainable_variables,
                optimizer_variables,
                metrics_variables,
            ) = state
            x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(
                data
            )
            (loss, (y_pred, non_trainable_variables)), grads = grad_fn(
                trainable_variables,
                non_trainable_variables,
                x,
                y,
                sample_weight,
            )
            (
                trainable_variables,
                optimizer_variables,
            ) = self.optimizer.stateless_apply(
                grads, trainable_variables, optimizer_variables
            )

            # Update metrics functionally: run the stateful metric updates
            # inside a StatelessScope mapped onto the current metric values,
            # then read the new values back out of the scope.
            with backend.StatelessScope(
                state_mapping=[
                    (ref_v, v)
                    for ref_v, v in zip(
                        self.metrics_variables, metrics_variables
                    )
                ]
            ) as scope:
                logs = self.compute_metrics(x, y, y_pred, sample_weight)
                self._loss_tracker.update_state(loss)

            new_metrics_variables = []
            for ref_v in self.metrics_variables:
                new_v = scope.get_current_value(ref_v)
                if new_v is None:
                    # Metric variable untouched in this step; keep its value.
                    new_v = ref_v.value
                new_metrics_variables.append(new_v)
            metrics_variables = new_metrics_variables

            state = (
                trainable_variables,
                non_trainable_variables,
                optimizer_variables,
                metrics_variables,
            )
            return logs, state

        def _train_multi_step(state, data):
            # Run several single-batch steps back to back (one per entry in
            # `data`), as yielded when steps_per_execution > 1.
            for single_step_data in data:
                logs, state = _train_step(state, single_step_data)
            return logs, state

        if not self.run_eagerly and self.jit_compile:

            @jax.jit
            def train_step(state, data):
                # `steps_per_execution` is a Python value, so this branch is
                # resolved at trace time.
                if self.steps_per_execution > 1:
                    return _train_multi_step(state, data)
                return _train_step(state, data[0])

        else:
            train_step = _train_multi_step

        self.stop_training = False
        callbacks.on_train_begin()

        for epoch in range(initial_epoch, epochs):
            self.reset_metrics()
            callbacks.on_epoch_begin(epoch)

            # Snapshot current variable values to thread through the steps.
            trainable_variables = self.trainable_variables
            non_trainable_variables = self.non_trainable_variables
            optimizer_variables = self.optimizer.variables
            metrics_variables = self.metrics_variables

            for step, data in epoch_iterator.enumerate_epoch(return_type="np"):
                # Callbacks
                callbacks.on_train_batch_begin(step)

                # Train step
                state = (
                    trainable_variables,
                    non_trainable_variables,
                    optimizer_variables,
                    metrics_variables,
                )
                logs, state = train_step(state, data)
                (
                    trainable_variables,
                    non_trainable_variables,
                    optimizer_variables,
                    metrics_variables,
                ) = state

                # Callbacks
                callbacks.on_train_batch_end(step, logs)
                if self.stop_training:
                    break

            # Update variable values
            # NOTE: doing this after each step would be a big performance
            # bottleneck.
            for ref_v, v in zip(self.trainable_variables, trainable_variables):
                ref_v.assign(v)
            for ref_v, v in zip(
                self.non_trainable_variables, non_trainable_variables
            ):
                ref_v.assign(v)
            for ref_v, v in zip(self.optimizer.variables, optimizer_variables):
                ref_v.assign(v)
            for ref_v, v in zip(self.metrics_variables, metrics_variables):
                ref_v.assign(v)

            # Override with model metrics instead of last step logs
            epoch_logs = self._pythonify_logs(self.get_metrics_result())

            # Run validation.
            if validation_data and self._should_eval(epoch, validation_freq):
                # Create EpochIterator for evaluation and cache it.
                if getattr(self, "_eval_epoch_iterator", None) is None:
                    self._eval_epoch_iterator = EpochIterator(
                        x=val_x,
                        y=val_y,
                        sample_weight=val_sample_weight,
                        batch_size=validation_batch_size or batch_size,
                        steps_per_execution=self.steps_per_execution,
                    )
                val_logs = self.evaluate(
                    x=val_x,
                    y=val_y,
                    sample_weight=val_sample_weight,
                    batch_size=validation_batch_size or batch_size,
                    steps=validation_steps,
                    callbacks=callbacks,
                    return_dict=True,
                    _use_cached_eval_dataset=True,
                )
                val_logs = {
                    "val_" + name: val for name, val in val_logs.items()
                }
                epoch_logs.update(self._pythonify_logs(val_logs))

            callbacks.on_epoch_end(epoch, epoch_logs)
            training_logs = epoch_logs
            if self.stop_training:
                break

        if (
            isinstance(self.optimizer, optimizers_module.Optimizer)
            and epochs > 0
        ):
            self.optimizer.finalize_variable_values(self.trainable_weights)

        # If _eval_epoch_iterator exists, delete it after all epochs are done.
        if getattr(self, "_eval_epoch_iterator", None) is not None:
            del self._eval_epoch_iterator
        callbacks.on_train_end(logs=training_logs)
        return self.history

    def evaluate(
        self,
        x=None,
        y=None,
        batch_size=None,
        verbose="auto",
        sample_weight=None,
        steps=None,
        callbacks=None,
        return_dict=False,
        **kwargs,
    ):
        """Evaluate the model over one pass of the data.

        Accepts the private `_use_cached_eval_dataset` kwarg so `fit()` can
        reuse its cached validation iterator. Any other kwarg is rejected.
        """
        # TODO: respect compiled trainable state
        use_cached_eval_dataset = kwargs.pop("_use_cached_eval_dataset", False)
        if kwargs:
            raise ValueError(f"Arguments not recognized: {kwargs}")

        if use_cached_eval_dataset:
            epoch_iterator = self._eval_epoch_iterator
        else:
            # Create an iterator that yields batches of input/target data.
            epoch_iterator = EpochIterator(
                x=x,
                y=y,
                sample_weight=sample_weight,
                batch_size=batch_size,
                steps_per_epoch=steps,
                shuffle=False,
                steps_per_execution=self.steps_per_execution,
            )

        if not self.built:
            # Build the model on one batch of data.
            for _, data in epoch_iterator.enumerate_epoch(return_type="np"):
                data = data[0]
                (
                    x,
                    y,
                    sample_weight,
                ) = data_adapter_utils.unpack_x_y_sample_weight(data)
                # Build model
                y_pred = self(x)
                # Build metrics
                self.compute_metrics(x, y, y_pred, sample_weight)
                self.reset_metrics()
                break

        # Container that configures and calls callbacks.
        if not isinstance(callbacks, callbacks_module.CallbackList):
            callbacks = callbacks_module.CallbackList(
                callbacks,
                add_history=True,
                add_progbar=verbose != 0,
                verbose=verbose,
                epochs=1,
                steps=epoch_iterator.num_batches,
                model=self,
            )

        def _test_step(state, data):
            # One forward pass + metric update over a single batch. The
            # trainable variables are read-only here, so they are not part
            # of the returned state.
            (
                trainable_variables,
                non_trainable_variables,
                metrics_variables,
            ) = state
            x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(
                data
            )
            loss, (
                y_pred,
                non_trainable_variables,
            ) = self.compute_loss_and_updates(
                trainable_variables,
                non_trainable_variables,
                x,
                y,
                sample_weight,
            )

            with backend.StatelessScope(
                state_mapping=[
                    (ref_v, v)
                    for ref_v, v in zip(
                        self.metrics_variables, metrics_variables
                    )
                ]
            ) as scope:
                logs = self.compute_metrics(x, y, y_pred, sample_weight)
                self._loss_tracker.update_state(loss)

            new_metrics_variables = []
            for ref_v in self.metrics_variables:
                new_v = scope.get_current_value(ref_v)
                if new_v is None:
                    new_v = ref_v.value
                new_metrics_variables.append(new_v)
            metrics_variables = new_metrics_variables

            state = (
                non_trainable_variables,
                metrics_variables,
            )
            return logs, state

        def _test_multi_step(state, data):
            # Run several single-batch test steps back to back.
            for single_step_data in data:
                logs, state = _test_step(state, single_step_data)
            return logs, state

        if not self.run_eagerly and self.jit_compile:

            @jax.jit
            def test_step(state, data):
                if self.steps_per_execution > 1:
                    return _test_multi_step(state, data)
                return _test_step(state, data[0])

        else:
            test_step = _test_multi_step

        callbacks.on_test_begin()
        logs = None
        self.reset_metrics()

        trainable_variables = self.trainable_variables
        non_trainable_variables = self.non_trainable_variables
        metrics_variables = self.metrics_variables

        for step, data in epoch_iterator.enumerate_epoch(return_type="np"):
            callbacks.on_test_batch_begin(step)

            state = (
                trainable_variables,
                non_trainable_variables,
                metrics_variables,
            )
            logs, state = test_step(state, data)

            # Note that trainable variables are not returned since they're
            # immutable here.
            non_trainable_variables, metrics_variables = state

            callbacks.on_test_batch_end(step, logs)

        for ref_v, v in zip(
            self.non_trainable_variables, non_trainable_variables
        ):
            # I wouldn't recommend modifying non-trainable model state
            # during evaluate(), but it's allowed.
            ref_v.assign(v)
        for ref_v, v in zip(self.metrics_variables, metrics_variables):
            ref_v.assign(v)

        logs = self._pythonify_logs(self.get_metrics_result())
        callbacks.on_test_end(logs)

        if return_dict:
            return logs
        return self._flatten_metrics_in_order(logs)

    def predict(
        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
    ):
        """Generate output predictions for the input samples.

        Outputs are accumulated per-batch and concatenated along the batch
        axis at the end, preserving any nested output structure.
        """
        # Create an iterator that yields batches of input data.
        epoch_iterator = EpochIterator(
            x=x,
            batch_size=batch_size,
            steps_per_epoch=steps,
            shuffle=False,
            steps_per_execution=self.steps_per_execution,
        )

        if not self.built:
            # Build the model on one batch of data.
            for _, data in epoch_iterator.enumerate_epoch(return_type="np"):
                # Build model
                self(data[0])
                break

        # Container that configures and calls callbacks.
        if not isinstance(callbacks, callbacks_module.CallbackList):
            callbacks = callbacks_module.CallbackList(
                callbacks,
                add_history=True,
                add_progbar=verbose != 0,
                verbose=verbose,
                epochs=1,
                steps=epoch_iterator.num_batches,
                model=self,
            )

        def _predict_multi_step(
            trainable_variables, non_trainable_variables, data
        ):
            # One stateless forward pass per batch in `data`; each entry of
            # the returned list is a (y_pred, non_trainable_variables) pair.
            return [
                self.stateless_call(
                    trainable_variables,
                    non_trainable_variables,
                    single_step_data,
                )
                for single_step_data in data
            ]

        if not self.run_eagerly and self.jit_compile:

            @jax.jit
            def predict_step(
                trainable_variables, non_trainable_variables, data
            ):
                if self.steps_per_execution > 1:
                    return _predict_multi_step(
                        trainable_variables, non_trainable_variables, data
                    )
                return [
                    self.stateless_call(
                        trainable_variables, non_trainable_variables, data[0]
                    )
                ]

        else:
            predict_step = _predict_multi_step

        callbacks.on_predict_begin()

        trainable_variables = self.trainable_variables
        non_trainable_variables = self.non_trainable_variables
        outputs = None
        for step, x in epoch_iterator.enumerate_epoch(return_type="np"):
            callbacks.on_predict_batch_begin(step)
            multi_step_return_values = predict_step(
                trainable_variables, non_trainable_variables, x
            )
            for batch_outputs, _ in multi_step_return_values:
                if outputs is None:
                    # First batch: wrap every leaf in a one-element list.
                    outputs = tf.nest.map_structure(
                        lambda batch_output: [batch_output],
                        batch_outputs,
                    )
                else:
                    # Subsequent batches: append each leaf to its list.
                    tf.__internal__.nest.map_structure_up_to(
                        batch_outputs,
                        lambda output, batch_output: output.append(
                            batch_output
                        ),
                        outputs,
                        batch_outputs,
                    )
            callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
        callbacks.on_predict_end()
        # Concatenate the accumulated per-batch arrays along the batch axis.
        return tf.__internal__.nest.map_structure_up_to(
            batch_outputs, np.concatenate, outputs
        )