keras/keras_core/optimizers/sgd.py
Neel Kovelamudi 4212fdd5cb Adds Adagrad and Adadelta optimizers and associated tests. (#72)
* Add golden correctness tests for Adam and SGD

* Fix dtype issues

* Sync with main (#56)

* Minor touch ups

* Fix a pretty major bug

* Format code

* Big rethink of Variable API

* Make build-by-run the default build(), leveraging new zero_history KerasTensor mode

* Minor fixes

* Format code

* Switch back to build-by-eager-run for simplicity

* Add raise upon build failure

* Work around JAX bug.

* Add a few more tests.

* Add saving tests

* Adds test suite for SGD and golden correctness tests for all optimizers (#40)

* Add golden correctness tests for Adam and SGD

* Fix dtype issues

* Add binary accuracy (#41)

* chore: adding binary accuracy

* chore: fix docstring

* Add tests for add_loss and activity regularization.

* Reformat code

* Add ActivityRegularization layer

* Fix JAX CI.

* Add Lambda Callback (#42)

* Add LambdaCallback

* Add Lambda Callback

* Add Lambda Callback

* Rename lambda_callback_test.py

* Add einsum (#43)

* Add einsum

* address comments

* Fix format line length (#45)

* Add Embedding layer

* Shorten lines

* Add .vscode to .gitignore (#46)

* rm vscode settings

* add .vscode to gitignore

* Set demo program backend (#48)

* Add tests for training arg resolution in Layer.

* Implement mixed precision.

* Replace backend.execute with backend.numpy.XXX (#50)

* Add cosine similarity loss and update l2_normalize from regularizers (#34)

* Begin cosine loss

* Add testing for cosine similarity

* Fix formatting

* Docstring standardization

* Formatting

* Create numerical_utils

* Fix issue with call context lingering.

* Add the EarlyStopping callback (#44)

* add earlystopping callback

* addressing comments

* address comments

* addressing comments

* remove unused imports

* re-enable imports checks (#51)

* Add nn.one_hot (#52)

* Add GaussianDropout layer.

* Add GaussianNoise layer

* Add Categorical Accuracy Metric (#47)

* chore: adding categorical accuracy metric

* chore: reformat docstrings

* chore: reformat

* chore: ndims with len

* refactor the docstring

* Fix typos

* Implement masking.

---------

Co-authored-by: Francois Chollet <francois.chollet@gmail.com>
Co-authored-by: Aritra Roy Gosthipaty <aritra.born2fly@gmail.com>
Co-authored-by: Ramesh Sampath <1437573+sampathweb@users.noreply.github.com>
Co-authored-by: Chen Qian <chenmoney@google.com>
Co-authored-by: Haifeng Jin <5476582+haifeng-jin@users.noreply.github.com>
Co-authored-by: Gabriel Rasskin <43894452+grasskin@users.noreply.github.com>

* Adds rmsprop optimizer and tests

* Add AdamW optimizer and tests, minor formatting changes

* Implemented formatting fixes

* Adds clip norm and clip value tests to Adam

* Adds Adagrad and Adadelta optimizers

* Applies fixes to formatting and deletes unnecessary kwargs

---------

Co-authored-by: Francois Chollet <francois.chollet@gmail.com>
Co-authored-by: Aritra Roy Gosthipaty <aritra.born2fly@gmail.com>
Co-authored-by: Ramesh Sampath <1437573+sampathweb@users.noreply.github.com>
Co-authored-by: Chen Qian <chenmoney@google.com>
Co-authored-by: Haifeng Jin <5476582+haifeng-jin@users.noreply.github.com>
Co-authored-by: Gabriel Rasskin <43894452+grasskin@users.noreply.github.com>
2023-05-03 02:12:03 +00:00

124 lines
3.8 KiB
Python

from keras_core import operations as ops
from keras_core.optimizers import optimizer
class SGD(optimizer.Optimizer):
    """Gradient descent (with momentum) optimizer.

    Update rule for parameter `w` with gradient `g` when `momentum` is 0:

    ```python
    w = w - learning_rate * g
    ```

    Update rule when `momentum` is larger than 0:

    ```python
    velocity = momentum * velocity - learning_rate * g
    w = w + velocity
    ```

    When `nesterov=True`, this rule becomes:

    ```python
    velocity = momentum * velocity - learning_rate * g
    w = w + momentum * velocity - learning_rate * g
    ```

    Args:
        learning_rate: A `Tensor`, floating point value, or a schedule that is
            a `LearningRateSchedule` instance, or a callable that takes no
            arguments and returns the actual value to use. The learning rate.
            Defaults to 0.01.
        momentum: float hyperparameter in `[0, 1]` that accelerates gradient
            descent in the relevant direction and dampens oscillations.
            Integer values (`0` or `1`) are also accepted. Defaults to 0,
            i.e., vanilla gradient descent (no momentum slots are created).
        nesterov: boolean. Whether to apply Nesterov momentum.
            Defaults to `False`.
        {{base_optimizer_keyword_args}}
    """

    def __init__(
        self,
        learning_rate=0.01,
        momentum=0.0,
        nesterov=False,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        name="SGD",
    ):
        super().__init__(
            learning_rate=learning_rate,
            name=name,
            weight_decay=weight_decay,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            use_ema=use_ema,
            ema_momentum=ema_momentum,
            ema_overwrite_frequency=ema_overwrite_frequency,
        )
        # Accept plain ints such as `momentum=0` in addition to floats, but
        # reject bools (a subclass of int). The chained comparison also
        # rejects NaN, since NaN compares False against both bounds.
        if (
            isinstance(momentum, bool)
            or not isinstance(momentum, (int, float))
            or not 0 <= momentum <= 1
        ):
            raise ValueError("`momentum` must be a float between [0, 1].")
        # Store as a Python float so `get_config` round-trips consistently.
        self.momentum = float(momentum)
        self.nesterov = nesterov

    def build(self, variables):
        """Initialize optimizer variables.

        SGD keeps one momentum slot per model variable in `self.momentums`.
        The slots are only created when `self.momentum` is not 0; plain
        gradient descent needs no extra state.

        Args:
            variables: list of model variables to build SGD variables on.
        """
        if self.built:
            return
        super().build(variables)
        self.momentums = []
        if self.momentum != 0:
            for variable in variables:
                self.momentums.append(
                    self.add_variable_from_reference(
                        reference_variable=variable, name="m"
                    )
                )

    def update_step(self, gradient, variable, learning_rate):
        """Update step given gradient and the associated model variable.

        Args:
            gradient: gradient for `variable`; cast to `variable.dtype`.
            variable: the model variable to update in place via `assign`.
            learning_rate: current learning rate; cast to `variable.dtype`.
        """
        learning_rate = ops.cast(learning_rate, variable.dtype)
        gradient = ops.cast(gradient, variable.dtype)

        if self.momentum == 0:
            # Vanilla gradient descent: `build` created no momentum slots.
            variable.assign(variable - gradient * learning_rate)
            return

        momentum = ops.cast(self.momentum, variable.dtype)
        m = self.momentums[self._get_variable_index(variable)]
        # velocity = momentum * velocity - learning_rate * g
        m.assign(-gradient * learning_rate + m * momentum)
        if self.nesterov:
            # w = w + momentum * velocity - learning_rate * g, using the
            # velocity that was just updated above.
            variable.assign(
                variable - gradient * learning_rate + m * momentum
            )
        else:
            # w = w + velocity
            variable.assign(variable + m)

    def get_config(self):
        """Return the base optimizer config plus SGD hyperparameters."""
        config = super().get_config()
        config.update(
            {
                "momentum": self.momentum,
                "nesterov": self.nesterov,
            }
        )
        return config
# Splice the shared base-optimizer keyword-argument docs into the class
# docstring. When Python runs with `-OO` docstrings are stripped and
# `SGD.__doc__` is None, so guard against that instead of failing at import.
if SGD.__doc__ is not None:
    SGD.__doc__ = SGD.__doc__.replace(
        "{{base_optimizer_keyword_args}}",
        optimizer.base_optimizer_keyword_args,
    )