Backwards compat for decay arg

Francois Chollet 2023-08-11 12:34:29 -07:00
parent 73b22e0f68
commit e0194197e7
12 changed files with 30 additions and 0 deletions

@@ -50,6 +50,7 @@ class Adadelta(optimizer.Optimizer):
         ema_momentum=0.99,
         ema_overwrite_frequency=None,
         name="adadelta",
+        **kwargs,
     ):
         super().__init__(
             learning_rate=learning_rate,
@@ -61,6 +62,7 @@ class Adadelta(optimizer.Optimizer):
             ema_momentum=ema_momentum,
             ema_overwrite_frequency=ema_overwrite_frequency,
             name=name,
+            **kwargs,
         )
         self.rho = rho
         self.epsilon = epsilon
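
The same two-line change repeats in every optimizer in this commit: the constructor grows a `**kwargs` catch-all and forwards it to the parent, so legacy keywords end up in one place instead of raising a `TypeError` in each subclass. Below is a minimal sketch of that forwarding shape, with hypothetical class names standing in for the real ones (what actually happens to the forwarded kwargs is defined in the BaseOptimizer hunk further down):

class BaseSketch:
    # Stand-in for optimizer.Optimizer; here it just collects whatever the
    # subclass forwards. The real base class decides what to do with it.
    def __init__(self, learning_rate=0.001, **kwargs):
        self.learning_rate = learning_rate
        self._forwarded = kwargs

class AdadeltaSketch(BaseSketch):
    def __init__(self, learning_rate=0.001, rho=0.95, **kwargs):
        # Mirror of the two added lines above: accept **kwargs, pass them on.
        super().__init__(learning_rate=learning_rate, **kwargs)
        self.rho = rho

opt = AdadeltaSketch(learning_rate=1.0, decay=1e-6)
print(opt._forwarded)  # {'decay': 1e-06}: the legacy argument reaches the base class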

@@ -57,6 +57,7 @@ class Adafactor(optimizer.Optimizer):
         ema_momentum=0.99,
         ema_overwrite_frequency=None,
         name="adafactor",
+        **kwargs,
     ):
         super().__init__(
             learning_rate=learning_rate,
@@ -68,6 +69,7 @@ class Adafactor(optimizer.Optimizer):
             use_ema=use_ema,
             ema_momentum=ema_momentum,
             ema_overwrite_frequency=ema_overwrite_frequency,
+            **kwargs,
         )
         self.beta_2_decay = beta_2_decay
         self.epsilon_1 = epsilon_1

@@ -45,6 +45,7 @@ class Adagrad(optimizer.Optimizer):
         ema_momentum=0.99,
         ema_overwrite_frequency=None,
         name="adagrad",
+        **kwargs,
     ):
         super().__init__(
             learning_rate=learning_rate,
@@ -56,6 +57,7 @@ class Adagrad(optimizer.Optimizer):
             ema_momentum=ema_momentum,
             ema_overwrite_frequency=ema_overwrite_frequency,
             name=name,
+            **kwargs,
         )
         self.initial_accumulator_value = initial_accumulator_value
         self.epsilon = epsilon

@@ -55,6 +55,7 @@ class Adam(optimizer.Optimizer):
         ema_momentum=0.99,
         ema_overwrite_frequency=None,
         name="adam",
+        **kwargs,
     ):
         super().__init__(
             learning_rate=learning_rate,
@@ -66,6 +67,7 @@ class Adam(optimizer.Optimizer):
             use_ema=use_ema,
             ema_momentum=ema_momentum,
             ema_overwrite_frequency=ema_overwrite_frequency,
+            **kwargs,
         )
         self.beta_1 = beta_1
         self.beta_2 = beta_2

@@ -64,6 +64,7 @@ class Adamax(optimizer.Optimizer):
         ema_momentum=0.99,
         ema_overwrite_frequency=None,
         name="adamax",
+        **kwargs,
     ):
         super().__init__(
             learning_rate=learning_rate,
@@ -75,6 +76,7 @@ class Adamax(optimizer.Optimizer):
             use_ema=use_ema,
             ema_momentum=ema_momentum,
             ema_overwrite_frequency=ema_overwrite_frequency,
+            **kwargs,
         )
         self.beta_1 = beta_1
         self.beta_2 = beta_2

@@ -65,6 +65,7 @@ class AdamW(adam.Adam):
         ema_momentum=0.99,
         ema_overwrite_frequency=None,
         name="adamw",
+        **kwargs,
     ):
         super().__init__(
             learning_rate=learning_rate,
@@ -80,6 +81,7 @@ class AdamW(adam.Adam):
             use_ema=use_ema,
             ema_momentum=ema_momentum,
             ema_overwrite_frequency=ema_overwrite_frequency,
+            **kwargs,
         )
 
         if self.weight_decay is None:

@@ -22,9 +22,17 @@ class BaseOptimizer:
         ema_momentum=0.99,
         ema_overwrite_frequency=None,
         name=None,
+        **kwargs,
     ):
         self._lock = False
 
+        if kwargs.pop("decay", None) is not None:
+            warnings.warn(
+                "Argument `decay` is no longer supported and will be ignored."
+            )
+        if kwargs:
+            raise ValueError(f"Argument(s) not recognized: {kwargs}")
+
         self.name = name
         self.weight_decay = weight_decay
         self.clipnorm = clipnorm
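
Taken on its own, the base-class hunk above is the whole backwards-compatibility story: a `decay` keyword is popped and produces a warning, while any other unrecognized keyword still raises. Here is a standalone sketch of that behavior, assuming nothing about the rest of BaseOptimizer and using a hypothetical class name:

import warnings

class OptimizerKwargsSketch:
    # Reproduces only the new kwargs handling from the hunk above.
    def __init__(self, learning_rate=0.001, **kwargs):
        if kwargs.pop("decay", None) is not None:
            warnings.warn(
                "Argument `decay` is no longer supported and will be ignored."
            )
        if kwargs:
            raise ValueError(f"Argument(s) not recognized: {kwargs}")
        self.learning_rate = learning_rate

OptimizerKwargsSketch(learning_rate=0.01, decay=1e-6)   # warns, `decay` is dropped
try:
    OptimizerKwargsSketch(learning_rate=0.01, momntum=0.9)  # typo still fails fast
except ValueError as err:
    print(err)  # Argument(s) not recognized: {'momntum': 0.9}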

@@ -92,6 +92,7 @@ class Ftrl(optimizer.Optimizer):
         ema_momentum=0.99,
         ema_overwrite_frequency=None,
         name="ftrl",
+        **kwargs,
     ):
         super().__init__(
             learning_rate=learning_rate,
@@ -103,6 +104,7 @@ class Ftrl(optimizer.Optimizer):
             use_ema=use_ema,
             ema_momentum=ema_momentum,
             ema_overwrite_frequency=ema_overwrite_frequency,
+            **kwargs,
         )
 
         if initial_accumulator_value < 0.0:

@@ -54,6 +54,7 @@ class Lion(optimizer.Optimizer):
         ema_momentum=0.99,
         ema_overwrite_frequency=None,
         name="lion",
+        **kwargs,
     ):
         super().__init__(
             learning_rate=learning_rate,
@@ -65,6 +66,7 @@ class Lion(optimizer.Optimizer):
             use_ema=use_ema,
             ema_momentum=ema_momentum,
             ema_overwrite_frequency=ema_overwrite_frequency,
+            **kwargs,
         )
         self.beta_1 = beta_1
         self.beta_2 = beta_2

@@ -50,6 +50,7 @@ class Nadam(optimizer.Optimizer):
         ema_momentum=0.99,
         ema_overwrite_frequency=None,
         name="nadam",
+        **kwargs,
     ):
         super().__init__(
             learning_rate=learning_rate,
@@ -61,6 +62,7 @@ class Nadam(optimizer.Optimizer):
             use_ema=use_ema,
             ema_momentum=ema_momentum,
             ema_overwrite_frequency=ema_overwrite_frequency,
+            **kwargs,
         )
         self.beta_1 = beta_1
         self.beta_2 = beta_2

@@ -65,6 +65,7 @@ class RMSprop(optimizer.Optimizer):
         ema_momentum=0.99,
         ema_overwrite_frequency=100,
         name="rmsprop",
+        **kwargs,
     ):
         super().__init__(
             learning_rate=learning_rate,
@@ -76,6 +77,7 @@ class RMSprop(optimizer.Optimizer):
             ema_momentum=ema_momentum,
             ema_overwrite_frequency=ema_overwrite_frequency,
             name=name,
+            **kwargs,
         )
         self.rho = rho
         self.momentum = momentum

@@ -53,6 +53,7 @@ class SGD(optimizer.Optimizer):
         ema_momentum=0.99,
         ema_overwrite_frequency=None,
         name="SGD",
+        **kwargs,
     ):
         super().__init__(
             learning_rate=learning_rate,
@@ -64,6 +65,7 @@ class SGD(optimizer.Optimizer):
             use_ema=use_ema,
             ema_momentum=ema_momentum,
             ema_overwrite_frequency=ema_overwrite_frequency,
+            **kwargs,
         )
         if not isinstance(momentum, float) or momentum < 0 or momentum > 1:
             raise ValueError("`momentum` must be a float between [0, 1].")