Extend support toSequential._maybe_build for functional/sequential model as first layer and fix .summary() (#20002)

This commit is contained in:
Hongyu, Chiu 2024-07-18 03:42:39 +08:00 committed by GitHub
parent 9f7e8050f1
commit 902f9da309
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 65 additions and 4 deletions

@ -137,6 +137,12 @@ class Sequential(Model):
if isinstance(self._layers[0], InputLayer) and len(self._layers) > 1:
input_shape = self._layers[0].batch_shape
self.build(input_shape)
elif hasattr(self._layers[0], "input_shape") and len(self._layers) > 1:
# We can build the Sequential model if the first layer has the
# `input_shape` property. This is most commonly found in Functional
# model.
input_shape = self._layers[0].input_shape
self.build(input_shape)
def _lock_state(self):
# Unlike other layers, Sequential is mutable after build.

@ -150,6 +150,58 @@ class SequentialTest(testing.TestCase):
y = model(x)
self.assertEqual(y.shape, (2, 3, 4))
def test_basic_flow_with_functional_model_as_first_layer(self):
# Build functional model
inputs = Input((16, 16, 3))
outputs = layers.Conv2D(4, 3, padding="same")(inputs)
functional_model = Model(inputs=inputs, outputs=outputs)
model = Sequential(
[functional_model, layers.Flatten(), layers.Dense(1)]
)
model.summary()
self.assertEqual(len(model.layers), 3)
self.assertTrue(model.built)
for layer in model.layers:
self.assertTrue(layer.built)
# Test eager call
x = np.random.random((1, 16, 16, 3))
y = model(x)
self.assertEqual(type(model._functional), Functional)
self.assertEqual(tuple(y.shape), (1, 1))
# Test symbolic call
x = backend.KerasTensor((1, 16, 16, 3))
y = model(x)
self.assertEqual(y.shape, (1, 1))
def test_basic_flow_with_sequential_model_as_first_layer(self):
# Build sequential model
sequential_model = Sequential(
[Input((16, 16, 3)), layers.Conv2D(4, 3, padding="same")]
)
model = Sequential(
[sequential_model, layers.Flatten(), layers.Dense(1)]
)
model.summary()
self.assertEqual(len(model.layers), 3)
self.assertTrue(model.built)
for layer in model.layers:
self.assertTrue(layer.built)
# Test eager call
x = np.random.random((1, 16, 16, 3))
y = model(x)
self.assertEqual(type(model._functional), Functional)
self.assertEqual(tuple(y.shape), (1, 1))
# Test symbolic call
x = backend.KerasTensor((1, 16, 16, 3))
y = model(x)
self.assertEqual(y.shape, (1, 1))
def test_dict_inputs(self):
class DictLayer(layers.Layer):
def call(self, inputs):

@ -96,12 +96,15 @@ def format_layer_shape(layer):
)
else:
try:
outputs = layer.compute_output_shape(**layer._build_shapes_dict)
if hasattr(layer, "output_shape"):
output_shapes = layer.output_shape
else:
outputs = layer.compute_output_shape(**layer._build_shapes_dict)
output_shapes = tree.map_shape_structure(
lambda x: format_shape(x), outputs
)
except NotImplementedError:
return "?"
output_shapes = tree.map_shape_structure(
lambda x: format_shape(x), outputs
)
if len(output_shapes) == 1:
return output_shapes[0]
out = str(output_shapes)