# keras/keras_core/backend/jax/trainer.py
import jax

from keras_core import backend
from keras_core import callbacks as callbacks_module
from keras_core import optimizers as optimizers_module
from keras_core.trainers import trainer as base_trainer
from keras_core.trainers.data_adapters import data_adapter_utils
from keras_core.trainers.epoch_iterator import EpochIterator
class Trainer(base_trainer.Trainer):
    """JAX-backend trainer implementing a functional `fit()` loop.

    All mutable state (trainable, non-trainable, optimizer, and metrics
    variables) is threaded explicitly through the per-batch train step so
    that the step is a pure function and can be compiled with `jax.jit`.
    Variable objects are only written back once per epoch.
    """

    def stateless_compute_loss_and_updates(
        self, trainable_variables, non_trainable_variables, x, y, sample_weight
    ):
        """Forward pass + loss as a pure function of the variable values.

        Args:
            trainable_variables: Current trainable variable values.
            non_trainable_variables: Current non-trainable variable values.
            x: Input batch.
            y: Target batch (may be `None`).
            sample_weight: Per-sample weights (may be `None`).

        Returns:
            `(loss, (y_pred, non_trainable_variables))` — the auxiliary
            tuple layout required by `jax.value_and_grad(..., has_aux=True)`.
        """
        y_pred, non_trainable_variables = self.stateless_call(
            trainable_variables, non_trainable_variables, x
        )
        loss = self.compute_loss(x, y, y_pred, sample_weight)
        return loss, (y_pred, non_trainable_variables)

    def fit(
        self,
        x=None,
        y=None,
        batch_size=None,
        epochs=1,
        verbose="auto",
        callbacks=None,
        validation_split=0.0,
        validation_data=None,
        shuffle=True,
        class_weight=None,
        sample_weight=None,
        initial_epoch=0,
        steps_per_epoch=None,
        validation_steps=None,
        validation_batch_size=None,
        validation_freq=1,
    ):
        """Train the model for a fixed number of epochs.

        Returns:
            The model's `History` callback.

        Raises:
            ValueError: If `compile()` has not been called.
        """
        if not self.compiled:
            raise ValueError(
                "You must call `compile()` before calling `fit()`."
            )

        # TODO: respect compiled trainable state
        if validation_split and validation_data is None:
            # Create the validation data using the training data. Only
            # supported for TF/numpy/jax arrays.
            (
                x,
                y,
                sample_weight,
            ), validation_data = data_adapter_utils.train_validation_split(
                (x, y, sample_weight), validation_split=validation_split
            )

        if validation_data:
            (
                val_x,
                val_y,
                val_sample_weight,
            ) = data_adapter_utils.unpack_x_y_sample_weight(validation_data)

        # Create an iterator that yields batches for one epoch.
        epoch_iterator = EpochIterator(
            x=x,
            y=y,
            sample_weight=sample_weight,
            batch_size=batch_size,
            steps_per_epoch=steps_per_epoch,
            shuffle=shuffle,
            class_weight=class_weight,
        )

        if not self.built:
            # Build the model (and its metrics) on one batch of data.
            for _, data in epoch_iterator.enumerate_epoch(return_type="np"):
                (
                    x,
                    y,
                    sample_weight,
                ) = data_adapter_utils.unpack_x_y_sample_weight(data)
                # Build model
                y_pred = self(x)
                # Build metrics
                self.compute_metrics(x, y, y_pred, sample_weight)
                self.reset_metrics()
                break

        if not self.optimizer.built:
            # Build optimizer
            self.optimizer.build(self.trainable_variables)

        # Container that configures and calls callbacks.
        if not isinstance(callbacks, callbacks_module.CallbackList):
            callbacks = callbacks_module.CallbackList(
                callbacks,
                add_history=True,
                add_progbar=verbose != 0,
                verbose=verbose,
                epochs=epochs,
                steps=epoch_iterator.num_batches,
                model=self,
            )

        grad_fn = jax.value_and_grad(
            self.stateless_compute_loss_and_updates, has_aux=True
        )

        def _train_step(state, data):
            # One gradient step on one batch; pure w.r.t. `state`.
            (
                trainable_variables,
                non_trainable_variables,
                optimizer_variables,
                metrics_variables,
            ) = state
            x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(
                data
            )
            (loss, (y_pred, non_trainable_variables)), grads = grad_fn(
                trainable_variables,
                non_trainable_variables,
                x,
                y,
                sample_weight,
            )
            (
                trainable_variables,
                optimizer_variables,
            ) = self.optimizer.stateless_apply(
                grads, trainable_variables, optimizer_variables
            )

            # Update metrics inside a stateless scope so metric variable
            # reads/writes go through the scope's state mapping instead of
            # mutating the variables (required for jit-compatibility).
            with backend.StatelessScope(
                state_mapping=[
                    (ref_v, v)
                    for ref_v, v in zip(
                        self.metrics_variables, metrics_variables
                    )
                ]
            ) as scope:
                logs = self.compute_metrics(x, y, y_pred, sample_weight)
                self._loss_tracker.update_state(loss)

            new_metrics_variables = []
            for ref_v in self.metrics_variables:
                new_v = scope.get_current_value(ref_v)
                if new_v is None:
                    # Variable untouched during this step: keep prior value.
                    new_v = ref_v.value
                new_metrics_variables.append(new_v)
            metrics_variables = new_metrics_variables

            state = (
                trainable_variables,
                non_trainable_variables,
                optimizer_variables,
                metrics_variables,
            )
            return logs, state

        if not self.run_eagerly and self.jit_compile:

            @jax.jit
            def train_step(state, data):
                return _train_step(state, data)

        else:
            train_step = _train_step

        self.stop_training = False
        callbacks.on_train_begin()
        # Initialize so `on_train_end(logs=...)` is well-defined even when
        # the epoch loop body never runs (e.g. `initial_epoch >= epochs`);
        # previously this raised `NameError` in that case.
        training_logs = None

        for epoch in range(initial_epoch, epochs):
            self.reset_metrics()
            callbacks.on_epoch_begin(epoch)

            # Snapshot current variable values; the inner loop threads these
            # through `train_step` without touching the variables themselves.
            trainable_variables = self.trainable_variables
            non_trainable_variables = self.non_trainable_variables
            optimizer_variables = self.optimizer.variables
            metrics_variables = self.metrics_variables

            for step, data in epoch_iterator.enumerate_epoch(
                return_type="np"
            ):
                # Callbacks
                callbacks.on_train_batch_begin(step)

                # Train step
                state = (
                    trainable_variables,
                    non_trainable_variables,
                    optimizer_variables,
                    metrics_variables,
                )
                logs, state = train_step(state, data)
                (
                    trainable_variables,
                    non_trainable_variables,
                    optimizer_variables,
                    metrics_variables,
                ) = state

                # Callbacks
                callbacks.on_train_batch_end(step, logs)
                if self.stop_training:
                    break

            # Update variable values
            # NOTE: doing this after each step would be a big performance
            # bottleneck.
            for ref_v, v in zip(
                self.trainable_variables, trainable_variables
            ):
                ref_v.assign(v)
            for ref_v, v in zip(
                self.non_trainable_variables, non_trainable_variables
            ):
                ref_v.assign(v)
            for ref_v, v in zip(
                self.optimizer.variables, optimizer_variables
            ):
                ref_v.assign(v)
            for ref_v, v in zip(self.metrics_variables, metrics_variables):
                ref_v.assign(v)

            # Override with model metrics instead of last step logs
            epoch_logs = self._process_logs(self.get_metrics_result())

            # Run validation.
            if validation_data and self._should_eval(epoch, validation_freq):
                # Create EpochIterator for evaluation and cache it.
                if getattr(self, "_eval_epoch_iterator", None) is None:
                    self._eval_epoch_iterator = EpochIterator(
                        x=val_x,
                        y=val_y,
                        sample_weight=val_sample_weight,
                        batch_size=validation_batch_size or batch_size,
                        epochs=1,
                    )
                val_logs = self.evaluate(
                    x=val_x,
                    y=val_y,
                    sample_weight=val_sample_weight,
                    batch_size=validation_batch_size or batch_size,
                    steps=validation_steps,
                    callbacks=callbacks,
                    return_dict=True,
                    _use_cached_eval_dataset=True,
                )
                val_logs = {
                    "val_" + name: val for name, val in val_logs.items()
                }
                epoch_logs.update(self._process_logs(val_logs))

            callbacks.on_epoch_end(epoch, epoch_logs)
            training_logs = epoch_logs
            if self.stop_training:
                break

        if (
            isinstance(self.optimizer, optimizers_module.Optimizer)
            and epochs > 0
        ):
            self.optimizer.finalize_variable_values(self.trainable_weights)

        # If _eval_epoch_iterator exists, delete it after all epochs are done.
        if getattr(self, "_eval_epoch_iterator", None) is not None:
            del self._eval_epoch_iterator

        callbacks.on_train_end(logs=training_logs)
        return self.history

    def evaluate(
        self,
        x=None,
        y=None,
        batch_size=None,
        verbose="auto",
        sample_weight=None,
        steps=None,
        callbacks=None,
        return_dict=False,
        **kwargs,
    ):
        # Not yet implemented for the JAX backend.
        raise NotImplementedError

    def predict(
        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
    ):
        # Not yet implemented for the JAX backend.
        raise NotImplementedError