Summary touch ups

This commit is contained in:
Francois Chollet 2023-04-13 14:07:42 -07:00
parent fe87e2bf05
commit 395df51ab7
2 changed files with 32 additions and 56 deletions

@ -145,12 +145,12 @@ def print_summary(
if sequential_like:
line_length = line_length or 84
positions = positions or [0.45, 0.85, 1.0]
positions = positions or [0.45, 0.84, 1.0]
# header names for the different log elements
header = ["Layer (type)", "Output Shape", "Param #"]
else:
line_length = line_length or 100
positions = positions or [0.3, 0.6, 0.70, 1.0]
line_length = line_length or 108
positions = positions or [0.3, 0.56, 0.70, 1.0]
# header names for the different log elements
header = ["Layer (type)", "Output Shape", "Param #", "Connected to"]
relevant_nodes = []
@ -162,20 +162,14 @@ def print_summary(
positions = [p * 0.86 for p in positions] + [1.0]
header.append("Trainable")
layer_range = get_layer_index_bound_by_layer_name(layers, layer_range)
print_fn(text_rendering.highlight_msg(f' Model: "{model.name}"'))
rows = []
def get_layer_fields(layer, prefix=""):
output_shape = format_layer_shape(layer)
name = prefix + layer.name
cls_name = layer.__class__.__name__
if not getattr(layer, "built", False):
# If a subclassed model has a layer that is not called in
# Model.call, the layer will not be built and we cannot call
# layer.count_params().
params = "0 (unused)"
if not hasattr(layer, "built"):
params = "0"
elif not layer.built:
params = "0 (unbuilt)"
else:
params = layer.count_params()
fields = [name + " (" + cls_name + ")", output_shape, str(params)]
@ -184,29 +178,8 @@ def print_summary(
fields.append("Y" if layer.trainable else "N")
return fields
def print_layer_summary(layer, prefix=""):
"""Prints a summary for a single layer.
Args:
layer: target layer.
nested_level: level of nesting of the layer inside its parent layer
(e.g. 0 for a top-level layer, 1 for a nested layer).
"""
fields = get_layer_fields(layer, prefix=prefix)
if show_trainable:
fields.append("Y" if layer.trainable else "N")
rows.append(fields)
def print_layer_summary_with_connections(layer, prefix=""):
"""Prints a summary for a single layer (including its connections).
Args:
layer: target layer.
nested_level: level of nesting of the layer inside its parent layer
(e.g. 0 for a top-level layer, 1 for a nested layer).
"""
fields = get_layer_fields(layer, prefix=prefix)
connections = []
def get_connections(layer):
connections = ""
for node in layer._inbound_nodes:
if relevant_nodes and node not in relevant_nodes:
# node is not part of the current network
@ -216,34 +189,42 @@ def print_summary(
inbound_layer = keras_history.operation
node_index = keras_history.node_index
tensor_index = keras_history.tensor_index
connections.append(
if connections:
connections += ", "
connections += (
f"{inbound_layer.name}[{node_index}][{tensor_index}]"
)
if not connections:
connections = "-"
fields.append(connections)
if show_trainable:
fields.append("Y" if layer.trainable else "N")
rows.append(fields)
return connections
def print_layer(layer, nested_level=0):
if nested_level:
prefix = " " * nested_level + "" + " "
else:
prefix = ""
if sequential_like:
print_layer_summary(layer, prefix=prefix)
else:
print_layer_summary_with_connections(layer, prefix=prefix)
fields = get_layer_fields(layer, prefix=prefix)
if not sequential_like:
fields.append(get_connections(layer))
if show_trainable:
fields.append("Y" if layer.trainable else "N")
rows = [fields]
if expand_nested and hasattr(layer, "layers") and layer.layers:
nested_layers = layer.layers
nested_level += 1
for i in range(len(nested_layers)):
print_layer(nested_layers[i], nested_level=nested_level)
rows.extend(
print_layer(nested_layers[i], nested_level=nested_level)
)
return rows
layer_range = get_layer_index_bound_by_layer_name(layers, layer_range)
print_fn(text_rendering.highlight_msg(f' Model: "{model.name}"'))
rows = []
for layer in layers[layer_range[0] : layer_range[1]]:
print_layer(layer)
rows.extend(print_layer(layer))
# Render summary as a table.
table = text_rendering.TextTable(
@ -254,13 +235,8 @@ def print_summary(
alignments=["left"] + ["center" for _ in range(len(header) - 1)],
max_line_length=line_length,
)
try:
table_str = table.make()
print_fn(table_str)
except UnicodeEncodeError:
printable = set(string.printable)
table_str = filter(lambda x: x in printable, table_str)
print_fn(table_str)
table_str = table.make()
print_fn(table_str)
# After the table, append information about parameter count and size.
if hasattr(model, "_collected_trainable_weights"):

@ -63,8 +63,8 @@ class TextTable:
):
alignments = alignments or ["center" for _ in fields]
lines = []
line_break_chars_post = (")", "}", "]")
line_break_chars_pre = ("(", "{", "[")
line_break_chars_post = ("),", "],")
line_break_chars_pre = ("(", "[")
for field, width, alignment in zip(
fields, self.column_widths, alignments
):