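Backwards compatibility for the `decay` argument: optimizers accept it via `**kwargs`, warn that it is no longer supported, and ignore it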
This commit is contained in:
parent
73b22e0f68
commit
e0194197e7
@ -50,6 +50,7 @@ class Adadelta(optimizer.Optimizer):
|
||||
ema_momentum=0.99,
|
||||
ema_overwrite_frequency=None,
|
||||
name="adadelta",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
learning_rate=learning_rate,
|
||||
@ -61,6 +62,7 @@ class Adadelta(optimizer.Optimizer):
|
||||
ema_momentum=ema_momentum,
|
||||
ema_overwrite_frequency=ema_overwrite_frequency,
|
||||
name=name,
|
||||
**kwargs,
|
||||
)
|
||||
self.rho = rho
|
||||
self.epsilon = epsilon
|
||||
|
@ -57,6 +57,7 @@ class Adafactor(optimizer.Optimizer):
|
||||
ema_momentum=0.99,
|
||||
ema_overwrite_frequency=None,
|
||||
name="adafactor",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
learning_rate=learning_rate,
|
||||
@ -68,6 +69,7 @@ class Adafactor(optimizer.Optimizer):
|
||||
use_ema=use_ema,
|
||||
ema_momentum=ema_momentum,
|
||||
ema_overwrite_frequency=ema_overwrite_frequency,
|
||||
**kwargs,
|
||||
)
|
||||
self.beta_2_decay = beta_2_decay
|
||||
self.epsilon_1 = epsilon_1
|
||||
|
@ -45,6 +45,7 @@ class Adagrad(optimizer.Optimizer):
|
||||
ema_momentum=0.99,
|
||||
ema_overwrite_frequency=None,
|
||||
name="adagrad",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
learning_rate=learning_rate,
|
||||
@ -56,6 +57,7 @@ class Adagrad(optimizer.Optimizer):
|
||||
ema_momentum=ema_momentum,
|
||||
ema_overwrite_frequency=ema_overwrite_frequency,
|
||||
name=name,
|
||||
**kwargs,
|
||||
)
|
||||
self.initial_accumulator_value = initial_accumulator_value
|
||||
self.epsilon = epsilon
|
||||
|
@ -55,6 +55,7 @@ class Adam(optimizer.Optimizer):
|
||||
ema_momentum=0.99,
|
||||
ema_overwrite_frequency=None,
|
||||
name="adam",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
learning_rate=learning_rate,
|
||||
@ -66,6 +67,7 @@ class Adam(optimizer.Optimizer):
|
||||
use_ema=use_ema,
|
||||
ema_momentum=ema_momentum,
|
||||
ema_overwrite_frequency=ema_overwrite_frequency,
|
||||
**kwargs,
|
||||
)
|
||||
self.beta_1 = beta_1
|
||||
self.beta_2 = beta_2
|
||||
|
@ -64,6 +64,7 @@ class Adamax(optimizer.Optimizer):
|
||||
ema_momentum=0.99,
|
||||
ema_overwrite_frequency=None,
|
||||
name="adamax",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
learning_rate=learning_rate,
|
||||
@ -75,6 +76,7 @@ class Adamax(optimizer.Optimizer):
|
||||
use_ema=use_ema,
|
||||
ema_momentum=ema_momentum,
|
||||
ema_overwrite_frequency=ema_overwrite_frequency,
|
||||
**kwargs,
|
||||
)
|
||||
self.beta_1 = beta_1
|
||||
self.beta_2 = beta_2
|
||||
|
@ -65,6 +65,7 @@ class AdamW(adam.Adam):
|
||||
ema_momentum=0.99,
|
||||
ema_overwrite_frequency=None,
|
||||
name="adamw",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
learning_rate=learning_rate,
|
||||
@ -80,6 +81,7 @@ class AdamW(adam.Adam):
|
||||
use_ema=use_ema,
|
||||
ema_momentum=ema_momentum,
|
||||
ema_overwrite_frequency=ema_overwrite_frequency,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if self.weight_decay is None:
|
||||
|
@ -22,9 +22,17 @@ class BaseOptimizer:
|
||||
ema_momentum=0.99,
|
||||
ema_overwrite_frequency=None,
|
||||
name=None,
|
||||
**kwargs,
|
||||
):
|
||||
self._lock = False
|
||||
|
||||
if kwargs.pop("decay", None) is not None:
|
||||
warnings.warn(
|
||||
"Argument `decay` is no longer supported and will be ignored."
|
||||
)
|
||||
if kwargs:
|
||||
raise ValueError(f"Argument(s) not recognized: {kwargs}")
|
||||
|
||||
self.name = name
|
||||
self.weight_decay = weight_decay
|
||||
self.clipnorm = clipnorm
|
||||
|
@ -92,6 +92,7 @@ class Ftrl(optimizer.Optimizer):
|
||||
ema_momentum=0.99,
|
||||
ema_overwrite_frequency=None,
|
||||
name="ftrl",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
learning_rate=learning_rate,
|
||||
@ -103,6 +104,7 @@ class Ftrl(optimizer.Optimizer):
|
||||
use_ema=use_ema,
|
||||
ema_momentum=ema_momentum,
|
||||
ema_overwrite_frequency=ema_overwrite_frequency,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if initial_accumulator_value < 0.0:
|
||||
|
@ -54,6 +54,7 @@ class Lion(optimizer.Optimizer):
|
||||
ema_momentum=0.99,
|
||||
ema_overwrite_frequency=None,
|
||||
name="lion",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
learning_rate=learning_rate,
|
||||
@ -65,6 +66,7 @@ class Lion(optimizer.Optimizer):
|
||||
use_ema=use_ema,
|
||||
ema_momentum=ema_momentum,
|
||||
ema_overwrite_frequency=ema_overwrite_frequency,
|
||||
**kwargs,
|
||||
)
|
||||
self.beta_1 = beta_1
|
||||
self.beta_2 = beta_2
|
||||
|
@ -50,6 +50,7 @@ class Nadam(optimizer.Optimizer):
|
||||
ema_momentum=0.99,
|
||||
ema_overwrite_frequency=None,
|
||||
name="nadam",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
learning_rate=learning_rate,
|
||||
@ -61,6 +62,7 @@ class Nadam(optimizer.Optimizer):
|
||||
use_ema=use_ema,
|
||||
ema_momentum=ema_momentum,
|
||||
ema_overwrite_frequency=ema_overwrite_frequency,
|
||||
**kwargs,
|
||||
)
|
||||
self.beta_1 = beta_1
|
||||
self.beta_2 = beta_2
|
||||
|
@ -65,6 +65,7 @@ class RMSprop(optimizer.Optimizer):
|
||||
ema_momentum=0.99,
|
||||
ema_overwrite_frequency=100,
|
||||
name="rmsprop",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
learning_rate=learning_rate,
|
||||
@ -76,6 +77,7 @@ class RMSprop(optimizer.Optimizer):
|
||||
ema_momentum=ema_momentum,
|
||||
ema_overwrite_frequency=ema_overwrite_frequency,
|
||||
name=name,
|
||||
**kwargs,
|
||||
)
|
||||
self.rho = rho
|
||||
self.momentum = momentum
|
||||
|
@ -53,6 +53,7 @@ class SGD(optimizer.Optimizer):
|
||||
ema_momentum=0.99,
|
||||
ema_overwrite_frequency=None,
|
||||
name="SGD",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
learning_rate=learning_rate,
|
||||
@ -64,6 +65,7 @@ class SGD(optimizer.Optimizer):
|
||||
use_ema=use_ema,
|
||||
ema_momentum=ema_momentum,
|
||||
ema_overwrite_frequency=ema_overwrite_frequency,
|
||||
**kwargs,
|
||||
)
|
||||
if not isinstance(momentum, float) or momentum < 0 or momentum > 1:
|
||||
raise ValueError("`momentum` must be a float between [0, 1].")
|
||||
|
Loading…
Reference in New Issue
Block a user