diff --git a/keras_core/backend/torch/optimizers/torch_lion.py b/keras_core/backend/torch/optimizers/torch_lion.py
new file mode 100644
index 000000000..e3d0e0e1f
--- /dev/null
+++ b/keras_core/backend/torch/optimizers/torch_lion.py
@@ -0,0 +1,37 @@
+import torch
+
+from keras_core import ops
+from keras_core import optimizers
+from keras_core.backend.torch.optimizers import torch_parallel_optimizer
+
+
+class Lion(torch_parallel_optimizer.TorchParallelOptimizer, optimizers.Lion):
+    def _parallel_update_step(
+        self,
+        grads,
+        variables,
+        learning_rate,
+    ):
+        keras_variables = variables
+        variables = [v.value for v in variables]
+
+        dtype = variables[0].dtype
+        lr = ops.cast(learning_rate, dtype)
+
+        m_list = [
+            self._momentums[self._get_variable_index(variable)].value
+            for variable in keras_variables
+        ]
+
+        c_t = torch._foreach_mul(m_list, self.beta_1)
+        torch._foreach_add_(c_t, grads, alpha=1 - self.beta_1)
+        c_t = [c.sign() for c in c_t]
+
+        torch._foreach_add_(
+            variables,
+            torch._foreach_mul(c_t, lr),
+            alpha=-1,
+        )
+
+        torch._foreach_mul_(m_list, self.beta_2)
+        torch._foreach_add_(m_list, grads, alpha=1 - self.beta_2)
diff --git a/keras_core/backend/torch/optimizers/torch_optimizer.py b/keras_core/backend/torch/optimizers/torch_optimizer.py
index f4940c62b..9914bb3c7 100644
--- a/keras_core/backend/torch/optimizers/torch_optimizer.py
+++ b/keras_core/backend/torch/optimizers/torch_optimizer.py
@@ -12,6 +12,7 @@ class TorchOptimizer(BaseOptimizer):
         from keras_core.backend.torch.optimizers import torch_adam
         from keras_core.backend.torch.optimizers import torch_adamax
         from keras_core.backend.torch.optimizers import torch_adamw
+        from keras_core.backend.torch.optimizers import torch_lion
         from keras_core.backend.torch.optimizers import torch_nadam
         from keras_core.backend.torch.optimizers import torch_rmsprop
         from keras_core.backend.torch.optimizers import torch_sgd
@@ -22,6 +23,7 @@ class TorchOptimizer(BaseOptimizer):
             optimizers.Adam: torch_adam.Adam,
             optimizers.Adamax: torch_adamax.Adamax,
             optimizers.AdamW: torch_adamw.AdamW,
+            optimizers.Lion: torch_lion.Lion,
             optimizers.Nadam: torch_nadam.Nadam,
             optimizers.RMSprop: torch_rmsprop.RMSprop,
             optimizers.SGD: torch_sgd.SGD,
diff --git a/keras_core/optimizers/__init__.py b/keras_core/optimizers/__init__.py
index ea3706bc7..5cc0c54b0 100644
--- a/keras_core/optimizers/__init__.py
+++ b/keras_core/optimizers/__init__.py
@@ -6,6 +6,7 @@ from keras_core.optimizers.adam import Adam
 from keras_core.optimizers.adamax import Adamax
 from keras_core.optimizers.adamw import AdamW
 from keras_core.optimizers.ftrl import Ftrl
+from keras_core.optimizers.lion import Lion
 from keras_core.optimizers.nadam import Nadam
 from keras_core.optimizers.optimizer import Optimizer
 from keras_core.optimizers.rmsprop import RMSprop
@@ -24,6 +25,7 @@ ALL_OBJECTS = {
     Adafactor,
     Nadam,
     Ftrl,
+    Lion,
 }
 ALL_OBJECTS_DICT = {cls.__name__.lower(): cls for cls in ALL_OBJECTS}
 
diff --git a/keras_core/optimizers/lion.py b/keras_core/optimizers/lion.py
new file mode 100644
index 000000000..1346622b1
--- /dev/null
+++ b/keras_core/optimizers/lion.py
@@ -0,0 +1,123 @@
+from keras_core import ops
+from keras_core.api_export import keras_core_export
+from keras_core.optimizers import optimizer
+
+
+@keras_core_export(["keras_core.optimizers.Lion"])
+class Lion(optimizer.Optimizer):
+    """Optimizer that implements the Lion algorithm.
+
+    The Lion optimizer is a stochastic-gradient-descent method that uses the
+    sign operator to control the magnitude of the update, unlike other
+    adaptive optimizers such as Adam that rely on second-order moments. This
+    makes Lion more memory-efficient as it only keeps track of the momentum.
+    According to the authors (see reference), its performance gain over Adam
+    grows with the batch size. Because the update of Lion is produced through
+    the sign operation, resulting in a larger norm, a suitable learning rate
+    for Lion is typically 3-10x smaller than that for AdamW. The weight decay
+    for Lion should in turn be 3-10x larger than that for AdamW to maintain a
+    similar strength (lr * wd).
+
+    Args:
+        learning_rate: A float, a
+            `keras_core.optimizers.schedules.LearningRateSchedule` instance, or
+            a callable that takes no arguments and returns the actual value to
+            use. The learning rate. Defaults to `0.001`.
+        beta_1: A float value or a constant float tensor, or a callable
+            that takes no arguments and returns the actual value to use. The
+            rate to combine the current gradient and the 1st moment estimate.
+            Defaults to `0.9`.
+        beta_2: A float value or a constant float tensor, or a callable
+            that takes no arguments and returns the actual value to use. The
+            exponential decay rate for the 1st moment estimate. Defaults to
+            `0.99`.
+        {{base_optimizer_keyword_args}}
+
+    References:
+
+    - [Chen et al., 2023](http://arxiv.org/abs/2302.06675)
+    - [Authors' implementation](
+        http://github.com/google/automl/tree/master/lion)
+
+    """
+
+    def __init__(
+        self,
+        learning_rate=0.001,
+        beta_1=0.9,
+        beta_2=0.99,
+        weight_decay=None,
+        clipnorm=None,
+        clipvalue=None,
+        global_clipnorm=None,
+        use_ema=False,
+        ema_momentum=0.99,
+        ema_overwrite_frequency=None,
+        name="lion",
+    ):
+        super().__init__(
+            learning_rate=learning_rate,
+            name=name,
+            weight_decay=weight_decay,
+            clipnorm=clipnorm,
+            clipvalue=clipvalue,
+            global_clipnorm=global_clipnorm,
+            use_ema=use_ema,
+            ema_momentum=ema_momentum,
+            ema_overwrite_frequency=ema_overwrite_frequency,
+        )
+        self.beta_1 = beta_1
+        self.beta_2 = beta_2
+        if beta_1 <= 0 or beta_1 > 1:
+            raise ValueError(
+                "Argument `beta_1` must be in the [0, 1] range. Otherwise, the "
+                f"optimizer degenerates to SignSGD. Received: beta_1={beta_1}."
+            )
+
+    def build(self, var_list):
+        """Initialize optimizer variables.
+
+        Lion optimizer has one variable `momentums`.
+
+        Args:
+            var_list: list of model variables to build Lion variables on.
+        """
+        if self.built:
+            return
+        super().build(var_list)
+        self._momentums = []
+        for var in var_list:
+            self._momentums.append(
+                self.add_variable_from_reference(
+                    reference_variable=var, name="m"
+                )
+            )
+
+    def update_step(self, gradient, variable, learning_rate):
+        """Update step given gradient and the associated model variable."""
+        lr = ops.cast(learning_rate, variable.dtype)
+        gradient = ops.cast(gradient, variable.dtype)
+        beta_1 = ops.cast(self.beta_1, variable.dtype)
+        beta_2 = ops.cast(self.beta_2, variable.dtype)
+        m = self._momentums[self._get_variable_index(variable)]
+
+        # TODO: currently only support dense gradients
+        variable.assign_sub(
+            lr * ops.sign(m * beta_1 + gradient * (1.0 - beta_1))
+        )
+        m.assign(m * beta_2 + gradient * (1.0 - beta_2))
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "beta_1": self.beta_1,
+                "beta_2": self.beta_2,
+            }
+        )
+        return config
+
+
+Lion.__doc__ = Lion.__doc__.replace(
+    "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
+)
diff --git a/keras_core/optimizers/lion_test.py b/keras_core/optimizers/lion_test.py
new file mode 100644
index 000000000..a31509103
--- /dev/null
+++ b/keras_core/optimizers/lion_test.py
@@ -0,0 +1,85 @@
+import numpy as np
+import pytest
+
+import keras_core
+from keras_core import backend
+from keras_core import ops
+from keras_core import testing
+from keras_core.optimizers.lion import Lion
+
+
+class LionTest(testing.TestCase):
+    def test_config(self):
+        optimizer = Lion(
+            learning_rate=0.5,
+            beta_1=0.5,
+            beta_2=0.67,
+        )
+        self.run_class_serialization_test(optimizer)
+
+    def test_single_step(self):
+        optimizer = Lion(learning_rate=0.5)
+        grads = ops.array([1.0, 6.0, 7.0, 2.0])
+        vars = backend.Variable([1.0, 2.0, 3.0, 4.0])
+        optimizer.apply_gradients(zip([grads], [vars]))
+        self.assertAllClose(vars, [0.5, 1.5, 2.5, 3.5], rtol=1e-4, atol=1e-4)
+
+    def test_weight_decay(self):
+        grads, var1, var2, var3 = (
+            ops.zeros(()),
+            backend.Variable(2.0),
+            backend.Variable(2.0, name="exclude"),
+            backend.Variable(2.0),
+        )
+        optimizer_1 = Lion(learning_rate=1.0, weight_decay=0.004)
+        optimizer_1.apply_gradients(zip([grads], [var1]))
+
+        optimizer_2 = Lion(learning_rate=1.0, weight_decay=0.004)
+        optimizer_2.exclude_from_weight_decay(var_names=["exclude"])
+        optimizer_2.apply_gradients(zip([grads, grads], [var1, var2]))
+
+        optimizer_3 = Lion(learning_rate=1.0, weight_decay=0.004)
+        optimizer_3.exclude_from_weight_decay(var_list=[var3])
+        optimizer_3.apply_gradients(zip([grads, grads], [var1, var3]))
+
+        self.assertAlmostEqual(var1.numpy(), 1.9760959, decimal=6)
+        self.assertAlmostEqual(var2.numpy(), 2.0, decimal=6)
+        self.assertAlmostEqual(var3.numpy(), 2.0, decimal=6)
+
+    def test_correctness_with_golden(self):
+        optimizer = Lion()
+
+        x = backend.Variable(np.ones([10]))
+        grads = ops.arange(0.1, 1.1, 0.1)
+        first_grads = ops.full((10,), 0.01)
+
+        golden = np.tile(
+            [[0.999], [0.998], [0.997], [0.996], [0.995]],
+            (1, 10),
+        )
+
+        optimizer.apply_gradients(zip([first_grads], [x]))
+        for i in range(5):
+            self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4)
+            optimizer.apply_gradients(zip([grads], [x]))
+
+    def test_clip_norm(self):
+        optimizer = Lion(clipnorm=1)
+        grad = [np.array([100.0, 100.0])]
+        clipped_grad = optimizer._clip_gradients(grad)
+        self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2])
+
+    def test_clip_value(self):
+        optimizer = Lion(clipvalue=1)
+        grad = [np.array([100.0, 100.0])]
+        clipped_grad = optimizer._clip_gradients(grad)
+        self.assertAllClose(clipped_grad[0], [1.0, 1.0])
+
+    @pytest.mark.requires_trainable_backend
+    def test_ema(self):
+        # TODO: test correctness
+        model = keras_core.Sequential([keras_core.layers.Dense(10)])
+        model.compile(optimizer=Lion(use_ema=True), loss="mse")
+        x = keras_core.ops.zeros((1, 5))
+        y = keras_core.ops.zeros((1, 10))
+        model.fit(x, y)
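Note for reviewers (not part of the patch): because `update_step` passes the interpolated momentum through `ops.sign`, each coordinate is nudged by ±`lr` and the gradient's magnitude only matters through its sign, which is why the docstring recommends a learning rate 3-10x smaller than AdamW's. The snippet below is an illustrative plain-NumPy sketch of that rule (the `lion_update` helper is assumed here, not part of this patch); it reproduces the expectations in `test_single_step` and the first golden value in `test_correctness_with_golden`.

    import numpy as np

    def lion_update(var, grad, m, lr=0.001, beta_1=0.9, beta_2=0.99):
        # Interpolate momentum and gradient, keep only the sign of the result.
        var = var - lr * np.sign(beta_1 * m + (1.0 - beta_1) * grad)
        # Momentum is an exponential moving average of past gradients.
        m = beta_2 * m + (1.0 - beta_2) * grad
        return var, m

    # Mirrors `test_single_step`: all gradients are positive, so every
    # variable moves by exactly -lr regardless of the gradient magnitude.
    var = np.array([1.0, 2.0, 3.0, 4.0])
    m = np.zeros_like(var)
    var, m = lion_update(var, np.array([1.0, 6.0, 7.0, 2.0]), m, lr=0.5)
    print(var)  # [0.5 1.5 2.5 3.5]

    # Mirrors the first step of `test_correctness_with_golden` with default
    # hyperparameters: x starts at 1.0 and moves to 1.0 - 0.001 = 0.999.
    x = np.ones(10)
    mx = np.zeros(10)
    x, mx = lion_update(x, np.full(10, 0.01), mx)
    print(x[0])  # 0.999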