Bug fixes
This commit is contained in:
parent
a33eb25ca9
commit
ad6fd768e4
@ -3,10 +3,10 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from keras import backend
|
|
||||||
from keras.datasets.cifar import load_batch
|
|
||||||
|
|
||||||
|
from keras_core import backend
|
||||||
from keras_core.api_export import keras_core_export
|
from keras_core.api_export import keras_core_export
|
||||||
|
from keras_core.datasets.cifar import load_batch
|
||||||
from keras_core.utils.file_utils import get_file
|
from keras_core.utils.file_utils import get_file
|
||||||
|
|
||||||
|
|
||||||
|
@ -119,10 +119,11 @@ class Nadam(optimizer.Optimizer):
|
|||||||
self._u_product_counter += 1
|
self._u_product_counter += 1
|
||||||
return u_product_t
|
return u_product_t
|
||||||
|
|
||||||
if self._u_product_counter == (self.iterations + 2):
|
u_product_t = ops.cond(
|
||||||
u_product_t = get_cached_u_product()
|
ops.equal(self._u_product_counter, (self.iterations + 2)),
|
||||||
else:
|
get_cached_u_product,
|
||||||
u_product_t = compute_new_u_product()
|
compute_new_u_product,
|
||||||
|
)
|
||||||
|
|
||||||
u_product_t_1 = u_product_t * u_t_1
|
u_product_t_1 = u_product_t * u_t_1
|
||||||
beta_2_power = ops.power(beta_2, local_step)
|
beta_2_power = ops.power(beta_2, local_step)
|
||||||
|
Loading…
Reference in New Issue
Block a user