Backwards compat for decay arg
This commit is contained in:
parent
73b22e0f68
commit
e0194197e7
@ -50,6 +50,7 @@ class Adadelta(optimizer.Optimizer):
|
|||||||
ema_momentum=0.99,
|
ema_momentum=0.99,
|
||||||
ema_overwrite_frequency=None,
|
ema_overwrite_frequency=None,
|
||||||
name="adadelta",
|
name="adadelta",
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
learning_rate=learning_rate,
|
learning_rate=learning_rate,
|
||||||
@ -61,6 +62,7 @@ class Adadelta(optimizer.Optimizer):
|
|||||||
ema_momentum=ema_momentum,
|
ema_momentum=ema_momentum,
|
||||||
ema_overwrite_frequency=ema_overwrite_frequency,
|
ema_overwrite_frequency=ema_overwrite_frequency,
|
||||||
name=name,
|
name=name,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
self.rho = rho
|
self.rho = rho
|
||||||
self.epsilon = epsilon
|
self.epsilon = epsilon
|
||||||
|
@ -57,6 +57,7 @@ class Adafactor(optimizer.Optimizer):
|
|||||||
ema_momentum=0.99,
|
ema_momentum=0.99,
|
||||||
ema_overwrite_frequency=None,
|
ema_overwrite_frequency=None,
|
||||||
name="adafactor",
|
name="adafactor",
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
learning_rate=learning_rate,
|
learning_rate=learning_rate,
|
||||||
@ -68,6 +69,7 @@ class Adafactor(optimizer.Optimizer):
|
|||||||
use_ema=use_ema,
|
use_ema=use_ema,
|
||||||
ema_momentum=ema_momentum,
|
ema_momentum=ema_momentum,
|
||||||
ema_overwrite_frequency=ema_overwrite_frequency,
|
ema_overwrite_frequency=ema_overwrite_frequency,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
self.beta_2_decay = beta_2_decay
|
self.beta_2_decay = beta_2_decay
|
||||||
self.epsilon_1 = epsilon_1
|
self.epsilon_1 = epsilon_1
|
||||||
|
@ -45,6 +45,7 @@ class Adagrad(optimizer.Optimizer):
|
|||||||
ema_momentum=0.99,
|
ema_momentum=0.99,
|
||||||
ema_overwrite_frequency=None,
|
ema_overwrite_frequency=None,
|
||||||
name="adagrad",
|
name="adagrad",
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
learning_rate=learning_rate,
|
learning_rate=learning_rate,
|
||||||
@ -56,6 +57,7 @@ class Adagrad(optimizer.Optimizer):
|
|||||||
ema_momentum=ema_momentum,
|
ema_momentum=ema_momentum,
|
||||||
ema_overwrite_frequency=ema_overwrite_frequency,
|
ema_overwrite_frequency=ema_overwrite_frequency,
|
||||||
name=name,
|
name=name,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
self.initial_accumulator_value = initial_accumulator_value
|
self.initial_accumulator_value = initial_accumulator_value
|
||||||
self.epsilon = epsilon
|
self.epsilon = epsilon
|
||||||
|
@ -55,6 +55,7 @@ class Adam(optimizer.Optimizer):
|
|||||||
ema_momentum=0.99,
|
ema_momentum=0.99,
|
||||||
ema_overwrite_frequency=None,
|
ema_overwrite_frequency=None,
|
||||||
name="adam",
|
name="adam",
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
learning_rate=learning_rate,
|
learning_rate=learning_rate,
|
||||||
@ -66,6 +67,7 @@ class Adam(optimizer.Optimizer):
|
|||||||
use_ema=use_ema,
|
use_ema=use_ema,
|
||||||
ema_momentum=ema_momentum,
|
ema_momentum=ema_momentum,
|
||||||
ema_overwrite_frequency=ema_overwrite_frequency,
|
ema_overwrite_frequency=ema_overwrite_frequency,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
self.beta_1 = beta_1
|
self.beta_1 = beta_1
|
||||||
self.beta_2 = beta_2
|
self.beta_2 = beta_2
|
||||||
|
@ -64,6 +64,7 @@ class Adamax(optimizer.Optimizer):
|
|||||||
ema_momentum=0.99,
|
ema_momentum=0.99,
|
||||||
ema_overwrite_frequency=None,
|
ema_overwrite_frequency=None,
|
||||||
name="adamax",
|
name="adamax",
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
learning_rate=learning_rate,
|
learning_rate=learning_rate,
|
||||||
@ -75,6 +76,7 @@ class Adamax(optimizer.Optimizer):
|
|||||||
use_ema=use_ema,
|
use_ema=use_ema,
|
||||||
ema_momentum=ema_momentum,
|
ema_momentum=ema_momentum,
|
||||||
ema_overwrite_frequency=ema_overwrite_frequency,
|
ema_overwrite_frequency=ema_overwrite_frequency,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
self.beta_1 = beta_1
|
self.beta_1 = beta_1
|
||||||
self.beta_2 = beta_2
|
self.beta_2 = beta_2
|
||||||
|
@ -65,6 +65,7 @@ class AdamW(adam.Adam):
|
|||||||
ema_momentum=0.99,
|
ema_momentum=0.99,
|
||||||
ema_overwrite_frequency=None,
|
ema_overwrite_frequency=None,
|
||||||
name="adamw",
|
name="adamw",
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
learning_rate=learning_rate,
|
learning_rate=learning_rate,
|
||||||
@ -80,6 +81,7 @@ class AdamW(adam.Adam):
|
|||||||
use_ema=use_ema,
|
use_ema=use_ema,
|
||||||
ema_momentum=ema_momentum,
|
ema_momentum=ema_momentum,
|
||||||
ema_overwrite_frequency=ema_overwrite_frequency,
|
ema_overwrite_frequency=ema_overwrite_frequency,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.weight_decay is None:
|
if self.weight_decay is None:
|
||||||
|
@ -22,9 +22,17 @@ class BaseOptimizer:
|
|||||||
ema_momentum=0.99,
|
ema_momentum=0.99,
|
||||||
ema_overwrite_frequency=None,
|
ema_overwrite_frequency=None,
|
||||||
name=None,
|
name=None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
self._lock = False
|
self._lock = False
|
||||||
|
|
||||||
|
if kwargs.pop("decay", None) is not None:
|
||||||
|
warnings.warn(
|
||||||
|
"Argument `decay` is no longer supported and will be ignored."
|
||||||
|
)
|
||||||
|
if kwargs:
|
||||||
|
raise ValueError(f"Argument(s) not recognized: {kwargs}")
|
||||||
|
|
||||||
self.name = name
|
self.name = name
|
||||||
self.weight_decay = weight_decay
|
self.weight_decay = weight_decay
|
||||||
self.clipnorm = clipnorm
|
self.clipnorm = clipnorm
|
||||||
|
@ -92,6 +92,7 @@ class Ftrl(optimizer.Optimizer):
|
|||||||
ema_momentum=0.99,
|
ema_momentum=0.99,
|
||||||
ema_overwrite_frequency=None,
|
ema_overwrite_frequency=None,
|
||||||
name="ftrl",
|
name="ftrl",
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
learning_rate=learning_rate,
|
learning_rate=learning_rate,
|
||||||
@ -103,6 +104,7 @@ class Ftrl(optimizer.Optimizer):
|
|||||||
use_ema=use_ema,
|
use_ema=use_ema,
|
||||||
ema_momentum=ema_momentum,
|
ema_momentum=ema_momentum,
|
||||||
ema_overwrite_frequency=ema_overwrite_frequency,
|
ema_overwrite_frequency=ema_overwrite_frequency,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if initial_accumulator_value < 0.0:
|
if initial_accumulator_value < 0.0:
|
||||||
|
@ -54,6 +54,7 @@ class Lion(optimizer.Optimizer):
|
|||||||
ema_momentum=0.99,
|
ema_momentum=0.99,
|
||||||
ema_overwrite_frequency=None,
|
ema_overwrite_frequency=None,
|
||||||
name="lion",
|
name="lion",
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
learning_rate=learning_rate,
|
learning_rate=learning_rate,
|
||||||
@ -65,6 +66,7 @@ class Lion(optimizer.Optimizer):
|
|||||||
use_ema=use_ema,
|
use_ema=use_ema,
|
||||||
ema_momentum=ema_momentum,
|
ema_momentum=ema_momentum,
|
||||||
ema_overwrite_frequency=ema_overwrite_frequency,
|
ema_overwrite_frequency=ema_overwrite_frequency,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
self.beta_1 = beta_1
|
self.beta_1 = beta_1
|
||||||
self.beta_2 = beta_2
|
self.beta_2 = beta_2
|
||||||
|
@ -50,6 +50,7 @@ class Nadam(optimizer.Optimizer):
|
|||||||
ema_momentum=0.99,
|
ema_momentum=0.99,
|
||||||
ema_overwrite_frequency=None,
|
ema_overwrite_frequency=None,
|
||||||
name="nadam",
|
name="nadam",
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
learning_rate=learning_rate,
|
learning_rate=learning_rate,
|
||||||
@ -61,6 +62,7 @@ class Nadam(optimizer.Optimizer):
|
|||||||
use_ema=use_ema,
|
use_ema=use_ema,
|
||||||
ema_momentum=ema_momentum,
|
ema_momentum=ema_momentum,
|
||||||
ema_overwrite_frequency=ema_overwrite_frequency,
|
ema_overwrite_frequency=ema_overwrite_frequency,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
self.beta_1 = beta_1
|
self.beta_1 = beta_1
|
||||||
self.beta_2 = beta_2
|
self.beta_2 = beta_2
|
||||||
|
@ -65,6 +65,7 @@ class RMSprop(optimizer.Optimizer):
|
|||||||
ema_momentum=0.99,
|
ema_momentum=0.99,
|
||||||
ema_overwrite_frequency=100,
|
ema_overwrite_frequency=100,
|
||||||
name="rmsprop",
|
name="rmsprop",
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
learning_rate=learning_rate,
|
learning_rate=learning_rate,
|
||||||
@ -76,6 +77,7 @@ class RMSprop(optimizer.Optimizer):
|
|||||||
ema_momentum=ema_momentum,
|
ema_momentum=ema_momentum,
|
||||||
ema_overwrite_frequency=ema_overwrite_frequency,
|
ema_overwrite_frequency=ema_overwrite_frequency,
|
||||||
name=name,
|
name=name,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
self.rho = rho
|
self.rho = rho
|
||||||
self.momentum = momentum
|
self.momentum = momentum
|
||||||
|
@ -53,6 +53,7 @@ class SGD(optimizer.Optimizer):
|
|||||||
ema_momentum=0.99,
|
ema_momentum=0.99,
|
||||||
ema_overwrite_frequency=None,
|
ema_overwrite_frequency=None,
|
||||||
name="SGD",
|
name="SGD",
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
learning_rate=learning_rate,
|
learning_rate=learning_rate,
|
||||||
@ -64,6 +65,7 @@ class SGD(optimizer.Optimizer):
|
|||||||
use_ema=use_ema,
|
use_ema=use_ema,
|
||||||
ema_momentum=ema_momentum,
|
ema_momentum=ema_momentum,
|
||||||
ema_overwrite_frequency=ema_overwrite_frequency,
|
ema_overwrite_frequency=ema_overwrite_frequency,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
if not isinstance(momentum, float) or momentum < 0 or momentum > 1:
|
if not isinstance(momentum, float) or momentum < 0 or momentum > 1:
|
||||||
raise ValueError("`momentum` must be a float between [0, 1].")
|
raise ValueError("`momentum` must be a float between [0, 1].")
|
||||||
|
Loading…
Reference in New Issue
Block a user