Add Lion optimizer (#610)

* Add Lion optimizer

* Address comments

* Add torch_lion optimizer

* Update correctness test

* Trigger Github syncing
HongYu 2023-07-27 10:02:26 +08:00 committed by Francois Chollet
parent 30121f4b5f
commit 794bbd7d29
5 changed files with 249 additions and 0 deletions

keras_core/backend/torch/optimizers/torch_lion.py
@@ -0,0 +1,37 @@
import torch
from keras_core import ops
from keras_core import optimizers
from keras_core.backend.torch.optimizers import torch_parallel_optimizer
class Lion(torch_parallel_optimizer.TorchParallelOptimizer, optimizers.Lion):
def _parallel_update_step(
self,
grads,
variables,
learning_rate,
):
keras_variables = variables
variables = [v.value for v in variables]
dtype = variables[0].dtype
lr = ops.cast(learning_rate, dtype)
m_list = [
self._momentums[self._get_variable_index(variable)].value
for variable in keras_variables
]
c_t = torch._foreach_mul(m_list, self.beta_1)
torch._foreach_add_(c_t, grads, alpha=1 - self.beta_1)
c_t = [c.sign() for c in c_t]
torch._foreach_add_(
variables,
torch._foreach_mul(c_t, lr),
alpha=-1,
)
torch._foreach_mul_(m_list, self.beta_2)
torch._foreach_add_(m_list, grads, alpha=1 - self.beta_2)
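
For readers less familiar with the fused `torch._foreach_*` API, the snippet below is a per-variable sketch of what `_parallel_update_step` above computes (illustration only, with hypothetical names; it is not part of the commit):

import torch

def lion_update_step(variables, grads, momentums, lr, beta_1, beta_2):
    # Per-variable equivalent of the fused update above.
    for v, g, m in zip(variables, grads, momentums):
        # Interpolate momentum and gradient, then keep only the sign.
        c_t = (beta_1 * m + (1 - beta_1) * g).sign()
        # Step each parameter by exactly lr in the signed direction.
        v.sub_(lr * c_t)
        # Decay the momentum toward the new gradient with beta_2.
        m.mul_(beta_2).add_(g, alpha=1 - beta_2)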

keras_core/backend/torch/optimizers/torch_optimizer.py
@@ -12,6 +12,7 @@ class TorchOptimizer(BaseOptimizer):
from keras_core.backend.torch.optimizers import torch_adam
from keras_core.backend.torch.optimizers import torch_adamax
from keras_core.backend.torch.optimizers import torch_adamw
from keras_core.backend.torch.optimizers import torch_lion
from keras_core.backend.torch.optimizers import torch_nadam
from keras_core.backend.torch.optimizers import torch_rmsprop
from keras_core.backend.torch.optimizers import torch_sgd
@@ -22,6 +23,7 @@ class TorchOptimizer(BaseOptimizer):
optimizers.Adam: torch_adam.Adam,
optimizers.Adamax: torch_adamax.Adamax,
optimizers.AdamW: torch_adamw.AdamW,
optimizers.Lion: torch_lion.Lion,
optimizers.Nadam: torch_nadam.Nadam,
optimizers.RMSprop: torch_rmsprop.RMSprop,
optimizers.SGD: torch_sgd.SGD,

keras_core/optimizers/__init__.py
@@ -6,6 +6,7 @@ from keras_core.optimizers.adam import Adam
from keras_core.optimizers.adamax import Adamax
from keras_core.optimizers.adamw import AdamW
from keras_core.optimizers.ftrl import Ftrl
from keras_core.optimizers.lion import Lion
from keras_core.optimizers.nadam import Nadam
from keras_core.optimizers.optimizer import Optimizer
from keras_core.optimizers.rmsprop import RMSprop
@@ -24,6 +25,7 @@ ALL_OBJECTS = {
Adafactor,
Nadam,
Ftrl,
Lion,
}
ALL_OBJECTS_DICT = {cls.__name__.lower(): cls for cls in ALL_OBJECTS}
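
Because `ALL_OBJECTS_DICT` keys optimizers by their lowercase class name, registering `Lion` here should make the usual string lookup work. A minimal usage sketch (the toy model and data are made up for illustration):

import numpy as np
import keras_core

# "lion" resolves through ALL_OBJECTS_DICT built above.
model = keras_core.Sequential([keras_core.layers.Dense(1)])
model.compile(optimizer="lion", loss="mse")
model.fit(
    np.random.random((8, 4)).astype("float32"),
    np.random.random((8, 1)).astype("float32"),
    epochs=1,
    verbose=0,
)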

keras_core/optimizers/lion.py
@@ -0,0 +1,123 @@
from keras_core import ops
from keras_core.api_export import keras_core_export
from keras_core.optimizers import optimizer
@keras_core_export(["keras_core.optimizers.Lion"])
class Lion(optimizer.Optimizer):
"""Optimizer that implements the Lion algorithm.
The Lion optimizer is a stochastic-gradient-descent method that uses the
sign operator to control the magnitude of the update, unlike other adaptive
optimizers such as Adam that rely on second-order moments. This makes
Lion more memory-efficient as it only keeps track of the momentum. According
to the authors (see reference), its performance gain over Adam grows with
the batch size. Because the update of Lion is produced through the sign
operation, resulting in a larger norm, a suitable learning rate for Lion is
typically 3-10x smaller than that for AdamW. The weight decay for Lion
should be in turn 3-10x larger than that for AdamW to maintain a
similar strength (lr * wd).
Args:
learning_rate: A float, a
`keras_core.optimizers.schedules.LearningRateSchedule` instance, or
a callable that takes no arguments and returns the actual value to
use. The learning rate. Defaults to `0.001`.
beta_1: A float value or a constant float tensor, or a callable
that takes no arguments and returns the actual value to use. The
rate to combine the current gradient and the 1st moment estimate.
Defaults to `0.9`.
beta_2: A float value or a constant float tensor, or a callable
that takes no arguments and returns the actual value to use. The
exponential decay rate for the 1st moment estimate. Defaults to
`0.99`.
{{base_optimizer_keyword_args}}
References:
- [Chen et al., 2023](http://arxiv.org/abs/2302.06675)
- [Authors' implementation](
http://github.com/google/automl/tree/master/lion)
"""
def __init__(
self,
learning_rate=0.001,
beta_1=0.9,
beta_2=0.99,
weight_decay=None,
clipnorm=None,
clipvalue=None,
global_clipnorm=None,
use_ema=False,
ema_momentum=0.99,
ema_overwrite_frequency=None,
name="lion",
):
super().__init__(
learning_rate=learning_rate,
name=name,
weight_decay=weight_decay,
clipnorm=clipnorm,
clipvalue=clipvalue,
global_clipnorm=global_clipnorm,
use_ema=use_ema,
ema_momentum=ema_momentum,
ema_overwrite_frequency=ema_overwrite_frequency,
)
self.beta_1 = beta_1
self.beta_2 = beta_2
if beta_1 <= 0 or beta_1 > 1:
raise ValueError(
"Argument `beta_1` must be in the [0, 1] range. Otherwise, the "
f"optimizer degenerates to SignSGD. Received: beta_1={beta_1}."
)
def build(self, var_list):
"""Initialize optimizer variables.
Lion optimizer has one variable `momentums`.
Args:
var_list: list of model variables to build Lion variables on.
"""
if self.built:
return
super().build(var_list)
self._momentums = []
for var in var_list:
self._momentums.append(
self.add_variable_from_reference(
reference_variable=var, name="m"
)
)
def update_step(self, gradient, variable, learning_rate):
"""Update step given gradient and the associated model variable."""
lr = ops.cast(learning_rate, variable.dtype)
gradient = ops.cast(gradient, variable.dtype)
beta_1 = ops.cast(self.beta_1, variable.dtype)
beta_2 = ops.cast(self.beta_2, variable.dtype)
m = self._momentums[self._get_variable_index(variable)]
# TODO: currently only supports dense gradients
variable.assign_sub(
lr * ops.sign(m * beta_1 + gradient * (1.0 - beta_1))
)
m.assign(m * beta_2 + gradient * (1.0 - beta_2))
def get_config(self):
config = super().get_config()
config.update(
{
"beta_1": self.beta_1,
"beta_2": self.beta_2,
}
)
return config
Lion.__doc__ = Lion.__doc__.replace(
"{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
)
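
For reference, the rule implemented by `update_step` can be summarized in a standalone NumPy sketch (illustration only; the decoupled weight decay, clipping, and EMA handled by the base optimizer are omitted):

import numpy as np

def lion_reference_step(w, g, m, lr=0.001, beta_1=0.9, beta_2=0.99):
    # Update direction: sign of an interpolation between momentum and gradient.
    c_t = np.sign(beta_1 * m + (1.0 - beta_1) * g)
    # Every coordinate moves by exactly lr, which is why Lion typically needs a
    # 3-10x smaller learning rate than AdamW (see the docstring above).
    w = w - lr * c_t
    # The momentum tracks the gradient with the slower decay rate beta_2.
    m = beta_2 * m + (1.0 - beta_2) * g
    return w, m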

keras_core/optimizers/lion_test.py
@@ -0,0 +1,85 @@
import numpy as np
import pytest
import keras_core
from keras_core import backend
from keras_core import ops
from keras_core import testing
from keras_core.optimizers.lion import Lion
class LionTest(testing.TestCase):
def test_config(self):
optimizer = Lion(
learning_rate=0.5,
beta_1=0.5,
beta_2=0.67,
)
self.run_class_serialization_test(optimizer)
def test_single_step(self):
optimizer = Lion(learning_rate=0.5)
grads = ops.array([1.0, 6.0, 7.0, 2.0])
vars = backend.Variable([1.0, 2.0, 3.0, 4.0])
optimizer.apply_gradients(zip([grads], [vars]))
self.assertAllClose(vars, [0.5, 1.5, 2.5, 3.5], rtol=1e-4, atol=1e-4)
def test_weight_decay(self):
grads, var1, var2, var3 = (
ops.zeros(()),
backend.Variable(2.0),
backend.Variable(2.0, name="exclude"),
backend.Variable(2.0),
)
optimizer_1 = Lion(learning_rate=1.0, weight_decay=0.004)
optimizer_1.apply_gradients(zip([grads], [var1]))
optimizer_2 = Lion(learning_rate=1.0, weight_decay=0.004)
optimizer_2.exclude_from_weight_decay(var_names=["exclude"])
optimizer_2.apply_gradients(zip([grads, grads], [var1, var2]))
optimizer_3 = Lion(learning_rate=1.0, weight_decay=0.004)
optimizer_3.exclude_from_weight_decay(var_list=[var3])
optimizer_3.apply_gradients(zip([grads, grads], [var1, var3]))
self.assertAlmostEqual(var1.numpy(), 1.9760959, decimal=6)
self.assertAlmostEqual(var2.numpy(), 2.0, decimal=6)
self.assertAlmostEqual(var3.numpy(), 2.0, decimal=6)
def test_correctness_with_golden(self):
optimizer = Lion()
x = backend.Variable(np.ones([10]))
grads = ops.arange(0.1, 1.1, 0.1)
first_grads = ops.full((10,), 0.01)
golden = np.tile(
[[0.999], [0.998], [0.997], [0.996], [0.995]],
(1, 10),
)
optimizer.apply_gradients(zip([first_grads], [x]))
for i in range(5):
self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4)
optimizer.apply_gradients(zip([grads], [x]))
def test_clip_norm(self):
optimizer = Lion(clipnorm=1)
grad = [np.array([100.0, 100.0])]
clipped_grad = optimizer._clip_gradients(grad)
self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2])
def test_clip_value(self):
optimizer = Lion(clipvalue=1)
grad = [np.array([100.0, 100.0])]
clipped_grad = optimizer._clip_gradients(grad)
self.assertAllClose(clipped_grad[0], [1.0, 1.0])
@pytest.mark.requires_trainable_backend
def test_ema(self):
# TODO: test correctness
model = keras_core.Sequential([keras_core.layers.Dense(10)])
model.compile(optimizer=Lion(use_ema=True), loss="mse")
x = keras_core.ops.zeros((1, 5))
y = keras_core.ops.zeros((1, 10))
model.fit(x, y)
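
A quick interactive check that mirrors `test_single_step` above (a sketch, assuming a configured keras_core backend):

from keras_core import backend, ops
from keras_core.optimizers import Lion

optimizer = Lion(learning_rate=0.5)
grads = ops.array([1.0, 6.0, 7.0, 2.0])
var = backend.Variable([1.0, 2.0, 3.0, 4.0])
optimizer.apply_gradients(zip([grads], [var]))
# With zero initial momentum the update is sign(grad), so every entry moves by
# exactly lr=0.5: var is now ~[0.5, 1.5, 2.5, 3.5].
print(var)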