From ee5be68ce9bd993f8d66805cfe3638fffdb6e903 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Wed, 19 Apr 2023 15:24:35 -0700 Subject: [PATCH] Fix progbar glitch --- keras_core/backend/jax/trainer.py | 35 ++++++++++++++++++------------- keras_core/trainers/trainer.py | 2 +- keras_core/utils/progbar.py | 28 ++++++++++++------------- 3 files changed, 36 insertions(+), 29 deletions(-) diff --git a/keras_core/backend/jax/trainer.py b/keras_core/backend/jax/trainer.py index bc06c298a..4f8e443ba 100644 --- a/keras_core/backend/jax/trainer.py +++ b/keras_core/backend/jax/trainer.py @@ -9,16 +9,15 @@ from keras_core.trainers.epoch_iterator import EpochIterator class Trainer(base_trainer.Trainer): - def stateless_compute_loss_and_updates( - self, trainable_variables, non_trainable_variables, x, y, sample_weight - ): - y_pred, non_trainable_variables = self.stateless_call( - trainable_variables, non_trainable_variables, x - ) + self, trainable_variables, non_trainable_variables, x, y, sample_weight + ): + y_pred, non_trainable_variables = self.stateless_call( + trainable_variables, non_trainable_variables, x + ) - loss = self.compute_loss(x, y, y_pred, sample_weight) - return loss, (y_pred, non_trainable_variables) + loss = self.compute_loss(x, y, y_pred, sample_weight) + return loss, (y_pred, non_trainable_variables) def fit( self, @@ -104,8 +103,10 @@ class Trainer(base_trainer.Trainer): model=self, ) - grad_fn = jax.value_and_grad(self.stateless_compute_loss_and_updates, has_aux=True) - + grad_fn = jax.value_and_grad( + self.stateless_compute_loss_and_updates, has_aux=True + ) + def _train_step(state, data): ( trainable_variables, @@ -117,7 +118,11 @@ class Trainer(base_trainer.Trainer): data ) (loss, (y_pred, non_trainable_variables)), grads = grad_fn( - trainable_variables, non_trainable_variables, x, y, sample_weight + trainable_variables, + non_trainable_variables, + x, + y, + sample_weight, ) ( @@ -153,14 +158,16 @@ class Trainer(base_trainer.Trainer): metrics_variables, ) return logs, state - + if not self.run_eagerly and self.jit_compile: + @jax.jit def train_step(state, data): return _train_step(state, data) + else: train_step = _train_step - + self.stop_training = False callbacks.on_train_begin() @@ -273,4 +280,4 @@ class Trainer(base_trainer.Trainer): def predict( self, x, batch_size=None, verbose="auto", steps=None, callbacks=None ): - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/keras_core/trainers/trainer.py b/keras_core/trainers/trainer.py index 8e6de1dca..178cbc2d0 100644 --- a/keras_core/trainers/trainer.py +++ b/keras_core/trainers/trainer.py @@ -296,4 +296,4 @@ class Trainer: except: pass result[key] = value - return result \ No newline at end of file + return result diff --git a/keras_core/utils/progbar.py b/keras_core/utils/progbar.py index 83d531f66..16ba6804c 100644 --- a/keras_core/utils/progbar.py +++ b/keras_core/utils/progbar.py @@ -49,7 +49,6 @@ class Progbar: or "posix" in sys.modules or "PYCHARM_HOSTED" in os.environ ) - self._total_width = 0 self._seen_so_far = 0 # We use a dict + list to avoid garbage collection # issues found in OrderedDict @@ -59,6 +58,7 @@ class Progbar: self._last_update = 0 self._time_at_epoch_start = self._start self._time_after_first_step = None + self._prev_total_width = 0 def update(self, current, values=None, finalize=None): """Updates the progress bar. @@ -102,6 +102,7 @@ class Progbar: self._seen_so_far = current message = "" + special_char_len = 0 now = time.time() time_per_unit = self._estimate_step_duration(current, now) @@ -109,9 +110,8 @@ class Progbar: if now - self._last_update < self.interval and not finalize: return - prev_total_width = self._total_width if self._dynamic_display: - message += "\b" * prev_total_width + message += "\b" * self._prev_total_width message += "\r" else: message += "\n" @@ -125,13 +125,12 @@ class Progbar: if prog_width > 0: bar += "\33[32m" + "━" * prog_width + "\x1b[0m" - + special_char_len += 17 bar += "\33[37m" + "━" * (self.width - prog_width) + "\x1b[0m" + special_char_len += 9 else: bar = "%7d/Unknown" % current - - self._total_width = len(bar) message += bar # Add ETA if applicable @@ -151,6 +150,7 @@ class Progbar: else: # Time elapsed since start, in seconds info = f" \x1b[1m{now - self._start:.0f}s\x1b[0m" + special_char_len += 8 # Add time/step info += self._format_time(time_per_unit, self.unit_name) @@ -168,16 +168,16 @@ class Progbar: info += f" {avg:.4e}" else: info += f" {self._values[k]}" - - self._total_width += len(info) - if prev_total_width > self._total_width: - info += " " * (prev_total_width - self._total_width) - - if finalize: - info += "\n" - message += info + + total_width = len(bar) + len(info) - special_char_len + if self._prev_total_width > total_width: + message += " " * (self._prev_total_width - total_width) + if finalize: + message += "\n" + io_utils.print_msg(message, line_break=False) + self._prev_total_width = total_width message = "" elif self.verbose == 2: