Switch TF trainer to a setup that enables future step fusing

Francois Chollet 2023-04-19 11:09:43 -07:00
parent bbcd6eee8a
commit 0b8dd458b4
3 changed files with 100 additions and 28 deletions
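Context for the change (not part of the commit itself): the traced train step now receives the dataset iterator and calls `next()` inside the trace, instead of receiving an already-fetched batch. A minimal sketch of the two calling conventions, using placeholder names (`train_step`, `step_on_data`, `step_on_iterator`) rather than the trainer's real API:

import tensorflow as tf

# Toy step function; illustrative only.
def train_step(data):
    return {"loss": tf.reduce_mean(data)}

dataset = tf.data.Dataset.range(8).map(lambda x: tf.cast(x, "float32")).batch(2)

# Old convention: the Python loop fetches the batch, the traced function
# only ever sees data.
step_on_data = tf.function(train_step, reduce_retracing=True)
for batch in dataset:
    step_on_data(batch)

# New convention: the traced function owns the `next()` call, so a later
# change can pull several batches within a single traced invocation.
@tf.function(reduce_retracing=True)
def step_on_iterator(iterator):
    return train_step(next(iterator))

iterator = iter(dataset)
for _ in range(4):  # 8 elements / batch size 2
    step_on_iterator(iterator)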

@@ -1,4 +1,8 @@
import contextlib
import warnings
import tensorflow as tf
from tensorflow.python.eager import context as tf_context
from keras_core import callbacks as callbacks_module
from keras_core import optimizers as optimizers_module
@@ -41,19 +45,25 @@ class Trainer(base_trainer.Trainer):
if self.train_function is not None and not force:
return self.train_function
def step_function(data):
def one_step_on_data(data):
"""Runs a single training step."""
return self.train_step(data)
if self.jit_compile:
train_function = tf.function(
step_function, jit_compile=True, reduce_retracing=True
if not self.run_eagerly and self.jit_compile:
one_step_on_data = tf.function(
one_step_on_data, jit_compile=True, reduce_retracing=True
)
elif not self.run_eagerly:
train_function = tf.function(step_function, reduce_retracing=True)
else:
train_function = step_function
def one_step_on_iterator(iterator):
data = next(iterator)
return one_step_on_data(data)
if not self.run_eagerly:
train_function = tf.function(
one_step_on_iterator, reduce_retracing=True
)
else:
train_function = one_step_on_iterator
self.train_function = train_function
def make_test_function(self, force=False):
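The commit title mentions future "step fusing": with the iterator consumed inside the trace, a follow-up change could run several train steps per `tf.function` call. A sketch of what such a fused function might look like, assuming a hypothetical `steps_per_execution` knob that this commit does not add:

import tensorflow as tf

def make_fused_train_function(train_step, steps_per_execution=4):
    # Hypothetical fused variant: the whole loop runs as one traced call.
    @tf.function(reduce_retracing=True)
    def multi_step_on_iterator(iterator):
        # Run one step first so the loop variable has a concrete structure,
        # then do the remaining steps inside the same trace.
        logs = train_step(next(iterator))
        for _ in tf.range(steps_per_execution - 1):
            logs = train_step(next(iterator))
        return logs

    return multi_step_on_iterator

Each call to the returned function would then advance the iterator by `steps_per_execution` batches, which the old per-batch `step_function(data)` signature could not express.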
@@ -105,7 +115,7 @@ class Trainer(base_trainer.Trainer):
) = data_adapter_utils.unpack_x_y_sample_weight(validation_data)
# Create an iterator that yields batches for one epoch.
epoch_iterator = EpochIterator(
epoch_iterator = TFEpochIterator(
x=x,
y=y,
sample_weight=sample_weight,
@@ -135,12 +145,13 @@ class Trainer(base_trainer.Trainer):
for epoch in range(initial_epoch, epochs):
self.reset_metrics()
callbacks.on_epoch_begin(epoch)
for step, batch in epoch_iterator.enumerate_epoch(return_type="tf"):
callbacks.on_train_batch_begin(step)
logs = self.train_function(batch)
callbacks.on_train_batch_end(step, logs)
if self.stop_training:
break
with epoch_iterator.catch_stop_iteration():
for step, iterator in epoch_iterator.enumerate_epoch():
callbacks.on_train_batch_begin(step)
logs = self.train_function(iterator)
callbacks.on_train_batch_end(step, logs)
if self.stop_training:
break
# Override with model metrics instead of last step logs
epoch_logs = self._process_logs(self.get_metrics_result())
@@ -216,3 +227,53 @@ class Trainer(base_trainer.Trainer):
pass
result[key] = value
return result
class TFEpochIterator(EpochIterator):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._steps_seen = 0
def enumerate_epoch(self):
if self.steps_per_epoch:
if not self._current_iterator:
self._current_iterator = iter(
self.data_adapter.get_tf_dataset()
)
for step in range(self.steps_per_epoch):
yield step, self._current_iterator
else:
iterator = iter(self.data_adapter.get_tf_dataset())
if self.num_batches:
for step in range(self.num_batches):
yield step, iterator
else:
step = -1
while True:
step += 1
self._steps_seen = step + 1
yield step, iterator
self.data_adapter.on_epoch_end()
def tf_sync(self):
tf_context.async_wait()
@contextlib.contextmanager
def catch_stop_iteration(self):
"""Catches errors when an iterator runs out of data."""
try:
yield
self.tf_sync()
except (StopIteration, tf.errors.OutOfRangeError):
if self._num_batches is None:
self._num_batches = self._steps_seen
warnings.warn(
"Your input ran out of data; interrupting training. "
"Make sure that your dataset or generator can generate "
"at least `steps_per_epoch * epochs` batches. "
"You may need to use the `.repeat()` "
"function when building your dataset.",
stacklevel=2,
)
self._current_iterator = None
self.data_adapter.on_epoch_end()

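Why `catch_stop_iteration()` is needed: once `next()` happens inside a traced function, exhausting the dataset surfaces as `tf.errors.OutOfRangeError` rather than a plain Python `StopIteration`, so the fit loop has to catch both. A standalone sketch of that failure mode (toy names, not the trainer code):

import tensorflow as tf

dataset = tf.data.Dataset.range(2)
iterator = iter(dataset)

@tf.function
def pull_one(it):
    # The fetch happens inside the trace, so exhaustion is reported as
    # tf.errors.OutOfRangeError instead of StopIteration.
    return next(it)

try:
    for _ in range(5):  # asks for more batches than the dataset holds
        pull_one(iterator)
except (StopIteration, tf.errors.OutOfRangeError):
    print("ran out of data; ending the epoch early")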
@@ -1,7 +1,7 @@
import numpy as np
from keras_core import testing
from keras_core import backend
from keras_core import testing
from keras_core.layers.layer import Layer
@@ -30,7 +30,7 @@ class LayerTest(testing.TestCase):
def call(self, x):
return backend.random.dropout(x, rate=0.5, seed=self.seed_gen)
layer = RNGLayer()
self.assertEqual(layer.variables, [layer.seed_gen.state])
self.assertAllClose(layer.variables[0], [1337, 0])
@@ -49,12 +49,14 @@ class LayerTest(testing.TestCase):
x = backend.random.dropout(x, rate=0.5, seed=self.seed_gens[0])
x = backend.random.dropout(x, rate=0.5, seed=self.seed_gens[1])
return x
layer = RNGListLayer()
self.assertEqual(layer.variables, [layer.seed_gens[0].state, layer.seed_gens[1].state])
self.assertEqual(
layer.variables,
[layer.seed_gens[0].state, layer.seed_gens[1].state],
)
self.assertAllClose(layer.variables[0], [1, 0])
self.assertAllClose(layer.variables[1], [10, 0])
layer(np.ones((3, 4)))
self.assertAllClose(layer.variables[0], [1, 1])
self.assertAllClose(layer.variables[1], [10, 1])

@@ -97,6 +97,7 @@ class EpochIterator:
raise ValueError(
f"Unrecognized data type: x={x} (of type {type(x)})"
)
self._num_batches = self.data_adapter.num_batches
def _get_iterator(self, return_type):
if return_type not in ("np", "tf"):
@@ -118,19 +119,27 @@ class EpochIterator:
try:
data = next(self._current_iterator)
yield step, data
except StopIteration:
except (StopIteration, tf.errors.OutOfRangeError):
warnings.warn(
"The dataset ran out of data before the end of the epoch. "
"When passing `steps_per_epoch` "
"(or otherwise `validation_steps` in `fit()` or `steps` in `evaluate()`), "
"make sure that your dataset size (number of batches) is divisible "
"by `steps_per_epoch`."
"Your input ran out of data; interrupting epoch. "
"Make sure that your dataset or generator can generate "
"at least `steps_per_epoch * epochs` batches. "
"You may need to use the `.repeat()` "
"function when building your dataset.",
stacklevel=2,
)
self._current_iterator = self._get_iterator(return_type)
self._current_iterator = None
else:
for step, data in enumerate(self._get_iterator(return_type)):
yield step, data
if not self._num_batches:
# Infer the number of batches returned by the data_adapter.
# Assumed static.
self._num_batches = step + 1
self.data_adapter.on_epoch_end()
@property
def num_batches(self):
return self.data_adapter.num_batches
# Either copied from the data_adapter, or
# inferred at the end of an iteration.
return self._num_batches
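On the `_num_batches` bookkeeping above: for generator-backed input the number of batches is unknown before iterating, so it can only be inferred after one full pass and is assumed to stay the same in later epochs. A small illustration of that situation (names here are illustrative):

import tensorflow as tf

def gen():
    for i in range(5):
        yield [float(i)]

ds = tf.data.Dataset.from_generator(
    gen, output_signature=tf.TensorSpec(shape=(1,), dtype=tf.float32)
)

# Cardinality is unknown before iterating...
print(bool(ds.cardinality() == tf.data.UNKNOWN_CARDINALITY))  # True

# ...so the batch count is only known after a full epoch, then cached.
num_batches = sum(1 for _ in ds)
print(num_batches)  # 5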