Fix progbar glitch

This commit is contained in:
Francois Chollet 2023-04-19 15:24:35 -07:00
parent 4bf04920f2
commit ee5be68ce9
3 changed files with 36 additions and 29 deletions

@ -9,16 +9,15 @@ from keras_core.trainers.epoch_iterator import EpochIterator
class Trainer(base_trainer.Trainer):
def stateless_compute_loss_and_updates(
self, trainable_variables, non_trainable_variables, x, y, sample_weight
):
y_pred, non_trainable_variables = self.stateless_call(
trainable_variables, non_trainable_variables, x
)
self, trainable_variables, non_trainable_variables, x, y, sample_weight
):
y_pred, non_trainable_variables = self.stateless_call(
trainable_variables, non_trainable_variables, x
)
loss = self.compute_loss(x, y, y_pred, sample_weight)
return loss, (y_pred, non_trainable_variables)
loss = self.compute_loss(x, y, y_pred, sample_weight)
return loss, (y_pred, non_trainable_variables)
def fit(
self,
@ -104,8 +103,10 @@ class Trainer(base_trainer.Trainer):
model=self,
)
grad_fn = jax.value_and_grad(self.stateless_compute_loss_and_updates, has_aux=True)
grad_fn = jax.value_and_grad(
self.stateless_compute_loss_and_updates, has_aux=True
)
def _train_step(state, data):
(
trainable_variables,
@ -117,7 +118,11 @@ class Trainer(base_trainer.Trainer):
data
)
(loss, (y_pred, non_trainable_variables)), grads = grad_fn(
trainable_variables, non_trainable_variables, x, y, sample_weight
trainable_variables,
non_trainable_variables,
x,
y,
sample_weight,
)
(
@ -153,14 +158,16 @@ class Trainer(base_trainer.Trainer):
metrics_variables,
)
return logs, state
if not self.run_eagerly and self.jit_compile:
@jax.jit
def train_step(state, data):
return _train_step(state, data)
else:
train_step = _train_step
self.stop_training = False
callbacks.on_train_begin()
@ -273,4 +280,4 @@ class Trainer(base_trainer.Trainer):
def predict(
self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
):
raise NotImplementedError
raise NotImplementedError

@ -296,4 +296,4 @@ class Trainer:
except:
pass
result[key] = value
return result
return result

@ -49,7 +49,6 @@ class Progbar:
or "posix" in sys.modules
or "PYCHARM_HOSTED" in os.environ
)
self._total_width = 0
self._seen_so_far = 0
# We use a dict + list to avoid garbage collection
# issues found in OrderedDict
@ -59,6 +58,7 @@ class Progbar:
self._last_update = 0
self._time_at_epoch_start = self._start
self._time_after_first_step = None
self._prev_total_width = 0
def update(self, current, values=None, finalize=None):
"""Updates the progress bar.
@ -102,6 +102,7 @@ class Progbar:
self._seen_so_far = current
message = ""
special_char_len = 0
now = time.time()
time_per_unit = self._estimate_step_duration(current, now)
@ -109,9 +110,8 @@ class Progbar:
if now - self._last_update < self.interval and not finalize:
return
prev_total_width = self._total_width
if self._dynamic_display:
message += "\b" * prev_total_width
message += "\b" * self._prev_total_width
message += "\r"
else:
message += "\n"
@ -125,13 +125,12 @@ class Progbar:
if prog_width > 0:
bar += "\33[32m" + "" * prog_width + "\x1b[0m"
special_char_len += 17
bar += "\33[37m" + "" * (self.width - prog_width) + "\x1b[0m"
special_char_len += 9
else:
bar = "%7d/Unknown" % current
self._total_width = len(bar)
message += bar
# Add ETA if applicable
@ -151,6 +150,7 @@ class Progbar:
else:
# Time elapsed since start, in seconds
info = f" \x1b[1m{now - self._start:.0f}s\x1b[0m"
special_char_len += 8
# Add time/step
info += self._format_time(time_per_unit, self.unit_name)
@ -168,16 +168,16 @@ class Progbar:
info += f" {avg:.4e}"
else:
info += f" {self._values[k]}"
self._total_width += len(info)
if prev_total_width > self._total_width:
info += " " * (prev_total_width - self._total_width)
if finalize:
info += "\n"
message += info
total_width = len(bar) + len(info) - special_char_len
if self._prev_total_width > total_width:
message += " " * (self._prev_total_width - total_width)
if finalize:
message += "\n"
io_utils.print_msg(message, line_break=False)
self._prev_total_width = total_width
message = ""
elif self.verbose == 2: