From bfc9feb17d0758619810ae26b50ace447e462a88 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Mon, 17 Jul 2023 10:39:56 -0700 Subject: [PATCH] Trainable summary touch ups --- keras_core/ops/numpy.py | 17 ++--------------- keras_core/utils/summary_utils.py | 17 ++++++++++------- 2 files changed, 12 insertions(+), 22 deletions(-) diff --git a/keras_core/ops/numpy.py b/keras_core/ops/numpy.py index f1c161806..666578acc 100644 --- a/keras_core/ops/numpy.py +++ b/keras_core/ops/numpy.py @@ -146,21 +146,8 @@ from keras_core.ops.operation_utils import reduce_shape def broadcast_shapes(shape1, shape2): - """Broadcast input shapes to a unified shape. - - Convert to list for mutability. - - Args: - shape1: A tuple or list of integers. - shape2: A tuple or list of integers. - - Returns: - output_shape (list of int or None): The broadcasted shape. - - Example: - >>> broadcast_shapes((5, 3), (1, 3)) - [5, 3] - """ + # Broadcast input shapes to a unified shape. + # Convert to list for mutability. shape1 = list(shape1) shape2 = list(shape2) origin_shape1 = shape1 diff --git a/keras_core/utils/summary_utils.py b/keras_core/utils/summary_utils.py index dfeba89a8..8717d4d40 100644 --- a/keras_core/utils/summary_utils.py +++ b/keras_core/utils/summary_utils.py @@ -192,8 +192,8 @@ def print_summary( relevant_nodes += v if show_trainable: - default_line_length += 8 - positions = [p * 0.88 for p in positions] + [1.0] + default_line_length += 12 + positions = [p * 0.90 for p in positions] + [1.0] header.append("Trainable") alignment.append("center") @@ -262,11 +262,14 @@ def print_summary( if not sequential_like: fields.append(get_connections(layer)) if show_trainable: - fields.append( - bold_text("Y", color=34) - if layer.trainable - else bold_text("N", color=9) - ) + if layer.weights: + fields.append( + bold_text("Y", color=34) + if layer.trainable + else bold_text("N", color=9) + ) + else: + fields.append(bold_text("-")) return fields def print_layer(layer, nested_level=0):