From 8b8e0bc15f80994150a8cb5e71716c08e8a0e50c Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Wed, 12 Apr 2023 17:35:54 -0700 Subject: [PATCH] Improve Sequential test coverage --- keras_core/models/sequential.py | 2 +- keras_core/models/sequential_test.py | 30 ++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/keras_core/models/sequential.py b/keras_core/models/sequential.py index f6480b9cd..aeec4e523 100644 --- a/keras_core/models/sequential.py +++ b/keras_core/models/sequential.py @@ -40,7 +40,7 @@ class Sequential(Model): ): raise ValueError( f"Sequential model '{self.name}' has already been configured to " - f"use input shape {self._layers[0].batch_input_shape}. You cannot add " + f"use input shape {self._layers[0].batch_shape}. You cannot add " f"a different Input layer to it." ) diff --git a/keras_core/models/sequential_test.py b/keras_core/models/sequential_test.py index 7866ddc29..375f17695 100644 --- a/keras_core/models/sequential_test.py +++ b/keras_core/models/sequential_test.py @@ -94,5 +94,35 @@ class SequentialTest(testing.TestCase): def test_dict_inputs(self): pass + def test_errors(self): + # Trying to pass 2 Inputs + model = Sequential() + model.add(Input(shape=(2,), batch_size=3)) + with self.assertRaisesRegex(ValueError, "already been configured"): + model.add(Input(shape=(2,), batch_size=3)) + with self.assertRaisesRegex(ValueError, "already been configured"): + model.add(layers.InputLayer(shape=(2,), batch_size=3)) + + # Same name 2x + model = Sequential() + model.add(layers.Dense(2, name="dense")) + with self.assertRaisesRegex(ValueError, "should have unique names"): + model.add(layers.Dense(2, name="dense")) + + # No layers + model = Sequential() + x = np.random.random((3, 2)) + with self.assertRaisesRegex(ValueError, "no layers"): + model(x) + + # Build conflict + model = Sequential() + model.add(Input(shape=(2,), batch_size=3)) + model.add(layers.Dense(2)) + with self.assertRaisesRegex(ValueError, "already been configured"): + model.build((3, 4)) + # But this works + model.build((3, 2)) + def test_serialization(self): pass