diff --git a/keras_core/backend/tensorflow/trainer.py b/keras_core/backend/tensorflow/trainer.py
index 7f115897d..ad6a041bc 100644
--- a/keras_core/backend/tensorflow/trainer.py
+++ b/keras_core/backend/tensorflow/trainer.py
@@ -114,12 +114,15 @@ class TensorFlowTrainer(base_trainer.Trainer):
             )
             return outputs
 
+        def train_function(iterator):
+            """Runs a training execution with multiple steps."""
+            for _ in tf.range(self.steps_per_execution):
+                outputs = one_step_on_iterator(iterator)
+            return outputs
+
         if not self.run_eagerly:
-            train_function = tf.function(
-                one_step_on_iterator, reduce_retracing=True
-            )
-        else:
-            train_function = one_step_on_iterator
+            train_function = tf.function(train_function, reduce_retracing=True)
+
         self.train_function = train_function
 
     def make_test_function(self, force=False):
@@ -244,6 +247,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
             shuffle=shuffle,
             class_weight=class_weight,
             distribute_strategy=self.distribute_strategy,
+            steps_per_execution=self.steps_per_execution,
         )
 
         # Container that configures and calls callbacks.
@@ -445,7 +449,9 @@ class TFEpochIterator(EpochIterator):
                         self.data_adapter.get_tf_dataset()
                     )
                 )
-            for step in range(self.steps_per_epoch):
+            for step in range(
+                0, self.steps_per_epoch, self.steps_per_execution
+            ):
                 yield step, self._current_iterator
         else:
             iterator = iter(
@@ -454,12 +460,16 @@ class TFEpochIterator(EpochIterator):
                 )
             )
             if self.num_batches:
-                for step in range(self.num_batches):
+                for step in range(
+                    0, self.num_batches, self.steps_per_execution
+                ):
                     yield step, iterator
             else:
-                step = -1
+                # Start one stride below zero so that the first yielded step
+                # is 0, as in the bounded `range(0, n, stride)` branches.
+                step = -self.steps_per_execution
                 while True:
-                    step += 1
+                    step += self.steps_per_execution
                     self._steps_seen = step + 1
                     yield step, iterator
         self.data_adapter.on_epoch_end()
diff --git a/keras_core/trainers/epoch_iterator.py b/keras_core/trainers/epoch_iterator.py
index 8a9bc8d81..535c359c3 100644
--- a/keras_core/trainers/epoch_iterator.py
+++ b/keras_core/trainers/epoch_iterator.py
@@ -60,8 +60,10 @@ class EpochIterator:
         steps_per_epoch=None,
         shuffle=False,
         class_weight=None,
+        steps_per_execution=1,
     ):
         self.steps_per_epoch = steps_per_epoch
+        self.steps_per_execution = steps_per_execution
         if steps_per_epoch:
             self._current_iterator = None
             self._insufficient_data = False
diff --git a/keras_core/trainers/trainer.py b/keras_core/trainers/trainer.py
index 20784fef1..13ed94896 100644
--- a/keras_core/trainers/trainer.py
+++ b/keras_core/trainers/trainer.py
@@ -26,6 +26,7 @@ class Trainer:
         metrics=None,
         weighted_metrics=None,
         run_eagerly=False,
+        steps_per_execution=1,
         jit_compile=True,
     ):
         self.optimizer = optimizers.get(optimizer)
@@ -49,6 +50,7 @@ class Trainer:
         self.stop_training = False
         self.compiled = True
         self._loss_tracker = metrics_module.Mean(name="loss")
+        self.steps_per_execution = steps_per_execution
 
         self._compile_config = serialization_lib.SerializableDict(
             optimizer=optimizer,
@@ -57,6 +59,7 @@ class Trainer:
             metrics=metrics,
             weighted_metrics=weighted_metrics,
             run_eagerly=run_eagerly,
+            steps_per_execution=steps_per_execution,
             jit_compile=jit_compile,
         )
 
diff --git a/keras_core/trainers/trainer_test.py b/keras_core/trainers/trainer_test.py
index 0b09a32ea..59064d324 100644
--- a/keras_core/trainers/trainer_test.py
+++ b/keras_core/trainers/trainer_test.py
@@ -1,4 +1,5 @@
 import numpy as np
+import pytest
 
 from keras_core import backend
 from keras_core import initializers
@@ -7,6 +8,7 @@ from keras_core import losses
 from keras_core import metrics
 from keras_core import optimizers
 from keras_core import testing
+from keras_core.callbacks.callback import Callback
 
 if backend.backend() == "jax":
     from keras_core.backend.jax.trainer import JAXTrainer as Trainer
@@ -213,3 +215,32 @@ class TestTrainer(testing.TestCase):
 
     def test_predict_flow_jit(self):
         self._test_predict_flow(run_eagerly=False, jit_compile=True)
+
+    # TODO: Remove the skipif once steps_per_execution is implemented for JAX.
+    @pytest.mark.skipif(
+        backend.backend() != "tensorflow",
+        reason="JAX does not support steps_per_execution yet",
+    )
+    def test_steps_per_execution_steps_count(self):
+        class StepCount(Callback):
+            def __init__(self):
+                super().__init__()
+                self.count = 0
+                self.batches = [0, 3, 6]
+
+            def on_batch_begin(self, batch, logs=None):
+                assert batch == self.batches[self.count]
+                self.count += 1
+
+        x = np.ones((100, 4))
+        y = np.ones((100, 1))
+        model = ExampleModel(units=1)
+        model.compile(loss="mse", optimizer="adam", steps_per_execution=3)
+        step_count = StepCount()
+        model.fit(x=x, y=y, batch_size=16, callbacks=[step_count])
+        self.assertEqual(step_count.count, 3)
+
+        model_2 = ExampleModel(units=1)
+        model_2.compile(loss="mse", optimizer="adam", steps_per_execution=1)
+        model_2.fit(x=x, y=y, batch_size=16)
+        self.assertAllClose(model.get_weights(), model_2.get_weights())