Add TF evaluate() flow.
This commit is contained in:
parent
ee5be68ce9
commit
aa4078899d
@ -9,9 +9,10 @@ from keras_core.trainers.epoch_iterator import EpochIterator
|
||||
|
||||
|
||||
class Trainer(base_trainer.Trainer):
|
||||
def stateless_compute_loss_and_updates(
|
||||
def compute_loss_and_updates(
|
||||
self, trainable_variables, non_trainable_variables, x, y, sample_weight
|
||||
):
|
||||
"""This method is stateless and is intended for use with jax.grad."""
|
||||
y_pred, non_trainable_variables = self.stateless_call(
|
||||
trainable_variables, non_trainable_variables, x
|
||||
)
|
||||
@ -104,7 +105,7 @@ class Trainer(base_trainer.Trainer):
|
||||
)
|
||||
|
||||
grad_fn = jax.value_and_grad(
|
||||
self.stateless_compute_loss_and_updates, has_aux=True
|
||||
self.compute_loss_and_updates, has_aux=True
|
||||
)
|
||||
|
||||
def _train_step(state, data):
|
||||
|
@ -13,7 +13,7 @@ from keras_core.utils.naming import auto_name
|
||||
DYNAMIC_SHAPES_OK = True
|
||||
|
||||
|
||||
class Variable(KerasVariable):
|
||||
class Variable(KerasVariable, tf.__internal__.types.Tensor):
|
||||
def __init__(self, value, dtype=None, trainable=True, name=None):
|
||||
self.name = name or auto_name(self.__class__.__name__)
|
||||
dtype = standardize_dtype(dtype)
|
||||
|
@ -12,6 +12,12 @@ from keras_core.trainers.epoch_iterator import EpochIterator
|
||||
|
||||
|
||||
class Trainer(base_trainer.Trainer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.train_function = None
|
||||
self.test_function = None
|
||||
self.predict_function = None
|
||||
|
||||
def train_step(self, data):
|
||||
x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data)
|
||||
|
||||
@ -36,10 +42,24 @@ class Trainer(base_trainer.Trainer):
|
||||
return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)
|
||||
|
||||
def test_step(self, data):
|
||||
raise NotImplementedError
|
||||
x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data)
|
||||
if self._call_has_training_arg():
|
||||
y_pred = self(x, training=False)
|
||||
else:
|
||||
y_pred = self(x)
|
||||
loss = self.compute_loss(
|
||||
x=x, y=y, y_pred=y_pred, sample_weight=sample_weight
|
||||
)
|
||||
self._loss_tracker.update_state(loss)
|
||||
return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)
|
||||
|
||||
def predict_step(self, data):
|
||||
raise NotImplementedError
|
||||
x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data)
|
||||
if self._call_has_training_arg():
|
||||
y_pred = self(x, training=False)
|
||||
else:
|
||||
y_pred = self(x)
|
||||
return y_pred
|
||||
|
||||
def make_train_function(self, force=False):
|
||||
# TODO: support tf.distribute and steps_per_execution.
|
||||
@ -47,7 +67,7 @@ class Trainer(base_trainer.Trainer):
|
||||
return self.train_function
|
||||
|
||||
def one_step_on_data(data):
|
||||
"""Runs a single training step."""
|
||||
"""Runs a single training step on a batch of data."""
|
||||
return self.train_step(data)
|
||||
|
||||
if not self.run_eagerly and self.jit_compile:
|
||||
@ -56,6 +76,7 @@ class Trainer(base_trainer.Trainer):
|
||||
)
|
||||
|
||||
def one_step_on_iterator(iterator):
|
||||
"""Runs a single training step given a Dataset iterator."""
|
||||
data = next(iterator)
|
||||
return one_step_on_data(data)
|
||||
|
||||
@ -68,10 +89,58 @@ class Trainer(base_trainer.Trainer):
|
||||
self.train_function = train_function
|
||||
|
||||
def make_test_function(self, force=False):
|
||||
raise NotImplementedError
|
||||
# TODO: support tf.distribute and steps_per_execution.
|
||||
if self.test_function is not None and not force:
|
||||
return self.test_function
|
||||
|
||||
def one_step_on_data(data):
|
||||
"""Runs a single test step on a batch of data."""
|
||||
return self.test_step(data)
|
||||
|
||||
if not self.run_eagerly and self.jit_compile:
|
||||
one_step_on_data = tf.function(
|
||||
one_step_on_data, jit_compile=True, reduce_retracing=True
|
||||
)
|
||||
|
||||
def one_step_on_iterator(iterator):
|
||||
"""Runs a single test step given a Dataset iterator."""
|
||||
data = next(iterator)
|
||||
return one_step_on_data(data)
|
||||
|
||||
if not self.run_eagerly:
|
||||
test_function = tf.function(
|
||||
one_step_on_iterator, reduce_retracing=True
|
||||
)
|
||||
else:
|
||||
test_function = one_step_on_iterator
|
||||
self.test_function = test_function
|
||||
|
||||
def make_predict_function(self, force=False):
|
||||
raise NotImplementedError
|
||||
# TODO: support tf.distribute and steps_per_execution.
|
||||
if self.predict_function is not None and not force:
|
||||
return self.predict_function
|
||||
|
||||
def one_step_on_data(data):
|
||||
"""Runs a single predict step on a batch of data."""
|
||||
return self.predict_step(data)
|
||||
|
||||
if not self.run_eagerly and self.jit_compile:
|
||||
one_step_on_data = tf.function(
|
||||
one_step_on_data, jit_compile=True, reduce_retracing=True
|
||||
)
|
||||
|
||||
def one_step_on_iterator(iterator):
|
||||
"""Runs a single predict step given a Dataset iterator."""
|
||||
data = next(iterator)
|
||||
return one_step_on_data(data)
|
||||
|
||||
if not self.run_eagerly:
|
||||
predict_function = tf.function(
|
||||
one_step_on_iterator, reduce_retracing=True
|
||||
)
|
||||
else:
|
||||
predict_function = one_step_on_iterator
|
||||
self.predict_function = predict_function
|
||||
|
||||
def fit(
|
||||
self,
|
||||
@ -212,13 +281,71 @@ class Trainer(base_trainer.Trainer):
|
||||
return_dict=False,
|
||||
**kwargs,
|
||||
):
|
||||
raise NotImplementedError
|
||||
# TODO: respect compiled trainable state
|
||||
use_cached_eval_dataset = kwargs.pop("_use_cached_eval_dataset", False)
|
||||
if kwargs:
|
||||
raise ValueError(f"Arguments not recognized: {kwargs}")
|
||||
|
||||
if use_cached_eval_dataset:
|
||||
epoch_iterator = self._eval_epoch_iterator
|
||||
else:
|
||||
# Create an iterator that yields batches for one epoch.
|
||||
epoch_iterator = TFEpochIterator(
|
||||
x=x,
|
||||
y=y,
|
||||
sample_weight=sample_weight,
|
||||
batch_size=batch_size,
|
||||
steps_per_epoch=steps,
|
||||
shuffle=False,
|
||||
)
|
||||
|
||||
# Container that configures and calls callbacks.
|
||||
if not isinstance(callbacks, callbacks_module.CallbackList):
|
||||
callbacks = callbacks_module.CallbackList(
|
||||
callbacks,
|
||||
add_history=True,
|
||||
add_progbar=verbose != 0,
|
||||
verbose=verbose,
|
||||
epochs=1,
|
||||
steps=epoch_iterator.num_batches,
|
||||
model=self,
|
||||
)
|
||||
|
||||
self.make_test_function()
|
||||
callbacks.on_test_begin()
|
||||
logs = None
|
||||
self.reset_metrics()
|
||||
with epoch_iterator.catch_stop_iteration():
|
||||
for step, iterator in epoch_iterator.enumerate_epoch():
|
||||
callbacks.on_test_batch_begin(step)
|
||||
logs = self.test_function(iterator)
|
||||
callbacks.on_test_batch_end(step, logs)
|
||||
logs = self._process_logs(self.get_metrics_result())
|
||||
callbacks.on_test_end(logs)
|
||||
|
||||
if return_dict:
|
||||
return logs
|
||||
return self._flatten_metrics_in_order(logs)
|
||||
|
||||
def predict(
|
||||
self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
def _flatten_metrics_in_order(self, logs):
|
||||
"""Turns the `logs` dict into a list as per key order of `metrics_names`."""
|
||||
metric_names = [m.name for m in self.metrics]
|
||||
results = []
|
||||
for name in metric_names:
|
||||
if name in logs:
|
||||
results.append(logs[name])
|
||||
for key in sorted(logs.keys()):
|
||||
if key not in metric_names:
|
||||
results.append(logs[key])
|
||||
if len(results) == 1:
|
||||
return results[0]
|
||||
return results
|
||||
|
||||
|
||||
class TFEpochIterator(EpochIterator):
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
@ -5,6 +5,7 @@ from keras_core import metrics as metrics_module
|
||||
from keras_core import operations as ops
|
||||
from keras_core.trainers.compile_utils import CompileLoss
|
||||
from keras_core.trainers.compile_utils import CompileMetrics
|
||||
from keras_core.utils import tracking
|
||||
|
||||
|
||||
class Trainer:
|
||||
@ -12,10 +13,8 @@ class Trainer:
|
||||
self._run_eagerly = False
|
||||
self._jit_compile = True
|
||||
self.compiled = False
|
||||
self.train_function = None
|
||||
self.test_function = None
|
||||
self.predict_function = None
|
||||
|
||||
@tracking.no_automatic_dependency_tracking
|
||||
def compile(
|
||||
self,
|
||||
optimizer,
|
||||
@ -209,24 +208,6 @@ class Trainer:
|
||||
return_metrics[metric.name] = result
|
||||
return return_metrics
|
||||
|
||||
def train_step(self, data):
|
||||
raise NotImplementedError
|
||||
|
||||
def test_step(self, data):
|
||||
raise NotImplementedError
|
||||
|
||||
def predict_step(self, data):
|
||||
raise NotImplementedError
|
||||
|
||||
def make_train_function(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def make_test_function(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def make_predict_function(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def fit(
|
||||
self,
|
||||
x=None,
|
||||
|
@ -27,7 +27,7 @@ class ExampleModel(layers.Dense, Trainer):
|
||||
|
||||
|
||||
class TestTrainer(testing.TestCase):
|
||||
def _test_basic_flow(self, run_eagerly, jit_compile):
|
||||
def _test_fit_flow(self, run_eagerly, jit_compile):
|
||||
model = ExampleModel(units=3)
|
||||
x = np.ones((100, 4))
|
||||
y = np.zeros((100, 3))
|
||||
@ -49,11 +49,41 @@ class TestTrainer(testing.TestCase):
|
||||
history["mean_squared_error"], [13.938, 9.547, 6.539], atol=1e-2
|
||||
)
|
||||
|
||||
def test_basic_flow_eager(self):
|
||||
self._test_basic_flow(run_eagerly=True, jit_compile=False)
|
||||
def test_fit_flow_eager(self):
|
||||
self._test_fit_flow(run_eagerly=True, jit_compile=False)
|
||||
|
||||
def test_basic_flow_graph_fn(self):
|
||||
self._test_basic_flow(run_eagerly=False, jit_compile=False)
|
||||
def test_fit_flow_graph_fn(self):
|
||||
self._test_fit_flow(run_eagerly=False, jit_compile=False)
|
||||
|
||||
def test_basic_flow_jit(self):
|
||||
self._test_basic_flow(run_eagerly=False, jit_compile=True)
|
||||
def test_fit_flow_jit(self):
|
||||
self._test_fit_flow(run_eagerly=False, jit_compile=True)
|
||||
|
||||
def _test_evaluate_flow(self, run_eagerly, jit_compile):
|
||||
model = ExampleModel(units=3)
|
||||
x = np.ones((100, 4))
|
||||
y = np.zeros((100, 3))
|
||||
batch_size = 16
|
||||
|
||||
model.compile(
|
||||
optimizer=optimizers.SGD(),
|
||||
loss=losses.MeanSquaredError(),
|
||||
metrics=[metrics.MeanSquaredError()],
|
||||
run_eagerly=run_eagerly,
|
||||
jit_compile=jit_compile,
|
||||
)
|
||||
output = model.evaluate(x, y, batch_size=batch_size)
|
||||
self.assertAllClose(output, [16.0, 16.0])
|
||||
output = model.evaluate(x, y, batch_size=batch_size, return_dict=True)
|
||||
self.assertTrue(isinstance(output, dict))
|
||||
self.assertIn("loss", output)
|
||||
self.assertIn("mean_squared_error", output)
|
||||
self.assertAllClose(output["mean_squared_error"], 16.0)
|
||||
|
||||
def test_evaluate_flow_eager(self):
|
||||
self._test_evaluate_flow(run_eagerly=True, jit_compile=False)
|
||||
|
||||
def test_evaluate_flow_graph_fn(self):
|
||||
self._test_evaluate_flow(run_eagerly=False, jit_compile=False)
|
||||
|
||||
def test_evaluate_flow_jit(self):
|
||||
self._test_evaluate_flow(run_eagerly=False, jit_compile=True)
|
||||
|
Loading…
Reference in New Issue
Block a user