Summary touch ups
This commit is contained in:
parent
fe87e2bf05
commit
395df51ab7
@ -145,12 +145,12 @@ def print_summary(
|
||||
|
||||
if sequential_like:
|
||||
line_length = line_length or 84
|
||||
positions = positions or [0.45, 0.85, 1.0]
|
||||
positions = positions or [0.45, 0.84, 1.0]
|
||||
# header names for the different log elements
|
||||
header = ["Layer (type)", "Output Shape", "Param #"]
|
||||
else:
|
||||
line_length = line_length or 100
|
||||
positions = positions or [0.3, 0.6, 0.70, 1.0]
|
||||
line_length = line_length or 108
|
||||
positions = positions or [0.3, 0.56, 0.70, 1.0]
|
||||
# header names for the different log elements
|
||||
header = ["Layer (type)", "Output Shape", "Param #", "Connected to"]
|
||||
relevant_nodes = []
|
||||
@ -162,20 +162,14 @@ def print_summary(
|
||||
positions = [p * 0.86 for p in positions] + [1.0]
|
||||
header.append("Trainable")
|
||||
|
||||
layer_range = get_layer_index_bound_by_layer_name(layers, layer_range)
|
||||
|
||||
print_fn(text_rendering.highlight_msg(f' Model: "{model.name}"'))
|
||||
rows = []
|
||||
|
||||
def get_layer_fields(layer, prefix=""):
|
||||
output_shape = format_layer_shape(layer)
|
||||
name = prefix + layer.name
|
||||
cls_name = layer.__class__.__name__
|
||||
if not getattr(layer, "built", False):
|
||||
# If a subclassed model has a layer that is not called in
|
||||
# Model.call, the layer will not be built and we cannot call
|
||||
# layer.count_params().
|
||||
params = "0 (unused)"
|
||||
if not hasattr(layer, "built"):
|
||||
params = "0"
|
||||
elif not layer.built:
|
||||
params = "0 (unbuilt)"
|
||||
else:
|
||||
params = layer.count_params()
|
||||
fields = [name + " (" + cls_name + ")", output_shape, str(params)]
|
||||
@ -184,29 +178,8 @@ def print_summary(
|
||||
fields.append("Y" if layer.trainable else "N")
|
||||
return fields
|
||||
|
||||
def print_layer_summary(layer, prefix=""):
|
||||
"""Prints a summary for a single layer.
|
||||
|
||||
Args:
|
||||
layer: target layer.
|
||||
nested_level: level of nesting of the layer inside its parent layer
|
||||
(e.g. 0 for a top-level layer, 1 for a nested layer).
|
||||
"""
|
||||
fields = get_layer_fields(layer, prefix=prefix)
|
||||
if show_trainable:
|
||||
fields.append("Y" if layer.trainable else "N")
|
||||
rows.append(fields)
|
||||
|
||||
def print_layer_summary_with_connections(layer, prefix=""):
|
||||
"""Prints a summary for a single layer (including its connections).
|
||||
|
||||
Args:
|
||||
layer: target layer.
|
||||
nested_level: level of nesting of the layer inside its parent layer
|
||||
(e.g. 0 for a top-level layer, 1 for a nested layer).
|
||||
"""
|
||||
fields = get_layer_fields(layer, prefix=prefix)
|
||||
connections = []
|
||||
def get_connections(layer):
|
||||
connections = ""
|
||||
for node in layer._inbound_nodes:
|
||||
if relevant_nodes and node not in relevant_nodes:
|
||||
# node is not part of the current network
|
||||
@ -216,34 +189,42 @@ def print_summary(
|
||||
inbound_layer = keras_history.operation
|
||||
node_index = keras_history.node_index
|
||||
tensor_index = keras_history.tensor_index
|
||||
connections.append(
|
||||
if connections:
|
||||
connections += ", "
|
||||
connections += (
|
||||
f"{inbound_layer.name}[{node_index}][{tensor_index}]"
|
||||
)
|
||||
if not connections:
|
||||
connections = "-"
|
||||
fields.append(connections)
|
||||
if show_trainable:
|
||||
fields.append("Y" if layer.trainable else "N")
|
||||
rows.append(fields)
|
||||
return connections
|
||||
|
||||
def print_layer(layer, nested_level=0):
|
||||
if nested_level:
|
||||
prefix = " " * nested_level + "└" + " "
|
||||
else:
|
||||
prefix = ""
|
||||
if sequential_like:
|
||||
print_layer_summary(layer, prefix=prefix)
|
||||
else:
|
||||
print_layer_summary_with_connections(layer, prefix=prefix)
|
||||
|
||||
fields = get_layer_fields(layer, prefix=prefix)
|
||||
if not sequential_like:
|
||||
fields.append(get_connections(layer))
|
||||
if show_trainable:
|
||||
fields.append("Y" if layer.trainable else "N")
|
||||
|
||||
rows = [fields]
|
||||
if expand_nested and hasattr(layer, "layers") and layer.layers:
|
||||
nested_layers = layer.layers
|
||||
nested_level += 1
|
||||
for i in range(len(nested_layers)):
|
||||
print_layer(nested_layers[i], nested_level=nested_level)
|
||||
rows.extend(
|
||||
print_layer(nested_layers[i], nested_level=nested_level)
|
||||
)
|
||||
return rows
|
||||
|
||||
layer_range = get_layer_index_bound_by_layer_name(layers, layer_range)
|
||||
print_fn(text_rendering.highlight_msg(f' Model: "{model.name}"'))
|
||||
rows = []
|
||||
for layer in layers[layer_range[0] : layer_range[1]]:
|
||||
print_layer(layer)
|
||||
rows.extend(print_layer(layer))
|
||||
|
||||
# Render summary as a table.
|
||||
table = text_rendering.TextTable(
|
||||
@ -254,13 +235,8 @@ def print_summary(
|
||||
alignments=["left"] + ["center" for _ in range(len(header) - 1)],
|
||||
max_line_length=line_length,
|
||||
)
|
||||
try:
|
||||
table_str = table.make()
|
||||
print_fn(table_str)
|
||||
except UnicodeEncodeError:
|
||||
printable = set(string.printable)
|
||||
table_str = filter(lambda x: x in printable, table_str)
|
||||
print_fn(table_str)
|
||||
table_str = table.make()
|
||||
print_fn(table_str)
|
||||
|
||||
# After the table, append information about parameter count and size.
|
||||
if hasattr(model, "_collected_trainable_weights"):
|
||||
|
@ -63,8 +63,8 @@ class TextTable:
|
||||
):
|
||||
alignments = alignments or ["center" for _ in fields]
|
||||
lines = []
|
||||
line_break_chars_post = (")", "}", "]")
|
||||
line_break_chars_pre = ("(", "{", "[")
|
||||
line_break_chars_post = ("),", "],")
|
||||
line_break_chars_pre = ("(", "[")
|
||||
for field, width, alignment in zip(
|
||||
fields, self.column_widths, alignments
|
||||
):
|
||||
|
Loading…
Reference in New Issue
Block a user