from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.optimizers import optimizer


@keras_core_export(["keras_core.optimizers.Adadelta"])
class Adadelta(optimizer.Optimizer):
"""Optimizer that implements the Adadelta algorithm.
|
|
|
|
Adadelta optimization is a stochastic gradient descent method that is based
|
|
on adaptive learning rate per dimension to address two drawbacks:
|
|
|
|
- The continual decay of learning rates throughout training.
|
|
- The need for a manually selected global learning rate.
|
|
|
|
Adadelta is a more robust extension of Adagrad that adapts learning rates
|
|
based on a moving window of gradient updates, instead of accumulating all
|
|
past gradients. This way, Adadelta continues learning even when many updates
|
|
have been done. Compared to Adagrad, in the original version of Adadelta you
|
|
don't have to set an initial learning rate. In this version, the initial
|
|
learning rate can be set, as in most other Keras optimizers.
|
|
|
|
Args:
|
|
learning_rate: Initial value for the learning rate: a floating
|
|
point value, Defaults to 0.001. Note that `Adadelta` tends
|
|
to benefit from higher initial learning rate values compared to
|
|
other optimizers.
|
|
To match the exact form in the original paper, use 1.0.
|
|
rho: A floating point value. The decay rate. Defaults to 0.95.
|
|
epsilon: Small floating point value used to maintain numerical
|
|
stability.
|
|
Defaults to 1e-7.
|
|
{{base_optimizer_keyword_args}}
|
|
|
|
Reference:
|
|
|
|
- [Zeiler, 2012](http://arxiv.org/abs/1212.5701)
|
|
"""

    def __init__(
        self,
        learning_rate=0.001,
        rho=0.95,
        epsilon=1e-7,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        name="adadelta",
    ):
        super().__init__(
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            use_ema=use_ema,
            ema_momentum=ema_momentum,
            ema_overwrite_frequency=ema_overwrite_frequency,
            name=name,
        )
        self.rho = rho
        self.epsilon = epsilon

    def build(self, var_list):
        if self.built:
            return
        super().build(var_list)
        # One slot pair per variable: a running average of squared gradients
        # and a running average of squared parameter deltas.
        self._accumulated_grads = []
        self._accumulated_delta_vars = []
        for var in var_list:
            self._accumulated_grads.append(
                self.add_variable_from_reference(var, "accumulated_grad")
            )
            self._accumulated_delta_vars.append(
                self.add_variable_from_reference(var, "accumulated_delta_var")
            )

    def update_step(self, grad, variable, learning_rate):
        """Update step given gradient and the associated model variable."""
        lr = ops.cast(learning_rate, variable.dtype)
        grad = ops.cast(grad, variable.dtype)

        rho = self.rho
        accumulated_grad = self._accumulated_grads[
            self._get_variable_index(variable)
        ]
        accumulated_delta_var = self._accumulated_delta_vars[
            self._get_variable_index(variable)
        ]

        def rms(x):
            return ops.sqrt(x + self.epsilon)

        # E[g^2]_t = rho * E[g^2]_{t-1} + (1 - rho) * g_t^2
        accumulated_grad.assign(
            rho * accumulated_grad + (1 - rho) * grad * grad
        )
        # delta_t = -(RMS of accumulated deltas / RMS of gradients) * g_t
        delta_var = -rms(accumulated_delta_var) * grad / rms(accumulated_grad)
        # E[delta^2]_t = rho * E[delta^2]_{t-1} + (1 - rho) * delta_t^2
        accumulated_delta_var.assign(
            rho * accumulated_delta_var + (1 - rho) * delta_var * delta_var
        )
        variable.assign(variable + lr * delta_var)

    def get_config(self):
        config = super().get_config()

        config.update(
            {
                "rho": self.rho,
                "epsilon": self.epsilon,
            }
        )
        return config


Adadelta.__doc__ = Adadelta.__doc__.replace(
    "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
)
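

# ---------------------------------------------------------------------------
# Reference sketch (not part of keras_core's API): a pure-NumPy version of a
# single Adadelta step that mirrors `update_step` above. The function name
# `adadelta_step_numpy` and the demo values are illustrative assumptions; a
# helper like this can serve as a hand-checkable golden value in tests.
# ---------------------------------------------------------------------------
import numpy as np


def adadelta_step_numpy(
    variable,
    grad,
    accumulated_grad,
    accumulated_delta_var,
    learning_rate=0.001,
    rho=0.95,
    epsilon=1e-7,
):
    """Apply one Adadelta update to NumPy arrays and return the new states."""

    def rms(x):
        return np.sqrt(x + epsilon)

    # Running average of squared gradients:
    # E[g^2]_t = rho * E[g^2]_{t-1} + (1 - rho) * g_t^2
    accumulated_grad = rho * accumulated_grad + (1 - rho) * grad * grad
    # Parameter delta scaled by the ratio of the two RMS terms:
    # delta_t = -(RMS[delta]_{t-1} / RMS[g]_t) * g_t
    delta_var = -rms(accumulated_delta_var) * grad / rms(accumulated_grad)
    # Running average of squared deltas:
    # E[delta^2]_t = rho * E[delta^2]_{t-1} + (1 - rho) * delta_t^2
    accumulated_delta_var = (
        rho * accumulated_delta_var + (1 - rho) * delta_var * delta_var
    )
    variable = variable + learning_rate * delta_var
    return variable, accumulated_grad, accumulated_delta_var


if __name__ == "__main__":
    # Tiny demo: one step on a scalar parameter with zero-initialized state.
    v = np.array(1.0)
    g = np.array(0.5)
    zeros = np.zeros_like(v)
    print(adadelta_step_numpy(v, g, zeros, zeros))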