Refactor print_summary.
This commit is contained in:
parent
f0e0527591
commit
e092fa8b57
@ -2324,10 +2324,9 @@ class Container(Layer):
|
||||
import yaml
|
||||
return yaml.dump(self._updated_config(), **kwargs)
|
||||
|
||||
def summary(self, line_length=100, positions=[.33, .55, .67, 1.]):
|
||||
def summary(self, line_length=None, positions=None):
|
||||
from keras.utils.layer_utils import print_summary
|
||||
print_summary(self.layers,
|
||||
getattr(self, 'container_nodes', None),
|
||||
print_summary(self,
|
||||
line_length=line_length,
|
||||
positions=positions)
|
||||
|
||||
|
@ -29,13 +29,11 @@ def layer_from_config(config, custom_objects=None):
|
||||
printable_module_name='layer')
|
||||
|
||||
|
||||
def print_summary(layers, relevant_nodes=None,
|
||||
line_length=100, positions=None):
|
||||
def print_summary(model, line_length=None, positions=None):
|
||||
"""Prints a summary of a layer.
|
||||
|
||||
# Arguments
|
||||
layers: list of layers to print summaries of
|
||||
relevant_nodes: list of relevant nodes
|
||||
model: Keras model instance.
|
||||
line_length: total length of printed lines
|
||||
positions: relative or absolute positions of log elements in each line.
|
||||
If not provided, defaults to `[.33, .55, .67, 1.]`.
|
||||
@ -43,11 +41,29 @@ def print_summary(layers, relevant_nodes=None,
|
||||
# TODO: don't print connectivity for sequential models
|
||||
maybe change API to accept a model instance
|
||||
"""
|
||||
positions = positions or [.33, .55, .67, 1.]
|
||||
if positions[-1] <= 1:
|
||||
positions = [int(line_length * p) for p in positions]
|
||||
# header names for the different log elements
|
||||
to_display = ['Layer (type)', 'Output Shape', 'Param #', 'Connected to']
|
||||
|
||||
if isinstance(model, Sequential):
|
||||
sequential_like = True
|
||||
else:
|
||||
sequential_like = True
|
||||
for k, v in model.nodes_by_depth.items():
|
||||
if len(v) > 1:
|
||||
sequential_like = False
|
||||
|
||||
if sequential_like:
|
||||
line_length = line_length or 65
|
||||
positions = positions or [.45, .85, 1.]
|
||||
if positions[-1] <= 1:
|
||||
positions = [int(line_length * p) for p in positions]
|
||||
# header names for the different log elements
|
||||
to_display = ['Layer (type)', 'Output Shape', 'Param #']
|
||||
else:
|
||||
line_length = line_length or 100
|
||||
positions = positions or [.33, .55, .67, 1.]
|
||||
if positions[-1] <= 1:
|
||||
positions = [int(line_length * p) for p in positions]
|
||||
# header names for the different log elements
|
||||
to_display = ['Layer (type)', 'Output Shape', 'Param #', 'Connected to']
|
||||
|
||||
def print_row(fields, positions):
|
||||
line = ''
|
||||
@ -64,6 +80,16 @@ def print_summary(layers, relevant_nodes=None,
|
||||
print('=' * line_length)
|
||||
|
||||
def print_layer_summary(layer):
|
||||
try:
|
||||
output_shape = layer.output_shape
|
||||
except AttributeError:
|
||||
output_shape = 'multiple'
|
||||
name = layer.name
|
||||
cls_name = layer.__class__.__name__
|
||||
fields = [name + ' (' + cls_name + ')', output_shape, layer.count_params()]
|
||||
print_row(fields, positions)
|
||||
|
||||
def print_layer_summary_with_connections(layer):
|
||||
"""Prints a summary for a single layer.
|
||||
|
||||
# Arguments
|
||||
@ -99,8 +125,12 @@ def print_summary(layers, relevant_nodes=None,
|
||||
fields = ['', '', '', connections[i]]
|
||||
print_row(fields, positions)
|
||||
|
||||
layers = model.layers
|
||||
for i in range(len(layers)):
|
||||
print_layer_summary(layers[i])
|
||||
if sequential_like:
|
||||
print_layer_summary(layers[i])
|
||||
else:
|
||||
print_layer_summary_with_connections(layers[i])
|
||||
if i == len(layers) - 1:
|
||||
print('=' * line_length)
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user