Extend support toSequential._maybe_build
for functional/sequential model as first layer and fix .summary()
(#20002)
This commit is contained in:
parent
9f7e8050f1
commit
902f9da309
@ -137,6 +137,12 @@ class Sequential(Model):
|
||||
if isinstance(self._layers[0], InputLayer) and len(self._layers) > 1:
|
||||
input_shape = self._layers[0].batch_shape
|
||||
self.build(input_shape)
|
||||
elif hasattr(self._layers[0], "input_shape") and len(self._layers) > 1:
|
||||
# We can build the Sequential model if the first layer has the
|
||||
# `input_shape` property. This is most commonly found in Functional
|
||||
# model.
|
||||
input_shape = self._layers[0].input_shape
|
||||
self.build(input_shape)
|
||||
|
||||
def _lock_state(self):
|
||||
# Unlike other layers, Sequential is mutable after build.
|
||||
|
@ -150,6 +150,58 @@ class SequentialTest(testing.TestCase):
|
||||
y = model(x)
|
||||
self.assertEqual(y.shape, (2, 3, 4))
|
||||
|
||||
def test_basic_flow_with_functional_model_as_first_layer(self):
|
||||
# Build functional model
|
||||
inputs = Input((16, 16, 3))
|
||||
outputs = layers.Conv2D(4, 3, padding="same")(inputs)
|
||||
functional_model = Model(inputs=inputs, outputs=outputs)
|
||||
|
||||
model = Sequential(
|
||||
[functional_model, layers.Flatten(), layers.Dense(1)]
|
||||
)
|
||||
model.summary()
|
||||
self.assertEqual(len(model.layers), 3)
|
||||
self.assertTrue(model.built)
|
||||
for layer in model.layers:
|
||||
self.assertTrue(layer.built)
|
||||
|
||||
# Test eager call
|
||||
x = np.random.random((1, 16, 16, 3))
|
||||
y = model(x)
|
||||
self.assertEqual(type(model._functional), Functional)
|
||||
self.assertEqual(tuple(y.shape), (1, 1))
|
||||
|
||||
# Test symbolic call
|
||||
x = backend.KerasTensor((1, 16, 16, 3))
|
||||
y = model(x)
|
||||
self.assertEqual(y.shape, (1, 1))
|
||||
|
||||
def test_basic_flow_with_sequential_model_as_first_layer(self):
|
||||
# Build sequential model
|
||||
sequential_model = Sequential(
|
||||
[Input((16, 16, 3)), layers.Conv2D(4, 3, padding="same")]
|
||||
)
|
||||
|
||||
model = Sequential(
|
||||
[sequential_model, layers.Flatten(), layers.Dense(1)]
|
||||
)
|
||||
model.summary()
|
||||
self.assertEqual(len(model.layers), 3)
|
||||
self.assertTrue(model.built)
|
||||
for layer in model.layers:
|
||||
self.assertTrue(layer.built)
|
||||
|
||||
# Test eager call
|
||||
x = np.random.random((1, 16, 16, 3))
|
||||
y = model(x)
|
||||
self.assertEqual(type(model._functional), Functional)
|
||||
self.assertEqual(tuple(y.shape), (1, 1))
|
||||
|
||||
# Test symbolic call
|
||||
x = backend.KerasTensor((1, 16, 16, 3))
|
||||
y = model(x)
|
||||
self.assertEqual(y.shape, (1, 1))
|
||||
|
||||
def test_dict_inputs(self):
|
||||
class DictLayer(layers.Layer):
|
||||
def call(self, inputs):
|
||||
|
@ -96,12 +96,15 @@ def format_layer_shape(layer):
|
||||
)
|
||||
else:
|
||||
try:
|
||||
outputs = layer.compute_output_shape(**layer._build_shapes_dict)
|
||||
if hasattr(layer, "output_shape"):
|
||||
output_shapes = layer.output_shape
|
||||
else:
|
||||
outputs = layer.compute_output_shape(**layer._build_shapes_dict)
|
||||
output_shapes = tree.map_shape_structure(
|
||||
lambda x: format_shape(x), outputs
|
||||
)
|
||||
except NotImplementedError:
|
||||
return "?"
|
||||
output_shapes = tree.map_shape_structure(
|
||||
lambda x: format_shape(x), outputs
|
||||
)
|
||||
if len(output_shapes) == 1:
|
||||
return output_shapes[0]
|
||||
out = str(output_shapes)
|
||||
|
Loading…
Reference in New Issue
Block a user