Allow nested inputs for the trainer (#192)
This commit is contained in:
parent
fe21b1aa71
commit
693cd31022
@ -42,6 +42,7 @@ import types
|
|||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
from tensorflow import nest
|
||||||
|
|
||||||
from keras_core.trainers.data_adapters import array_data_adapter
|
from keras_core.trainers.data_adapters import array_data_adapter
|
||||||
from keras_core.trainers.data_adapters import data_adapter_utils
|
from keras_core.trainers.data_adapters import data_adapter_utils
|
||||||
@ -67,7 +68,8 @@ class EpochIterator:
|
|||||||
if steps_per_epoch:
|
if steps_per_epoch:
|
||||||
self._current_iterator = None
|
self._current_iterator = None
|
||||||
self._insufficient_data = False
|
self._insufficient_data = False
|
||||||
if isinstance(x, data_adapter_utils.ARRAY_TYPES):
|
first_element = next(iter(nest.flatten(x)), None)
|
||||||
|
if isinstance(first_element, data_adapter_utils.ARRAY_TYPES):
|
||||||
self.data_adapter = array_data_adapter.ArrayDataAdapter(
|
self.data_adapter = array_data_adapter.ArrayDataAdapter(
|
||||||
x,
|
x,
|
||||||
y,
|
y,
|
||||||
|
@ -29,7 +29,7 @@ class ExampleModel(layers.Dense, Trainer):
|
|||||||
Trainer.__init__(self)
|
Trainer.__init__(self)
|
||||||
|
|
||||||
|
|
||||||
class OutputStructModel(layers.Layer, Trainer):
|
class StructModel(layers.Layer, Trainer):
|
||||||
def __init__(self, units):
|
def __init__(self, units):
|
||||||
layers.Layer.__init__(self)
|
layers.Layer.__init__(self)
|
||||||
Trainer.__init__(self)
|
Trainer.__init__(self)
|
||||||
@ -46,8 +46,8 @@ class OutputStructModel(layers.Layer, Trainer):
|
|||||||
|
|
||||||
def call(self, x):
|
def call(self, x):
|
||||||
return {
|
return {
|
||||||
"y_one": self.dense_1(x),
|
"y_one": self.dense_1(x["x_one"]),
|
||||||
"y_two": self.dense_2(x),
|
"y_two": self.dense_2(x["x_two"]),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -193,12 +193,15 @@ class TestTrainer(testing.TestCase):
|
|||||||
outputs = model.predict(x, batch_size=batch_size)
|
outputs = model.predict(x, batch_size=batch_size)
|
||||||
self.assertAllClose(outputs, 4 * np.ones((100, 3)))
|
self.assertAllClose(outputs, 4 * np.ones((100, 3)))
|
||||||
|
|
||||||
# Test with output struct
|
# Test with input/output structs
|
||||||
model = OutputStructModel(units=3)
|
model = StructModel(units=3)
|
||||||
model.run_eagerly = run_eagerly
|
model.run_eagerly = run_eagerly
|
||||||
model.jit_compile = jit_compile
|
model.jit_compile = jit_compile
|
||||||
|
|
||||||
x = np.ones((100, 4))
|
x = {
|
||||||
|
"x_one": np.ones((100, 4)),
|
||||||
|
"x_two": np.ones((100, 4)),
|
||||||
|
}
|
||||||
batch_size = 16
|
batch_size = 16
|
||||||
outputs = model.predict(x, batch_size=batch_size)
|
outputs = model.predict(x, batch_size=batch_size)
|
||||||
self.assertTrue(isinstance(outputs, dict))
|
self.assertTrue(isinstance(outputs, dict))
|
||||||
|
Loading…
Reference in New Issue
Block a user