Switch TF trainer to a setup that enables future step fusing
parent bbcd6eee8a
commit 0b8dd458b4
@@ -1,4 +1,8 @@
+import contextlib
+import warnings
+
 import tensorflow as tf
+from tensorflow.python.eager import context as tf_context
 
 from keras_core import callbacks as callbacks_module
 from keras_core import optimizers as optimizers_module
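The new imports map onto the machinery added below: `contextlib` backs the new `catch_stop_iteration()` context manager, `warnings` carries the end-of-data warning, and `tf_context.async_wait()` is used to flush pending asynchronous eager execution before an epoch is wound down.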
@@ -41,19 +45,25 @@ class Trainer(base_trainer.Trainer):
         if self.train_function is not None and not force:
             return self.train_function
 
-        def step_function(data):
+        def one_step_on_data(data):
             """Runs a single training step."""
             return self.train_step(data)
 
-        if self.jit_compile:
-            train_function = tf.function(
-                step_function, jit_compile=True, reduce_retracing=True
+        if not self.run_eagerly and self.jit_compile:
+            one_step_on_data = tf.function(
+                one_step_on_data, jit_compile=True, reduce_retracing=True
             )
-        elif not self.run_eagerly:
-            train_function = tf.function(step_function, reduce_retracing=True)
-        else:
-            train_function = step_function
+
+        def one_step_on_iterator(iterator):
+            data = next(iterator)
+            return one_step_on_data(data)
+
+        if not self.run_eagerly:
+            train_function = tf.function(
+                one_step_on_iterator, reduce_retracing=True
+            )
+        else:
+            train_function = one_step_on_iterator
         self.train_function = train_function
 
     def make_test_function(self, force=False):
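The substantive change: the compiled train function now receives the data iterator itself rather than an already-fetched batch, so `next(iterator)` is traced inside the `tf.function`. That is what makes future step fusing possible: a later change can pull several batches within a single traced call instead of paying one function-call dispatch per batch. A minimal sketch of what such a fused step could look like; `make_fused_train_function` and `steps_per_execution` are assumed names, not part of this commit:

import tensorflow as tf

def make_fused_train_function(one_step_on_data, steps_per_execution=4):
    # Hypothetical follow-up enabled by this commit: because next() now
    # happens inside the traced function, several train steps can be
    # fused into one tf.function call.
    def multi_step_on_iterator(iterator):
        logs = one_step_on_data(next(iterator))
        # Unrolled at trace time; every next() joins the same graph.
        for _ in range(steps_per_execution - 1):
            logs = one_step_on_data(next(iterator))
        return logs

    return tf.function(multi_step_on_iterator, reduce_retracing=True)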
@@ -105,7 +115,7 @@ class Trainer(base_trainer.Trainer):
         ) = data_adapter_utils.unpack_x_y_sample_weight(validation_data)
 
         # Create an iterator that yields batches for one epoch.
-        epoch_iterator = EpochIterator(
+        epoch_iterator = TFEpochIterator(
             x=x,
             y=y,
             sample_weight=sample_weight,
@@ -135,12 +145,13 @@ class Trainer(base_trainer.Trainer):
         for epoch in range(initial_epoch, epochs):
             self.reset_metrics()
             callbacks.on_epoch_begin(epoch)
-            for step, batch in epoch_iterator.enumerate_epoch(return_type="tf"):
-                callbacks.on_train_batch_begin(step)
-                logs = self.train_function(batch)
-                callbacks.on_train_batch_end(step, logs)
-                if self.stop_training:
-                    break
+            with epoch_iterator.catch_stop_iteration():
+                for step, iterator in epoch_iterator.enumerate_epoch():
+                    callbacks.on_train_batch_begin(step)
+                    logs = self.train_function(iterator)
+                    callbacks.on_train_batch_end(step, logs)
+                    if self.stop_training:
+                        break
 
             # Override with model metrics instead of last step logs
             epoch_logs = self._process_logs(self.get_metrics_result())
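Because `next(iterator)` now executes inside the traced train function, an exhausted dataset no longer ends the Python `for` loop on its own: exhaustion surfaces as a `tf.errors.OutOfRangeError` raised by the traced call, which is exactly what the new `catch_stop_iteration()` wrapper absorbs. A standalone sketch of that failure mode, assuming a plain `tf.data` iterator (`step` stands in for the real train step):

import tensorflow as tf

dataset = tf.data.Dataset.range(5).batch(2)  # three batches, then exhausted
iterator = iter(dataset)

@tf.function(reduce_retracing=True)
def step(it):
    return next(it) * 2  # placeholder for a real train step

try:
    while True:
        print(step(iterator))
except tf.errors.OutOfRangeError:
    print("ran out of data")  # the case catch_stop_iteration() turns into a warning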
@@ -216,3 +227,53 @@ class Trainer(base_trainer.Trainer):
             pass
         result[key] = value
         return result
+
+
+class TFEpochIterator(EpochIterator):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._steps_seen = 0
+
+    def enumerate_epoch(self):
+        if self.steps_per_epoch:
+            if not self._current_iterator:
+                self._current_iterator = iter(
+                    self.data_adapter.get_tf_dataset()
+                )
+            for step in range(self.steps_per_epoch):
+                yield step, self._current_iterator
+        else:
+            iterator = iter(self.data_adapter.get_tf_dataset())
+            if self.num_batches:
+                for step in range(self.num_batches):
+                    yield step, iterator
+            else:
+                step = -1
+                while True:
+                    step += 1
+                    self._steps_seen = step + 1
+                    yield step, iterator
+        self.data_adapter.on_epoch_end()
+
+    def tf_sync(self):
+        tf_context.async_wait()
+
+    @contextlib.contextmanager
+    def catch_stop_iteration(self):
+        """Catches errors when an iterator runs out of data."""
+        try:
+            yield
+            self.tf_sync()
+        except (StopIteration, tf.errors.OutOfRangeError):
+            if self._num_batches is None:
+                self._num_batches = self._steps_seen
+            warnings.warn(
+                "Your input ran out of data; interrupting training. "
+                "Make sure that your dataset or generator can generate "
+                "at least `steps_per_epoch * epochs` batches. "
+                "You may need to use the `.repeat()` "
+                "function when building your dataset.",
+                stacklevel=2,
+            )
+            self._current_iterator = None
+            self.data_adapter.on_epoch_end()
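Note the ordering of `yield` and `tf_sync()` in `catch_stop_iteration`: with asynchronous eager execution a step can fail after its Python call has already returned, so `tf_context.async_wait()` forces pending work to finish while control is still inside the `try`, letting a late `OutOfRangeError` be caught and downgraded to a warning. The intended composition, roughly as `fit()` uses it above (the constructor arguments here are illustrative assumptions):

epoch_iterator = TFEpochIterator(x=x_train, y=y_train, batch_size=32)
with epoch_iterator.catch_stop_iteration():
    for step, iterator in epoch_iterator.enumerate_epoch():
        logs = train_function(iterator)  # may raise OutOfRangeError mid-epoch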
@@ -1,7 +1,7 @@
 import numpy as np
 
-from keras_core import testing
+from keras_core import backend
+from keras_core import testing
 from keras_core.layers.layer import Layer
 
 
@@ -30,7 +30,7 @@ class LayerTest(testing.TestCase):
 
             def call(self, x):
                 return backend.random.dropout(x, rate=0.5, seed=self.seed_gen)
 
         layer = RNGLayer()
         self.assertEqual(layer.variables, [layer.seed_gen.state])
         self.assertAllClose(layer.variables[0], [1337, 0])
@@ -49,12 +49,14 @@ class LayerTest(testing.TestCase):
                 x = backend.random.dropout(x, rate=0.5, seed=self.seed_gens[0])
                 x = backend.random.dropout(x, rate=0.5, seed=self.seed_gens[1])
                 return x
 
         layer = RNGListLayer()
-        self.assertEqual(layer.variables, [layer.seed_gens[0].state, layer.seed_gens[1].state])
+        self.assertEqual(
+            layer.variables,
+            [layer.seed_gens[0].state, layer.seed_gens[1].state],
+        )
         self.assertAllClose(layer.variables[0], [1, 0])
         self.assertAllClose(layer.variables[1], [10, 0])
         layer(np.ones((3, 4)))
         self.assertAllClose(layer.variables[0], [1, 1])
         self.assertAllClose(layer.variables[1], [10, 1])
@@ -97,6 +97,7 @@ class EpochIterator:
             raise ValueError(
                 f"Unrecognized data type: x={x} (of type {type(x)})"
             )
+        self._num_batches = self.data_adapter.num_batches
 
     def _get_iterator(self, return_type):
         if return_type not in ("np", "tf"):
@@ -118,19 +119,27 @@ class EpochIterator:
             try:
                 data = next(self._current_iterator)
                 yield step, data
-            except StopIteration:
+            except (StopIteration, tf.errors.OutOfRangeError):
                 warnings.warn(
-                    "The dataset ran out of data before the end of the epoch. "
-                    "When passing `steps_per_epoch` "
-                    "(or otherwise `validation_steps` in `fit()` or `steps` in `evaluate()`), "
-                    "make sure that your dataset size (number of batches) is divisible "
-                    "by `steps_per_epoch`."
+                    "Your input ran out of data; interrupting epoch. "
+                    "Make sure that your dataset or generator can generate "
+                    "at least `steps_per_epoch * epochs` batches. "
+                    "You may need to use the `.repeat()` "
+                    "function when building your dataset.",
+                    stacklevel=2,
                 )
-                self._current_iterator = self._get_iterator(return_type)
+                self._current_iterator = None
         else:
             for step, data in enumerate(self._get_iterator(return_type)):
                 yield step, data
+            if not self._num_batches:
+                # Infer the number of batches returned by the data_adapter.
+                # Assumed static.
+                self._num_batches = step + 1
         self.data_adapter.on_epoch_end()
 
     @property
     def num_batches(self):
-        return self.data_adapter.num_batches
+        # Either copied from the data_adapter, or
+        # inferred at the end of an iteration.
+        return self._num_batches
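The cached `_num_batches` matters for inputs whose size the adapter cannot know up front, such as generators: the first epoch runs open-ended, and the count observed then becomes the bound for later epochs. A sketch of the expected behaviour; the generator input and constructor call are illustrative assumptions:

import numpy as np

def gen():
    for _ in range(7):
        yield np.ones((2, 4)), np.ones((2, 1))

epoch_iterator = EpochIterator(x=gen())  # adapter reports num_batches=None
for step, data in epoch_iterator.enumerate_epoch(return_type="np"):
    pass
print(epoch_iterator.num_batches)  # 7, inferred after one full pass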