Add Lion optimizer (#610)
* Add Lion optimizer
* Address comments
* Add torch_lion optimizer
* Update correctness test
* Trigger Github syncing
This commit is contained in:
parent 30121f4b5f
commit 794bbd7d29
37 keras_core/backend/torch/optimizers/torch_lion.py Normal file
@@ -0,0 +1,37 @@
import torch

from keras_core import ops
from keras_core import optimizers
from keras_core.backend.torch.optimizers import torch_parallel_optimizer


class Lion(torch_parallel_optimizer.TorchParallelOptimizer, optimizers.Lion):
    def _parallel_update_step(
        self,
        grads,
        variables,
        learning_rate,
    ):
        keras_variables = variables
        variables = [v.value for v in variables]

        dtype = variables[0].dtype
        lr = ops.cast(learning_rate, dtype)

        m_list = [
            self._momentums[self._get_variable_index(variable)].value
            for variable in keras_variables
        ]

        # c_t = sign(beta_1 * m + (1 - beta_1) * g), computed with fused
        # torch._foreach_* ops across all variables at once.
        c_t = torch._foreach_mul(m_list, self.beta_1)
        torch._foreach_add_(c_t, grads, alpha=1 - self.beta_1)
        c_t = [c.sign() for c in c_t]

        # Parameter update: w <- w - lr * c_t
        torch._foreach_add_(
            variables,
            torch._foreach_mul(c_t, lr),
            alpha=-1,
        )

        # Momentum update: m <- beta_2 * m + (1 - beta_2) * g
        torch._foreach_mul_(m_list, self.beta_2)
        torch._foreach_add_(m_list, grads, alpha=1 - self.beta_2)
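For readers less familiar with the fused `torch._foreach_*` calls: across all variables at once they perform the same three-step Lion update that the generic `update_step` in `keras_core/optimizers/lion.py` applies per variable. A minimal single-tensor sketch in plain PyTorch (the tensors and hyperparameter values below are illustrative placeholders, not part of the commit):

import torch

# Illustrative single-tensor equivalent of the fused update above.
beta_1, beta_2, lr = 0.9, 0.99, 1e-4   # placeholder hyperparameters
v = torch.ones(4)                      # parameter
g = torch.full((4,), 0.5)              # gradient
m = torch.zeros(4)                     # momentum slot

c_t = (beta_1 * m + (1.0 - beta_1) * g).sign()  # signed interpolation of m and g
v = v - lr * c_t                                # parameter step: w <- w - lr * sign(...)
m = beta_2 * m + (1.0 - beta_2) * g             # momentum update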
@@ -12,6 +12,7 @@ class TorchOptimizer(BaseOptimizer):
        from keras_core.backend.torch.optimizers import torch_adam
        from keras_core.backend.torch.optimizers import torch_adamax
        from keras_core.backend.torch.optimizers import torch_adamw
        from keras_core.backend.torch.optimizers import torch_lion
        from keras_core.backend.torch.optimizers import torch_nadam
        from keras_core.backend.torch.optimizers import torch_rmsprop
        from keras_core.backend.torch.optimizers import torch_sgd
@@ -22,6 +23,7 @@ class TorchOptimizer(BaseOptimizer):
            optimizers.Adam: torch_adam.Adam,
            optimizers.Adamax: torch_adamax.Adamax,
            optimizers.AdamW: torch_adamw.AdamW,
            optimizers.Lion: torch_lion.Lion,
            optimizers.Nadam: torch_nadam.Nadam,
            optimizers.RMSprop: torch_rmsprop.RMSprop,
            optimizers.SGD: torch_sgd.SGD,
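With `optimizers.Lion` mapped to `torch_lion.Lion` above, requesting the generic Lion class under the torch backend is expected to hand back the fused implementation. A small sanity-check sketch, assuming the backend is selected via the `KERAS_BACKEND` environment variable and that `TorchOptimizer` performs the class swap (the dispatch code itself is outside this hunk):

import os

os.environ["KERAS_BACKEND"] = "torch"  # must be set before importing keras_core

from keras_core import optimizers

opt = optimizers.Lion(learning_rate=3e-5)
print(type(opt))  # expected: the torch backend's Lion subclass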
@@ -6,6 +6,7 @@ from keras_core.optimizers.adam import Adam
from keras_core.optimizers.adamax import Adamax
from keras_core.optimizers.adamw import AdamW
from keras_core.optimizers.ftrl import Ftrl
from keras_core.optimizers.lion import Lion
from keras_core.optimizers.nadam import Nadam
from keras_core.optimizers.optimizer import Optimizer
from keras_core.optimizers.rmsprop import RMSprop
@@ -24,6 +25,7 @@ ALL_OBJECTS = {
    Adafactor,
    Nadam,
    Ftrl,
    Lion,
}
ALL_OBJECTS_DICT = {cls.__name__.lower(): cls for cls in ALL_OBJECTS}
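Adding `Lion` to `ALL_OBJECTS` (and therefore to `ALL_OBJECTS_DICT`) makes the string identifier `"lion"` resolvable wherever optimizers are looked up by name, e.g. in `model.compile(optimizer="lion")`. A short sketch, assuming keras_core exposes the usual `optimizers.get` helper:

from keras_core import optimizers

opt = optimizers.get("lion")  # string lookup backed by ALL_OBJECTS_DICT
print(type(opt).__name__)     # expected: "Lion"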
123 keras_core/optimizers/lion.py Normal file
@@ -0,0 +1,123 @@
from keras_core import ops
from keras_core.api_export import keras_core_export
from keras_core.optimizers import optimizer


@keras_core_export(["keras_core.optimizers.Lion"])
class Lion(optimizer.Optimizer):
    """Optimizer that implements the Lion algorithm.

    The Lion optimizer is a stochastic-gradient-descent method that uses the
    sign operator to control the magnitude of the update, unlike other
    adaptive optimizers such as Adam that rely on second-order moments. This
    makes Lion more memory-efficient as it only keeps track of the momentum.
    According to the authors (see reference), its performance gain over Adam
    grows with the batch size. Because the update of Lion is produced through
    the sign operation, resulting in a larger norm, a suitable learning rate
    for Lion is typically 3-10x smaller than that for AdamW. The weight decay
    for Lion should in turn be 3-10x larger than that for AdamW to maintain a
    similar strength (lr * wd).

    Args:
        learning_rate: A float, a
            `keras_core.optimizers.schedules.LearningRateSchedule` instance,
            or a callable that takes no arguments and returns the actual
            value to use. The learning rate. Defaults to `0.001`.
        beta_1: A float value or a constant float tensor, or a callable
            that takes no arguments and returns the actual value to use. The
            rate to combine the current gradient and the 1st moment estimate.
            Defaults to `0.9`.
        beta_2: A float value or a constant float tensor, or a callable
            that takes no arguments and returns the actual value to use. The
            exponential decay rate for the 1st moment estimate. Defaults to
            `0.99`.
        {{base_optimizer_keyword_args}}

    References:

    - [Chen et al., 2023](http://arxiv.org/abs/2302.06675)
    - [Authors' implementation](
        http://github.com/google/automl/tree/master/lion)

    """

    def __init__(
        self,
        learning_rate=0.001,
        beta_1=0.9,
        beta_2=0.99,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        name="lion",
    ):
        super().__init__(
            learning_rate=learning_rate,
            name=name,
            weight_decay=weight_decay,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            use_ema=use_ema,
            ema_momentum=ema_momentum,
            ema_overwrite_frequency=ema_overwrite_frequency,
        )
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        if beta_1 <= 0 or beta_1 > 1:
            raise ValueError(
                "Argument `beta_1` must be in the (0, 1] range. Otherwise, the"
                f" optimizer degenerates to SignSGD. Received: beta_1={beta_1}."
            )

    def build(self, var_list):
        """Initialize optimizer variables.

        Lion optimizer has one variable `momentums`.

        Args:
            var_list: list of model variables to build Lion variables on.
        """
        if self.built:
            return
        super().build(var_list)
        self._momentums = []
        for var in var_list:
            self._momentums.append(
                self.add_variable_from_reference(
                    reference_variable=var, name="m"
                )
            )

    def update_step(self, gradient, variable, learning_rate):
        """Update step given gradient and the associated model variable."""
        lr = ops.cast(learning_rate, variable.dtype)
        gradient = ops.cast(gradient, variable.dtype)
        beta_1 = ops.cast(self.beta_1, variable.dtype)
        beta_2 = ops.cast(self.beta_2, variable.dtype)
        m = self._momentums[self._get_variable_index(variable)]

        # TODO: currently only supports dense gradients.
        # Lion update: w <- w - lr * sign(beta_1 * m + (1 - beta_1) * g)
        variable.assign_sub(
            lr * ops.sign(m * beta_1 + gradient * (1.0 - beta_1))
        )
        # Momentum update: m <- beta_2 * m + (1 - beta_2) * g
        m.assign(m * beta_2 + gradient * (1.0 - beta_2))

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "beta_1": self.beta_1,
                "beta_2": self.beta_2,
            }
        )
        return config


Lion.__doc__ = Lion.__doc__.replace(
    "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
)
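As the docstring notes, Lion typically wants a learning rate 3-10x smaller (and a weight decay correspondingly larger) than AdamW. A minimal end-to-end usage sketch, mirroring the compile/fit pattern used in the test file below; the model shape, data, and hyperparameter values are placeholders:

import keras_core
from keras_core.optimizers import Lion

# Placeholder model and data; lr is set ~3-10x lower than a typical AdamW value.
model = keras_core.Sequential([keras_core.layers.Dense(1)])
model.compile(optimizer=Lion(learning_rate=1e-4, weight_decay=0.1), loss="mse")
x = keras_core.ops.zeros((8, 4))
y = keras_core.ops.zeros((8, 1))
model.fit(x, y)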
85 keras_core/optimizers/lion_test.py Normal file
@@ -0,0 +1,85 @@
import numpy as np
import pytest

import keras_core
from keras_core import backend
from keras_core import ops
from keras_core import testing
from keras_core.optimizers.lion import Lion


class LionTest(testing.TestCase):
    def test_config(self):
        optimizer = Lion(
            learning_rate=0.5,
            beta_1=0.5,
            beta_2=0.67,
        )
        self.run_class_serialization_test(optimizer)

    def test_single_step(self):
        optimizer = Lion(learning_rate=0.5)
        grads = ops.array([1.0, 6.0, 7.0, 2.0])
        vars = backend.Variable([1.0, 2.0, 3.0, 4.0])
        optimizer.apply_gradients(zip([grads], [vars]))
        self.assertAllClose(vars, [0.5, 1.5, 2.5, 3.5], rtol=1e-4, atol=1e-4)

    def test_weight_decay(self):
        grads, var1, var2, var3 = (
            ops.zeros(()),
            backend.Variable(2.0),
            backend.Variable(2.0, name="exclude"),
            backend.Variable(2.0),
        )
        optimizer_1 = Lion(learning_rate=1.0, weight_decay=0.004)
        optimizer_1.apply_gradients(zip([grads], [var1]))

        optimizer_2 = Lion(learning_rate=1.0, weight_decay=0.004)
        optimizer_2.exclude_from_weight_decay(var_names=["exclude"])
        optimizer_2.apply_gradients(zip([grads, grads], [var1, var2]))

        optimizer_3 = Lion(learning_rate=1.0, weight_decay=0.004)
        optimizer_3.exclude_from_weight_decay(var_list=[var3])
        optimizer_3.apply_gradients(zip([grads, grads], [var1, var3]))

        self.assertAlmostEqual(var1.numpy(), 1.9760959, decimal=6)
        self.assertAlmostEqual(var2.numpy(), 2.0, decimal=6)
        self.assertAlmostEqual(var3.numpy(), 2.0, decimal=6)

    def test_correctness_with_golden(self):
        optimizer = Lion()

        x = backend.Variable(np.ones([10]))
        grads = ops.arange(0.1, 1.1, 0.1)
        first_grads = ops.full((10,), 0.01)

        golden = np.tile(
            [[0.999], [0.998], [0.997], [0.996], [0.995]],
            (1, 10),
        )

        optimizer.apply_gradients(zip([first_grads], [x]))
        for i in range(5):
            self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4)
            optimizer.apply_gradients(zip([grads], [x]))

    def test_clip_norm(self):
        optimizer = Lion(clipnorm=1)
        grad = [np.array([100.0, 100.0])]
        clipped_grad = optimizer._clip_gradients(grad)
        self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2])

    def test_clip_value(self):
        optimizer = Lion(clipvalue=1)
        grad = [np.array([100.0, 100.0])]
        clipped_grad = optimizer._clip_gradients(grad)
        self.assertAllClose(clipped_grad[0], [1.0, 1.0])

    @pytest.mark.requires_trainable_backend
    def test_ema(self):
        # TODO: test correctness
        model = keras_core.Sequential([keras_core.layers.Dense(10)])
        model.compile(optimizer=Lion(use_ema=True), loss="mse")
        x = keras_core.ops.zeros((1, 5))
        y = keras_core.ops.zeros((1, 10))
        model.fit(x, y)
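A quick cross-check of `test_single_step`: with the momentum slot initialized to zero, the Lion step reduces to `lr * sign((1 - beta_1) * g) = lr * sign(g)`, so each variable moves by exactly `lr = 0.5`. A small NumPy sketch reproducing the asserted values (not part of the commit):

import numpy as np

lr, beta_1 = 0.5, 0.9
grads = np.array([1.0, 6.0, 7.0, 2.0])
variables = np.array([1.0, 2.0, 3.0, 4.0])
m = np.zeros_like(variables)  # momentum starts at zero

update = lr * np.sign(beta_1 * m + (1.0 - beta_1) * grads)
print(variables - update)  # [0.5 1.5 2.5 3.5], matching the test's assertion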