# keras/keras_core/optimizers/adafactor.py

from keras_core import backend
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.optimizers import optimizer


@keras_core_export(["keras_core.optimizers.Adafactor"])
class Adafactor(optimizer.Optimizer):
"""Optimizer that implements the Adafactor algorithm.
Adafactor is commonly used in NLP tasks, and has the advantage
of taking less memory because it only saves partial information of previous
gradients.
The default argument setup is based on the original paper (see reference).
When gradients are of dimension > 2, Adafactor optimizer will delete the
last 2 dimensions separately in its accumulator variables.
Args:
learning_rate: A float, a
`keras_core.optimizers.schedules.LearningRateSchedule` instance, or
a callable that takes no arguments and returns the actual value to
use. The learning rate. Defaults to 0.001.
beta_2_decay: float, defaults to -0.8. The decay rate of `beta_2`.
        epsilon_1: float, defaults to 1e-30. A small offset to keep the
            denominator away from 0.
        epsilon_2: float, defaults to 1e-3. A small offset to avoid the
            learning rate becoming too small over time.
clip_threshold: float, defaults to 1.0. Clipping threshold. This is a
part of Adafactor algorithm, independent from `clipnorm`,
`clipvalue`, and `global_clipnorm`.
        relative_step: bool, defaults to True. If `learning_rate` is a
            constant and `relative_step=True`, the learning rate will be
            capped at `1 / sqrt(iterations)`. This is the default learning
            rate decay used by Adafactor.
{{base_optimizer_keyword_args}}
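
    Example (a minimal usage sketch; the tiny model and random data below are
    illustrative placeholders, not part of this module):

    ```python
    import numpy as np
    import keras_core

    model = keras_core.Sequential(
        [keras_core.Input(shape=(8,)), keras_core.layers.Dense(4)]
    )
    model.compile(optimizer=keras_core.optimizers.Adafactor(), loss="mse")
    model.fit(np.random.rand(16, 8), np.random.rand(16, 4), epochs=1)
    ```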

    Reference:
- [Shazeer, Noam et al., 2018](https://arxiv.org/abs/1804.04235).
"""

    def __init__(
self,
learning_rate=0.001,
beta_2_decay=-0.8,
epsilon_1=1e-30,
epsilon_2=1e-3,
clip_threshold=1.0,
relative_step=True,
weight_decay=None,
clipnorm=None,
clipvalue=None,
global_clipnorm=None,
use_ema=False,
ema_momentum=0.99,
ema_overwrite_frequency=None,
name="adafactor",
):
super().__init__(
learning_rate=learning_rate,
name=name,
weight_decay=weight_decay,
clipnorm=clipnorm,
clipvalue=clipvalue,
global_clipnorm=global_clipnorm,
use_ema=use_ema,
ema_momentum=ema_momentum,
ema_overwrite_frequency=ema_overwrite_frequency,
)
self.beta_2_decay = beta_2_decay
self.epsilon_1 = epsilon_1
self.epsilon_2 = epsilon_2
self.clip_threshold = clip_threshold
self.relative_step = relative_step

    def build(self, var_list):
        """Initialize optimizer variables.

        Adafactor keeps 3 kinds of variables per model variable: a factored
        row accumulator, a factored column accumulator, and a running
        second-moment estimate.

        Args:
            var_list: list of model variables to build Adafactor variables on.
        """
if self.built:
return
super().build(var_list)
self._r = []
self._c = []
self._v = []
for var in var_list:
if len(var.shape) < 2:
                # Don't factor if the variable has fewer than 2 dimensions,
                # but we still need to create dummy variables as placeholders.
self._r.append(backend.Variable(0, name=var.name))
self._c.append(backend.Variable(0, name=var.name))
else:
                # Always factor the last 2 dimensions.
                r_shape = var.shape[:-1]
                c_shape = var.shape[:-2] + var.shape[-1:]
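                # For example, a kernel of shape (in_dim, out_dim) gets a row
                # accumulator of shape (in_dim,) and a column accumulator of
                # shape (out_dim,).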
self._r.append(
self.add_variable(
shape=r_shape,
dtype=var.dtype,
name=var.name,
)
)
self._c.append(
self.add_variable(
shape=c_shape,
dtype=var.dtype,
name=var.name,
)
)
self._v.append(
self.add_variable_from_reference(
reference_variable=var, name="v"
)
)

    def _rms(self, x):
return ops.sqrt(ops.mean(ops.square(x)))

    def update_step(self, gradient, variable, learning_rate):
"""Update step given gradient and the associated model variable."""
lr = ops.cast(learning_rate, variable.dtype)
gradient = ops.cast(gradient, variable.dtype)
epsilon_2 = ops.cast(self.epsilon_2, variable.dtype)
one = ops.cast(1.0, variable.dtype)
local_step = ops.cast(self.iterations + 1, variable.dtype)
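        # With `relative_step` and a constant learning rate, cap the learning
        # rate at 1 / sqrt(step), i.e. the paper's relative step size.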
if not callable(self._learning_rate) and self.relative_step:
lr = ops.minimum(lr, 1 / ops.sqrt(local_step))
r = self._r[self._get_variable_index(variable)]
c = self._c[self._get_variable_index(variable)]
v = self._v[self._get_variable_index(variable)]
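        # Per the paper, the effective step size `alpha_t` scales the relative
        # step size `rho_t` by the RMS of the parameter, floored at epsilon_2.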
rho_t = ops.minimum(lr, 1 / ops.sqrt(local_step))
alpha_t = ops.maximum(epsilon_2, self._rms(variable)) * rho_t
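        # `epsilon_1` keeps the squared-gradient accumulator away from zero;
        # `beta_2_t` follows the paper's decay schedule 1 - t**beta_2_decay.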
regulated_grad_square = ops.square(gradient) + self.epsilon_1
beta_2_t = 1 - ops.power(local_step, self.beta_2_decay)
if len(variable.shape) >= 2:
# `r` deletes the last dimension of gradient, so it is of shape
# `gradient.shape[:-1]`.
r.assign(
beta_2_t * r
+ (1 - beta_2_t) * ops.mean(regulated_grad_square, axis=-1)
)
# `c` deletes the second last dimension of gradient, so it is of
# shape `gradient.shape[:-2] + gradient.shape[-1]`.
c.assign(
beta_2_t * c
+ (1 - beta_2_t) * ops.mean(regulated_grad_square, axis=-2)
)
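            # The factored second-moment estimate `v` is the outer product of
            # the normalized row accumulator and the column accumulator.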
v.assign(
ops.expand_dims(
r / ops.mean(r, axis=-1, keepdims=True), axis=-1
)
* ops.expand_dims(c, -2)
)
else:
v.assign(beta_2_t * v + (1 - beta_2_t) * regulated_grad_square)
        # Clip the update by its RMS relative to `clip_threshold`, as in the
        # paper, before applying it.
u_t = gradient / ops.sqrt(v)
u_t_hat = u_t / ops.maximum(one, (self._rms(u_t) / self.clip_threshold))
variable.assign(variable - alpha_t * u_t_hat)

    def get_config(self):
config = super().get_config()
config.update(
{
"beta_2_decay": self.beta_2_decay,
"epsilon_1": self.epsilon_1,
"epsilon_2": self.epsilon_2,
"clip_threshold": self.clip_threshold,
"relative_step": self.relative_step,
}
)
return config


Adafactor.__doc__ = Adafactor.__doc__.replace(
    "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
)