Add steps_per_execution to TensorFlow backend (#153)
* Add `steps_per_execution` to TensorFlow backend
* add test
* fix the test

---------

Co-authored-by: Haifeng Jin <haifeng-jin@users.noreply.github.com>
This commit is contained in:
parent
6f2d44e3ef
commit
77b4fcc3dc
@ -114,12 +114,15 @@ class TensorFlowTrainer(base_trainer.Trainer):
|
||||
)
|
||||
return outputs
|
||||
|
||||
def train_function(iterator):
|
||||
"""Runs a training execution with multiple steps."""
|
||||
for _ in tf.range(self.steps_per_execution):
|
||||
outputs = one_step_on_iterator(iterator)
|
||||
return outputs
|
||||
|
||||
if not self.run_eagerly:
|
||||
train_function = tf.function(
|
||||
one_step_on_iterator, reduce_retracing=True
|
||||
)
|
||||
else:
|
||||
train_function = one_step_on_iterator
|
||||
train_function = tf.function(train_function, reduce_retracing=True)
|
||||
|
||||
self.train_function = train_function
|
||||
|
||||
def make_test_function(self, force=False):
|
||||
@ -244,6 +247,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
|
||||
shuffle=shuffle,
|
||||
class_weight=class_weight,
|
||||
distribute_strategy=self.distribute_strategy,
|
||||
steps_per_execution=self.steps_per_execution,
|
||||
)
|
||||
|
||||
# Container that configures and calls callbacks.
|
||||
@ -445,7 +449,9 @@ class TFEpochIterator(EpochIterator):
|
||||
self.data_adapter.get_tf_dataset()
|
||||
)
|
||||
)
|
||||
for step in range(self.steps_per_epoch):
|
||||
for step in range(
|
||||
0, self.steps_per_epoch, self.steps_per_execution
|
||||
):
|
||||
yield step, self._current_iterator
|
||||
else:
|
||||
iterator = iter(
|
||||
@ -454,12 +460,14 @@ class TFEpochIterator(EpochIterator):
|
||||
)
|
||||
)
|
||||
if self.num_batches:
|
||||
for step in range(self.num_batches):
|
||||
for step in range(
|
||||
0, self.num_batches, self.steps_per_execution
|
||||
):
|
||||
yield step, iterator
|
||||
else:
|
||||
step = -1
|
||||
while True:
|
||||
step += 1
|
||||
step += self.steps_per_execution
|
||||
self._steps_seen = step + 1
|
||||
yield step, iterator
|
||||
self.data_adapter.on_epoch_end()
|
||||
|
@ -60,8 +60,10 @@ class EpochIterator:
|
||||
steps_per_epoch=None,
|
||||
shuffle=False,
|
||||
class_weight=None,
|
||||
steps_per_execution=1,
|
||||
):
|
||||
self.steps_per_epoch = steps_per_epoch
|
||||
self.steps_per_execution = steps_per_execution
|
||||
if steps_per_epoch:
|
||||
self._current_iterator = None
|
||||
self._insufficient_data = False
|
||||
|
@ -26,6 +26,7 @@ class Trainer:
|
||||
metrics=None,
|
||||
weighted_metrics=None,
|
||||
run_eagerly=False,
|
||||
steps_per_execution=1,
|
||||
jit_compile=True,
|
||||
):
|
||||
self.optimizer = optimizers.get(optimizer)
|
||||
@ -49,6 +50,7 @@ class Trainer:
|
||||
self.stop_training = False
|
||||
self.compiled = True
|
||||
self._loss_tracker = metrics_module.Mean(name="loss")
|
||||
self.steps_per_execution = steps_per_execution
|
||||
|
||||
self._compile_config = serialization_lib.SerializableDict(
|
||||
optimizer=optimizer,
|
||||
@ -57,6 +59,7 @@ class Trainer:
|
||||
metrics=metrics,
|
||||
weighted_metrics=weighted_metrics,
|
||||
run_eagerly=run_eagerly,
|
||||
steps_per_execution=steps_per_execution,
|
||||
jit_compile=jit_compile,
|
||||
)
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from keras_core import backend
|
||||
from keras_core import initializers
|
||||
@ -7,6 +8,7 @@ from keras_core import losses
|
||||
from keras_core import metrics
|
||||
from keras_core import optimizers
|
||||
from keras_core import testing
|
||||
from keras_core.callbacks.callback import Callback
|
||||
|
||||
if backend.backend() == "jax":
|
||||
from keras_core.backend.jax.trainer import JAXTrainer as Trainer
|
||||
@ -213,3 +215,32 @@ class TestTrainer(testing.TestCase):
|
||||
|
||||
def test_predict_flow_jit(self):
|
||||
self._test_predict_flow(run_eagerly=False, jit_compile=True)
|
||||
|
||||
# TODO: Remove the skipif when implemented steps_per_execution for JAX.
|
||||
@pytest.mark.skipif(
|
||||
backend.backend() != "tensorflow",
|
||||
reason="JAX does not support steps_per_execution yet",
|
||||
)
|
||||
def test_steps_per_execution_steps_count(self):
|
||||
class StepCount(Callback):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.count = 0
|
||||
self.batches = [0, 3, 6]
|
||||
|
||||
def on_batch_begin(self, batch, logs=None):
|
||||
assert batch == self.batches[self.count]
|
||||
self.count += 1
|
||||
|
||||
x = np.ones((100, 4))
|
||||
y = np.ones((100, 1))
|
||||
model = ExampleModel(units=1)
|
||||
model.compile(loss="mse", optimizer="adam", steps_per_execution=3)
|
||||
step_count = StepCount()
|
||||
model.fit(x=x, y=y, batch_size=16, callbacks=[step_count])
|
||||
self.assertEqual(step_count.count, 3)
|
||||
|
||||
model_2 = ExampleModel(units=1)
|
||||
model_2.compile(loss="mse", optimizer="adam", steps_per_execution=1)
|
||||
model_2.fit(x=x, y=y, batch_size=16)
|
||||
self.assertAllClose(model.get_weights(), model_2.get_weights())
|
||||
|
Loading…
Reference in New Issue
Block a user