Generalize handling of compute_output_shape
This commit is contained in:
parent
8cc8ed3cd5
commit
9a4a063158
@ -34,6 +34,7 @@ from keras_core.metrics.metric import Metric
|
||||
from keras_core.operations.operation import Operation
|
||||
from keras_core.utils import summary_utils
|
||||
from keras_core.utils import traceback_utils
|
||||
from keras_core.utils.shape_utils import map_shape_structure
|
||||
from keras_core.utils.tracking import Tracker
|
||||
|
||||
|
||||
@ -601,22 +602,10 @@ class Layer(Operation):
|
||||
):
|
||||
return KerasTensor(output_shape, dtype=self.compute_dtype)
|
||||
# Case: nested. Could be a tuple/list of shapes, or a dict of
|
||||
# shapes.
|
||||
if isinstance(output_shape, list):
|
||||
return [
|
||||
KerasTensor(s, dtype=self.compute_dtype)
|
||||
for s in output_shape
|
||||
]
|
||||
if isinstance(output_shape, tuple):
|
||||
return tuple(
|
||||
KerasTensor(s, dtype=self.compute_dtype)
|
||||
for s in output_shape
|
||||
)
|
||||
if isinstance(output_shape, dict):
|
||||
return {
|
||||
name: KerasTensor(s, dtype=self.compute_dtype)
|
||||
for name, s in output_shape.items()
|
||||
}
|
||||
# shapes. Could be deeply nested.
|
||||
return map_shape_structure(
|
||||
lambda s: KerasTensor(s, dtype=self.compute_dtype), output_shape
|
||||
)
|
||||
|
||||
@utils.default
|
||||
def compute_output_shape(self, *args, **kwargs):
|
||||
|
@ -8,6 +8,115 @@ from keras_core import testing
|
||||
|
||||
|
||||
class LayerTest(testing.TestCase):
|
||||
def test_compute_output_spec(self):
|
||||
# Test that implementing compute_output_shape
|
||||
# is enough to make compute_output_spec work.
|
||||
|
||||
# Case: single output
|
||||
class TestLayer(layers.Layer):
|
||||
def call(self, x):
|
||||
assert False # Should never be called.
|
||||
|
||||
def compute_output_shape(self, input_shape):
|
||||
return input_shape
|
||||
|
||||
layer = TestLayer()
|
||||
self.assertEqual(
|
||||
layer.compute_output_spec(backend.KerasTensor((2, 3))).shape, (2, 3)
|
||||
)
|
||||
|
||||
# Case: tuple output
|
||||
class TestLayer(layers.Layer):
|
||||
def call(self, x):
|
||||
assert False # Should never be called.
|
||||
|
||||
def compute_output_shape(self, input_shape):
|
||||
return (input_shape, input_shape)
|
||||
|
||||
layer = TestLayer()
|
||||
out = layer.compute_output_spec(backend.KerasTensor((2, 3)))
|
||||
self.assertTrue(isinstance(out, tuple))
|
||||
self.assertEqual(len(out), 2)
|
||||
self.assertEqual(out[0].shape, (2, 3))
|
||||
self.assertEqual(out[1].shape, (2, 3))
|
||||
|
||||
# Case: list output
|
||||
class TestLayer(layers.Layer):
|
||||
def call(self, x):
|
||||
assert False # Should never be called.
|
||||
|
||||
def compute_output_shape(self, input_shape):
|
||||
return [input_shape, input_shape]
|
||||
|
||||
layer = TestLayer()
|
||||
out = layer.compute_output_spec(backend.KerasTensor((2, 3)))
|
||||
self.assertTrue(isinstance(out, list))
|
||||
self.assertEqual(len(out), 2)
|
||||
self.assertEqual(out[0].shape, (2, 3))
|
||||
self.assertEqual(out[1].shape, (2, 3))
|
||||
|
||||
# Case: dict output
|
||||
class TestLayer(layers.Layer):
|
||||
def call(self, x):
|
||||
assert False # Should never be called.
|
||||
|
||||
def compute_output_shape(self, input_shape):
|
||||
return {"1": input_shape, "2": input_shape}
|
||||
|
||||
layer = TestLayer()
|
||||
out = layer.compute_output_spec(backend.KerasTensor((2, 3)))
|
||||
self.assertTrue(isinstance(out, dict))
|
||||
self.assertEqual(len(out), 2)
|
||||
self.assertEqual(out["1"].shape, (2, 3))
|
||||
self.assertEqual(out["2"].shape, (2, 3))
|
||||
|
||||
# Case: nested tuple output
|
||||
class TestLayer(layers.Layer):
|
||||
def call(self, x):
|
||||
assert False # Should never be called.
|
||||
|
||||
def compute_output_shape(self, input_shape):
|
||||
return (
|
||||
input_shape,
|
||||
(input_shape, input_shape),
|
||||
(input_shape, input_shape),
|
||||
)
|
||||
|
||||
layer = TestLayer()
|
||||
out = layer.compute_output_spec(backend.KerasTensor((2, 3)))
|
||||
self.assertTrue(isinstance(out, tuple))
|
||||
self.assertEqual(len(out), 3)
|
||||
self.assertEqual(out[0].shape, (2, 3))
|
||||
self.assertTrue(isinstance(out[1], tuple))
|
||||
self.assertEqual(len(out[1]), 2)
|
||||
self.assertEqual(out[1][0].shape, (2, 3))
|
||||
self.assertEqual(out[1][1].shape, (2, 3))
|
||||
self.assertTrue(isinstance(out[2], tuple))
|
||||
self.assertEqual(len(out[2]), 2)
|
||||
self.assertEqual(out[2][0].shape, (2, 3))
|
||||
self.assertEqual(out[2][1].shape, (2, 3))
|
||||
|
||||
# Case: nested dict output
|
||||
class TestLayer(layers.Layer):
|
||||
def call(self, x):
|
||||
assert False # Should never be called.
|
||||
|
||||
def compute_output_shape(self, input_shape):
|
||||
return {
|
||||
"1": input_shape,
|
||||
"2": {"11": input_shape, "22": input_shape},
|
||||
}
|
||||
|
||||
layer = TestLayer()
|
||||
out = layer.compute_output_spec(backend.KerasTensor((2, 3)))
|
||||
self.assertTrue(isinstance(out, dict))
|
||||
self.assertEqual(len(out), 2)
|
||||
self.assertEqual(out["1"].shape, (2, 3))
|
||||
self.assertTrue(isinstance(out["2"], dict))
|
||||
self.assertEqual(len(out["2"]), 2)
|
||||
self.assertEqual(out["2"]["11"].shape, (2, 3))
|
||||
self.assertEqual(out["2"]["22"].shape, (2, 3))
|
||||
|
||||
def test_positional_arg_error(self):
|
||||
class SomeLayer(layers.Layer):
|
||||
def call(self, x, bool_arg):
|
||||
|
@ -1,4 +1,3 @@
|
||||
from keras_core import backend
|
||||
from keras_core import initializers
|
||||
from keras_core import operations as ops
|
||||
from keras_core.api_export import keras_core_export
|
||||
|
Loading…
Reference in New Issue
Block a user