diff --git a/keras_core/trainers/epoch_iterator.py b/keras_core/trainers/epoch_iterator.py index 397e3cff7..4f40cf9c3 100644 --- a/keras_core/trainers/epoch_iterator.py +++ b/keras_core/trainers/epoch_iterator.py @@ -42,6 +42,7 @@ import types import warnings import tensorflow as tf +from tensorflow import nest from keras_core.trainers.data_adapters import array_data_adapter from keras_core.trainers.data_adapters import data_adapter_utils @@ -67,7 +68,8 @@ class EpochIterator: if steps_per_epoch: self._current_iterator = None self._insufficient_data = False - if isinstance(x, data_adapter_utils.ARRAY_TYPES): + first_element = next(iter(nest.flatten(x)), None) + if isinstance(first_element, data_adapter_utils.ARRAY_TYPES): self.data_adapter = array_data_adapter.ArrayDataAdapter( x, y, diff --git a/keras_core/trainers/trainer_test.py b/keras_core/trainers/trainer_test.py index 073f52fe0..f682ee04e 100644 --- a/keras_core/trainers/trainer_test.py +++ b/keras_core/trainers/trainer_test.py @@ -29,7 +29,7 @@ class ExampleModel(layers.Dense, Trainer): Trainer.__init__(self) -class OutputStructModel(layers.Layer, Trainer): +class StructModel(layers.Layer, Trainer): def __init__(self, units): layers.Layer.__init__(self) Trainer.__init__(self) @@ -46,8 +46,8 @@ class OutputStructModel(layers.Layer, Trainer): def call(self, x): return { - "y_one": self.dense_1(x), - "y_two": self.dense_2(x), + "y_one": self.dense_1(x["x_one"]), + "y_two": self.dense_2(x["x_two"]), } @@ -193,12 +193,15 @@ class TestTrainer(testing.TestCase): outputs = model.predict(x, batch_size=batch_size) self.assertAllClose(outputs, 4 * np.ones((100, 3))) - # Test with output struct - model = OutputStructModel(units=3) + # Test with input/output structs + model = StructModel(units=3) model.run_eagerly = run_eagerly model.jit_compile = jit_compile - x = np.ones((100, 4)) + x = { + "x_one": np.ones((100, 4)), + "x_two": np.ones((100, 4)), + } batch_size = 16 outputs = model.predict(x, batch_size=batch_size) self.assertTrue(isinstance(outputs, dict))