from keras_core import initializers
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.optimizers import optimizer


@keras_core_export(["keras_core.optimizers.Adagrad"])
class Adagrad(optimizer.Optimizer):
    """Optimizer that implements the Adagrad algorithm.

    Adagrad is an optimizer with parameter-specific learning rates,
    which are adapted relative to how frequently a parameter gets
    updated during training. The more updates a parameter receives,
    the smaller the updates.

    Args:
        learning_rate: A float, a
            `keras_core.optimizers.schedules.LearningRateSchedule` instance, or
            a callable that takes no arguments and returns the actual value to
            use. The learning rate. Defaults to 0.001. Note that `Adagrad`
            tends to benefit from higher initial learning rate values compared
            to other optimizers. To match the exact form in the original paper,
            use 1.0.
        initial_accumulator_value: Floating point value. Starting value for the
            accumulators (per-parameter momentum values). Must be non-negative.
        epsilon: Small floating point value for maintaining numerical
            stability.
        {{base_optimizer_keyword_args}}

    Reference:

    - [Duchi et al., 2011](
        http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf).
    """

    def __init__(
        self,
        learning_rate=0.001,
        initial_accumulator_value=0.1,
        epsilon=1e-7,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        name="adagrad",
    ):
        super().__init__(
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            use_ema=use_ema,
            ema_momentum=ema_momentum,
            ema_overwrite_frequency=ema_overwrite_frequency,
            name=name,
        )
        self.initial_accumulator_value = initial_accumulator_value
        self.epsilon = epsilon

    def build(self, var_list):
        if self.built:
            return
        super().build(var_list)
        self._accumulators = []
        initializer = initializers.Constant(self.initial_accumulator_value)
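        # One accumulator slot is created per model variable, initialized to
        # `initial_accumulator_value`.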
        for var in var_list:
            self._accumulators.append(
                self.add_variable(
                    shape=var.shape,
                    initializer=initializer,
                    dtype=var.dtype,
                    name="accumulator",
                )
            )

    def update_step(self, gradient, variable, learning_rate):
        """Update step given gradient and the associated model variable."""
        lr = ops.cast(learning_rate, variable.dtype)
        gradient = ops.cast(gradient, variable.dtype)

        accumulator = self._accumulators[self._get_variable_index(variable)]

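        # Adagrad update rule (element-wise):
        #   accumulator <- accumulator + g^2
        #   variable    <- variable - lr * g / sqrt(accumulator + epsilon)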
        accumulator.assign(accumulator + gradient * gradient)
        variable.assign(
            variable - (lr * gradient / ops.sqrt(accumulator + self.epsilon))
        )

    def get_config(self):
        config = super().get_config()

        config.update(
            {
                "initial_accumulator_value": self.initial_accumulator_value,
                "epsilon": self.epsilon,
            }
        )
        return config


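# Substitute the shared base-optimizer keyword-argument documentation into the
# class docstring.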
Adagrad.__doc__ = Adagrad.__doc__.replace(
    "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
)
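

# A minimal usage sketch, kept as a comment so it is not executed on import.
# It assumes the standard `keras_core` Sequential/compile API; the
# `learning_rate=1.0` setting matches the original-paper form noted in the
# class docstring above.
#
#   import keras_core
#
#   model = keras_core.Sequential([keras_core.layers.Dense(1)])
#   model.compile(
#       optimizer=Adagrad(learning_rate=1.0, initial_accumulator_value=0.1),
#       loss="mse",
#   )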