More extensive Functional model test coverage.
This commit is contained in:
parent
d663d5bcc1
commit
4fd0029ad7
@ -2,6 +2,7 @@ import warnings
|
||||
|
||||
from tensorflow import nest
|
||||
|
||||
from keras_core import backend
|
||||
from keras_core import operations as ops
|
||||
from keras_core.layers.layer import Layer
|
||||
from keras_core.models.model import Model
|
||||
@ -32,6 +33,19 @@ class Functional(Function, Model):
|
||||
skip_init = kwargs.pop("skip_init", False)
|
||||
if skip_init:
|
||||
return
|
||||
if isinstance(inputs, dict):
|
||||
for k, v in inputs.items():
|
||||
if not isinstance(v, backend.KerasTensor):
|
||||
raise ValueError(
|
||||
"When providing an input dict, all values in the dict "
|
||||
f"must be KerasTensors. Received: inputs={inputs} including "
|
||||
f"invalid value {v} of type {type(v)}")
|
||||
if k != v.name:
|
||||
raise ValueError(
|
||||
"When providing an input dict, all keys in the dict "
|
||||
"must match the names of the corresponding tensors. "
|
||||
f"Received key '{k}' mapping to value {v} which has name '{v.name}'. "
|
||||
f"Change the tensor name to '{k}' (via `Input(..., name='{k}')`)")
|
||||
super().__init__(inputs, outputs, name=name, **kwargs)
|
||||
self._layers = self.layers
|
||||
self.built = True
|
||||
@ -59,6 +73,9 @@ class Functional(Function, Model):
|
||||
|
||||
def compute_output_spec(self, inputs, training=False, mask=None):
|
||||
return super().compute_output_spec(inputs)
|
||||
|
||||
def _assert_input_compatibility(self, *args):
|
||||
return super(Model, self)._assert_input_compatibility(*args)
|
||||
|
||||
def _flatten_to_reference_inputs(self, inputs):
|
||||
if isinstance(inputs, dict):
|
||||
|
@ -2,7 +2,6 @@ import numpy as np
|
||||
|
||||
from keras_core import backend
|
||||
from keras_core import layers
|
||||
from keras_core import operations as ops
|
||||
from keras_core import testing
|
||||
from keras_core.layers.core.input_layer import Input
|
||||
from keras_core.models.functional import Functional
|
||||
@ -51,6 +50,33 @@ class FunctionalTest(testing.TestCase):
|
||||
self.assertEqual(out_val[0].shape, (2, 4))
|
||||
self.assertEqual(out_val[1].shape, (2, 5))
|
||||
|
||||
def test_basic_flow_dict_io(self):
|
||||
input_a = Input(shape=(3,), batch_size=2, name="a")
|
||||
input_b = Input(shape=(3,), batch_size=2, name="b")
|
||||
x = input_a + input_b
|
||||
x = layers.Dense(5)(x)
|
||||
outputs = layers.Dense(4)(x)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "all values in the dict must be KerasTensors"):
|
||||
model = Functional({"aa": [input_a], "bb": input_b}, outputs)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "all keys in the dict must match the names"):
|
||||
model = Functional({"aa": input_a, "bb": input_b}, outputs)
|
||||
|
||||
model = Functional({"a": input_a, "b": input_b}, outputs)
|
||||
|
||||
# Eager call
|
||||
in_val = {"a": np.random.random((2, 3)), "b": np.random.random((2, 3))}
|
||||
out_val = model(in_val)
|
||||
self.assertEqual(out_val.shape, (2, 4))
|
||||
|
||||
# Symbolic call
|
||||
input_a_2 = Input(shape=(3,), batch_size=2)
|
||||
input_b_2 = Input(shape=(3,), batch_size=2)
|
||||
in_val = {"a": input_a_2, "b": input_b_2}
|
||||
out_val = model(in_val)
|
||||
self.assertEqual(out_val.shape, (2, 4))
|
||||
|
||||
def test_layer_getters(self):
|
||||
# Test mixing ops and layers
|
||||
input_a = Input(shape=(3,), batch_size=2, name="input_a")
|
||||
@ -87,10 +113,39 @@ class FunctionalTest(testing.TestCase):
|
||||
pass
|
||||
|
||||
def test_passing_inputs_by_name(self):
|
||||
pass
|
||||
input_a = Input(shape=(3,), batch_size=2, name="input_a")
|
||||
input_b = Input(shape=(3,), batch_size=2, name="input_b")
|
||||
x = input_a + input_b
|
||||
x = layers.Dense(5)(x)
|
||||
outputs = layers.Dense(4)(x)
|
||||
model = Functional([input_a, input_b], outputs)
|
||||
|
||||
# Eager call
|
||||
in_val = {"input_a": np.random.random((2, 3)), "input_b": np.random.random((2, 3))}
|
||||
out_val = model(in_val)
|
||||
self.assertEqual(out_val.shape, (2, 4))
|
||||
|
||||
# Symbolic call
|
||||
input_a_2 = Input(shape=(3,), batch_size=2, name="input_a_2")
|
||||
input_b_2 = Input(shape=(3,), batch_size=2, name="input_b_2")
|
||||
in_val = {"input_a": input_a_2, "input_b": input_b_2}
|
||||
out_val = model(in_val)
|
||||
self.assertEqual(out_val.shape, (2, 4))
|
||||
|
||||
def test_rank_standardization(self):
|
||||
pass
|
||||
# Downranking
|
||||
inputs = Input(shape=(3,), batch_size=2)
|
||||
outputs = layers.Dense(3)(inputs)
|
||||
model = Functional(inputs, outputs)
|
||||
out_val = model(np.random.random((2, 3, 1)))
|
||||
self.assertEqual(out_val.shape, (2, 3))
|
||||
|
||||
# Upranking
|
||||
inputs = Input(shape=(3, 1), batch_size=2)
|
||||
outputs = layers.Dense(3)(inputs)
|
||||
model = Functional(inputs, outputs)
|
||||
out_val = model(np.random.random((2, 3)))
|
||||
self.assertEqual(out_val.shape, (2, 3, 3))
|
||||
|
||||
def test_serialization(self):
|
||||
# TODO
|
||||
|
Loading…
Reference in New Issue
Block a user