Add TF evaluate() flow.

Francois Chollet 2023-04-19 16:25:56 -07:00
parent ee5be68ce9
commit aa4078899d
5 changed files with 176 additions and 37 deletions

@@ -9,9 +9,10 @@ from keras_core.trainers.epoch_iterator import EpochIterator
 class Trainer(base_trainer.Trainer):
-    def stateless_compute_loss_and_updates(
+    def compute_loss_and_updates(
         self, trainable_variables, non_trainable_variables, x, y, sample_weight
     ):
+        """This method is stateless and is intended for use with jax.grad."""
         y_pred, non_trainable_variables = self.stateless_call(
             trainable_variables, non_trainable_variables, x
         )
@@ -104,7 +105,7 @@ class Trainer(base_trainer.Trainer):
         )
         grad_fn = jax.value_and_grad(
-            self.stateless_compute_loss_and_updates, has_aux=True
+            self.compute_loss_and_updates, has_aux=True
         )

         def _train_step(state, data):
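
The `has_aux=True` pattern above is worth seeing in isolation: the differentiated function must return a `(loss, aux)` pair, gradients are taken with respect to the first positional argument, and the auxiliary output (here, the updated non-trainable variables) is threaded through untouched. A minimal self-contained sketch with toy names, not the Keras Core API:

import jax
import jax.numpy as jnp

def compute_loss_and_updates(params, state, x, y):
    # Stateless: inputs in, (loss, aux) out; no attribute mutation.
    y_pred = params["w"] * x + state["bias_estimate"]
    loss = jnp.mean((y_pred - y) ** 2)
    return loss, state  # (loss, aux)

# Differentiates w.r.t. the first argument; aux passes through unchanged.
grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)

params = {"w": jnp.array(2.0)}
state = {"bias_estimate": jnp.array(0.5)}
(loss, new_state), grads = grad_fn(params, state, jnp.ones(4), jnp.zeros(4))
print(loss, grads["w"])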

@@ -13,7 +13,7 @@ from keras_core.utils.naming import auto_name
 DYNAMIC_SHAPES_OK = True


-class Variable(KerasVariable):
+class Variable(KerasVariable, tf.__internal__.types.Tensor):
     def __init__(self, value, dtype=None, trainable=True, name=None):
         self.name = name or auto_name(self.__class__.__name__)
         dtype = standardize_dtype(dtype)
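
Subclassing `tf.__internal__.types.Tensor` registers the backend `Variable` as a tensor type for TensorFlow's internal `isinstance` checks, so it can be passed straight into `tf.*` ops; the actual conversion goes through TensorFlow's `__tf_tensor__` protocol. A minimal sketch of that protocol with a hypothetical wrapper class (`Box` is illustrative, not part of this commit):

import tensorflow as tf

class Box:
    """Hypothetical value wrapper that TF ops can consume directly."""

    def __init__(self, value):
        self._value = tf.convert_to_tensor(value)

    def __tf_tensor__(self, dtype=None, name=None):
        # tf.convert_to_tensor (and hence most tf.* ops) calls this hook.
        return tf.convert_to_tensor(self._value, dtype=dtype, name=name)

b = Box([1.0, 2.0, 3.0])
print(tf.reduce_sum(b))  # tf.Tensor(6.0, shape=(), dtype=float32)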

@@ -12,6 +12,12 @@ from keras_core.trainers.epoch_iterator import EpochIterator
 class Trainer(base_trainer.Trainer):
+    def __init__(self):
+        super().__init__()
+        self.train_function = None
+        self.test_function = None
+        self.predict_function = None
+
     def train_step(self, data):
         x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data)
@@ -36,10 +42,24 @@ class Trainer(base_trainer.Trainer):
         return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)

     def test_step(self, data):
-        raise NotImplementedError
+        x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data)
+        if self._call_has_training_arg():
+            y_pred = self(x, training=False)
+        else:
+            y_pred = self(x)
+        loss = self.compute_loss(
+            x=x, y=y, y_pred=y_pred, sample_weight=sample_weight
+        )
+        self._loss_tracker.update_state(loss)
+        return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)

     def predict_step(self, data):
-        raise NotImplementedError
+        x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data)
+        if self._call_has_training_arg():
+            y_pred = self(x, training=False)
+        else:
+            y_pred = self(x)
+        return y_pred

     def make_train_function(self, force=False):
         # TODO: support tf.distribute and steps_per_execution.
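
Both new steps lean on `unpack_x_y_sample_weight`, which normalizes a batch into an `(x, y, sample_weight)` triple whether the dataset yields `x`, `(x,)`, `(x, y)`, or `(x, y, sample_weight)`. A simplified re-implementation of that convention (the real utility lives in `data_adapter_utils`):

def unpack_x_y_sample_weight(data):
    # Simplified sketch of the Keras unpacking convention.
    if isinstance(data, tuple):
        if len(data) == 1:
            return data[0], None, None
        if len(data) == 2:
            return data[0], data[1], None
        if len(data) == 3:
            return data[0], data[1], data[2]
    # A bare tensor/array is treated as `x` alone.
    return data, None, None

assert unpack_x_y_sample_weight(("x",)) == ("x", None, None)
assert unpack_x_y_sample_weight(("x", "y")) == ("x", "y", None)
assert unpack_x_y_sample_weight(("x", "y", "sw")) == ("x", "y", "sw")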
@@ -47,7 +67,7 @@ class Trainer(base_trainer.Trainer):
             return self.train_function

         def one_step_on_data(data):
-            """Runs a single training step."""
+            """Runs a single training step on a batch of data."""
             return self.train_step(data)

         if not self.run_eagerly and self.jit_compile:
@@ -56,6 +76,7 @@ class Trainer(base_trainer.Trainer):
             )

         def one_step_on_iterator(iterator):
+            """Runs a single training step given a Dataset iterator."""
             data = next(iterator)
             return one_step_on_data(data)
@@ -68,10 +89,58 @@ class Trainer(base_trainer.Trainer):
         self.train_function = train_function

     def make_test_function(self, force=False):
-        raise NotImplementedError
+        # TODO: support tf.distribute and steps_per_execution.
+        if self.test_function is not None and not force:
+            return self.test_function
+
+        def one_step_on_data(data):
+            """Runs a single test step on a batch of data."""
+            return self.test_step(data)
+
+        if not self.run_eagerly and self.jit_compile:
+            one_step_on_data = tf.function(
+                one_step_on_data, jit_compile=True, reduce_retracing=True
+            )
+
+        def one_step_on_iterator(iterator):
+            """Runs a single test step given a Dataset iterator."""
+            data = next(iterator)
+            return one_step_on_data(data)
+
+        if not self.run_eagerly:
+            test_function = tf.function(
+                one_step_on_iterator, reduce_retracing=True
+            )
+        else:
+            test_function = one_step_on_iterator
+        self.test_function = test_function

     def make_predict_function(self, force=False):
-        raise NotImplementedError
+        # TODO: support tf.distribute and steps_per_execution.
+        if self.predict_function is not None and not force:
+            return self.predict_function
+
+        def one_step_on_data(data):
+            """Runs a single predict step on a batch of data."""
+            return self.predict_step(data)
+
+        if not self.run_eagerly and self.jit_compile:
+            one_step_on_data = tf.function(
+                one_step_on_data, jit_compile=True, reduce_retracing=True
+            )
+
+        def one_step_on_iterator(iterator):
+            """Runs a single predict step given a Dataset iterator."""
+            data = next(iterator)
+            return one_step_on_data(data)
+
+        if not self.run_eagerly:
+            predict_function = tf.function(
+                one_step_on_iterator, reduce_retracing=True
+            )
+        else:
+            predict_function = one_step_on_iterator
+        self.predict_function = predict_function

     def fit(
         self,
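
`make_test_function` and `make_predict_function` reuse the two-level wrapping scheme from `make_train_function`: an inner per-batch step that can be XLA-compiled (`jit_compile=True`), and an outer function that pulls from the `tf.data` iterator so the `next()` call is traced into the same graph; `reduce_retracing=True` limits retraces when batch shapes vary. The scheme in isolation, with a toy step function:

import tensorflow as tf

def step(batch):
    # Stand-in for test_step/predict_step: any per-batch computation.
    return tf.reduce_mean(batch)

# Inner wrapper: optionally XLA-compiled per-batch step.
step_fn = tf.function(step, jit_compile=True, reduce_retracing=True)

def one_step_on_iterator(iterator):
    # Pulling from the iterator inside tf.function keeps the whole
    # fetch-then-compute loop in graph mode.
    return step_fn(next(iterator))

run = tf.function(one_step_on_iterator, reduce_retracing=True)

ds = tf.data.Dataset.from_tensor_slices(tf.ones((8, 4))).batch(2)
print(run(iter(ds)))  # tf.Tensor(1.0, shape=(), dtype=float32)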
@@ -212,13 +281,71 @@ class Trainer(base_trainer.Trainer):
         return_dict=False,
         **kwargs,
     ):
-        raise NotImplementedError
+        # TODO: respect compiled trainable state
+        use_cached_eval_dataset = kwargs.pop("_use_cached_eval_dataset", False)
+        if kwargs:
+            raise ValueError(f"Arguments not recognized: {kwargs}")
+        if use_cached_eval_dataset:
+            epoch_iterator = self._eval_epoch_iterator
+        else:
+            # Create an iterator that yields batches for one epoch.
+            epoch_iterator = TFEpochIterator(
+                x=x,
+                y=y,
+                sample_weight=sample_weight,
+                batch_size=batch_size,
+                steps_per_epoch=steps,
+                shuffle=False,
+            )
+
+        # Container that configures and calls callbacks.
+        if not isinstance(callbacks, callbacks_module.CallbackList):
+            callbacks = callbacks_module.CallbackList(
+                callbacks,
+                add_history=True,
+                add_progbar=verbose != 0,
+                verbose=verbose,
+                epochs=1,
+                steps=epoch_iterator.num_batches,
+                model=self,
+            )
+
+        self.make_test_function()
+        callbacks.on_test_begin()
+        logs = None
+        self.reset_metrics()
+        with epoch_iterator.catch_stop_iteration():
+            for step, iterator in epoch_iterator.enumerate_epoch():
+                callbacks.on_test_batch_begin(step)
+                logs = self.test_function(iterator)
+                callbacks.on_test_batch_end(step, logs)
+        logs = self._process_logs(self.get_metrics_result())
+        callbacks.on_test_end(logs)
+
+        if return_dict:
+            return logs
+        return self._flatten_metrics_in_order(logs)

     def predict(
         self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
     ):
         raise NotImplementedError

+    def _flatten_metrics_in_order(self, logs):
+        """Turns the `logs` dict into a list as per key order of `metrics_names`."""
+        metric_names = [m.name for m in self.metrics]
+        results = []
+        for name in metric_names:
+            if name in logs:
+                results.append(logs[name])
+        for key in sorted(logs.keys()):
+            if key not in metric_names:
+                results.append(logs[key])
+        if len(results) == 1:
+            return results[0]
+        return results
+

 class TFEpochIterator(EpochIterator):
     def __init__(self, *args, **kwargs):
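
The flattening rule explains `evaluate`'s two return shapes: with `return_dict=True` the logs dict comes back as-is; otherwise `_flatten_metrics_in_order` emits compiled-metric values first, in metric order, then any remaining log keys in sorted order, collapsing a single value to a scalar. A standalone trace of that logic:

# Standalone trace of the ordering logic (values are illustrative).
logs = {"mean_squared_error": 16.0, "loss": 16.0}
metric_names = ["loss", "mean_squared_error"]  # order of self.metrics

results = [logs[name] for name in metric_names if name in logs]
results += [logs[key] for key in sorted(logs) if key not in metric_names]
print(results)  # [16.0, 16.0]; the list form asserted in the tests below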

@@ -5,6 +5,7 @@ from keras_core import metrics as metrics_module
 from keras_core import operations as ops
 from keras_core.trainers.compile_utils import CompileLoss
 from keras_core.trainers.compile_utils import CompileMetrics
+from keras_core.utils import tracking


 class Trainer:
@@ -12,10 +13,8 @@ class Trainer:
         self._run_eagerly = False
         self._jit_compile = True
         self.compiled = False
-        self.train_function = None
-        self.test_function = None
-        self.predict_function = None

+    @tracking.no_automatic_dependency_tracking
     def compile(
         self,
         optimizer,
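
The new `@tracking.no_automatic_dependency_tracking` decorator keeps the loss and metric containers built in `compile()` from being auto-registered as tracked sub-objects of the model. The keras_core tracking internals are not shown in this diff; a hypothetical sketch of how such a decorator is typically built (the flag name `_tracking_enabled` is assumed, not from the source):

import functools

def no_automatic_dependency_tracking(fn):
    """Hypothetical sketch: suspend attribute tracking while fn runs."""

    @functools.wraps(fn)
    def wrapper(self, *args, **kwargs):
        previous = getattr(self, "_tracking_enabled", True)  # assumed flag
        self._tracking_enabled = False
        try:
            return fn(self, *args, **kwargs)
        finally:
            self._tracking_enabled = previous

    return wrapper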
@@ -209,24 +208,6 @@ class Trainer:
             return_metrics[metric.name] = result
         return return_metrics

-    def train_step(self, data):
-        raise NotImplementedError
-
-    def test_step(self, data):
-        raise NotImplementedError
-
-    def predict_step(self, data):
-        raise NotImplementedError
-
-    def make_train_function(self):
-        raise NotImplementedError
-
-    def make_test_function(self):
-        raise NotImplementedError
-
-    def make_predict_function(self):
-        raise NotImplementedError
-
     def fit(
         self,
         x=None,

@@ -27,7 +27,7 @@ class ExampleModel(layers.Dense, Trainer):

 class TestTrainer(testing.TestCase):
-    def _test_basic_flow(self, run_eagerly, jit_compile):
+    def _test_fit_flow(self, run_eagerly, jit_compile):
         model = ExampleModel(units=3)
         x = np.ones((100, 4))
         y = np.zeros((100, 3))
@@ -49,11 +49,41 @@ class TestTrainer(testing.TestCase):
             history["mean_squared_error"], [13.938, 9.547, 6.539], atol=1e-2
         )

-    def test_basic_flow_eager(self):
-        self._test_basic_flow(run_eagerly=True, jit_compile=False)
+    def test_fit_flow_eager(self):
+        self._test_fit_flow(run_eagerly=True, jit_compile=False)

-    def test_basic_flow_graph_fn(self):
-        self._test_basic_flow(run_eagerly=False, jit_compile=False)
+    def test_fit_flow_graph_fn(self):
+        self._test_fit_flow(run_eagerly=False, jit_compile=False)

-    def test_basic_flow_jit(self):
-        self._test_basic_flow(run_eagerly=False, jit_compile=True)
+    def test_fit_flow_jit(self):
+        self._test_fit_flow(run_eagerly=False, jit_compile=True)

+    def _test_evaluate_flow(self, run_eagerly, jit_compile):
+        model = ExampleModel(units=3)
+        x = np.ones((100, 4))
+        y = np.zeros((100, 3))
+        batch_size = 16
+        model.compile(
+            optimizer=optimizers.SGD(),
+            loss=losses.MeanSquaredError(),
+            metrics=[metrics.MeanSquaredError()],
+            run_eagerly=run_eagerly,
+            jit_compile=jit_compile,
+        )
+        output = model.evaluate(x, y, batch_size=batch_size)
+        self.assertAllClose(output, [16.0, 16.0])
+        output = model.evaluate(x, y, batch_size=batch_size, return_dict=True)
+        self.assertTrue(isinstance(output, dict))
+        self.assertIn("loss", output)
+        self.assertIn("mean_squared_error", output)
+        self.assertAllClose(output["mean_squared_error"], 16.0)
+
+    def test_evaluate_flow_eager(self):
+        self._test_evaluate_flow(run_eagerly=True, jit_compile=False)
+
+    def test_evaluate_flow_graph_fn(self):
+        self._test_evaluate_flow(run_eagerly=False, jit_compile=False)
+
+    def test_evaluate_flow_jit(self):
+        self._test_evaluate_flow(run_eagerly=False, jit_compile=True)
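
The expected value of 16.0 follows from the test setup, assuming `ExampleModel` is a Dense layer with a ones-initialized kernel and no bias (consistent with the fit-flow numbers above): every prediction is the sum of four ones, and the squared error against a zero target is 4 squared, i.e. 16, for both the loss and the MSE metric. A quick NumPy check:

import numpy as np

x = np.ones((100, 4))
y = np.zeros((100, 3))
kernel = np.ones((4, 3))           # assumed ones-initialized, no bias
y_pred = x @ kernel                # every entry equals 4.0
print(np.mean((y_pred - y) ** 2))  # 16.0 -> both `loss` and the metric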