Trainable summary touch ups
This commit is contained in:
parent
37f4728267
commit
bfc9feb17d
@ -146,21 +146,8 @@ from keras_core.ops.operation_utils import reduce_shape
|
||||
|
||||
|
||||
def broadcast_shapes(shape1, shape2):
|
||||
"""Broadcast input shapes to a unified shape.
|
||||
|
||||
Convert to list for mutability.
|
||||
|
||||
Args:
|
||||
shape1: A tuple or list of integers.
|
||||
shape2: A tuple or list of integers.
|
||||
|
||||
Returns:
|
||||
output_shape (list of int or None): The broadcasted shape.
|
||||
|
||||
Example:
|
||||
>>> broadcast_shapes((5, 3), (1, 3))
|
||||
[5, 3]
|
||||
"""
|
||||
# Broadcast input shapes to a unified shape.
|
||||
# Convert to list for mutability.
|
||||
shape1 = list(shape1)
|
||||
shape2 = list(shape2)
|
||||
origin_shape1 = shape1
|
||||
|
@ -192,8 +192,8 @@ def print_summary(
|
||||
relevant_nodes += v
|
||||
|
||||
if show_trainable:
|
||||
default_line_length += 8
|
||||
positions = [p * 0.88 for p in positions] + [1.0]
|
||||
default_line_length += 12
|
||||
positions = [p * 0.90 for p in positions] + [1.0]
|
||||
header.append("Trainable")
|
||||
alignment.append("center")
|
||||
|
||||
@ -262,11 +262,14 @@ def print_summary(
|
||||
if not sequential_like:
|
||||
fields.append(get_connections(layer))
|
||||
if show_trainable:
|
||||
if layer.weights:
|
||||
fields.append(
|
||||
bold_text("Y", color=34)
|
||||
if layer.trainable
|
||||
else bold_text("N", color=9)
|
||||
)
|
||||
else:
|
||||
fields.append(bold_text("-"))
|
||||
return fields
|
||||
|
||||
def print_layer(layer, nested_level=0):
|
||||
|
Loading…
Reference in New Issue
Block a user