Trainable summary touch ups

This commit is contained in:
Francois Chollet 2023-07-17 10:39:56 -07:00
parent 37f4728267
commit bfc9feb17d
2 changed files with 12 additions and 22 deletions

@ -146,21 +146,8 @@ from keras_core.ops.operation_utils import reduce_shape
def broadcast_shapes(shape1, shape2):
"""Broadcast input shapes to a unified shape.
Convert to list for mutability.
Args:
shape1: A tuple or list of integers.
shape2: A tuple or list of integers.
Returns:
output_shape (list of int or None): The broadcasted shape.
Example:
>>> broadcast_shapes((5, 3), (1, 3))
[5, 3]
"""
# Broadcast input shapes to a unified shape.
# Convert to list for mutability.
shape1 = list(shape1)
shape2 = list(shape2)
origin_shape1 = shape1

@ -192,8 +192,8 @@ def print_summary(
relevant_nodes += v
if show_trainable:
default_line_length += 8
positions = [p * 0.88 for p in positions] + [1.0]
default_line_length += 12
positions = [p * 0.90 for p in positions] + [1.0]
header.append("Trainable")
alignment.append("center")
@ -262,11 +262,14 @@ def print_summary(
if not sequential_like:
fields.append(get_connections(layer))
if show_trainable:
if layer.weights:
fields.append(
bold_text("Y", color=34)
if layer.trainable
else bold_text("N", color=9)
)
else:
fields.append(bold_text("-"))
return fields
def print_layer(layer, nested_level=0):