4212fdd5cb
* Add golden correctness tests for Adam and SGD * Fix dtype issues * Sync with main (#56) * Minor touch ups * Fix a pretty major bug * Format code * Big rethink of Variable API * Make build-by-run the default build(), leveraging new zero_history KerasTensor mode * Minor fixes * Format code * Switch back to build-by-eager-run for simplicity * Add raise upon build failure * Work around JAX bug. * Add a few more tests. * Add saving tests * Adds test suite for SGD and golden correctness tests for all optimizers (#40) * Add golden correctness tests for Adam and SGD * Fix dtype issues * Add binary accuracy (#41) * chore: adding binary accuracy * chore: fix docstring * Add tests for add_loss and activity regularization. * Reformat code * Add ActivityRegularization layer * Fix JAX CI. * Add Lambda Callback (#42) * Add LambdaCallback * Add Lambda Callback * Add Lambda Callback * Rename lambda_callback_test.py * Add einsum (#43) * Add einsum * address comments * Fix format line length (#45) * Add Embedding layer * Shorten lines * Add .vscode to .gitignore (#46) * rm vscode settings * add .vscode to gitignore * Set demo program backend (#48) * Add tests for training arg resolution in Layer. * Implement mixed precision. * Replace backend.execute with backend.numpy.XXX (#50) * Add cosine similarity loss and update l2_normalize from regularizers (#34) * Begin cosine loss * Add testing for cosine similarity * Fix formatting * Docstring standardization * Formatting * Create numerical_utils * Fix issue with call context lingering. * Add the EarlyStopping callback (#44) * add earlystopping callback * addressing comments * address comments * addressing comments * remove unused imports * re-enable imports checks (#51) * Add nn.one_hot (#52) * Add GaussianDropout layer. 
* Add GaussianNoise layer * Add Categorical Accuracy Metric (#47) * chore: adding categorical accuracy metric * chore: reformat docstrings * chore: reformat * chore: ndims with len * refactor the docstring * Fix typos * Implement masking. --------- Co-authored-by: Francois Chollet <francois.chollet@gmail.com> Co-authored-by: Aritra Roy Gosthipaty <aritra.born2fly@gmail.com> Co-authored-by: Ramesh Sampath <1437573+sampathweb@users.noreply.github.com> Co-authored-by: Chen Qian <chenmoney@google.com> Co-authored-by: Haifeng Jin <5476582+haifeng-jin@users.noreply.github.com> Co-authored-by: Gabriel Rasskin <43894452+grasskin@users.noreply.github.com> * Adds rmsprop optimizer and tests * Add AdamW optimizer and tests, minor formatting changes * Implemented formatting fixes * Adds clip norm and clip value tests to Adam * Adds Adagrad and Adadelta optimizers * Applies fixes to formatting and deletes unnecessary kwargs --------- Co-authored-by: Francois Chollet <francois.chollet@gmail.com> Co-authored-by: Aritra Roy Gosthipaty <aritra.born2fly@gmail.com> Co-authored-by: Ramesh Sampath <1437573+sampathweb@users.noreply.github.com> Co-authored-by: Chen Qian <chenmoney@google.com> Co-authored-by: Haifeng Jin <5476582+haifeng-jin@users.noreply.github.com> Co-authored-by: Gabriel Rasskin <43894452+grasskin@users.noreply.github.com>
108 lines
3.5 KiB
Python
108 lines
3.5 KiB
Python
from keras_core import initializers
|
|
from keras_core import operations as ops
|
|
from keras_core.api_export import keras_core_export
|
|
from keras_core.optimizers import optimizer
|
|
|
|
|
|
@keras_core_export(["keras_core.optimizers.Adagrad"])
class Adagrad(optimizer.Optimizer):
    """Optimizer that implements the Adagrad algorithm.

    Adagrad is an optimizer with parameter-specific learning rates,
    which are adapted relative to how frequently a parameter gets
    updated during training. The more updates a parameter receives,
    the smaller the updates.

    Args:
        learning_rate: Initial value for the learning rate:
            a floating point value. Defaults to 0.001.
            Note that `Adagrad` tends to benefit from higher initial
            learning rate values compared to other optimizers.
            To match the exact form in the original paper, use 1.0.
        initial_accumulator_value: Floating point value.
            Starting value for the accumulators (the per-parameter
            running sums of squared gradients). Must be non-negative.
            Defaults to 0.1.
        epsilon: Small floating point value used to maintain
            numerical stability. Defaults to 1e-7.
        {{base_optimizer_keyword_args}}

    Reference:

    - [Duchi et al., 2011](
        http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf).
    """

    def __init__(
        self,
        learning_rate=0.001,
        initial_accumulator_value=0.1,
        epsilon=1e-7,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=False,
        ema_momentum=0.99,
        ema_overwrite_frequency=None,
        name="adagrad",
    ):
        # Enforce the documented contract up front: a negative starting
        # accumulator can make `accumulator + epsilon` negative inside the
        # square root in `update_step`, yielding NaN updates.
        if initial_accumulator_value < 0.0:
            raise ValueError(
                "`initial_accumulator_value` must be non-negative. "
                f"Received: initial_accumulator_value="
                f"{initial_accumulator_value}"
            )
        super().__init__(
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            use_ema=use_ema,
            ema_momentum=ema_momentum,
            ema_overwrite_frequency=ema_overwrite_frequency,
            name=name,
        )
        self.initial_accumulator_value = initial_accumulator_value
        self.epsilon = epsilon

    def build(self, var_list):
        """Create one accumulator slot per variable in `var_list`.

        Each accumulator matches its variable's shape and dtype and is
        initialized to `initial_accumulator_value`. The list is built in
        `var_list` order, which `update_step` relies on when it looks up a
        slot via `_get_variable_index`. Calling `build` again after the
        optimizer is built is a no-op.

        Args:
            var_list: The model variables this optimizer will update.
        """
        if self.built:
            return
        super().build(var_list)
        initializer = initializers.Constant(self.initial_accumulator_value)
        self._accumulators = [
            self.add_variable(
                shape=var.shape,
                initializer=initializer,
                dtype=var.dtype,
                name="accumulator",
            )
            for var in var_list
        ]

    def update_step(self, gradient, variable, learning_rate):
        """Apply one Adagrad update to `variable`.

        Performs, in order:
            accumulator <- accumulator + gradient * gradient
            variable <- variable - lr * gradient / sqrt(accumulator + epsilon)

        Args:
            gradient: Gradient for `variable`; cast to `variable.dtype`.
            variable: The model variable being updated in place.
            learning_rate: Current learning rate; cast to `variable.dtype`.
        """
        lr = ops.cast(learning_rate, variable.dtype)
        gradient = ops.cast(gradient, variable.dtype)

        accumulator = self._accumulators[self._get_variable_index(variable)]

        accumulator.assign(accumulator + gradient * gradient)
        # The variable update reads the just-assigned accumulator value;
        # epsilon is added inside the square root, not to the denominator.
        variable.assign(
            variable - (lr * gradient / ops.sqrt(accumulator + self.epsilon))
        )

    def get_config(self):
        """Return the base optimizer config plus Adagrad's hyperparameters."""
        config = super().get_config()

        config.update(
            {
                "initial_accumulator_value": self.initial_accumulator_value,
                "epsilon": self.epsilon,
            }
        )
        return config
|
|
|
|
|
|
# Substitute the shared base-optimizer keyword-argument docs into the
# class docstring. Under `python -OO` docstrings are stripped (`__doc__` is
# None), so skip the substitution instead of failing at import time.
if Adagrad.__doc__ is not None:
    Adagrad.__doc__ = Adagrad.__doc__.replace(
        "{{base_optimizer_keyword_args}}",
        optimizer.base_optimizer_keyword_args,
    )
|