Allow nested inputs for the trainer (#192)

This commit is contained in:
Matt Watson 2023-05-19 09:33:52 -07:00 committed by Francois Chollet
parent fe21b1aa71
commit 693cd31022
2 changed files with 12 additions and 7 deletions

@@ -42,6 +42,7 @@ import types
import warnings import warnings
import tensorflow as tf import tensorflow as tf
from tensorflow import nest
from keras_core.trainers.data_adapters import array_data_adapter from keras_core.trainers.data_adapters import array_data_adapter
from keras_core.trainers.data_adapters import data_adapter_utils from keras_core.trainers.data_adapters import data_adapter_utils
@@ -67,7 +68,8 @@ class EpochIterator:
if steps_per_epoch: if steps_per_epoch:
self._current_iterator = None self._current_iterator = None
self._insufficient_data = False self._insufficient_data = False
if isinstance(x, data_adapter_utils.ARRAY_TYPES): first_element = next(iter(nest.flatten(x)), None)
if isinstance(first_element, data_adapter_utils.ARRAY_TYPES):
self.data_adapter = array_data_adapter.ArrayDataAdapter( self.data_adapter = array_data_adapter.ArrayDataAdapter(
x, x,
y, y,

@@ -29,7 +29,7 @@ class ExampleModel(layers.Dense, Trainer):
Trainer.__init__(self) Trainer.__init__(self)
class OutputStructModel(layers.Layer, Trainer): class StructModel(layers.Layer, Trainer):
def __init__(self, units): def __init__(self, units):
layers.Layer.__init__(self) layers.Layer.__init__(self)
Trainer.__init__(self) Trainer.__init__(self)
@@ -46,8 +46,8 @@ class OutputStructModel(layers.Layer, Trainer):
def call(self, x): def call(self, x):
return { return {
"y_one": self.dense_1(x), "y_one": self.dense_1(x["x_one"]),
"y_two": self.dense_2(x), "y_two": self.dense_2(x["x_two"]),
} }
@@ -193,12 +193,15 @@ class TestTrainer(testing.TestCase):
outputs = model.predict(x, batch_size=batch_size) outputs = model.predict(x, batch_size=batch_size)
self.assertAllClose(outputs, 4 * np.ones((100, 3))) self.assertAllClose(outputs, 4 * np.ones((100, 3)))
# Test with output struct # Test with input/output structs
model = OutputStructModel(units=3) model = StructModel(units=3)
model.run_eagerly = run_eagerly model.run_eagerly = run_eagerly
model.jit_compile = jit_compile model.jit_compile = jit_compile
x = np.ones((100, 4)) x = {
"x_one": np.ones((100, 4)),
"x_two": np.ones((100, 4)),
}
batch_size = 16 batch_size = 16
outputs = model.predict(x, batch_size=batch_size) outputs = model.predict(x, batch_size=batch_size)
self.assertTrue(isinstance(outputs, dict)) self.assertTrue(isinstance(outputs, dict))