Restore print_fn option for summaries (#236)

This commit is contained in:
Matt Watson 2023-05-31 19:32:45 -07:00 committed by Francois Chollet
parent 7267c4e32c
commit 1eb1bf05d9

@ -125,7 +125,7 @@ def print_summary(
from keras_core.models import Functional
from keras_core.models import Sequential
if print_fn is None:
if not print_fn and not io_utils.is_interactive_logging_enabled():
print_fn = io_utils.print_msg
if isinstance(model, Sequential):
@ -292,13 +292,13 @@ def print_summary(
total_memory_size = trainable_memory_size + non_trainable_memory_size
# Create a rich console for printing. Capture for non-interactive logging.
if io_utils.is_interactive_logging_enabled():
console = rich.console.Console(highlight=False)
else:
if print_fn:
console = rich.console.Console(
highlight=False, force_terminal=False, color_system=None
)
console.begin_capture()
else:
console = rich.console.Console(highlight=False)
# Print the to the console.
console.print(bold_text(f'Model: "{rich.markup.escape(model.name)}"'))
@ -320,8 +320,8 @@ def print_summary(
)
# Output captured summary for non-interactive logging.
if not io_utils.is_interactive_logging_enabled():
io_utils.print_msg(console.end_capture(), line_break=False)
if print_fn:
print_fn(console.end_capture(), line_break=False)
def get_layer_index_bound_by_layer_name(layers, layer_range=None):