Bug fixes
This commit is contained in:
parent
a33eb25ca9
commit
ad6fd768e4
@ -3,10 +3,10 @@
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from keras import backend
|
||||
from keras.datasets.cifar import load_batch
|
||||
|
||||
from keras_core import backend
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.datasets.cifar import load_batch
|
||||
from keras_core.utils.file_utils import get_file
|
||||
|
||||
|
||||
|
@ -119,10 +119,11 @@ class Nadam(optimizer.Optimizer):
|
||||
self._u_product_counter += 1
|
||||
return u_product_t
|
||||
|
||||
if self._u_product_counter == (self.iterations + 2):
|
||||
u_product_t = get_cached_u_product()
|
||||
else:
|
||||
u_product_t = compute_new_u_product()
|
||||
u_product_t = ops.cond(
|
||||
ops.equal(self._u_product_counter, (self.iterations + 2)),
|
||||
get_cached_u_product,
|
||||
compute_new_u_product,
|
||||
)
|
||||
|
||||
u_product_t_1 = u_product_t * u_t_1
|
||||
beta_2_power = ops.power(beta_2, local_step)
|
||||
|
Loading…
Reference in New Issue
Block a user