From 4fd0029ad7b9b6b7a6e0d69ca7550c8af38f17c2 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Wed, 12 Apr 2023 15:41:35 -0700 Subject: [PATCH] More extensive Functional model test coverage. --- keras_core/models/functional.py | 17 ++++++++ keras_core/models/functional_test.py | 61 ++++++++++++++++++++++++++-- 2 files changed, 75 insertions(+), 3 deletions(-) diff --git a/keras_core/models/functional.py b/keras_core/models/functional.py index 055969f89..5b5a4d62f 100644 --- a/keras_core/models/functional.py +++ b/keras_core/models/functional.py @@ -2,6 +2,7 @@ import warnings from tensorflow import nest +from keras_core import backend from keras_core import operations as ops from keras_core.layers.layer import Layer from keras_core.models.model import Model @@ -32,6 +33,19 @@ class Functional(Function, Model): skip_init = kwargs.pop("skip_init", False) if skip_init: return + if isinstance(inputs, dict): + for k, v in inputs.items(): + if not isinstance(v, backend.KerasTensor): + raise ValueError( + "When providing an input dict, all values in the dict " + f"must be KerasTensors. Received: inputs={inputs} including " + f"invalid value {v} of type {type(v)}") + if k != v.name: + raise ValueError( + "When providing an input dict, all keys in the dict " + "must match the names of the corresponding tensors. " + f"Received key '{k}' mapping to value {v} which has name '{v.name}'. " + f"Change the tensor name to '{k}' (via `Input(..., name='{k}')`)") super().__init__(inputs, outputs, name=name, **kwargs) self._layers = self.layers self.built = True @@ -59,6 +73,9 @@ class Functional(Function, Model): def compute_output_spec(self, inputs, training=False, mask=None): return super().compute_output_spec(inputs) + + def _assert_input_compatibility(self, *args): + return super(Model, self)._assert_input_compatibility(*args) def _flatten_to_reference_inputs(self, inputs): if isinstance(inputs, dict): diff --git a/keras_core/models/functional_test.py b/keras_core/models/functional_test.py index b7c063493..6964b04e7 100644 --- a/keras_core/models/functional_test.py +++ b/keras_core/models/functional_test.py @@ -2,7 +2,6 @@ import numpy as np from keras_core import backend from keras_core import layers -from keras_core import operations as ops from keras_core import testing from keras_core.layers.core.input_layer import Input from keras_core.models.functional import Functional @@ -51,6 +50,33 @@ class FunctionalTest(testing.TestCase): self.assertEqual(out_val[0].shape, (2, 4)) self.assertEqual(out_val[1].shape, (2, 5)) + def test_basic_flow_dict_io(self): + input_a = Input(shape=(3,), batch_size=2, name="a") + input_b = Input(shape=(3,), batch_size=2, name="b") + x = input_a + input_b + x = layers.Dense(5)(x) + outputs = layers.Dense(4)(x) + + with self.assertRaisesRegex(ValueError, "all values in the dict must be KerasTensors"): + model = Functional({"aa": [input_a], "bb": input_b}, outputs) + + with self.assertRaisesRegex(ValueError, "all keys in the dict must match the names"): + model = Functional({"aa": input_a, "bb": input_b}, outputs) + + model = Functional({"a": input_a, "b": input_b}, outputs) + + # Eager call + in_val = {"a": np.random.random((2, 3)), "b": np.random.random((2, 3))} + out_val = model(in_val) + self.assertEqual(out_val.shape, (2, 4)) + + # Symbolic call + input_a_2 = Input(shape=(3,), batch_size=2) + input_b_2 = Input(shape=(3,), batch_size=2) + in_val = {"a": input_a_2, "b": input_b_2} + out_val = model(in_val) + self.assertEqual(out_val.shape, (2, 4)) + def test_layer_getters(self): # Test mixing ops and layers input_a = Input(shape=(3,), batch_size=2, name="input_a") @@ -87,10 +113,39 @@ class FunctionalTest(testing.TestCase): pass def test_passing_inputs_by_name(self): - pass + input_a = Input(shape=(3,), batch_size=2, name="input_a") + input_b = Input(shape=(3,), batch_size=2, name="input_b") + x = input_a + input_b + x = layers.Dense(5)(x) + outputs = layers.Dense(4)(x) + model = Functional([input_a, input_b], outputs) + + # Eager call + in_val = {"input_a": np.random.random((2, 3)), "input_b": np.random.random((2, 3))} + out_val = model(in_val) + self.assertEqual(out_val.shape, (2, 4)) + + # Symbolic call + input_a_2 = Input(shape=(3,), batch_size=2, name="input_a_2") + input_b_2 = Input(shape=(3,), batch_size=2, name="input_b_2") + in_val = {"input_a": input_a_2, "input_b": input_b_2} + out_val = model(in_val) + self.assertEqual(out_val.shape, (2, 4)) def test_rank_standardization(self): - pass + # Downranking + inputs = Input(shape=(3,), batch_size=2) + outputs = layers.Dense(3)(inputs) + model = Functional(inputs, outputs) + out_val = model(np.random.random((2, 3, 1))) + self.assertEqual(out_val.shape, (2, 3)) + + # Upranking + inputs = Input(shape=(3, 1), batch_size=2) + outputs = layers.Dense(3)(inputs) + model = Functional(inputs, outputs) + out_val = model(np.random.random((2, 3))) + self.assertEqual(out_val.shape, (2, 3, 3)) def test_serialization(self): # TODO