913405c237
* add pytorch backend trainer * addressing the comments --------- Co-authored-by: Haifeng Jin <haifeng-jin@users.noreply.github.com>
264 lines
9.9 KiB
Python
264 lines
9.9 KiB
Python
import math
|
|
import os
|
|
import sys
|
|
import time
|
|
|
|
from keras_core import operations as ops
|
|
from keras_core.api_export import keras_core_export
|
|
from keras_core.utils import io_utils
|
|
|
|
|
|
@keras_core_export("keras_core.utils.Progbar")
class Progbar:
    """Displays a progress bar.

    Args:
        target: Total number of steps expected, None if unknown.
        width: Progress bar width on screen.
        verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
        stateful_metrics: Iterable of string names of metrics that should *not*
            be averaged over time. Metrics in this list will be displayed as-is.
            All others will be averaged by the progbar before display.
        interval: Minimum visual progress update interval (in seconds).
        unit_name: Display name for step counts (usually "step" or "sample").
    """

    def __init__(
        self,
        target,
        width=20,
        verbose=1,
        interval=0.05,
        stateful_metrics=None,
        unit_name="step",
    ):
        self.target = target
        self.width = width
        self.verbose = verbose
        self.interval = interval
        self.unit_name = unit_name
        if stateful_metrics:
            self.stateful_metrics = set(stateful_metrics)
        else:
            self.stateful_metrics = set()

        # When the display is "dynamic" the bar is redrawn in place
        # (backspaces + carriage return); otherwise every refresh starts
        # on a new line.
        self._dynamic_display = (
            (hasattr(sys.stdout, "isatty") and sys.stdout.isatty())
            or "ipykernel" in sys.modules
            or "posix" in sys.modules
            or "PYCHARM_HOSTED" in os.environ
        )
        self._seen_so_far = 0
        # We use a dict + list to avoid garbage collection
        # issues found in OrderedDict
        self._values = {}
        self._values_order = []
        self._start = time.time()
        self._last_update = 0
        self._time_at_epoch_start = self._start
        self._time_after_first_step = None
        self._prev_total_width = 0

    def update(self, current, values=None, finalize=None):
        """Updates the progress bar.

        Args:
            current: Index of current step.
            values: List of tuples: `(name, value_for_last_step)`. If `name` is
                in `stateful_metrics`, `value_for_last_step` will be displayed
                as-is. Else, an average of the metric over time will be
                displayed.
            finalize: Whether this is the last update for the progress bar. If
                `None`, defaults to `current >= self.target` (or `False` when
                the target is unknown).
        """
        if finalize is None:
            if self.target is None:
                finalize = False
            else:
                finalize = current >= self.target

        for k, v in values or []:
            if k not in self._values_order:
                self._values_order.append(k)
            if k not in self.stateful_metrics:
                # In the case that progress bar doesn't have a target value in
                # the first epoch, both on_batch_end and on_epoch_end will be
                # called, which will cause 'current' and 'self._seen_so_far' to
                # have the same value. Force the minimal value to 1 here,
                # otherwise stateful_metric will be 0s.
                value_base = max(current - self._seen_so_far, 1)
                if k not in self._values:
                    self._values[k] = [v * value_base, value_base]
                else:
                    self._values[k][0] += v * value_base
                    self._values[k][1] += value_base
            else:
                # Stateful metrics output a numeric value. This representation
                # means "take an average from a single value" but keeps the
                # numeric formatting.
                self._values[k] = [v, 1]
        self._seen_so_far = current

        now = time.time()
        # Must run on every call (even when the redraw below is skipped):
        # it records the timestamp of the first step as a side effect.
        time_per_unit = self._estimate_step_duration(current, now)

        if self.verbose == 1:
            # Throttle redraws, but never drop the final one.
            if now - self._last_update < self.interval and not finalize:
                return

            message = ""
            # Number of non-printing (ANSI escape) characters emitted, so the
            # visible width can be derived from the string lengths.
            special_char_len = 0

            if self._dynamic_display:
                message += "\b" * self._prev_total_width
                message += "\r"
            else:
                message += "\n"

            if self.target is not None:
                bar = f"\x1b[1m{self._format_count(current)}\x1b[0m "
                special_char_len += 8  # bold on (4) + reset (4)
                if self.target > 0:
                    prog = float(current) / self.target
                else:
                    # Nothing to do for an empty target: show a full bar
                    # instead of dividing by zero.
                    prog = 1.0
                prog_width = int(self.width * prog)

                if prog_width > 0:
                    bar += "\33[32m" + "━" * prog_width + "\x1b[0m"
                    special_char_len += 9  # green on (5) + reset (4)
                bar += "\33[37m" + "━" * (self.width - prog_width) + "\x1b[0m"
                special_char_len += 9  # grey on (5) + reset (4)
            else:
                bar = self._format_count(current)
            message += bar

            # Add ETA if applicable
            if self.target is not None and not finalize:
                eta = time_per_unit * (self.target - current)
                if eta > 3600:
                    eta_format = "%d:%02d:%02d" % (
                        eta // 3600,
                        (eta % 3600) // 60,
                        eta % 60,
                    )
                elif eta > 60:
                    eta_format = "%d:%02d" % (eta // 60, eta % 60)
                else:
                    eta_format = "%ds" % eta
                info = f" \x1b[1m{eta_format}\x1b[0m"
            else:
                # Time elapsed since start, in seconds
                info = f" \x1b[1m{now - self._start:.0f}s\x1b[0m"
            special_char_len += 8  # bold on (4) + reset (4), both branches

            # Add time/step
            info += self._format_time(time_per_unit, self.unit_name)

            # Add metrics
            info += self._format_metrics()
            message += info

            total_width = len(bar) + len(info) - special_char_len
            if self._prev_total_width > total_width:
                # Blank out leftovers of a longer previous line.
                message += " " * (self._prev_total_width - total_width)
            if finalize:
                message += "\n"

            io_utils.print_msg(message, line_break=False)
            self._prev_total_width = total_width

        elif self.verbose == 2:
            # Semi-verbose mode prints a single summary line at the end.
            if finalize:
                info = (
                    f"{self._format_count(current)}"
                    f" - {now - self._start:.0f}s"
                )
                info += " -" + self._format_time(time_per_unit, self.unit_name)
                info += self._format_metrics()
                info += "\n"
                io_utils.print_msg(info, line_break=False)

        self._last_update = now

    def add(self, n, values=None):
        """Advances the progress bar by `n` steps.

        Args:
            n: Number of steps to add to the steps seen so far.
            values: Same as the `values` argument of `update()`.
        """
        self.update(self._seen_so_far + n, values)

    def _format_count(self, current):
        """Format the step counter, e.g. `" 5/100"` or `"      5/Unknown"`.

        Args:
            current: Index of current step.

        Returns:
            `current/target`, with `current` right-aligned to the number of
            digits of the target, or `current/Unknown` when no target is set.
        """
        if self.target is None:
            return "%7d/Unknown" % current
        # Clamp to 1 so that a target of 0 does not hit log10's domain error.
        numdigits = int(math.log10(max(self.target, 1))) + 1
        return ("%" + str(numdigits) + "d/%d") % (current, self.target)

    def _format_metrics(self):
        """Format all tracked metrics in the order they were first reported.

        Returns:
            A string made of one ` - name: value` chunk per metric. Values
            whose magnitude is above 1e-3 use fixed notation, smaller ones
            use scientific notation.
        """
        chunks = []
        for k in self._values_order:
            entry = self._values[k]
            if isinstance(entry, list):
                # entry is [weighted sum, total weight].
                avg = ops.mean(entry[0] / max(1, entry[1]))
                # Compare the magnitude so negative metrics are formatted the
                # same way as positive ones.
                if abs(avg) > 1e-3:
                    chunks.append(f" - {k}: {avg:.4f}")
                else:
                    chunks.append(f" - {k}: {avg:.4e}")
            else:
                chunks.append(f" - {k}: {entry}")
        return "".join(chunks)

    def _format_time(self, time_per_unit, unit_name):
        """format a given duration to display to the user.

        Given the duration, this function formats it in either milliseconds
        or seconds and displays the unit (i.e. ms/step or s/epoch).

        Args:
            time_per_unit: the duration to display
            unit_name: the name of the unit to display

        Returns:
            A string with the correctly formatted duration and units
        """
        if time_per_unit >= 1 or time_per_unit == 0:
            return f" {time_per_unit:.0f}s/{unit_name}"
        if time_per_unit >= 1e-3:
            return f" {time_per_unit * 1000.0:.0f}ms/{unit_name}"
        return f" {time_per_unit * 1000000.0:.0f}us/{unit_name}"

    def _estimate_step_duration(self, current, now):
        """Estimate the duration of a single step.

        Given the step number `current` and the corresponding time `now` this
        function returns an estimate for how long a single step takes. If this
        is called before one step has been completed (i.e. `current == 0`) then
        zero is given as an estimate. The duration estimate ignores the duration
        of the (assumed to be non-representative) first step for estimates when
        more steps are available (i.e. `current>1`).

        Args:
            current: Index of current step.
            now: The current time.

        Returns: Estimate of the duration of a single step.
        """
        if not current:
            return 0

        # there are a few special scenarios here:
        # 1) somebody is calling the progress bar without ever supplying
        #    step 1
        # 2) somebody is calling the progress bar and supplies step one
        #    multiple times, e.g. as part of a finalizing call
        # in these cases, we just fall back to the simple calculation
        if self._time_after_first_step is not None and current > 1:
            time_per_unit = (now - self._time_after_first_step) / (
                current - 1
            )
        else:
            time_per_unit = (now - self._start) / current

        if current == 1:
            self._time_after_first_step = now
        return time_per_unit
|