Fix progbar glitch
This commit is contained in:
parent
4bf04920f2
commit
ee5be68ce9
@ -9,16 +9,15 @@ from keras_core.trainers.epoch_iterator import EpochIterator
|
||||
|
||||
|
||||
class Trainer(base_trainer.Trainer):
|
||||
|
||||
def stateless_compute_loss_and_updates(
|
||||
self, trainable_variables, non_trainable_variables, x, y, sample_weight
|
||||
):
|
||||
y_pred, non_trainable_variables = self.stateless_call(
|
||||
trainable_variables, non_trainable_variables, x
|
||||
)
|
||||
self, trainable_variables, non_trainable_variables, x, y, sample_weight
|
||||
):
|
||||
y_pred, non_trainable_variables = self.stateless_call(
|
||||
trainable_variables, non_trainable_variables, x
|
||||
)
|
||||
|
||||
loss = self.compute_loss(x, y, y_pred, sample_weight)
|
||||
return loss, (y_pred, non_trainable_variables)
|
||||
loss = self.compute_loss(x, y, y_pred, sample_weight)
|
||||
return loss, (y_pred, non_trainable_variables)
|
||||
|
||||
def fit(
|
||||
self,
|
||||
@ -104,8 +103,10 @@ class Trainer(base_trainer.Trainer):
|
||||
model=self,
|
||||
)
|
||||
|
||||
grad_fn = jax.value_and_grad(self.stateless_compute_loss_and_updates, has_aux=True)
|
||||
|
||||
grad_fn = jax.value_and_grad(
|
||||
self.stateless_compute_loss_and_updates, has_aux=True
|
||||
)
|
||||
|
||||
def _train_step(state, data):
|
||||
(
|
||||
trainable_variables,
|
||||
@ -117,7 +118,11 @@ class Trainer(base_trainer.Trainer):
|
||||
data
|
||||
)
|
||||
(loss, (y_pred, non_trainable_variables)), grads = grad_fn(
|
||||
trainable_variables, non_trainable_variables, x, y, sample_weight
|
||||
trainable_variables,
|
||||
non_trainable_variables,
|
||||
x,
|
||||
y,
|
||||
sample_weight,
|
||||
)
|
||||
|
||||
(
|
||||
@ -153,14 +158,16 @@ class Trainer(base_trainer.Trainer):
|
||||
metrics_variables,
|
||||
)
|
||||
return logs, state
|
||||
|
||||
|
||||
if not self.run_eagerly and self.jit_compile:
|
||||
|
||||
@jax.jit
|
||||
def train_step(state, data):
|
||||
return _train_step(state, data)
|
||||
|
||||
else:
|
||||
train_step = _train_step
|
||||
|
||||
|
||||
self.stop_training = False
|
||||
callbacks.on_train_begin()
|
||||
|
||||
@ -273,4 +280,4 @@ class Trainer(base_trainer.Trainer):
|
||||
def predict(
|
||||
self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
|
||||
):
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError
|
||||
|
@ -296,4 +296,4 @@ class Trainer:
|
||||
except:
|
||||
pass
|
||||
result[key] = value
|
||||
return result
|
||||
return result
|
||||
|
@ -49,7 +49,6 @@ class Progbar:
|
||||
or "posix" in sys.modules
|
||||
or "PYCHARM_HOSTED" in os.environ
|
||||
)
|
||||
self._total_width = 0
|
||||
self._seen_so_far = 0
|
||||
# We use a dict + list to avoid garbage collection
|
||||
# issues found in OrderedDict
|
||||
@ -59,6 +58,7 @@ class Progbar:
|
||||
self._last_update = 0
|
||||
self._time_at_epoch_start = self._start
|
||||
self._time_after_first_step = None
|
||||
self._prev_total_width = 0
|
||||
|
||||
def update(self, current, values=None, finalize=None):
|
||||
"""Updates the progress bar.
|
||||
@ -102,6 +102,7 @@ class Progbar:
|
||||
self._seen_so_far = current
|
||||
|
||||
message = ""
|
||||
special_char_len = 0
|
||||
now = time.time()
|
||||
time_per_unit = self._estimate_step_duration(current, now)
|
||||
|
||||
@ -109,9 +110,8 @@ class Progbar:
|
||||
if now - self._last_update < self.interval and not finalize:
|
||||
return
|
||||
|
||||
prev_total_width = self._total_width
|
||||
if self._dynamic_display:
|
||||
message += "\b" * prev_total_width
|
||||
message += "\b" * self._prev_total_width
|
||||
message += "\r"
|
||||
else:
|
||||
message += "\n"
|
||||
@ -125,13 +125,12 @@ class Progbar:
|
||||
|
||||
if prog_width > 0:
|
||||
bar += "\33[32m" + "━" * prog_width + "\x1b[0m"
|
||||
|
||||
special_char_len += 17
|
||||
bar += "\33[37m" + "━" * (self.width - prog_width) + "\x1b[0m"
|
||||
special_char_len += 9
|
||||
|
||||
else:
|
||||
bar = "%7d/Unknown" % current
|
||||
|
||||
self._total_width = len(bar)
|
||||
message += bar
|
||||
|
||||
# Add ETA if applicable
|
||||
@ -151,6 +150,7 @@ class Progbar:
|
||||
else:
|
||||
# Time elapsed since start, in seconds
|
||||
info = f" \x1b[1m{now - self._start:.0f}s\x1b[0m"
|
||||
special_char_len += 8
|
||||
|
||||
# Add time/step
|
||||
info += self._format_time(time_per_unit, self.unit_name)
|
||||
@ -168,16 +168,16 @@ class Progbar:
|
||||
info += f" {avg:.4e}"
|
||||
else:
|
||||
info += f" {self._values[k]}"
|
||||
|
||||
self._total_width += len(info)
|
||||
if prev_total_width > self._total_width:
|
||||
info += " " * (prev_total_width - self._total_width)
|
||||
|
||||
if finalize:
|
||||
info += "\n"
|
||||
|
||||
message += info
|
||||
|
||||
total_width = len(bar) + len(info) - special_char_len
|
||||
if self._prev_total_width > total_width:
|
||||
message += " " * (self._prev_total_width - total_width)
|
||||
if finalize:
|
||||
message += "\n"
|
||||
|
||||
io_utils.print_msg(message, line_break=False)
|
||||
self._prev_total_width = total_width
|
||||
message = ""
|
||||
|
||||
elif self.verbose == 2:
|
||||
|
Loading…
Reference in New Issue
Block a user