Restore print_fn option for summaries (#236)
This commit is contained in:
parent
7267c4e32c
commit
1eb1bf05d9
@ -125,7 +125,7 @@ def print_summary(
|
|||||||
from keras_core.models import Functional
|
from keras_core.models import Functional
|
||||||
from keras_core.models import Sequential
|
from keras_core.models import Sequential
|
||||||
|
|
||||||
if print_fn is None:
|
if not print_fn and not io_utils.is_interactive_logging_enabled():
|
||||||
print_fn = io_utils.print_msg
|
print_fn = io_utils.print_msg
|
||||||
|
|
||||||
if isinstance(model, Sequential):
|
if isinstance(model, Sequential):
|
||||||
@ -292,13 +292,13 @@ def print_summary(
|
|||||||
total_memory_size = trainable_memory_size + non_trainable_memory_size
|
total_memory_size = trainable_memory_size + non_trainable_memory_size
|
||||||
|
|
||||||
# Create a rich console for printing. Capture for non-interactive logging.
|
# Create a rich console for printing. Capture for non-interactive logging.
|
||||||
if io_utils.is_interactive_logging_enabled():
|
if print_fn:
|
||||||
console = rich.console.Console(highlight=False)
|
|
||||||
else:
|
|
||||||
console = rich.console.Console(
|
console = rich.console.Console(
|
||||||
highlight=False, force_terminal=False, color_system=None
|
highlight=False, force_terminal=False, color_system=None
|
||||||
)
|
)
|
||||||
console.begin_capture()
|
console.begin_capture()
|
||||||
|
else:
|
||||||
|
console = rich.console.Console(highlight=False)
|
||||||
|
|
||||||
# Print the to the console.
|
# Print the to the console.
|
||||||
console.print(bold_text(f'Model: "{rich.markup.escape(model.name)}"'))
|
console.print(bold_text(f'Model: "{rich.markup.escape(model.name)}"'))
|
||||||
@ -320,8 +320,8 @@ def print_summary(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Output captured summary for non-interactive logging.
|
# Output captured summary for non-interactive logging.
|
||||||
if not io_utils.is_interactive_logging_enabled():
|
if print_fn:
|
||||||
io_utils.print_msg(console.end_capture(), line_break=False)
|
print_fn(console.end_capture(), line_break=False)
|
||||||
|
|
||||||
|
|
||||||
def get_layer_index_bound_by_layer_name(layers, layer_range=None):
|
def get_layer_index_bound_by_layer_name(layers, layer_range=None):
|
||||||
|
Loading…
Reference in New Issue
Block a user