keras/keras_core/callbacks/progbar_logger.py
Francois Chollet 69ed8781fd Format code
2023-06-05 18:42:03 -07:00

103 lines
3.0 KiB
Python

from keras_core.api_export import keras_core_export
from keras_core.callbacks.callback import Callback
from keras_core.utils import io_utils
from keras_core.utils.progbar import Progbar
@keras_core_export("keras_core.callbacks.ProgbarLogger")
class ProgbarLogger(Callback):
    """Callback that prints metrics to stdout.

    Drives a `Progbar` whose unit is the step (batch) for training epochs,
    standalone evaluation and prediction. Validation that runs inside
    `fit` is kept silent so it does not interleave with the training bar.

    This callback takes no constructor arguments. Its verbosity, the number
    of epochs and the number of steps per loop are read from the `params`
    dict given to `set_params()` (keys `"verbose"`, `"epochs"`, `"steps"`);
    a `"verbose"` value of `"auto"` is treated as `1`. When `verbose` is
    truthy and there is more than one epoch, an `Epoch i/N` header is
    printed at the start of every epoch.
    """

    def __init__(self):
        super().__init__()
        # Number of batches seen in the current loop (one-indexed).
        self.seen = 0
        # Active `Progbar`, or `None` until a begin hook creates one.
        self.progbar = None
        # Steps per loop; `None` means unknown until the first loop
        # finishes, at which point `_finalize_progbar` records `self.seen`.
        self.target = None
        self.verbose = 1
        self.epochs = 1

        # True once `fit` has started; while set, the `on_test_*` hooks
        # (i.e. validation inside `fit`) produce no output.
        self._called_in_fit = False

    def set_params(self, params):
        """Reads verbosity, epoch count and steps from `params`.

        Args:
            params: Dict containing at least the keys `"verbose"`,
                `"epochs"` and `"steps"`.
        """
        # Delegate to the base class so the attributes it manages
        # (the documented `Callback.params`) are populated as usual.
        super().set_params(params)
        verbose = params["verbose"]
        if verbose == "auto":
            verbose = 1
        self.verbose = verbose
        self.epochs = params["epochs"]
        self.target = params["steps"]

    def on_train_begin(self, logs=None):
        # When this logger is called inside `fit`, validation is silent.
        self._called_in_fit = True

    def on_train_end(self, logs=None):
        # `fit` is over: if this same instance is later used for a
        # standalone `evaluate`, its progress must be reported again.
        self._called_in_fit = False

    def on_test_begin(self, logs=None):
        if not self._called_in_fit:
            self._reset_progbar()
            self._maybe_init_progbar()

    def on_predict_begin(self, logs=None):
        self._reset_progbar()
        self._maybe_init_progbar()

    def on_epoch_begin(self, epoch, logs=None):
        self._reset_progbar()
        self._maybe_init_progbar()
        if self.verbose and self.epochs > 1:
            io_utils.print_msg(f"Epoch {epoch + 1}/{self.epochs}")

    def on_train_batch_end(self, batch, logs=None):
        self._update_progbar(batch, logs)

    def on_test_batch_end(self, batch, logs=None):
        if not self._called_in_fit:
            self._update_progbar(batch, logs)

    def on_predict_batch_end(self, batch, logs=None):
        # Don't pass prediction results.
        self._update_progbar(batch, None)

    def on_epoch_end(self, epoch, logs=None):
        self._finalize_progbar(logs)

    def on_test_end(self, logs=None):
        if not self._called_in_fit:
            self._finalize_progbar(logs)

    def on_predict_end(self, logs=None):
        self._finalize_progbar(logs)

    def _reset_progbar(self):
        """Drops the current progbar and batch count for a fresh loop."""
        self.seen = 0
        self.progbar = None

    def _maybe_init_progbar(self):
        """Creates a step-counting `Progbar` if none is active."""
        if self.progbar is None:
            self.progbar = Progbar(
                target=self.target, verbose=self.verbose, unit_name="step"
            )

    def _update_progbar(self, batch, logs=None):
        """Updates the progbar after `batch` (zero-indexed) completes.

        Args:
            batch: Zero-indexed batch number within the current loop.
            logs: Optional dict of metric values to display.
        """
        logs = logs or {}
        self._maybe_init_progbar()
        self.seen = batch + 1  # One-indexed.

        # Per-batch updates are only sent when `verbose == 1`; the final
        # update is sent by `_finalize_progbar` regardless of verbosity.
        if self.verbose == 1:
            self.progbar.update(self.seen, list(logs.items()), finalize=False)

    def _finalize_progbar(self, logs):
        """Sends the final progbar update for the current loop.

        If the number of steps was unknown (`self.target is None`), the
        number of batches actually seen becomes the target from now on.

        Args:
            logs: Optional dict of metric values to display.
        """
        logs = logs or {}
        # Guard against an end hook arriving without a matching begin
        # hook; this is a no-op when a progbar is already active.
        self._maybe_init_progbar()
        if self.target is None:
            self.target = self.seen
            self.progbar.target = self.target
        self.progbar.update(self.target, list(logs.items()), finalize=True)