Refactor print_summary.

This commit is contained in:
Francois Chollet 2017-02-19 17:13:31 -08:00
parent f0e0527591
commit e092fa8b57
2 changed files with 42 additions and 13 deletions

@ -2324,10 +2324,9 @@ class Container(Layer):
import yaml
return yaml.dump(self._updated_config(), **kwargs)
def summary(self, line_length=100, positions=[.33, .55, .67, 1.]):
def summary(self, line_length=None, positions=None):
from keras.utils.layer_utils import print_summary
print_summary(self.layers,
getattr(self, 'container_nodes', None),
print_summary(self,
line_length=line_length,
positions=positions)

@ -29,13 +29,11 @@ def layer_from_config(config, custom_objects=None):
printable_module_name='layer')
def print_summary(layers, relevant_nodes=None,
line_length=100, positions=None):
def print_summary(model, line_length=None, positions=None):
"""Prints a summary of a layer.
# Arguments
layers: list of layers to print summaries of
relevant_nodes: list of relevant nodes
model: Keras model instance.
line_length: total length of printed lines
positions: relative or absolute positions of log elements in each line.
If not provided, defaults to `[.33, .55, .67, 1.]`.
@ -43,11 +41,29 @@ def print_summary(layers, relevant_nodes=None,
# TODO: don't print connectivity for sequential models
maybe change API to accept a model instance
"""
positions = positions or [.33, .55, .67, 1.]
if positions[-1] <= 1:
positions = [int(line_length * p) for p in positions]
# header names for the different log elements
to_display = ['Layer (type)', 'Output Shape', 'Param #', 'Connected to']
if isinstance(model, Sequential):
sequential_like = True
else:
sequential_like = True
for k, v in model.nodes_by_depth.items():
if len(v) > 1:
sequential_like = False
if sequential_like:
line_length = line_length or 65
positions = positions or [.45, .85, 1.]
if positions[-1] <= 1:
positions = [int(line_length * p) for p in positions]
# header names for the different log elements
to_display = ['Layer (type)', 'Output Shape', 'Param #']
else:
line_length = line_length or 100
positions = positions or [.33, .55, .67, 1.]
if positions[-1] <= 1:
positions = [int(line_length * p) for p in positions]
# header names for the different log elements
to_display = ['Layer (type)', 'Output Shape', 'Param #', 'Connected to']
def print_row(fields, positions):
line = ''
@ -64,6 +80,16 @@ def print_summary(layers, relevant_nodes=None,
print('=' * line_length)
def print_layer_summary(layer):
try:
output_shape = layer.output_shape
except AttributeError:
output_shape = 'multiple'
name = layer.name
cls_name = layer.__class__.__name__
fields = [name + ' (' + cls_name + ')', output_shape, layer.count_params()]
print_row(fields, positions)
def print_layer_summary_with_connections(layer):
"""Prints a summary for a single layer.
# Arguments
@ -99,8 +125,12 @@ def print_summary(layers, relevant_nodes=None,
fields = ['', '', '', connections[i]]
print_row(fields, positions)
layers = model.layers
for i in range(len(layers)):
print_layer_summary(layers[i])
if sequential_like:
print_layer_summary(layers[i])
else:
print_layer_summary_with_connections(layers[i])
if i == len(layers) - 1:
print('=' * line_length)
else: