Bug fixes

This commit is contained in:
Francois Chollet 2023-05-30 11:06:56 -07:00
parent a33eb25ca9
commit ad6fd768e4
2 changed files with 7 additions and 6 deletions

@ -3,10 +3,10 @@
import os
import numpy as np
from keras import backend
from keras.datasets.cifar import load_batch
from keras_core import backend
from keras_core.api_export import keras_core_export
from keras_core.datasets.cifar import load_batch
from keras_core.utils.file_utils import get_file

@ -119,10 +119,11 @@ class Nadam(optimizer.Optimizer):
self._u_product_counter += 1
return u_product_t
if self._u_product_counter == (self.iterations + 2):
u_product_t = get_cached_u_product()
else:
u_product_t = compute_new_u_product()
u_product_t = ops.cond(
ops.equal(self._u_product_counter, (self.iterations + 2)),
get_cached_u_product,
compute_new_u_product,
)
u_product_t_1 = u_product_t * u_t_1
beta_2_power = ops.power(beta_2, local_step)