Generalize handling of compute_output_shape

This commit is contained in:
Francois Chollet 2023-05-14 18:04:38 -07:00
parent 8cc8ed3cd5
commit 9a4a063158
3 changed files with 114 additions and 17 deletions

@ -34,6 +34,7 @@ from keras_core.metrics.metric import Metric
from keras_core.operations.operation import Operation
from keras_core.utils import summary_utils
from keras_core.utils import traceback_utils
from keras_core.utils.shape_utils import map_shape_structure
from keras_core.utils.tracking import Tracker
@ -601,22 +602,10 @@ class Layer(Operation):
):
return KerasTensor(output_shape, dtype=self.compute_dtype)
# Case: nested. Could be a tuple/list of shapes, or a dict of
# shapes.
if isinstance(output_shape, list):
return [
KerasTensor(s, dtype=self.compute_dtype)
for s in output_shape
]
if isinstance(output_shape, tuple):
return tuple(
KerasTensor(s, dtype=self.compute_dtype)
for s in output_shape
# shapes. Could be deeply nested.
return map_shape_structure(
lambda s: KerasTensor(s, dtype=self.compute_dtype), output_shape
)
if isinstance(output_shape, dict):
return {
name: KerasTensor(s, dtype=self.compute_dtype)
for name, s in output_shape.items()
}
@utils.default
def compute_output_shape(self, *args, **kwargs):

@ -8,6 +8,115 @@ from keras_core import testing
class LayerTest(testing.TestCase):
def test_compute_output_spec(self):
    """Check that `compute_output_shape` alone drives `compute_output_spec`.

    Each case builds a layer whose `call` must never run and whose
    `compute_output_shape` returns the input shape arranged in a given
    structure (single shape, tuple, list, dict, nested tuple, nested
    dict). The spec returned by `compute_output_spec` must mirror that
    structure, and every leaf must report shape `(2, 3)`.
    """

    def spec_for(build_shapes):
        # Build a throwaway layer whose output-shape structure is produced
        # by `build_shapes(input_shape)` and return its output spec for a
        # (2, 3) symbolic input.
        class TestLayer(layers.Layer):
            def call(self, x):
                # Explicit raise rather than `assert False`: bare asserts
                # are stripped under `python -O`, which would silently
                # disable this guard.
                raise AssertionError("call() should never be called.")

            def compute_output_shape(self, input_shape):
                return build_shapes(input_shape)

        return TestLayer().compute_output_spec(backend.KerasTensor((2, 3)))

    # Test that implementing compute_output_shape
    # is enough to make compute_output_spec work.

    # Case: single output
    out = spec_for(lambda shape: shape)
    self.assertEqual(out.shape, (2, 3))

    # Case: tuple output
    out = spec_for(lambda shape: (shape, shape))
    self.assertIsInstance(out, tuple)
    self.assertEqual(len(out), 2)
    self.assertEqual(out[0].shape, (2, 3))
    self.assertEqual(out[1].shape, (2, 3))

    # Case: list output
    out = spec_for(lambda shape: [shape, shape])
    self.assertIsInstance(out, list)
    self.assertEqual(len(out), 2)
    self.assertEqual(out[0].shape, (2, 3))
    self.assertEqual(out[1].shape, (2, 3))

    # Case: dict output
    out = spec_for(lambda shape: {"1": shape, "2": shape})
    self.assertIsInstance(out, dict)
    self.assertEqual(len(out), 2)
    self.assertEqual(out["1"].shape, (2, 3))
    self.assertEqual(out["2"].shape, (2, 3))

    # Case: nested tuple output
    out = spec_for(
        lambda shape: (
            shape,
            (shape, shape),
            (shape, shape),
        )
    )
    self.assertIsInstance(out, tuple)
    self.assertEqual(len(out), 3)
    self.assertEqual(out[0].shape, (2, 3))
    for inner in (out[1], out[2]):
        self.assertIsInstance(inner, tuple)
        self.assertEqual(len(inner), 2)
        self.assertEqual(inner[0].shape, (2, 3))
        self.assertEqual(inner[1].shape, (2, 3))

    # Case: nested dict output
    out = spec_for(
        lambda shape: {
            "1": shape,
            "2": {"11": shape, "22": shape},
        }
    )
    self.assertIsInstance(out, dict)
    self.assertEqual(len(out), 2)
    self.assertEqual(out["1"].shape, (2, 3))
    self.assertIsInstance(out["2"], dict)
    self.assertEqual(len(out["2"]), 2)
    self.assertEqual(out["2"]["11"].shape, (2, 3))
    self.assertEqual(out["2"]["22"].shape, (2, 3))
def test_positional_arg_error(self):
class SomeLayer(layers.Layer):
def call(self, x, bool_arg):

@ -1,4 +1,3 @@
from keras_core import backend
from keras_core import initializers
from keras_core import operations as ops
from keras_core.api_export import keras_core_export