From 9a4a063158836dd0df8faec959f5c311c80ebe7f Mon Sep 17 00:00:00 2001
From: Francois Chollet
Date: Sun, 14 May 2023 18:04:38 -0700
Subject: [PATCH] Generalize handling of compute_output_shape

---
 keras_core/layers/layer.py                    |  21 +---
 keras_core/layers/layer_test.py               | 109 ++++++++++++++++++
 .../normalization/spectral_normalization.py   |   1 -
 3 files changed, 114 insertions(+), 17 deletions(-)

diff --git a/keras_core/layers/layer.py b/keras_core/layers/layer.py
index ac041be21..72d231316 100644
--- a/keras_core/layers/layer.py
+++ b/keras_core/layers/layer.py
@@ -34,6 +34,7 @@ from keras_core.metrics.metric import Metric
 from keras_core.operations.operation import Operation
 from keras_core.utils import summary_utils
 from keras_core.utils import traceback_utils
+from keras_core.utils.shape_utils import map_shape_structure
 from keras_core.utils.tracking import Tracker
 
 
@@ -601,22 +602,10 @@ class Layer(Operation):
             ):
                 return KerasTensor(output_shape, dtype=self.compute_dtype)
             # Case: nested. Could be a tuple/list of shapes, or a dict of
-            # shapes.
-            if isinstance(output_shape, list):
-                return [
-                    KerasTensor(s, dtype=self.compute_dtype)
-                    for s in output_shape
-                ]
-            if isinstance(output_shape, tuple):
-                return tuple(
-                    KerasTensor(s, dtype=self.compute_dtype)
-                    for s in output_shape
-                )
-            if isinstance(output_shape, dict):
-                return {
-                    name: KerasTensor(s, dtype=self.compute_dtype)
-                    for name, s in output_shape.items()
-                }
+            # shapes. Could be deeply nested.
+            return map_shape_structure(
+                lambda s: KerasTensor(s, dtype=self.compute_dtype), output_shape
+            )
 
     @utils.default
     def compute_output_shape(self, *args, **kwargs):
diff --git a/keras_core/layers/layer_test.py b/keras_core/layers/layer_test.py
index 75e3fced1..806f78deb 100644
--- a/keras_core/layers/layer_test.py
+++ b/keras_core/layers/layer_test.py
@@ -8,6 +8,115 @@ from keras_core import testing
 
 
 class LayerTest(testing.TestCase):
+    def test_compute_output_spec(self):
+        # Test that implementing compute_output_shape
+        # is enough to make compute_output_spec work.
+
+        # Case: single output
+        class TestLayer(layers.Layer):
+            def call(self, x):
+                assert False  # Should never be called.
+
+            def compute_output_shape(self, input_shape):
+                return input_shape
+
+        layer = TestLayer()
+        self.assertEqual(
+            layer.compute_output_spec(backend.KerasTensor((2, 3))).shape, (2, 3)
+        )
+
+        # Case: tuple output
+        class TestLayer(layers.Layer):
+            def call(self, x):
+                assert False  # Should never be called.
+
+            def compute_output_shape(self, input_shape):
+                return (input_shape, input_shape)
+
+        layer = TestLayer()
+        out = layer.compute_output_spec(backend.KerasTensor((2, 3)))
+        self.assertTrue(isinstance(out, tuple))
+        self.assertEqual(len(out), 2)
+        self.assertEqual(out[0].shape, (2, 3))
+        self.assertEqual(out[1].shape, (2, 3))
+
+        # Case: list output
+        class TestLayer(layers.Layer):
+            def call(self, x):
+                assert False  # Should never be called.
+
+            def compute_output_shape(self, input_shape):
+                return [input_shape, input_shape]
+
+        layer = TestLayer()
+        out = layer.compute_output_spec(backend.KerasTensor((2, 3)))
+        self.assertTrue(isinstance(out, list))
+        self.assertEqual(len(out), 2)
+        self.assertEqual(out[0].shape, (2, 3))
+        self.assertEqual(out[1].shape, (2, 3))
+
+        # Case: dict output
+        class TestLayer(layers.Layer):
+            def call(self, x):
+                assert False  # Should never be called.
+
+            def compute_output_shape(self, input_shape):
+                return {"1": input_shape, "2": input_shape}
+
+        layer = TestLayer()
+        out = layer.compute_output_spec(backend.KerasTensor((2, 3)))
+        self.assertTrue(isinstance(out, dict))
+        self.assertEqual(len(out), 2)
+        self.assertEqual(out["1"].shape, (2, 3))
+        self.assertEqual(out["2"].shape, (2, 3))
+
+        # Case: nested tuple output
+        class TestLayer(layers.Layer):
+            def call(self, x):
+                assert False  # Should never be called.
+
+            def compute_output_shape(self, input_shape):
+                return (
+                    input_shape,
+                    (input_shape, input_shape),
+                    (input_shape, input_shape),
+                )
+
+        layer = TestLayer()
+        out = layer.compute_output_spec(backend.KerasTensor((2, 3)))
+        self.assertTrue(isinstance(out, tuple))
+        self.assertEqual(len(out), 3)
+        self.assertEqual(out[0].shape, (2, 3))
+        self.assertTrue(isinstance(out[1], tuple))
+        self.assertEqual(len(out[1]), 2)
+        self.assertEqual(out[1][0].shape, (2, 3))
+        self.assertEqual(out[1][1].shape, (2, 3))
+        self.assertTrue(isinstance(out[2], tuple))
+        self.assertEqual(len(out[2]), 2)
+        self.assertEqual(out[2][0].shape, (2, 3))
+        self.assertEqual(out[2][1].shape, (2, 3))
+
+        # Case: nested dict output
+        class TestLayer(layers.Layer):
+            def call(self, x):
+                assert False  # Should never be called.
+
+            def compute_output_shape(self, input_shape):
+                return {
+                    "1": input_shape,
+                    "2": {"11": input_shape, "22": input_shape},
+                }
+
+        layer = TestLayer()
+        out = layer.compute_output_spec(backend.KerasTensor((2, 3)))
+        self.assertTrue(isinstance(out, dict))
+        self.assertEqual(len(out), 2)
+        self.assertEqual(out["1"].shape, (2, 3))
+        self.assertTrue(isinstance(out["2"], dict))
+        self.assertEqual(len(out["2"]), 2)
+        self.assertEqual(out["2"]["11"].shape, (2, 3))
+        self.assertEqual(out["2"]["22"].shape, (2, 3))
+
     def test_positional_arg_error(self):
         class SomeLayer(layers.Layer):
             def call(self, x, bool_arg):
diff --git a/keras_core/layers/normalization/spectral_normalization.py b/keras_core/layers/normalization/spectral_normalization.py
index 7508353f9..ea7a6c02f 100644
--- a/keras_core/layers/normalization/spectral_normalization.py
+++ b/keras_core/layers/normalization/spectral_normalization.py
@@ -1,4 +1,3 @@
-from keras_core import backend
 from keras_core import initializers
 from keras_core import operations as ops
 from keras_core.api_export import keras_core_export