diff --git a/keras_core/datasets/cifar100.py b/keras_core/datasets/cifar100.py index a1fe42927..a1d57cc67 100644 --- a/keras_core/datasets/cifar100.py +++ b/keras_core/datasets/cifar100.py @@ -3,10 +3,10 @@ import os import numpy as np -from keras import backend -from keras.datasets.cifar import load_batch +from keras_core import backend from keras_core.api_export import keras_core_export +from keras_core.datasets.cifar import load_batch from keras_core.utils.file_utils import get_file diff --git a/keras_core/optimizers/nadam.py b/keras_core/optimizers/nadam.py index d88815362..bd057bbba 100644 --- a/keras_core/optimizers/nadam.py +++ b/keras_core/optimizers/nadam.py @@ -119,10 +119,11 @@ class Nadam(optimizer.Optimizer): self._u_product_counter += 1 return u_product_t - if self._u_product_counter == (self.iterations + 2): - u_product_t = get_cached_u_product() - else: - u_product_t = compute_new_u_product() + u_product_t = ops.cond( + ops.equal(self._u_product_counter, (self.iterations + 2)), + get_cached_u_product, + compute_new_u_product, + ) u_product_t_1 = u_product_t * u_t_1 beta_2_power = ops.power(beta_2, local_step)