More extensive Functional model test coverage.

This commit is contained in:
Francois Chollet 2023-04-12 15:41:35 -07:00
parent d663d5bcc1
commit 4fd0029ad7
2 changed files with 75 additions and 3 deletions

@ -2,6 +2,7 @@ import warnings
from tensorflow import nest
from keras_core import backend
from keras_core import operations as ops
from keras_core.layers.layer import Layer
from keras_core.models.model import Model
@ -32,6 +33,19 @@ class Functional(Function, Model):
skip_init = kwargs.pop("skip_init", False)
if skip_init:
return
if isinstance(inputs, dict):
for k, v in inputs.items():
if not isinstance(v, backend.KerasTensor):
raise ValueError(
"When providing an input dict, all values in the dict "
f"must be KerasTensors. Received: inputs={inputs} including "
f"invalid value {v} of type {type(v)}")
if k != v.name:
raise ValueError(
"When providing an input dict, all keys in the dict "
"must match the names of the corresponding tensors. "
f"Received key '{k}' mapping to value {v} which has name '{v.name}'. "
f"Change the tensor name to '{k}' (via `Input(..., name='{k}')`)")
super().__init__(inputs, outputs, name=name, **kwargs)
self._layers = self.layers
self.built = True
@ -59,6 +73,9 @@ class Functional(Function, Model):
def compute_output_spec(self, inputs, training=False, mask=None):
return super().compute_output_spec(inputs)
def _assert_input_compatibility(self, *args):
return super(Model, self)._assert_input_compatibility(*args)
def _flatten_to_reference_inputs(self, inputs):
if isinstance(inputs, dict):

@ -2,7 +2,6 @@ import numpy as np
from keras_core import backend
from keras_core import layers
from keras_core import operations as ops
from keras_core import testing
from keras_core.layers.core.input_layer import Input
from keras_core.models.functional import Functional
@ -51,6 +50,33 @@ class FunctionalTest(testing.TestCase):
self.assertEqual(out_val[0].shape, (2, 4))
self.assertEqual(out_val[1].shape, (2, 5))
def test_basic_flow_dict_io(self):
input_a = Input(shape=(3,), batch_size=2, name="a")
input_b = Input(shape=(3,), batch_size=2, name="b")
x = input_a + input_b
x = layers.Dense(5)(x)
outputs = layers.Dense(4)(x)
with self.assertRaisesRegex(ValueError, "all values in the dict must be KerasTensors"):
model = Functional({"aa": [input_a], "bb": input_b}, outputs)
with self.assertRaisesRegex(ValueError, "all keys in the dict must match the names"):
model = Functional({"aa": input_a, "bb": input_b}, outputs)
model = Functional({"a": input_a, "b": input_b}, outputs)
# Eager call
in_val = {"a": np.random.random((2, 3)), "b": np.random.random((2, 3))}
out_val = model(in_val)
self.assertEqual(out_val.shape, (2, 4))
# Symbolic call
input_a_2 = Input(shape=(3,), batch_size=2)
input_b_2 = Input(shape=(3,), batch_size=2)
in_val = {"a": input_a_2, "b": input_b_2}
out_val = model(in_val)
self.assertEqual(out_val.shape, (2, 4))
def test_layer_getters(self):
# Test mixing ops and layers
input_a = Input(shape=(3,), batch_size=2, name="input_a")
@ -87,10 +113,39 @@ class FunctionalTest(testing.TestCase):
pass
def test_passing_inputs_by_name(self):
pass
input_a = Input(shape=(3,), batch_size=2, name="input_a")
input_b = Input(shape=(3,), batch_size=2, name="input_b")
x = input_a + input_b
x = layers.Dense(5)(x)
outputs = layers.Dense(4)(x)
model = Functional([input_a, input_b], outputs)
# Eager call
in_val = {"input_a": np.random.random((2, 3)), "input_b": np.random.random((2, 3))}
out_val = model(in_val)
self.assertEqual(out_val.shape, (2, 4))
# Symbolic call
input_a_2 = Input(shape=(3,), batch_size=2, name="input_a_2")
input_b_2 = Input(shape=(3,), batch_size=2, name="input_b_2")
in_val = {"input_a": input_a_2, "input_b": input_b_2}
out_val = model(in_val)
self.assertEqual(out_val.shape, (2, 4))
def test_rank_standardization(self):
pass
# Downranking
inputs = Input(shape=(3,), batch_size=2)
outputs = layers.Dense(3)(inputs)
model = Functional(inputs, outputs)
out_val = model(np.random.random((2, 3, 1)))
self.assertEqual(out_val.shape, (2, 3))
# Upranking
inputs = Input(shape=(3, 1), batch_size=2)
outputs = layers.Dense(3)(inputs)
model = Functional(inputs, outputs)
out_val = model(np.random.random((2, 3)))
self.assertEqual(out_val.shape, (2, 3, 3))
def test_serialization(self):
# TODO