Add steps_per_execution to TensorFlow backend (#153)

* Add `steps_per_execution` to TensorFlow backend

* add test

* fix the test

---------

Co-authored-by: Haifeng Jin <haifeng-jin@users.noreply.github.com>
Haifeng Jin 2023-05-12 14:57:32 -07:00 committed by Francois Chollet
parent 6f2d44e3ef
commit 77b4fcc3dc
4 changed files with 52 additions and 8 deletions
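
For orientation, `steps_per_execution` lets the trainer run several batches inside a single compiled call, so Python-level and callback overhead is paid once per chunk rather than once per batch. A minimal usage sketch (not part of this commit; it assumes the standard `keras_core` Sequential/Dense API):

```python
import numpy as np
import keras_core as keras

x = np.ones((100, 4))
y = np.ones((100, 1))

model = keras.Sequential([keras.layers.Dense(1)])
# Execute 3 train steps per call into the compiled train function;
# batch-level callbacks then fire at batches 0, 3, 6, ...
model.compile(loss="mse", optimizer="adam", steps_per_execution=3)
model.fit(x, y, batch_size=16)
```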

@@ -114,12 +114,15 @@ class TensorFlowTrainer(base_trainer.Trainer):
            )
            return outputs

        def train_function(iterator):
            """Runs a training execution with multiple steps."""
            for _ in tf.range(self.steps_per_execution):
                outputs = one_step_on_iterator(iterator)
            return outputs

        if not self.run_eagerly:
            train_function = tf.function(
                one_step_on_iterator, reduce_retracing=True
            )
        else:
            train_function = one_step_on_iterator
        train_function = tf.function(train_function, reduce_retracing=True)
        self.train_function = train_function

    def make_test_function(self, force=False):
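
The hunk above is the heart of the change: instead of compiling the single-step function directly, the trainer now defines a wrapper that pulls `steps_per_execution` batches from the iterator inside one traced loop, and that wrapper is what gets passed to `tf.function`. A self-contained sketch of the same pattern outside the trainer (the names `step_fn`, `multi_step_fn`, and the toy dataset are illustrative; the first step is hoisted out of the loop only so `outputs` is always defined):

```python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(tf.range(12)).batch(2)
steps_per_execution = 3

def step_fn(iterator):
    # Stand-in for one optimizer step on one batch.
    batch = next(iterator)
    return tf.reduce_sum(batch)

@tf.function(reduce_retracing=True)
def multi_step_fn(iterator):
    # Run the first step, then the remaining steps in a graph-level loop;
    # only the last step's outputs are returned, as in the trainer.
    outputs = step_fn(iterator)
    for _ in tf.range(steps_per_execution - 1):
        outputs = step_fn(iterator)
    return outputs

it = iter(dataset)
print(multi_step_fn(it))  # consumes the first 3 batches in a single call
```
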
@@ -244,6 +247,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
            shuffle=shuffle,
            class_weight=class_weight,
            distribute_strategy=self.distribute_strategy,
            steps_per_execution=self.steps_per_execution,
        )

        # Container that configures and calls callbacks.
@@ -445,7 +449,9 @@ class TFEpochIterator(EpochIterator):
                        self.data_adapter.get_tf_dataset()
                    )
                )
            for step in range(self.steps_per_epoch):
            for step in range(
                0, self.steps_per_epoch, self.steps_per_execution
            ):
                yield step, self._current_iterator
        else:
            iterator = iter(
@@ -454,12 +460,14 @@ class TFEpochIterator(EpochIterator):
                )
            )
            if self.num_batches:
                for step in range(self.num_batches):
                for step in range(
                    0, self.num_batches, self.steps_per_execution
                ):
                    yield step, iterator
            else:
                step = -1
                while True:
                    step += 1
                    step += self.steps_per_execution
                    self._steps_seen = step + 1
                    yield step, iterator
        self.data_adapter.on_epoch_end()
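
With the epoch iterator now stepping by `steps_per_execution`, each yielded step index marks the first batch of a chunk, and the multi-step train function consumes the whole chunk. A tiny illustration of the indices batch-level callbacks end up seeing (numbers chosen to match the test added below):

```python
num_batches = 7            # e.g. 100 samples with batch_size=16
steps_per_execution = 3

# The iterator yields the start index of each chunk of batches.
chunk_starts = list(range(0, num_batches, steps_per_execution))
print(chunk_starts)        # [0, 3, 6] -> on_batch_begin fires three times
```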

@@ -60,8 +60,10 @@ class EpochIterator:
        steps_per_epoch=None,
        shuffle=False,
        class_weight=None,
        steps_per_execution=1,
    ):
        self.steps_per_epoch = steps_per_epoch
        self.steps_per_execution = steps_per_execution
        if steps_per_epoch:
            self._current_iterator = None
            self._insufficient_data = False

@@ -26,6 +26,7 @@ class Trainer:
        metrics=None,
        weighted_metrics=None,
        run_eagerly=False,
        steps_per_execution=1,
        jit_compile=True,
    ):
        self.optimizer = optimizers.get(optimizer)
@@ -49,6 +50,7 @@ class Trainer:
        self.stop_training = False
        self.compiled = True
        self._loss_tracker = metrics_module.Mean(name="loss")
        self.steps_per_execution = steps_per_execution
        self._compile_config = serialization_lib.SerializableDict(
            optimizer=optimizer,
@@ -57,6 +59,7 @@ class Trainer:
            metrics=metrics,
            weighted_metrics=weighted_metrics,
            run_eagerly=run_eagerly,
            steps_per_execution=steps_per_execution,
            jit_compile=jit_compile,
        )
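
On the base `Trainer`, `compile()` now accepts `steps_per_execution` (defaulting to 1, i.e. one step per call), stores it as a plain attribute, and records it in the compile config alongside the other compile arguments. A quick sketch of the default from user code (model construction is illustrative):

```python
import keras_core as keras

model = keras.Sequential([keras.layers.Dense(1)])
model.compile(loss="mse", optimizer="adam")  # steps_per_execution defaults to 1

print(model.steps_per_execution)  # 1 -> unchanged single-step behavior
# Passing steps_per_execution=N here is what the TensorFlow trainer and
# epoch iterator above pick up.
```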

@@ -1,4 +1,5 @@
import numpy as np
import pytest

from keras_core import backend
from keras_core import initializers
@@ -7,6 +8,7 @@ from keras_core import losses
from keras_core import metrics
from keras_core import optimizers
from keras_core import testing
from keras_core.callbacks.callback import Callback

if backend.backend() == "jax":
    from keras_core.backend.jax.trainer import JAXTrainer as Trainer
@@ -213,3 +215,32 @@ class TestTrainer(testing.TestCase):
    def test_predict_flow_jit(self):
        self._test_predict_flow(run_eagerly=False, jit_compile=True)

    # TODO: Remove the skipif when steps_per_execution is implemented for JAX.
    @pytest.mark.skipif(
        backend.backend() != "tensorflow",
        reason="JAX does not support steps_per_execution yet",
    )
    def test_steps_per_execution_steps_count(self):
        class StepCount(Callback):
            def __init__(self):
                super().__init__()
                self.count = 0
                self.batches = [0, 3, 6]

            def on_batch_begin(self, batch, logs=None):
                assert batch == self.batches[self.count]
                self.count += 1

        x = np.ones((100, 4))
        y = np.ones((100, 1))
        model = ExampleModel(units=1)
        model.compile(loss="mse", optimizer="adam", steps_per_execution=3)
        step_count = StepCount()
        model.fit(x=x, y=y, batch_size=16, callbacks=[step_count])
        self.assertEqual(step_count.count, 3)

        model_2 = ExampleModel(units=1)
        model_2.compile(loss="mse", optimizer="adam", steps_per_execution=1)
        model_2.fit(x=x, y=y, batch_size=16)
        self.assertAllClose(model.get_weights(), model_2.get_weights())
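
A note on the numbers the new test relies on: 100 samples with `batch_size=16` give ceil(100/16) = 7 batches per epoch, so with `steps_per_execution=3` the callback should fire at batch indices 0, 3 and 6, i.e. exactly three times. And because both models see the same batches in the same order for one epoch (given `ExampleModel`'s deterministic initialization, defined earlier in this test file), grouping the steps should not change the final weights. The arithmetic, spelled out (illustrative only, not part of the test):

```python
import math

batches_per_epoch = math.ceil(100 / 16)              # 7
starts = list(range(0, batches_per_epoch, 3))        # [0, 3, 6]
assert len(starts) == 3                              # matches step_count.count
```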