2023-04-17 21:56:37 +00:00
|
|
|
import numpy as np
|
|
|
|
|
2023-04-18 04:26:04 +00:00
|
|
|
from keras_core import backend
|
2023-04-19 20:50:22 +00:00
|
|
|
from keras_core import initializers
|
2023-04-17 21:56:37 +00:00
|
|
|
from keras_core import layers
|
|
|
|
from keras_core import losses
|
2023-04-17 22:41:48 +00:00
|
|
|
from keras_core import metrics
|
|
|
|
from keras_core import optimizers
|
|
|
|
from keras_core import testing
|
2023-04-19 20:50:22 +00:00
|
|
|
|
|
|
|
# Bind `Trainer` to the implementation matching the active backend: the JAX
# trainer when the backend is "jax", the TensorFlow trainer for anything else.
if backend.backend() != "jax":
    from keras_core.backend.tensorflow.trainer import (
        TensorFlowTrainer as Trainer,
    )
else:
    from keras_core.backend.jax.trainer import JAXTrainer as Trainer
|
2023-04-17 21:56:37 +00:00
|
|
|
|
|
|
|
|
|
|
|
# A model is just a layer mixed in with a Trainer.
|
2023-04-19 20:50:22 +00:00
|
|
|
class ExampleModel(layers.Dense, Trainer):
    """Bias-free `Dense` layer with an all-ones kernel, mixed with `Trainer`.

    Args:
        units: Number of output units of the dense layer.
    """

    def __init__(self, units):
        # Both bases are initialized explicitly (Dense first, then Trainer)
        # rather than through `super()`.
        dense_kwargs = {
            "units": units,
            "use_bias": False,
            "kernel_initializer": initializers.Ones(),
        }
        layers.Dense.__init__(self, **dense_kwargs)
        Trainer.__init__(self)
|
2023-04-17 21:56:37 +00:00
|
|
|
|
|
|
|
|
2023-04-20 20:08:37 +00:00
|
|
|
class OutputStructModel(layers.Layer, Trainer):
    """Trainable layer whose `call` returns a dict of two dense outputs.

    Args:
        units: Number of output units of each of the two dense branches.
    """

    def __init__(self, units):
        layers.Layer.__init__(self)
        Trainer.__init__(self)

        def ones_dense():
            # Each branch gets its own bias-free Dense layer and its own
            # all-ones initializer instance.
            return layers.Dense(
                units,
                use_bias=False,
                kernel_initializer=initializers.Ones(),
            )

        self.dense_1 = ones_dense()
        self.dense_2 = ones_dense()

    def call(self, x):
        """Apply both dense branches to `x`; keys are "y_one" and "y_two"."""
        first = self.dense_1(x)
        second = self.dense_2(x)
        return {"y_one": first, "y_two": second}
|
|
|
|
|
|
|
|
|
2023-04-19 20:50:22 +00:00
|
|
|
class TestTrainer(testing.TestCase):
    """Tests for metric tracking and the fit/evaluate/predict flows.

    Each flow is exercised in three modes: eager (`run_eagerly=True`),
    non-eager without jit (`jit_compile=False`), and non-eager with
    `jit_compile=True`.
    """

    def test_metric_tracking(self):
        """A metric stored as a model attribute is tracked with the others."""

        class ModelWithMetric(layers.Dense, Trainer):
            def __init__(self, units):
                layers.Dense.__init__(
                    self,
                    units=units,
                    use_bias=False,
                    kernel_initializer=initializers.Ones(),
                )
                Trainer.__init__(self)
                self.my_metric = metrics.MeanSquaredError(name="my_metric")

        model = ModelWithMetric(units=3)
        model.compile(
            optimizer=optimizers.SGD(),
            loss=losses.MeanSquaredError(),
            metrics=[metrics.MeanSquaredError()],
        )
        x = np.ones((2, 4))
        y = np.zeros((2, 3))
        # Fit the model to make sure compile_metrics are built
        model.fit(x, y, batch_size=2, epochs=1)

        # The model should have 3 metrics, in this order:
        # loss_tracker, my_metric, compile_metrics
        self.assertEqual(len(model.metrics), 3)
        self.assertEqual(model.metrics[0], model._loss_tracker)
        self.assertEqual(model.metrics[1], model.my_metric)
        self.assertEqual(model.metrics[2], model._compile_metrics)

        # All metrics should have their weights created
        self.assertEqual(len(model._loss_tracker.variables), 2)
        self.assertEqual(len(model._compile_metrics.variables), 2)
        self.assertEqual(len(model.my_metric.variables), 2)

        # And those weights are tracked at the model level
        self.assertEqual(len(model.metrics_variables), 6)

    def _test_fit_flow(self, run_eagerly, jit_compile):
        """Fit `ExampleModel` for 3 epochs and check the returned history.

        Args:
            run_eagerly: Forwarded to `model.compile(run_eagerly=...)`.
            jit_compile: Forwarded to `model.compile(jit_compile=...)`.
        """
        model = ExampleModel(units=3)
        x = np.ones((100, 4))
        y = np.zeros((100, 3))
        batch_size = 16
        epochs = 3

        model.compile(
            optimizer=optimizers.SGD(),
            loss=losses.MeanSquaredError(),
            metrics=[metrics.MeanSquaredError()],
            run_eagerly=run_eagerly,
            jit_compile=jit_compile,
        )
        history = model.fit(x, y, batch_size=batch_size, epochs=epochs)
        history = history.history
        self.assertIn("loss", history)
        self.assertIn("mean_squared_error", history)
        # One entry per epoch; reference values for the all-ones setup.
        self.assertAllClose(
            history["mean_squared_error"], [13.938, 9.547, 6.539], atol=1e-2
        )

    def test_fit_flow_eager(self):
        self._test_fit_flow(run_eagerly=True, jit_compile=False)

    def test_fit_flow_graph_fn(self):
        self._test_fit_flow(run_eagerly=False, jit_compile=False)

    def test_fit_flow_jit(self):
        self._test_fit_flow(run_eagerly=False, jit_compile=True)

    def _test_evaluate_flow(self, run_eagerly, jit_compile):
        """Evaluate `ExampleModel` and check list and dict return formats.

        Args:
            run_eagerly: Forwarded to `model.compile(run_eagerly=...)`.
            jit_compile: Forwarded to `model.compile(jit_compile=...)`.
        """
        model = ExampleModel(units=3)
        x = np.ones((100, 4))
        y = np.zeros((100, 3))
        batch_size = 16

        model.compile(
            optimizer=optimizers.SGD(),
            loss=losses.MeanSquaredError(),
            metrics=[metrics.MeanSquaredError()],
            run_eagerly=run_eagerly,
            jit_compile=jit_compile,
        )
        # The all-ones kernel maps each all-ones input row of width 4 to
        # outputs of 4, so the squared error against zero targets is 16.
        output = model.evaluate(x, y, batch_size=batch_size)
        self.assertAllClose(output, [16.0, 16.0])
        output = model.evaluate(x, y, batch_size=batch_size, return_dict=True)
        self.assertIsInstance(output, dict)
        self.assertIn("loss", output)
        self.assertIn("mean_squared_error", output)
        self.assertAllClose(output["mean_squared_error"], 16.0)

    def test_evaluate_flow_eager(self):
        self._test_evaluate_flow(run_eagerly=True, jit_compile=False)

    def test_evaluate_flow_graph_fn(self):
        self._test_evaluate_flow(run_eagerly=False, jit_compile=False)

    def test_evaluate_flow_jit(self):
        self._test_evaluate_flow(run_eagerly=False, jit_compile=True)

    def _test_predict_flow(self, run_eagerly, jit_compile):
        """Run `predict` on a single-output and a dict-output model.

        The models are not compiled here; `run_eagerly` and `jit_compile`
        are set directly as attributes on each model.

        Args:
            run_eagerly: Value assigned to `model.run_eagerly`.
            jit_compile: Value assigned to `model.jit_compile`.
        """
        x = np.ones((100, 4))
        batch_size = 16
        expected = 4 * np.ones((100, 3))

        # Test basic example
        model = ExampleModel(units=3)
        model.run_eagerly = run_eagerly
        model.jit_compile = jit_compile

        outputs = model.predict(x, batch_size=batch_size)
        self.assertAllClose(outputs, expected)

        # Test with output struct
        model = OutputStructModel(units=3)
        model.run_eagerly = run_eagerly
        model.jit_compile = jit_compile

        outputs = model.predict(x, batch_size=batch_size)
        self.assertIsInstance(outputs, dict)
        self.assertEqual(len(outputs), 2)
        self.assertAllClose(outputs["y_one"], expected)
        self.assertAllClose(outputs["y_two"], expected)

    def test_predict_flow_eager(self):
        self._test_predict_flow(run_eagerly=True, jit_compile=False)

    def test_predict_flow_graph_fn(self):
        self._test_predict_flow(run_eagerly=False, jit_compile=False)

    def test_predict_flow_jit(self):
        self._test_predict_flow(run_eagerly=False, jit_compile=True)
|