From aa4078899d4e4b75bc759aa4932716afc3f714eb Mon Sep 17 00:00:00 2001
From: Francois Chollet
Date: Wed, 19 Apr 2023 16:25:56 -0700
Subject: [PATCH] Add TF evaluate() flow.

---
 keras_core/backend/jax/trainer.py         |   5 +-
 keras_core/backend/tensorflow/__init__.py |   2 +-
 keras_core/backend/tensorflow/trainer.py  | 139 +++++++++++++++++++++-
 keras_core/trainers/trainer.py            |  23 +---
 keras_core/trainers/trainer_test.py       |  44 +++++--
 5 files changed, 176 insertions(+), 37 deletions(-)

diff --git a/keras_core/backend/jax/trainer.py b/keras_core/backend/jax/trainer.py
index 4f8e443ba..8a9c1ac5d 100644
--- a/keras_core/backend/jax/trainer.py
+++ b/keras_core/backend/jax/trainer.py
@@ -9,9 +9,10 @@ from keras_core.trainers.epoch_iterator import EpochIterator
 
 
 class Trainer(base_trainer.Trainer):
-    def stateless_compute_loss_and_updates(
+    def compute_loss_and_updates(
         self, trainable_variables, non_trainable_variables, x, y, sample_weight
     ):
+        """This method is stateless and is intended for use with jax.grad."""
         y_pred, non_trainable_variables = self.stateless_call(
             trainable_variables, non_trainable_variables, x
         )
@@ -104,7 +105,7 @@ class Trainer(base_trainer.Trainer):
         )
 
         grad_fn = jax.value_and_grad(
-            self.stateless_compute_loss_and_updates, has_aux=True
+            self.compute_loss_and_updates, has_aux=True
         )
 
         def _train_step(state, data):
diff --git a/keras_core/backend/tensorflow/__init__.py b/keras_core/backend/tensorflow/__init__.py
index 4ae063dd5..04c20e38e 100644
--- a/keras_core/backend/tensorflow/__init__.py
+++ b/keras_core/backend/tensorflow/__init__.py
@@ -13,7 +13,7 @@ from keras_core.utils.naming import auto_name
 DYNAMIC_SHAPES_OK = True
 
 
-class Variable(KerasVariable):
+class Variable(KerasVariable, tf.__internal__.types.Tensor):
     def __init__(self, value, dtype=None, trainable=True, name=None):
         self.name = name or auto_name(self.__class__.__name__)
         dtype = standardize_dtype(dtype)
diff --git a/keras_core/backend/tensorflow/trainer.py b/keras_core/backend/tensorflow/trainer.py
index 96ff37b85..51d442c84 100644
--- a/keras_core/backend/tensorflow/trainer.py
+++ b/keras_core/backend/tensorflow/trainer.py
@@ -12,6 +12,12 @@ from keras_core.trainers.epoch_iterator import EpochIterator
 
 
 class Trainer(base_trainer.Trainer):
+    def __init__(self):
+        super().__init__()
+        self.train_function = None
+        self.test_function = None
+        self.predict_function = None
+
     def train_step(self, data):
         x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data)
 
@@ -36,10 +42,24 @@ class Trainer(base_trainer.Trainer):
         return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)
 
     def test_step(self, data):
-        raise NotImplementedError
+        x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data)
+        if self._call_has_training_arg():
+            y_pred = self(x, training=False)
+        else:
+            y_pred = self(x)
+        loss = self.compute_loss(
+            x=x, y=y, y_pred=y_pred, sample_weight=sample_weight
+        )
+        self._loss_tracker.update_state(loss)
+        return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)
 
     def predict_step(self, data):
-        raise NotImplementedError
+        x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data)
+        if self._call_has_training_arg():
+            y_pred = self(x, training=False)
+        else:
+            y_pred = self(x)
+        return y_pred
 
     def make_train_function(self, force=False):
         # TODO: support tf.distribute and steps_per_execution.
@@ -47,7 +67,7 @@ class Trainer(base_trainer.Trainer):
             return self.train_function
 
         def one_step_on_data(data):
-            """Runs a single training step."""
+            """Runs a single training step on a batch of data."""
             return self.train_step(data)
 
         if not self.run_eagerly and self.jit_compile:
@@ -56,6 +76,7 @@
             )
 
         def one_step_on_iterator(iterator):
+            """Runs a single training step given a Dataset iterator."""
             data = next(iterator)
             return one_step_on_data(data)
 
@@ -68,10 +89,58 @@
         self.train_function = train_function
 
     def make_test_function(self, force=False):
-        raise NotImplementedError
+        # TODO: support tf.distribute and steps_per_execution.
+        if self.test_function is not None and not force:
+            return self.test_function
+
+        def one_step_on_data(data):
+            """Runs a single test step on a batch of data."""
+            return self.test_step(data)
+
+        if not self.run_eagerly and self.jit_compile:
+            one_step_on_data = tf.function(
+                one_step_on_data, jit_compile=True, reduce_retracing=True
+            )
+
+        def one_step_on_iterator(iterator):
+            """Runs a single test step given a Dataset iterator."""
+            data = next(iterator)
+            return one_step_on_data(data)
+
+        if not self.run_eagerly:
+            test_function = tf.function(
+                one_step_on_iterator, reduce_retracing=True
+            )
+        else:
+            test_function = one_step_on_iterator
+        self.test_function = test_function
 
     def make_predict_function(self, force=False):
-        raise NotImplementedError
+        # TODO: support tf.distribute and steps_per_execution.
+        if self.predict_function is not None and not force:
+            return self.predict_function
+
+        def one_step_on_data(data):
+            """Runs a single predict step on a batch of data."""
+            return self.predict_step(data)
+
+        if not self.run_eagerly and self.jit_compile:
+            one_step_on_data = tf.function(
+                one_step_on_data, jit_compile=True, reduce_retracing=True
+            )
+
+        def one_step_on_iterator(iterator):
+            """Runs a single predict step given a Dataset iterator."""
+            data = next(iterator)
+            return one_step_on_data(data)
+
+        if not self.run_eagerly:
+            predict_function = tf.function(
+                one_step_on_iterator, reduce_retracing=True
+            )
+        else:
+            predict_function = one_step_on_iterator
+        self.predict_function = predict_function
 
     def fit(
         self,
@@ -212,13 +281,71 @@
         return_dict=False,
         **kwargs,
     ):
-        raise NotImplementedError
+        # TODO: respect compiled trainable state
+        use_cached_eval_dataset = kwargs.pop("_use_cached_eval_dataset", False)
+        if kwargs:
+            raise ValueError(f"Arguments not recognized: {kwargs}")
+
+        if use_cached_eval_dataset:
+            epoch_iterator = self._eval_epoch_iterator
+        else:
+            # Create an iterator that yields batches for one epoch.
+            epoch_iterator = TFEpochIterator(
+                x=x,
+                y=y,
+                sample_weight=sample_weight,
+                batch_size=batch_size,
+                steps_per_epoch=steps,
+                shuffle=False,
+            )
+
+        # Container that configures and calls callbacks.
+        if not isinstance(callbacks, callbacks_module.CallbackList):
+            callbacks = callbacks_module.CallbackList(
+                callbacks,
+                add_history=True,
+                add_progbar=verbose != 0,
+                verbose=verbose,
+                epochs=1,
+                steps=epoch_iterator.num_batches,
+                model=self,
+            )
+
+        self.make_test_function()
+        callbacks.on_test_begin()
+        logs = None
+        self.reset_metrics()
+        with epoch_iterator.catch_stop_iteration():
+            for step, iterator in epoch_iterator.enumerate_epoch():
+                callbacks.on_test_batch_begin(step)
+                logs = self.test_function(iterator)
+                callbacks.on_test_batch_end(step, logs)
+        logs = self._process_logs(self.get_metrics_result())
+        callbacks.on_test_end(logs)
+
+        if return_dict:
+            return logs
+        return self._flatten_metrics_in_order(logs)
 
     def predict(
         self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
     ):
         raise NotImplementedError
 
+    def _flatten_metrics_in_order(self, logs):
+        """Turns the `logs` dict into a list as per key order of `metrics_names`."""
+        metric_names = [m.name for m in self.metrics]
+        results = []
+        for name in metric_names:
+            if name in logs:
+                results.append(logs[name])
+        for key in sorted(logs.keys()):
+            if key not in metric_names:
+                results.append(logs[key])
+        if len(results) == 1:
+            return results[0]
+        return results
+
 
 class TFEpochIterator(EpochIterator):
     def __init__(self, *args, **kwargs):
diff --git a/keras_core/trainers/trainer.py b/keras_core/trainers/trainer.py
index 178cbc2d0..ec06fd025 100644
--- a/keras_core/trainers/trainer.py
+++ b/keras_core/trainers/trainer.py
@@ -5,6 +5,7 @@ from keras_core import metrics as metrics_module
 from keras_core import operations as ops
 from keras_core.trainers.compile_utils import CompileLoss
 from keras_core.trainers.compile_utils import CompileMetrics
+from keras_core.utils import tracking
 
 
 class Trainer:
@@ -12,10 +13,8 @@ class Trainer:
         self._run_eagerly = False
         self._jit_compile = True
         self.compiled = False
-        self.train_function = None
-        self.test_function = None
-        self.predict_function = None
 
+    @tracking.no_automatic_dependency_tracking
     def compile(
         self,
         optimizer,
@@ -209,24 +208,6 @@ class Trainer:
             return_metrics[metric.name] = result
         return return_metrics
 
-    def train_step(self, data):
-        raise NotImplementedError
-
-    def test_step(self, data):
-        raise NotImplementedError
-
-    def predict_step(self, data):
-        raise NotImplementedError
-
-    def make_train_function(self):
-        raise NotImplementedError
-
-    def make_test_function(self):
-        raise NotImplementedError
-
-    def make_predict_function(self):
-        raise NotImplementedError
-
     def fit(
         self,
         x=None,
diff --git a/keras_core/trainers/trainer_test.py b/keras_core/trainers/trainer_test.py
index 498f8289a..114cf029d 100644
--- a/keras_core/trainers/trainer_test.py
+++ b/keras_core/trainers/trainer_test.py
@@ -27,7 +27,7 @@ class ExampleModel(layers.Dense, Trainer):
 
 
 class TestTrainer(testing.TestCase):
-    def _test_basic_flow(self, run_eagerly, jit_compile):
+    def _test_fit_flow(self, run_eagerly, jit_compile):
         model = ExampleModel(units=3)
         x = np.ones((100, 4))
         y = np.zeros((100, 3))
@@ -49,11 +49,41 @@ class TestTrainer(testing.TestCase):
             history["mean_squared_error"], [13.938, 9.547, 6.539], atol=1e-2
         )
 
-    def test_basic_flow_eager(self):
-        self._test_basic_flow(run_eagerly=True, jit_compile=False)
+    def test_fit_flow_eager(self):
+        self._test_fit_flow(run_eagerly=True, jit_compile=False)
 
-    def test_basic_flow_graph_fn(self):
-        self._test_basic_flow(run_eagerly=False, jit_compile=False)
+    def test_fit_flow_graph_fn(self):
+        self._test_fit_flow(run_eagerly=False, jit_compile=False)
 
-    def test_basic_flow_jit(self):
-        self._test_basic_flow(run_eagerly=False, jit_compile=True)
+    def test_fit_flow_jit(self):
+        self._test_fit_flow(run_eagerly=False, jit_compile=True)
+
+    def _test_evaluate_flow(self, run_eagerly, jit_compile):
+        model = ExampleModel(units=3)
+        x = np.ones((100, 4))
+        y = np.zeros((100, 3))
+        batch_size = 16
+
+        model.compile(
+            optimizer=optimizers.SGD(),
+            loss=losses.MeanSquaredError(),
+            metrics=[metrics.MeanSquaredError()],
+            run_eagerly=run_eagerly,
+            jit_compile=jit_compile,
+        )
+        output = model.evaluate(x, y, batch_size=batch_size)
+        self.assertAllClose(output, [16.0, 16.0])
+        output = model.evaluate(x, y, batch_size=batch_size, return_dict=True)
+        self.assertTrue(isinstance(output, dict))
+        self.assertIn("loss", output)
+        self.assertIn("mean_squared_error", output)
+        self.assertAllClose(output["mean_squared_error"], 16.0)
+
+    def test_evaluate_flow_eager(self):
+        self._test_evaluate_flow(run_eagerly=True, jit_compile=False)
+
+    def test_evaluate_flow_graph_fn(self):
+        self._test_evaluate_flow(run_eagerly=False, jit_compile=False)
+
+    def test_evaluate_flow_jit(self):
+        self._test_evaluate_flow(run_eagerly=False, jit_compile=True)
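
Note (not part of the patch): a minimal usage sketch of the evaluate() flow added above, mirroring the compile()/evaluate() calls from trainer_test.py. Importing ExampleModel from the test module is an assumption made only for illustration; any compiled keras_core model with the TF backend Trainer would do.

# Hypothetical usage sketch of the new TF evaluate() flow.
# ExampleModel is the Dense-plus-Trainer test model from trainer_test.py;
# importing it from the test module here is an assumption for illustration.
import numpy as np

from keras_core import losses
from keras_core import metrics
from keras_core import optimizers
from keras_core.trainers.trainer_test import ExampleModel

model = ExampleModel(units=3)
model.compile(
    optimizer=optimizers.SGD(),
    loss=losses.MeanSquaredError(),
    metrics=[metrics.MeanSquaredError()],
)
x = np.ones((100, 4))
y = np.zeros((100, 3))

# Default output: a list ordered as [loss, mean_squared_error],
# produced by the new _flatten_metrics_in_order() helper.
print(model.evaluate(x, y, batch_size=16))

# Dict output through the return_dict=True code path.
print(model.evaluate(x, y, batch_size=16, return_dict=True))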