Allow nested inputs for the trainer (#192)
This commit is contained in:
parent
fe21b1aa71
commit
693cd31022
@ -42,6 +42,7 @@ import types
|
||||
import warnings
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow import nest
|
||||
|
||||
from keras_core.trainers.data_adapters import array_data_adapter
|
||||
from keras_core.trainers.data_adapters import data_adapter_utils
|
||||
@ -67,7 +68,8 @@ class EpochIterator:
|
||||
if steps_per_epoch:
|
||||
self._current_iterator = None
|
||||
self._insufficient_data = False
|
||||
if isinstance(x, data_adapter_utils.ARRAY_TYPES):
|
||||
first_element = next(iter(nest.flatten(x)), None)
|
||||
if isinstance(first_element, data_adapter_utils.ARRAY_TYPES):
|
||||
self.data_adapter = array_data_adapter.ArrayDataAdapter(
|
||||
x,
|
||||
y,
|
||||
|
@ -29,7 +29,7 @@ class ExampleModel(layers.Dense, Trainer):
|
||||
Trainer.__init__(self)
|
||||
|
||||
|
||||
class OutputStructModel(layers.Layer, Trainer):
|
||||
class StructModel(layers.Layer, Trainer):
|
||||
def __init__(self, units):
|
||||
layers.Layer.__init__(self)
|
||||
Trainer.__init__(self)
|
||||
@ -46,8 +46,8 @@ class OutputStructModel(layers.Layer, Trainer):
|
||||
|
||||
def call(self, x):
|
||||
return {
|
||||
"y_one": self.dense_1(x),
|
||||
"y_two": self.dense_2(x),
|
||||
"y_one": self.dense_1(x["x_one"]),
|
||||
"y_two": self.dense_2(x["x_two"]),
|
||||
}
|
||||
|
||||
|
||||
@ -193,12 +193,15 @@ class TestTrainer(testing.TestCase):
|
||||
outputs = model.predict(x, batch_size=batch_size)
|
||||
self.assertAllClose(outputs, 4 * np.ones((100, 3)))
|
||||
|
||||
# Test with output struct
|
||||
model = OutputStructModel(units=3)
|
||||
# Test with input/output structs
|
||||
model = StructModel(units=3)
|
||||
model.run_eagerly = run_eagerly
|
||||
model.jit_compile = jit_compile
|
||||
|
||||
x = np.ones((100, 4))
|
||||
x = {
|
||||
"x_one": np.ones((100, 4)),
|
||||
"x_two": np.ones((100, 4)),
|
||||
}
|
||||
batch_size = 16
|
||||
outputs = model.predict(x, batch_size=batch_size)
|
||||
self.assertTrue(isinstance(outputs, dict))
|
||||
|
Loading…
Reference in New Issue
Block a user