# JAX backend trainer implementation for keras_core.
import jax
|
|
import numpy as np
|
|
import tensorflow as tf # for nest
|
|
|
|
from keras_core import backend
|
|
from keras_core import callbacks as callbacks_module
|
|
from keras_core import optimizers as optimizers_module
|
|
from keras_core.trainers import trainer as base_trainer
|
|
from keras_core.trainers.data_adapters import data_adapter_utils
|
|
from keras_core.trainers.epoch_iterator import EpochIterator
|
|
from keras_core.utils import traceback_utils
|
|
|
|
|
|
class JAXTrainer(base_trainer.Trainer):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.train_function = None
|
|
self.test_function = None
|
|
self.predict_function = None
|
|
|
|
def compute_loss_and_updates(
|
|
self,
|
|
trainable_variables,
|
|
non_trainable_variables,
|
|
x,
|
|
y,
|
|
sample_weight,
|
|
training=False,
|
|
):
|
|
"""This method is stateless and is intended for use with jax.grad."""
|
|
kwargs = {}
|
|
if self._call_has_training_arg():
|
|
kwargs["training"] = training
|
|
y_pred, non_trainable_variables = self.stateless_call(
|
|
trainable_variables, non_trainable_variables, x, **kwargs
|
|
)
|
|
|
|
loss = self.compute_loss(x, y, y_pred, sample_weight)
|
|
return loss, (y_pred, non_trainable_variables)
|
|
|
|
def _eager_build(self, data_batch):
|
|
compile_metrics_unbuilt = (
|
|
self._compile_metrics is not None
|
|
and not self._compile_metrics.built
|
|
)
|
|
if not self.built or compile_metrics_unbuilt:
|
|
# Build the model on one batch of data.
|
|
(
|
|
x,
|
|
y,
|
|
sample_weight,
|
|
) = data_adapter_utils.unpack_x_y_sample_weight(data_batch)
|
|
# Build model
|
|
with backend.StatelessScope():
|
|
y_pred = self(x)
|
|
if compile_metrics_unbuilt:
|
|
# Build metrics
|
|
self.compute_metrics(
|
|
x, y, y_pred, sample_weight=sample_weight
|
|
)
|
|
if self.optimizer is not None and not self.optimizer.built:
|
|
# Build optimizer
|
|
self.optimizer.build(self.trainable_variables)
|
|
|
|
def make_train_function(self, force=False):
|
|
if self.train_function is not None and not force:
|
|
return self.train_function
|
|
|
|
grad_fn = jax.value_and_grad(
|
|
self.compute_loss_and_updates, has_aux=True
|
|
)
|
|
|
|
def one_train_step(state, data):
|
|
data = data[0]
|
|
(
|
|
trainable_variables,
|
|
non_trainable_variables,
|
|
optimizer_variables,
|
|
metrics_variables,
|
|
) = state
|
|
x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(
|
|
data
|
|
)
|
|
(loss, (y_pred, non_trainable_variables)), grads = grad_fn(
|
|
trainable_variables,
|
|
non_trainable_variables,
|
|
x,
|
|
y,
|
|
sample_weight,
|
|
training=True,
|
|
)
|
|
|
|
(
|
|
trainable_variables,
|
|
optimizer_variables,
|
|
) = self.optimizer.stateless_apply(
|
|
grads, trainable_variables, optimizer_variables
|
|
)
|
|
|
|
with backend.StatelessScope(
|
|
state_mapping=[
|
|
(ref_v, v)
|
|
for ref_v, v in zip(
|
|
self.metrics_variables, metrics_variables
|
|
)
|
|
]
|
|
) as scope:
|
|
self._loss_tracker.update_state(loss)
|
|
logs = self.compute_metrics(x, y, y_pred, sample_weight)
|
|
|
|
new_metrics_variables = []
|
|
for ref_v in self.metrics_variables:
|
|
new_v = scope.get_current_value(ref_v)
|
|
if new_v is None:
|
|
new_v = ref_v.value
|
|
new_metrics_variables.append(new_v)
|
|
metrics_variables = new_metrics_variables
|
|
|
|
state = (
|
|
trainable_variables,
|
|
non_trainable_variables,
|
|
optimizer_variables,
|
|
metrics_variables,
|
|
)
|
|
return logs, state
|
|
|
|
def multi_train_steps(state, data):
|
|
for single_step_data in data:
|
|
logs, state = one_train_step(state, [single_step_data])
|
|
return logs, state
|
|
|
|
if self.steps_per_execution > 1:
|
|
train_step = multi_train_steps
|
|
else:
|
|
train_step = one_train_step
|
|
|
|
if not self.run_eagerly and self.jit_compile:
|
|
|
|
@jax.jit
|
|
def compiled_train_step(state, data):
|
|
return train_step(state, data)
|
|
|
|
self.train_function = compiled_train_step
|
|
|
|
else:
|
|
self.train_function = train_step
|
|
|
|
def make_test_function(self, force=False):
|
|
if self.test_function is not None and not force:
|
|
return self.test_function
|
|
|
|
def one_test_step(state, data):
|
|
data = data[0]
|
|
(
|
|
trainable_variables,
|
|
non_trainable_variables,
|
|
metrics_variables,
|
|
) = state
|
|
x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(
|
|
data
|
|
)
|
|
loss, (
|
|
y_pred,
|
|
non_trainable_variables,
|
|
) = self.compute_loss_and_updates(
|
|
trainable_variables,
|
|
non_trainable_variables,
|
|
x,
|
|
y,
|
|
sample_weight,
|
|
training=False,
|
|
)
|
|
|
|
with backend.StatelessScope(
|
|
state_mapping=[
|
|
(ref_v, v)
|
|
for ref_v, v in zip(
|
|
self.metrics_variables, metrics_variables
|
|
)
|
|
]
|
|
) as scope:
|
|
self._loss_tracker.update_state(loss)
|
|
logs = self.compute_metrics(x, y, y_pred, sample_weight)
|
|
|
|
new_metrics_variables = []
|
|
for ref_v in self.metrics_variables:
|
|
new_v = scope.get_current_value(ref_v)
|
|
if new_v is None:
|
|
new_v = ref_v.value
|
|
new_metrics_variables.append(new_v)
|
|
metrics_variables = new_metrics_variables
|
|
|
|
state = (
|
|
trainable_variables,
|
|
non_trainable_variables,
|
|
metrics_variables,
|
|
)
|
|
return logs, state
|
|
|
|
def multi_test_steps(state, data):
|
|
for single_step_data in data:
|
|
logs, state = one_test_step(state, [single_step_data])
|
|
return logs, state
|
|
|
|
if self.steps_per_execution > 1:
|
|
test_step = multi_test_steps
|
|
else:
|
|
test_step = one_test_step
|
|
|
|
if not self.run_eagerly and self.jit_compile:
|
|
|
|
@jax.jit
|
|
def compiled_test_step(state, data):
|
|
return test_step(state, data)
|
|
|
|
self.test_function = compiled_test_step
|
|
|
|
else:
|
|
self.test_function = test_step
|
|
|
|
def make_predict_function(self, force=False):
|
|
if self.predict_function is not None and not force:
|
|
return self.predict_function
|
|
|
|
def one_predict_step(
|
|
trainable_variables, non_trainable_variables, data
|
|
):
|
|
kwargs = {}
|
|
if self._call_has_training_arg():
|
|
kwargs["training"] = False
|
|
|
|
x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data[0])
|
|
outputs, _ = self.stateless_call(
|
|
trainable_variables, non_trainable_variables, x, **kwargs
|
|
)
|
|
return outputs
|
|
|
|
def multi_predict_steps(
|
|
trainable_variables, non_trainable_variables, data
|
|
):
|
|
outputs = one_predict_step(
|
|
trainable_variables, non_trainable_variables, data[:1]
|
|
)
|
|
for single_step_data in data[1:]:
|
|
step_outputs = one_predict_step(
|
|
trainable_variables,
|
|
non_trainable_variables,
|
|
[single_step_data],
|
|
)
|
|
outputs = tf.nest.map_structure(
|
|
lambda t1, t2: jax.numpy.concatenate([t1, t2]),
|
|
outputs,
|
|
step_outputs,
|
|
)
|
|
return outputs
|
|
|
|
if self.steps_per_execution > 1:
|
|
predict_step = multi_predict_steps
|
|
else:
|
|
predict_step = one_predict_step
|
|
|
|
if not self.run_eagerly and self.jit_compile:
|
|
|
|
@jax.jit
|
|
def compiled_predict_step(
|
|
trainable_variables, non_trainable_variables, data
|
|
):
|
|
return predict_step(
|
|
trainable_variables, non_trainable_variables, data
|
|
)
|
|
|
|
self.predict_function = compiled_predict_step
|
|
|
|
else:
|
|
self.predict_function = predict_step
|
|
|
|
@traceback_utils.filter_traceback
|
|
def fit(
|
|
self,
|
|
x=None,
|
|
y=None,
|
|
batch_size=None,
|
|
epochs=1,
|
|
verbose="auto",
|
|
callbacks=None,
|
|
validation_split=0.0,
|
|
validation_data=None,
|
|
shuffle=True,
|
|
class_weight=None,
|
|
sample_weight=None,
|
|
initial_epoch=0,
|
|
steps_per_epoch=None,
|
|
validation_steps=None,
|
|
validation_batch_size=None,
|
|
validation_freq=1,
|
|
):
|
|
self._assert_compile_called("fit")
|
|
# TODO: respect compiled trainable state
|
|
if validation_split and validation_data is None:
|
|
# Create the validation data using the training data. Only supported
|
|
# for TF/numpy/jax arrays.
|
|
(
|
|
x,
|
|
y,
|
|
sample_weight,
|
|
), validation_data = data_adapter_utils.train_validation_split(
|
|
(x, y, sample_weight), validation_split=validation_split
|
|
)
|
|
|
|
if validation_data:
|
|
(
|
|
val_x,
|
|
val_y,
|
|
val_sample_weight,
|
|
) = data_adapter_utils.unpack_x_y_sample_weight(validation_data)
|
|
|
|
# Create an iterator that yields batches for one epoch.
|
|
epoch_iterator = EpochIterator(
|
|
x=x,
|
|
y=y,
|
|
sample_weight=sample_weight,
|
|
batch_size=batch_size,
|
|
steps_per_epoch=steps_per_epoch,
|
|
shuffle=shuffle,
|
|
class_weight=class_weight,
|
|
steps_per_execution=self.steps_per_execution,
|
|
)
|
|
|
|
needs_building = (
|
|
not self.built
|
|
or not self.optimizer.built
|
|
or (
|
|
self._compile_metrics is not None
|
|
and not self._compile_metrics.built
|
|
)
|
|
)
|
|
if needs_building:
|
|
# Build the model on one batch of data.
|
|
for _, data in epoch_iterator.enumerate_epoch(return_type="np"):
|
|
data_batch = data[0]
|
|
self._eager_build(data_batch)
|
|
break
|
|
|
|
# Container that configures and calls callbacks.
|
|
if not isinstance(callbacks, callbacks_module.CallbackList):
|
|
callbacks = callbacks_module.CallbackList(
|
|
callbacks,
|
|
add_history=True,
|
|
add_progbar=verbose != 0,
|
|
verbose=verbose,
|
|
epochs=epochs,
|
|
steps=epoch_iterator.num_batches,
|
|
model=self,
|
|
)
|
|
|
|
self.make_train_function()
|
|
self.stop_training = False
|
|
callbacks.on_train_begin()
|
|
|
|
for epoch in range(initial_epoch, epochs):
|
|
self.reset_metrics()
|
|
callbacks.on_epoch_begin(epoch)
|
|
|
|
trainable_variables = self.trainable_variables
|
|
non_trainable_variables = self.non_trainable_variables
|
|
optimizer_variables = self.optimizer.variables
|
|
metrics_variables = self.metrics_variables
|
|
|
|
for step, data in epoch_iterator.enumerate_epoch(return_type="np"):
|
|
# Callbacks
|
|
callbacks.on_train_batch_begin(step)
|
|
|
|
# Train step
|
|
state = (
|
|
trainable_variables,
|
|
non_trainable_variables,
|
|
optimizer_variables,
|
|
metrics_variables,
|
|
)
|
|
logs, state = self.train_function(state, data)
|
|
(
|
|
trainable_variables,
|
|
non_trainable_variables,
|
|
optimizer_variables,
|
|
metrics_variables,
|
|
) = state
|
|
|
|
# Setting _jax_state enables callbacks to force a state sync
|
|
# if they need to.
|
|
self._jax_state = {
|
|
"trainable_variables": trainable_variables,
|
|
"non_trainable_variables": non_trainable_variables,
|
|
"optimizer_variables": optimizer_variables,
|
|
"metrics_variables": metrics_variables,
|
|
}
|
|
|
|
# Callbacks
|
|
callbacks.on_train_batch_end(step, self._pythonify_logs(logs))
|
|
if self.stop_training:
|
|
break
|
|
|
|
# Reattach state to model variables.
|
|
# NOTE: doing this after each step would be a big performance
|
|
# bottleneck.
|
|
self.jax_state_sync()
|
|
|
|
# Override with model metrics instead of last step logs
|
|
epoch_logs = self.get_metrics_result()
|
|
|
|
# Run validation.
|
|
if validation_data and self._should_eval(epoch, validation_freq):
|
|
# Create EpochIterator for evaluation and cache it.
|
|
if getattr(self, "_eval_epoch_iterator", None) is None:
|
|
self._eval_epoch_iterator = EpochIterator(
|
|
x=val_x,
|
|
y=val_y,
|
|
sample_weight=val_sample_weight,
|
|
batch_size=validation_batch_size or batch_size,
|
|
steps_per_execution=self.steps_per_execution,
|
|
)
|
|
val_logs = self.evaluate(
|
|
x=val_x,
|
|
y=val_y,
|
|
sample_weight=val_sample_weight,
|
|
batch_size=validation_batch_size or batch_size,
|
|
steps=validation_steps,
|
|
callbacks=callbacks,
|
|
return_dict=True,
|
|
_use_cached_eval_dataset=True,
|
|
)
|
|
val_logs = {
|
|
"val_" + name: val for name, val in val_logs.items()
|
|
}
|
|
epoch_logs.update(val_logs)
|
|
|
|
callbacks.on_epoch_end(epoch, epoch_logs)
|
|
training_logs = epoch_logs
|
|
if self.stop_training:
|
|
break
|
|
|
|
if (
|
|
isinstance(self.optimizer, optimizers_module.Optimizer)
|
|
and epochs > 0
|
|
):
|
|
self.optimizer.finalize_variable_values(self.trainable_weights)
|
|
|
|
# If _eval_epoch_iterator exists, delete it after all epochs are done.
|
|
if getattr(self, "_eval_epoch_iterator", None) is not None:
|
|
del self._eval_epoch_iterator
|
|
callbacks.on_train_end(logs=training_logs)
|
|
self._jax_state = None
|
|
return self.history
|
|
|
|
    @traceback_utils.filter_traceback
    def evaluate(
        self,
        x=None,
        y=None,
        batch_size=None,
        verbose="auto",
        sample_weight=None,
        steps=None,
        callbacks=None,
        return_dict=False,
        **kwargs,
    ):
        """Evaluate the model, returning its loss and metric values.

        See the base `Trainer.evaluate` for full argument documentation.
        The internal-only `_use_cached_eval_dataset` kwarg lets `fit()`
        reuse the validation `EpochIterator` it cached.

        Returns:
            The metric logs as a dict if `return_dict=True`, otherwise a
            flat list ordered to match the model's metrics.
        """
        self._assert_compile_called("evaluate")
        # TODO: respect compiled trainable state
        use_cached_eval_dataset = kwargs.pop("_use_cached_eval_dataset", False)
        if kwargs:
            raise ValueError(f"Arguments not recognized: {kwargs}")

        if use_cached_eval_dataset:
            # Reuse the iterator cached by `fit()` for validation runs.
            epoch_iterator = self._eval_epoch_iterator
        else:
            # Create an iterator that yields batches of input/target data.
            epoch_iterator = EpochIterator(
                x=x,
                y=y,
                sample_weight=sample_weight,
                batch_size=batch_size,
                steps_per_epoch=steps,
                shuffle=False,
                steps_per_execution=self.steps_per_execution,
            )

        if not self.built:
            # Build the model on one batch of data.
            for _, data in epoch_iterator.enumerate_epoch(return_type="np"):
                data_batch = data[0]
                self._eager_build(data_batch)
                break

        # Container that configures and calls callbacks.
        if not isinstance(callbacks, callbacks_module.CallbackList):
            callbacks = callbacks_module.CallbackList(
                callbacks,
                add_history=True,
                add_progbar=verbose != 0,
                verbose=verbose,
                epochs=1,
                steps=epoch_iterator.num_batches,
                model=self,
            )

        self.make_test_function()
        callbacks.on_test_begin()
        logs = None
        self.reset_metrics()

        # Snapshot variable state once; the test step is functional and
        # threads this state through every batch.
        trainable_variables = self.trainable_variables
        non_trainable_variables = self.non_trainable_variables
        metrics_variables = self.metrics_variables

        for step, data in epoch_iterator.enumerate_epoch(return_type="np"):
            callbacks.on_test_batch_begin(step)

            state = (
                trainable_variables,
                non_trainable_variables,
                metrics_variables,
            )
            logs, state = self.test_function(state, data)
            # Note that trainable variables are not returned since they're
            # immutable here.
            _, non_trainable_variables, metrics_variables = state

            # Setting _jax_state enables callbacks to force a state sync
            # if they need to.
            self._jax_state = {
                # I wouldn't recommend modifying non-trainable model state
                # during evaluate(), but it's allowed.
                "non_trainable_variables": non_trainable_variables,
                "metrics_variables": metrics_variables,
            }
            callbacks.on_test_batch_end(step, self._pythonify_logs(logs))

        # Reattach state back to model.
        self.jax_state_sync()

        logs = self.get_metrics_result()
        callbacks.on_test_end(logs)
        self._jax_state = None
        if return_dict:
            return logs
        return self._flatten_metrics_in_order(logs)
|
|
|
|
@traceback_utils.filter_traceback
|
|
def predict(
|
|
self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
|
|
):
|
|
# Create an iterator that yields batches of input data.
|
|
epoch_iterator = EpochIterator(
|
|
x=x,
|
|
batch_size=batch_size,
|
|
steps_per_epoch=steps,
|
|
shuffle=False,
|
|
steps_per_execution=self.steps_per_execution,
|
|
)
|
|
|
|
if not self.built:
|
|
# Build the model on one batch of data.
|
|
for _, data in epoch_iterator.enumerate_epoch(return_type="np"):
|
|
# Build model
|
|
x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data[0])
|
|
with backend.StatelessScope():
|
|
self(x)
|
|
break
|
|
|
|
# Container that configures and calls callbacks.
|
|
if not isinstance(callbacks, callbacks_module.CallbackList):
|
|
callbacks = callbacks_module.CallbackList(
|
|
callbacks,
|
|
add_history=True,
|
|
add_progbar=verbose != 0,
|
|
verbose=verbose,
|
|
epochs=1,
|
|
steps=epoch_iterator.num_batches,
|
|
model=self,
|
|
)
|
|
|
|
self.make_predict_function()
|
|
callbacks.on_predict_begin()
|
|
|
|
def append_to_outputs(batch_outputs, outputs):
|
|
if outputs is None:
|
|
outputs = tf.nest.map_structure(
|
|
lambda batch_output: [batch_output],
|
|
batch_outputs,
|
|
)
|
|
else:
|
|
tf.__internal__.nest.map_structure_up_to(
|
|
batch_outputs,
|
|
lambda output, batch_output: output.append(batch_output),
|
|
outputs,
|
|
batch_outputs,
|
|
)
|
|
return outputs
|
|
|
|
trainable_variables = self.trainable_variables
|
|
non_trainable_variables = self.non_trainable_variables
|
|
outputs = None
|
|
for step, x in epoch_iterator.enumerate_epoch(return_type="np"):
|
|
callbacks.on_predict_batch_begin(step)
|
|
batch_outputs = self.predict_function(
|
|
trainable_variables, non_trainable_variables, x
|
|
)
|
|
outputs = append_to_outputs(batch_outputs, outputs)
|
|
callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
|
|
callbacks.on_predict_end()
|
|
return tf.__internal__.nest.map_structure_up_to(
|
|
batch_outputs, np.concatenate, outputs
|
|
)
|
|
|
|
def train_on_batch(
|
|
self,
|
|
x,
|
|
y=None,
|
|
sample_weight=None,
|
|
class_weight=None,
|
|
return_dict=False,
|
|
):
|
|
"""Runs a single gradient update on a single batch of data.
|
|
|
|
Args:
|
|
x: Input data. Must be array-like.
|
|
y: Target data. Must be array-like.
|
|
sample_weight: Optional array of the same length as x, containing
|
|
weights to apply to the model's loss for each sample.
|
|
In the case of temporal data, you can pass a 2D array
|
|
with shape `(samples, sequence_length)`, to apply a different
|
|
weight to every timestep of every sample.
|
|
class_weight: Optional dictionary mapping class indices (integers)
|
|
to a weight (float) to apply to the model's loss for the samples
|
|
from this class during training. This can be useful to tell the
|
|
model to "pay more attention" to samples from an
|
|
under-represented class. When `class_weight` is specified
|
|
and targets have a rank of 2 or greater, either `y` must
|
|
be one-hot encoded, or an explicit final dimension of 1
|
|
must be included for sparse class labels.
|
|
return_dict: If `True`, loss and metric results are returned as a
|
|
dict, with each key being the name of the metric. If `False`,
|
|
they are returned as a list.
|
|
|
|
Returns:
|
|
A scalar loss value (when no metrics and `return_dict=False`),
|
|
a list of loss and metric values
|
|
(if there are metrics and `return_dict=False`), or a dict of
|
|
metric and loss values (if `return_dict=True`).
|
|
"""
|
|
self._assert_compile_called("train_on_batch")
|
|
if class_weight is not None:
|
|
if sample_weight is not None:
|
|
raise ValueError(
|
|
"Arguments `sample_weight` and `class_weight` "
|
|
"cannot be specified at the same time. "
|
|
f"Received: sample_weight={sample_weight}, "
|
|
f"class_weight={class_weight}"
|
|
)
|
|
sample_weight = data_adapter_utils.class_weight_to_sample_weights(
|
|
y, class_weight
|
|
)
|
|
data = (x, y, sample_weight)
|
|
|
|
# Maybe build model
|
|
self._eager_build(data)
|
|
self.make_train_function()
|
|
|
|
# Train step
|
|
trainable_variables = self.trainable_variables
|
|
non_trainable_variables = self.non_trainable_variables
|
|
optimizer_variables = self.optimizer.variables
|
|
metrics_variables = self.metrics_variables
|
|
state = (
|
|
trainable_variables,
|
|
non_trainable_variables,
|
|
optimizer_variables,
|
|
metrics_variables,
|
|
)
|
|
logs, state = self.train_function(state, [data])
|
|
|
|
# State sync
|
|
(
|
|
trainable_variables,
|
|
non_trainable_variables,
|
|
optimizer_variables,
|
|
metrics_variables,
|
|
) = state
|
|
self._jax_state = {
|
|
"trainable_variables": trainable_variables,
|
|
"non_trainable_variables": non_trainable_variables,
|
|
"optimizer_variables": optimizer_variables,
|
|
"metrics_variables": metrics_variables,
|
|
}
|
|
self.jax_state_sync()
|
|
|
|
# Format return values
|
|
logs = tf.nest.map_structure(lambda x: np.array(x), logs)
|
|
if return_dict:
|
|
return logs
|
|
return self._flatten_metrics_in_order(logs)
|
|
|
|
def test_on_batch(
|
|
self,
|
|
x,
|
|
y=None,
|
|
sample_weight=None,
|
|
return_dict=False,
|
|
):
|
|
"""Test the model on a single batch of samples.
|
|
|
|
Args:
|
|
x: Input data. Must be array-like.
|
|
y: Target data. Must be array-like.
|
|
sample_weight: Optional array of the same length as x, containing
|
|
weights to apply to the model's loss for each sample.
|
|
In the case of temporal data, you can pass a 2D array
|
|
with shape `(samples, sequence_length)`, to apply a different
|
|
weight to every timestep of every sample.
|
|
return_dict: If `True`, loss and metric results are returned as a
|
|
dict, with each key being the name of the metric. If `False`,
|
|
they are returned as a list.
|
|
|
|
Returns:
|
|
A scalar loss value (when no metrics and `return_dict=False`),
|
|
a list of loss and metric values
|
|
(if there are metrics and `return_dict=False`), or a dict of
|
|
metric and loss values (if `return_dict=True`).
|
|
"""
|
|
self._assert_compile_called("test_on_batch")
|
|
|
|
data = (x, y, sample_weight)
|
|
# Maybe build model
|
|
self._eager_build(data)
|
|
self.make_test_function()
|
|
|
|
# Test step
|
|
trainable_variables = self.trainable_variables
|
|
non_trainable_variables = self.non_trainable_variables
|
|
metrics_variables = self.metrics_variables
|
|
state = (
|
|
trainable_variables,
|
|
non_trainable_variables,
|
|
metrics_variables,
|
|
)
|
|
logs, state = self.test_function(state, [data])
|
|
|
|
# State sync
|
|
_, non_trainable_variables, metrics_variables = state
|
|
self._jax_state = {
|
|
"non_trainable_variables": non_trainable_variables,
|
|
"metrics_variables": metrics_variables,
|
|
}
|
|
self.jax_state_sync()
|
|
|
|
# Format return values.
|
|
logs = tf.nest.map_structure(lambda x: np.array(x), logs)
|
|
if return_dict:
|
|
return logs
|
|
return self._flatten_metrics_in_order(logs)
|
|
|
|
def predict_on_batch(self, x):
|
|
"""Returns predictions for a single batch of samples.
|
|
|
|
Args:
|
|
x: Input data. It must be array-like.
|
|
|
|
Returns:
|
|
NumPy array(s) of predictions.
|
|
"""
|
|
if not self.built:
|
|
# Build model
|
|
with backend.StatelessScope():
|
|
self(x)
|
|
|
|
self.make_predict_function()
|
|
trainable_variables = self.trainable_variables
|
|
non_trainable_variables = self.non_trainable_variables
|
|
batch_outputs = self.predict_function(
|
|
trainable_variables, non_trainable_variables, [x]
|
|
)
|
|
batch_outputs = tf.nest.map_structure(
|
|
lambda x: np.array(x), batch_outputs
|
|
)
|
|
return batch_outputs
|
|
|
|
def jax_state_sync(self):
|
|
if not getattr(self, "_jax_state", None):
|
|
return
|
|
|
|
trainable_variables = self._jax_state.get("trainable_variables", None)
|
|
non_trainable_variables = self._jax_state.get(
|
|
"non_trainable_variables", None
|
|
)
|
|
optimizer_variables = self._jax_state.get("optimizer_variables", None)
|
|
metrics_variables = self._jax_state.get("metrics_variables", None)
|
|
if trainable_variables:
|
|
for ref_v, v in zip(self.trainable_variables, trainable_variables):
|
|
ref_v.assign(v)
|
|
if non_trainable_variables:
|
|
for ref_v, v in zip(
|
|
self.non_trainable_variables, non_trainable_variables
|
|
):
|
|
ref_v.assign(v)
|
|
if optimizer_variables:
|
|
for ref_v, v in zip(self.optimizer.variables, optimizer_variables):
|
|
ref_v.assign(v)
|
|
if metrics_variables:
|
|
for ref_v, v in zip(self.metrics_variables, metrics_variables):
|
|
ref_v.assign(v)
|