keras/keras_core/backend/jax/trainer.py
2023-06-12 10:40:40 -07:00

811 lines
28 KiB
Python

import jax
import numpy as np
import tensorflow as tf # for nest
from keras_core import backend
from keras_core import callbacks as callbacks_module
from keras_core import optimizers as optimizers_module
from keras_core.trainers import trainer as base_trainer
from keras_core.trainers.data_adapters import data_adapter_utils
from keras_core.trainers.epoch_iterator import EpochIterator
from keras_core.utils import traceback_utils
class JAXTrainer(base_trainer.Trainer):
def __init__(self):
    """Initialize the trainer with no step functions created yet.

    `train_function`, `test_function` and `predict_function` start as
    `None`; the corresponding `make_*_function` methods assign them.
    """
    super().__init__()
    # Placeholders filled in lazily by the `make_*_function` methods.
    self.train_function = self.test_function = self.predict_function = None
def compute_loss_and_updates(
    self,
    trainable_variables,
    non_trainable_variables,
    x,
    y,
    sample_weight,
    training=False,
):
    """Run a stateless forward pass and compute the loss.

    Variable values are passed in and handed back explicitly, so this
    method is stateless and is intended for use with `jax.grad`.

    Args:
        trainable_variables: Forwarded to `stateless_call` as the trainable
            variable values.
        non_trainable_variables: Forwarded to `stateless_call` as the
            non-trainable variable values.
        x: Model inputs.
        y: Targets handed to `compute_loss`.
        sample_weight: Sample weights handed to `compute_loss`.
        training: Passed as the `training` keyword of the call, but only
            when `_call_has_training_arg()` returns a truthy value.

    Returns:
        A tuple `(loss, (y_pred, non_trainable_variables))`, where the
        non-trainable variables are those returned by `stateless_call`.
    """
    call_kwargs = (
        {"training": training} if self._call_has_training_arg() else {}
    )
    outputs = self.stateless_call(
        trainable_variables, non_trainable_variables, x, **call_kwargs
    )
    y_pred, updated_non_trainable_variables = outputs
    loss = self.compute_loss(x, y, y_pred, sample_weight)
    # The second element is the auxiliary output (see `has_aux=True` usage).
    aux = (y_pred, updated_non_trainable_variables)
    return loss, aux
def _eager_build(self, data_batch):
    """Build the model, the compile metrics and the optimizer if needed.

    Args:
        data_batch: One batch of data, unpacked with
            `data_adapter_utils.unpack_x_y_sample_weight` into
            `(x, y, sample_weight)`.
    """
    metrics_need_build = (
        self._compile_metrics is not None
        and not self._compile_metrics.built
    )
    if not self.built or metrics_need_build:
        batch_x, batch_y, batch_weight = (
            data_adapter_utils.unpack_x_y_sample_weight(data_batch)
        )
        # Calling the model on the batch builds it; this happens inside a
        # StatelessScope.
        with backend.StatelessScope():
            y_pred = self(batch_x)
            if metrics_need_build:
                # Computing the metrics once builds the compile metrics.
                self.compute_metrics(
                    batch_x, batch_y, y_pred, sample_weight=batch_weight
                )

    optimizer = self.optimizer
    if optimizer is not None and not optimizer.built:
        optimizer.build(self.trainable_variables)
def make_train_function(self, force=False):
    """Create (or reuse) the function that runs the training step(s).

    The function is stored on `self.train_function` and is called as
    `train_function(state, data)`, returning `(logs, state)`. `state` is
    the tuple `(trainable_variables, non_trainable_variables,
    optimizer_variables, metrics_variables)` and `data` is a list of
    batches, one per step.

    Args:
        force: If `True`, rebuild the function even if one already exists.

    Returns:
        The function stored in `self.train_function`, whether it was
        reused or freshly built.
    """
    if self.train_function is not None and not force:
        return self.train_function

    # `has_aux=True`: compute_loss_and_updates returns `(loss, aux)` and
    # only `loss` is differentiated.
    grad_fn = jax.value_and_grad(
        self.compute_loss_and_updates, has_aux=True
    )

    def one_train_step(state, data):
        # A single step consumes the first batch of the list.
        data = data[0]
        (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            metrics_variables,
        ) = state
        x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(
            data
        )
        (loss, (y_pred, non_trainable_variables)), grads = grad_fn(
            trainable_variables,
            non_trainable_variables,
            x,
            y,
            sample_weight,
            training=True,
        )

        (
            trainable_variables,
            optimizer_variables,
        ) = self.optimizer.stateless_apply(
            grads, trainable_variables, optimizer_variables
        )

        # Update the metrics with each metric variable mapped to the value
        # carried in `state`.
        with backend.StatelessScope(
            state_mapping=[
                (ref_v, v)
                for ref_v, v in zip(
                    self.metrics_variables, metrics_variables
                )
            ]
        ) as scope:
            self._loss_tracker.update_state(loss)
            logs = self.compute_metrics(x, y, y_pred, sample_weight)

        # Collect the metric values from the scope; variables the scope
        # has no value for keep their current value.
        new_metrics_variables = []
        for ref_v in self.metrics_variables:
            new_v = scope.get_current_value(ref_v)
            if new_v is None:
                new_v = ref_v.value
            new_metrics_variables.append(new_v)
        metrics_variables = new_metrics_variables

        state = (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            metrics_variables,
        )
        return logs, state

    def multi_train_steps(state, data):
        # Run one step per batch, threading the state through; the logs
        # of the last step are returned.
        for single_step_data in data:
            logs, state = one_train_step(state, [single_step_data])
        return logs, state

    if self.steps_per_execution > 1:
        train_step = multi_train_steps
    else:
        train_step = one_train_step

    if not self.run_eagerly and self.jit_compile:

        @jax.jit
        def compiled_train_step(state, data):
            return train_step(state, data)

        self.train_function = compiled_train_step
    else:
        self.train_function = train_step

    # Return the function on this path too, so both paths behave alike.
    return self.train_function
def make_test_function(self, force=False):
    """Create (or reuse) the function that runs the evaluation step(s).

    The function is stored on `self.test_function` and is called as
    `test_function(state, data)`, returning `(logs, state)`. `state` is the
    tuple `(trainable_variables, non_trainable_variables,
    metrics_variables)` and `data` is a list of batches, one per step.

    Args:
        force: If `True`, rebuild the function even if one already exists.

    Returns:
        The function stored in `self.test_function`, whether it was reused
        or freshly built.
    """
    if self.test_function is not None and not force:
        return self.test_function

    def one_test_step(state, data):
        # A single step consumes the first batch of the list.
        data = data[0]
        (
            trainable_variables,
            non_trainable_variables,
            metrics_variables,
        ) = state
        x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(
            data
        )
        loss, (
            y_pred,
            non_trainable_variables,
        ) = self.compute_loss_and_updates(
            trainable_variables,
            non_trainable_variables,
            x,
            y,
            sample_weight,
            training=False,
        )

        # Update the metrics with each metric variable mapped to the value
        # carried in `state`.
        with backend.StatelessScope(
            state_mapping=[
                (ref_v, v)
                for ref_v, v in zip(
                    self.metrics_variables, metrics_variables
                )
            ]
        ) as scope:
            self._loss_tracker.update_state(loss)
            logs = self.compute_metrics(x, y, y_pred, sample_weight)

        # Collect the metric values from the scope; variables the scope
        # has no value for keep their current value.
        new_metrics_variables = []
        for ref_v in self.metrics_variables:
            new_v = scope.get_current_value(ref_v)
            if new_v is None:
                new_v = ref_v.value
            new_metrics_variables.append(new_v)
        metrics_variables = new_metrics_variables

        state = (
            trainable_variables,
            non_trainable_variables,
            metrics_variables,
        )
        return logs, state

    def multi_test_steps(state, data):
        # Run one step per batch, threading the state through; the logs
        # of the last step are returned.
        for single_step_data in data:
            logs, state = one_test_step(state, [single_step_data])
        return logs, state

    if self.steps_per_execution > 1:
        test_step = multi_test_steps
    else:
        test_step = one_test_step

    if not self.run_eagerly and self.jit_compile:

        @jax.jit
        def compiled_test_step(state, data):
            return test_step(state, data)

        self.test_function = compiled_test_step
    else:
        self.test_function = test_step

    # Return the function on this path too, so both paths behave alike.
    return self.test_function
def make_predict_function(self, force=False):
    """Create (or reuse) the function that runs the prediction step(s).

    The function is stored on `self.predict_function` and is called as
    `predict_function(trainable_variables, non_trainable_variables, data)`,
    where `data` is a list of batches, one per step. It returns the model
    outputs; with several steps the per-step outputs are concatenated.

    Args:
        force: If `True`, rebuild the function even if one already exists.

    Returns:
        The function stored in `self.predict_function`, whether it was
        reused or freshly built.
    """
    if self.predict_function is not None and not force:
        return self.predict_function

    def one_predict_step(
        trainable_variables, non_trainable_variables, data
    ):
        kwargs = {}
        if self._call_has_training_arg():
            kwargs["training"] = False
        # Only the inputs of the first batch in the list are used.
        x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data[0])
        # The non-trainable variables returned by the call are discarded.
        outputs, _ = self.stateless_call(
            trainable_variables, non_trainable_variables, x, **kwargs
        )
        return outputs

    def multi_predict_steps(
        trainable_variables, non_trainable_variables, data
    ):
        # The first step seeds `outputs`; later steps are concatenated
        # onto it leaf by leaf.
        outputs = one_predict_step(
            trainable_variables, non_trainable_variables, data[:1]
        )

        for single_step_data in data[1:]:
            step_outputs = one_predict_step(
                trainable_variables,
                non_trainable_variables,
                [single_step_data],
            )
            outputs = tf.nest.map_structure(
                lambda t1, t2: jax.numpy.concatenate([t1, t2]),
                outputs,
                step_outputs,
            )
        return outputs

    if self.steps_per_execution > 1:
        predict_step = multi_predict_steps
    else:
        predict_step = one_predict_step

    if not self.run_eagerly and self.jit_compile:

        @jax.jit
        def compiled_predict_step(
            trainable_variables, non_trainable_variables, data
        ):
            return predict_step(
                trainable_variables, non_trainable_variables, data
            )

        self.predict_function = compiled_predict_step
    else:
        self.predict_function = predict_step

    # Return the function on this path too, so both paths behave alike.
    return self.predict_function
@traceback_utils.filter_traceback
def fit(
    self,
    x=None,
    y=None,
    batch_size=None,
    epochs=1,
    verbose="auto",
    callbacks=None,
    validation_split=0.0,
    validation_data=None,
    shuffle=True,
    class_weight=None,
    sample_weight=None,
    initial_epoch=0,
    steps_per_epoch=None,
    validation_steps=None,
    validation_batch_size=None,
    validation_freq=1,
):
    """Train the model for epochs `initial_epoch` up to `epochs`.

    During an epoch the variable values are carried in a local state tuple
    and only written back to the model variables (via `jax_state_sync`)
    once the epoch's steps are done.

    Args:
        x: Input data, forwarded to `EpochIterator`.
        y: Target data, forwarded to `EpochIterator`.
        batch_size: Batch size, forwarded to `EpochIterator`.
        epochs: Exclusive upper bound of the epoch loop.
        verbose: Forwarded to the `CallbackList`; a progress bar is added
            unless `verbose` is `0`.
        callbacks: A `CallbackList`, or a value wrapped into a new
            `CallbackList`.
        validation_split: If truthy and `validation_data` is `None`, the
            training data is split with
            `data_adapter_utils.train_validation_split`.
        validation_data: Validation data, unpacked into
            `(x, y, sample_weight)`.
        shuffle: Forwarded to the training `EpochIterator`.
        class_weight: Forwarded to the training `EpochIterator`.
        sample_weight: Forwarded to the training `EpochIterator`.
        initial_epoch: First epoch index of the epoch loop.
        steps_per_epoch: Forwarded to the training `EpochIterator`.
        validation_steps: Number of steps per validation run.
        validation_batch_size: Batch size for validation; falls back to
            `batch_size` when falsy.
        validation_freq: Forwarded to `_should_eval` to decide whether to
            validate after a given epoch.

    Returns:
        `self.history`.
    """
    self._assert_compile_called("fit")
    # TODO: respect compiled trainable state
    if validation_split and validation_data is None:
        # Create the validation data using the training data. Only supported
        # for TF/numpy/jax arrays.
        (
            x,
            y,
            sample_weight,
        ), validation_data = data_adapter_utils.train_validation_split(
            (x, y, sample_weight), validation_split=validation_split
        )

    if validation_data:
        (
            val_x,
            val_y,
            val_sample_weight,
        ) = data_adapter_utils.unpack_x_y_sample_weight(validation_data)

    # Create an iterator that yields batches for one epoch.
    epoch_iterator = EpochIterator(
        x=x,
        y=y,
        sample_weight=sample_weight,
        batch_size=batch_size,
        steps_per_epoch=steps_per_epoch,
        shuffle=shuffle,
        class_weight=class_weight,
        steps_per_execution=self.steps_per_execution,
    )

    needs_building = (
        not self.built
        or not self.optimizer.built
        or (
            self._compile_metrics is not None
            and not self._compile_metrics.built
        )
    )
    if needs_building:
        # Build the model on one batch of data.
        for _, data in epoch_iterator.enumerate_epoch(return_type="np"):
            data_batch = data[0]
            self._eager_build(data_batch)
            break

    # Container that configures and calls callbacks.
    if not isinstance(callbacks, callbacks_module.CallbackList):
        callbacks = callbacks_module.CallbackList(
            callbacks,
            add_history=True,
            add_progbar=verbose != 0,
            verbose=verbose,
            epochs=epochs,
            steps=epoch_iterator.num_batches,
            model=self,
        )

    self.make_train_function()
    self.stop_training = False
    callbacks.on_train_begin()

    # Stays `None` if the epoch loop never runs (`initial_epoch >= epochs`),
    # so `on_train_end` can still be called.
    training_logs = None
    for epoch in range(initial_epoch, epochs):
        self.reset_metrics()
        callbacks.on_epoch_begin(epoch)

        trainable_variables = self.trainable_variables
        non_trainable_variables = self.non_trainable_variables
        optimizer_variables = self.optimizer.variables
        metrics_variables = self.metrics_variables

        for step, data in epoch_iterator.enumerate_epoch(return_type="np"):
            # Callbacks
            callbacks.on_train_batch_begin(step)

            # Train step
            state = (
                trainable_variables,
                non_trainable_variables,
                optimizer_variables,
                metrics_variables,
            )
            logs, state = self.train_function(state, data)
            (
                trainable_variables,
                non_trainable_variables,
                optimizer_variables,
                metrics_variables,
            ) = state

            # Setting _jax_state enables callbacks to force a state sync
            # if they need to.
            self._jax_state = {
                "trainable_variables": trainable_variables,
                "non_trainable_variables": non_trainable_variables,
                "optimizer_variables": optimizer_variables,
                "metrics_variables": metrics_variables,
            }

            # Callbacks
            callbacks.on_train_batch_end(step, self._pythonify_logs(logs))
            if self.stop_training:
                break

        # Reattach state to model variables.
        # NOTE: doing this after each step would be a big performance
        # bottleneck.
        self.jax_state_sync()

        # Override with model metrics instead of last step logs
        epoch_logs = self.get_metrics_result()

        # Run validation.
        if validation_data and self._should_eval(epoch, validation_freq):
            # Create EpochIterator for evaluation and cache it.
            if getattr(self, "_eval_epoch_iterator", None) is None:
                # `evaluate` uses this cached iterator instead of building
                # one from its `steps` argument, so `validation_steps`
                # has to be applied here.
                self._eval_epoch_iterator = EpochIterator(
                    x=val_x,
                    y=val_y,
                    sample_weight=val_sample_weight,
                    batch_size=validation_batch_size or batch_size,
                    steps_per_epoch=validation_steps,
                    shuffle=False,
                    steps_per_execution=self.steps_per_execution,
                )
            val_logs = self.evaluate(
                x=val_x,
                y=val_y,
                sample_weight=val_sample_weight,
                batch_size=validation_batch_size or batch_size,
                steps=validation_steps,
                callbacks=callbacks,
                return_dict=True,
                _use_cached_eval_dataset=True,
            )
            val_logs = {
                "val_" + name: val for name, val in val_logs.items()
            }
            epoch_logs.update(val_logs)

        callbacks.on_epoch_end(epoch, epoch_logs)
        training_logs = epoch_logs
        if self.stop_training:
            break

    if (
        isinstance(self.optimizer, optimizers_module.Optimizer)
        and epochs > 0
    ):
        self.optimizer.finalize_variable_values(self.trainable_weights)

    # If _eval_epoch_iterator exists, delete it after all epochs are done.
    if getattr(self, "_eval_epoch_iterator", None) is not None:
        del self._eval_epoch_iterator
    callbacks.on_train_end(logs=training_logs)
    self._jax_state = None
    return self.history
@traceback_utils.filter_traceback
def evaluate(
    self,
    x=None,
    y=None,
    batch_size=None,
    verbose="auto",
    sample_weight=None,
    steps=None,
    callbacks=None,
    return_dict=False,
    **kwargs,
):
    """Compute the loss and metrics of the model on the given data.

    Args:
        x: Input data, forwarded to `EpochIterator`.
        y: Target data, forwarded to `EpochIterator`.
        batch_size: Batch size, forwarded to `EpochIterator`.
        verbose: Forwarded to the `CallbackList`; a progress bar is added
            unless `verbose` is `0`.
        sample_weight: Forwarded to `EpochIterator`.
        steps: Forwarded to `EpochIterator` as `steps_per_epoch`.
        callbacks: A `CallbackList`, or a value wrapped into a new
            `CallbackList`.
        return_dict: If `True`, return the metrics dict; otherwise return
            the result of `_flatten_metrics_in_order`.
        **kwargs: Only the private `_use_cached_eval_dataset` flag is
            accepted. When it is truthy, `self._eval_epoch_iterator` is
            used and `x`, `y`, `batch_size`, `sample_weight` and `steps`
            are not used to build an iterator.

    Returns:
        The metrics dict from `get_metrics_result()` if `return_dict` is
        `True`, otherwise that dict passed through
        `_flatten_metrics_in_order`.

    Raises:
        ValueError: If unrecognized keyword arguments are passed.
    """
    self._assert_compile_called("evaluate")
    # TODO: respect compiled trainable state
    use_cached_eval_dataset = kwargs.pop("_use_cached_eval_dataset", False)
    if kwargs:
        raise ValueError(f"Arguments not recognized: {kwargs}")

    if use_cached_eval_dataset:
        epoch_iterator = self._eval_epoch_iterator
    else:
        # Create an iterator that yields batches of input/target data.
        epoch_iterator = EpochIterator(
            x=x,
            y=y,
            sample_weight=sample_weight,
            batch_size=batch_size,
            steps_per_epoch=steps,
            shuffle=False,
            steps_per_execution=self.steps_per_execution,
        )

    # Also build when only the compile metrics are unbuilt (as `fit` does);
    # otherwise a built model reaches the test step with unbuilt metrics.
    compile_metrics_unbuilt = (
        self._compile_metrics is not None
        and not self._compile_metrics.built
    )
    if not self.built or compile_metrics_unbuilt:
        # Build the model on one batch of data.
        for _, data in epoch_iterator.enumerate_epoch(return_type="np"):
            data_batch = data[0]
            self._eager_build(data_batch)
            break

    # Container that configures and calls callbacks.
    if not isinstance(callbacks, callbacks_module.CallbackList):
        callbacks = callbacks_module.CallbackList(
            callbacks,
            add_history=True,
            add_progbar=verbose != 0,
            verbose=verbose,
            epochs=1,
            steps=epoch_iterator.num_batches,
            model=self,
        )

    self.make_test_function()
    callbacks.on_test_begin()
    logs = None
    self.reset_metrics()

    trainable_variables = self.trainable_variables
    non_trainable_variables = self.non_trainable_variables
    metrics_variables = self.metrics_variables

    for step, data in epoch_iterator.enumerate_epoch(return_type="np"):
        callbacks.on_test_batch_begin(step)

        state = (
            trainable_variables,
            non_trainable_variables,
            metrics_variables,
        )
        logs, state = self.test_function(state, data)
        # Note that trainable variables are not returned since they're
        # immutable here.
        _, non_trainable_variables, metrics_variables = state

        # Setting _jax_state enables callbacks to force a state sync
        # if they need to.
        self._jax_state = {
            # I wouldn't recommend modifying non-trainable model state
            # during evaluate(), but it's allowed.
            "non_trainable_variables": non_trainable_variables,
            "metrics_variables": metrics_variables,
        }
        callbacks.on_test_batch_end(step, self._pythonify_logs(logs))

    # Reattach state back to model.
    self.jax_state_sync()

    # The returned logs come from the metric objects, not the last step.
    logs = self.get_metrics_result()
    callbacks.on_test_end(logs)
    self._jax_state = None
    if return_dict:
        return logs
    return self._flatten_metrics_in_order(logs)
@traceback_utils.filter_traceback
def predict(
    self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
):
    """Generate predictions for the input data.

    Args:
        x: Input data, forwarded to `EpochIterator`.
        batch_size: Batch size, forwarded to `EpochIterator`.
        verbose: Forwarded to the `CallbackList`; a progress bar is added
            unless `verbose` is `0`.
        steps: Forwarded to `EpochIterator` as `steps_per_epoch`.
        callbacks: A `CallbackList`, or a value wrapped into a new
            `CallbackList`.

    Returns:
        The per-step outputs joined with `np.concatenate`, in the same
        nested structure as a single step's outputs.

    Raises:
        ValueError: If the input data yields no batches, so there is
            nothing to concatenate.
    """
    # Create an iterator that yields batches of input data.
    epoch_iterator = EpochIterator(
        x=x,
        batch_size=batch_size,
        steps_per_epoch=steps,
        shuffle=False,
        steps_per_execution=self.steps_per_execution,
    )

    if not self.built:
        # Build the model on one batch of data.
        for _, data in epoch_iterator.enumerate_epoch(return_type="np"):
            # Build model
            first_x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(
                data[0]
            )
            with backend.StatelessScope():
                self(first_x)
            break

    # Container that configures and calls callbacks.
    if not isinstance(callbacks, callbacks_module.CallbackList):
        callbacks = callbacks_module.CallbackList(
            callbacks,
            add_history=True,
            add_progbar=verbose != 0,
            verbose=verbose,
            epochs=1,
            steps=epoch_iterator.num_batches,
            model=self,
        )

    self.make_predict_function()
    callbacks.on_predict_begin()

    def append_to_outputs(batch_outputs, outputs):
        # `outputs` mirrors the structure of `batch_outputs`, with a list
        # of per-step values at every leaf.
        if outputs is None:
            outputs = tf.nest.map_structure(
                lambda batch_output: [batch_output],
                batch_outputs,
            )
        else:
            tf.__internal__.nest.map_structure_up_to(
                batch_outputs,
                lambda output, batch_output: output.append(batch_output),
                outputs,
                batch_outputs,
            )
        return outputs

    trainable_variables = self.trainable_variables
    non_trainable_variables = self.non_trainable_variables
    outputs = None
    batch_outputs = None
    for step, data in epoch_iterator.enumerate_epoch(return_type="np"):
        callbacks.on_predict_batch_begin(step)
        batch_outputs = self.predict_function(
            trainable_variables, non_trainable_variables, data
        )
        outputs = append_to_outputs(batch_outputs, outputs)
        callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
    callbacks.on_predict_end()

    if outputs is None:
        # No step ran, so `batch_outputs` was never produced. Raise a
        # clear error rather than failing on an unbound name below.
        raise ValueError(
            "`predict()` produced no outputs because the input data "
            "did not yield any batches. Make sure `x` is not empty "
            f"and that `steps` is greater than 0. Received: steps={steps}"
        )
    return tf.__internal__.nest.map_structure_up_to(
        batch_outputs, np.concatenate, outputs
    )
def train_on_batch(
    self,
    x,
    y=None,
    sample_weight=None,
    class_weight=None,
    return_dict=False,
):
    """Perform one gradient update using a single batch of data.

    Args:
        x: Input data. Must be array-like.
        y: Target data. Must be array-like.
        sample_weight: Optional array with the same length as `x`, holding
            the weight applied to the model's loss for each sample. For
            temporal data, a 2D array of shape
            `(samples, sequence_length)` can be passed to weight every
            timestep of every sample differently.
        class_weight: Optional dict mapping class indices (integers) to a
            weight (float) applied to the model's loss for samples of that
            class during training, e.g. to make the model "pay more
            attention" to an under-represented class. When it is given and
            the targets have rank 2 or greater, `y` must either be one-hot
            encoded or carry an explicit final dimension of 1 for sparse
            class labels.
        return_dict: If `True`, loss and metric results are returned as a
            dict keyed by metric name. If `False`, they are returned as a
            list.

    Returns:
        A scalar loss value (no metrics and `return_dict=False`), a list
        of loss and metric values (metrics and `return_dict=False`), or a
        dict of metric and loss values (`return_dict=True`).

    Raises:
        ValueError: If both `sample_weight` and `class_weight` are given.
    """
    self._assert_compile_called("train_on_batch")
    if class_weight is not None and sample_weight is not None:
        raise ValueError(
            "Arguments `sample_weight` and `class_weight` "
            "cannot be specified at the same time. "
            f"Received: sample_weight={sample_weight}, "
            f"class_weight={class_weight}"
        )
    if class_weight is not None:
        # Convert the class weights into sample weights.
        sample_weight = data_adapter_utils.class_weight_to_sample_weights(
            y, class_weight
        )
    batch = (x, y, sample_weight)

    # Build whatever is still unbuilt, then make sure the step exists.
    self._eager_build(batch)
    self.make_train_function()

    # Run a single step; the step function expects a list of batches.
    state = (
        self.trainable_variables,
        self.non_trainable_variables,
        self.optimizer.variables,
        self.metrics_variables,
    )
    logs, state = self.train_function(state, [batch])

    # Write the updated values back onto the model variables.
    new_trainable, new_non_trainable, new_optimizer, new_metrics = state
    self._jax_state = {
        "trainable_variables": new_trainable,
        "non_trainable_variables": new_non_trainable,
        "optimizer_variables": new_optimizer,
        "metrics_variables": new_metrics,
    }
    self.jax_state_sync()

    # Convert every log value to a NumPy array before returning.
    logs = tf.nest.map_structure(np.array, logs)
    return logs if return_dict else self._flatten_metrics_in_order(logs)
def test_on_batch(
    self,
    x,
    y=None,
    sample_weight=None,
    return_dict=False,
):
    """Evaluate the model on a single batch of samples.

    Args:
        x: Input data. Must be array-like.
        y: Target data. Must be array-like.
        sample_weight: Optional array with the same length as `x`, holding
            the weight applied to the model's loss for each sample. For
            temporal data, a 2D array of shape
            `(samples, sequence_length)` can be passed to weight every
            timestep of every sample differently.
        return_dict: If `True`, loss and metric results are returned as a
            dict keyed by metric name. If `False`, they are returned as a
            list.

    Returns:
        A scalar loss value (no metrics and `return_dict=False`), a list
        of loss and metric values (metrics and `return_dict=False`), or a
        dict of metric and loss values (`return_dict=True`).
    """
    self._assert_compile_called("test_on_batch")
    batch = (x, y, sample_weight)

    # Build whatever is still unbuilt, then make sure the step exists.
    self._eager_build(batch)
    self.make_test_function()

    # Run a single step; the step function expects a list of batches.
    state = (
        self.trainable_variables,
        self.non_trainable_variables,
        self.metrics_variables,
    )
    logs, state = self.test_function(state, [batch])

    # Only non-trainable and metric values are written back to the model.
    _, new_non_trainable, new_metrics = state
    self._jax_state = {
        "non_trainable_variables": new_non_trainable,
        "metrics_variables": new_metrics,
    }
    self.jax_state_sync()

    # Convert every log value to a NumPy array before returning.
    logs = tf.nest.map_structure(np.array, logs)
    return logs if return_dict else self._flatten_metrics_in_order(logs)
def predict_on_batch(self, x):
    """Return predictions for a single batch of samples.

    Args:
        x: Input data. It must be array-like.

    Returns:
        NumPy array(s) of predictions.
    """
    if not self.built:
        # Calling the model once builds it; done inside a StatelessScope.
        with backend.StatelessScope():
            self(x)
    self.make_predict_function()

    # The step function expects a list of batches.
    predictions = self.predict_function(
        self.trainable_variables, self.non_trainable_variables, [x]
    )
    # Convert every output leaf to a NumPy array.
    return tf.nest.map_structure(np.array, predictions)
def jax_state_sync(self):
    """Assign the values held in `self._jax_state` to the real variables.

    Does nothing when `_jax_state` is missing or falsy. Each entry
    (`trainable_variables`, `non_trainable_variables`,
    `optimizer_variables`, `metrics_variables`) is optional; absent or
    empty entries are skipped.
    """
    state = getattr(self, "_jax_state", None)
    if not state:
        return

    # The getters are lazy so a variable collection (e.g. the optimizer's)
    # is only accessed when the state actually holds values for it.
    targets = (
        ("trainable_variables", lambda: self.trainable_variables),
        ("non_trainable_variables", lambda: self.non_trainable_variables),
        ("optimizer_variables", lambda: self.optimizer.variables),
        ("metrics_variables", lambda: self.metrics_variables),
    )
    for key, get_reference_variables in targets:
        values = state.get(key, None)
        if not values:
            continue
        for ref_v, v in zip(get_reference_variables(), values):
            ref_v.assign(v)