keras/keras_core/optimizers/adamax_test.py
Neel Kovelamudi d9b92cafb5 Adds all remaining Keras optimizers (Adamax, Adafactor, Nadam, and Ftrl) (#80)
* Add golden correctness tests for Adam and SGD

* Fix dtype issues

* Sync with main (#56)

* Minor touch ups

* Fix a pretty major bug

* Format code

* Big rethink of Variable API

* Make build-by-run the default build(), leveraging new zero_history KerasTensor mode

* Minor fixes

* Format code

* Switch back to build-by-eager-run for simplicity

* Add raise upon build failure

* Work around JAX bug.

* Add a few more tests.

* Add saving tests

* Adds test suite for SGD and golden correctness tests for all optimizers (#40)

* Add golden correctness tests for Adam and SGD

* Fix dtype issues

* Add binary accuracy (#41)

* chore: adding binary accuracy

* chore: fix docstring

* Add tests for add_loss and activity regularization.

* Reformat code

* Add ActivityRegularization layer

* Fix JAX CI.

* Add Lambda Callback (#42)

* Add LambdaCallback

* Add Lambda Callback

* Add Lambda Callback

* Rename lambda_callback_test.py

* Add einsum (#43)

* Add einsum

* address comments

* Fix format line length (#45)

* Add Embedding layer

* Shorten lines

* Add .vscode to .gitignore (#46)

* rm vscode settings

* add .vscode to gitignore

* Set demo program backend (#48)

* Add tests for training arg resolution in Layer.

* Implement mixed precision.

* Replace backend.execute with backend.numpy.XXX (#50)

* Add cosine similarity loss and update l2_normalize from regularizers (#34)

* Begin cosine loss

* Add testing for cosine similarity

* Fix formatting

* Docstring standardization

* Formatting

* Create numerical_utils

* Fix issue with call context lingering.

* Add the EarlyStopping callback (#44)

* add earlystopping callback

* addressing comments

* address comments

* addressing comments

* remove unused imports

* re-enable imports checks (#51)

* Add nn.one_hot (#52)

* Add GaussianDropout layer.

* Add GaussianNoise layer

* Add Categorical Accuracy Metric (#47)

* chore: adding categorical accuracy metric

* chore: reformat docstrings

* chore: reformat

* chore: ndims with len

* refactor the docstring

* Fix typos

* Implement masking.

---------

Co-authored-by: Francois Chollet <francois.chollet@gmail.com>
Co-authored-by: Aritra Roy Gosthipaty <aritra.born2fly@gmail.com>
Co-authored-by: Ramesh Sampath <1437573+sampathweb@users.noreply.github.com>
Co-authored-by: Chen Qian <chenmoney@google.com>
Co-authored-by: Haifeng Jin <5476582+haifeng-jin@users.noreply.github.com>
Co-authored-by: Gabriel Rasskin <43894452+grasskin@users.noreply.github.com>

* Adds rmsprop optimizer and tests

* Add AdamW optimizer and tests, minor formatting changes

* Implemented formatting fixes

* Adds clip norm and clip value tests to Adam

* Adds Adagrad and Adadelta optimizers

* Applies fixes to formatting and deletes unnecessary kwargs

* Adds Adamax and Adafactor and associated tests

* Adds Nadam and Ftrl optimizers and associated tests

---------

Co-authored-by: Francois Chollet <francois.chollet@gmail.com>
Co-authored-by: Aritra Roy Gosthipaty <aritra.born2fly@gmail.com>
Co-authored-by: Ramesh Sampath <1437573+sampathweb@users.noreply.github.com>
Co-authored-by: Chen Qian <chenmoney@google.com>
Co-authored-by: Haifeng Jin <5476582+haifeng-jin@users.noreply.github.com>
Co-authored-by: Gabriel Rasskin <43894452+grasskin@users.noreply.github.com>
2023-05-03 21:16:00 +00:00

85 lines
3.1 KiB
Python

# flake8: noqa
import numpy as np
from keras_core import backend
from keras_core import testing
from keras_core.optimizers.adamax import Adamax
class AdamaxTest(testing.TestCase):
    """Unit tests for the `Adamax` optimizer."""

    def test_config(self):
        """Run the shared class-serialization check on a configured optimizer."""
        hyperparameters = {
            "learning_rate": 0.5,
            "beta_1": 0.8,
            "beta_2": 0.95,
            "epsilon": 1e-5,
        }
        self.run_class_serialization_test(Adamax(**hyperparameters))

    def test_single_step(self):
        """Apply one gradient step and compare against the expected values."""
        adamax = Adamax(learning_rate=0.5)
        gradient = np.array([1.0, 6.0, 7.0, 2.0])
        variable = backend.Variable([1.0, 2.0, 3.0, 4.0])

        adamax.apply_gradients(zip([gradient], [variable]))

        expected = [0.5, 1.5, 2.5, 3.5]
        self.assertAllClose(variable, expected, rtol=1e-4, atol=1e-4)

    def test_weight_decay(self):
        """Check weight decay and both ways of excluding a variable from it.

        `var1` is passed to all three optimizers; `var2` is excluded by name
        and `var3` is excluded by reference.
        """
        # Creation order is kept as: gradient, var1, "exclude", var3.
        zero_grad = np.zeros(())
        var1 = backend.Variable(2.0)
        var2 = backend.Variable(2.0, name="exclude")
        var3 = backend.Variable(2.0)

        # Plain weight decay, nothing excluded.
        plain = Adamax(learning_rate=1.0, weight_decay=0.004)
        plain.apply_gradients(zip([zero_grad], [var1]))

        # Exclusion by variable name.
        by_name = Adamax(learning_rate=1.0, weight_decay=0.004)
        by_name.exclude_from_weight_decay(var_names=["exclude"])
        by_name.apply_gradients(zip([zero_grad, zero_grad], [var1, var2]))

        # Exclusion by variable object.
        by_reference = Adamax(learning_rate=1.0, weight_decay=0.004)
        by_reference.exclude_from_weight_decay(var_list=[var3])
        by_reference.apply_gradients(
            zip([zero_grad, zero_grad], [var1, var3])
        )

        # NOTE(review): 1.9760959 matches 2.0 * (1 - 0.004) ** 3, i.e.
        # presumably one decay per apply_gradients call -- confirm against
        # the optimizer implementation.
        self.assertAlmostEqual(var1.numpy(), 1.9760959, decimal=6)
        self.assertAlmostEqual(var2.numpy(), 2.0, decimal=6)
        self.assertAlmostEqual(var3.numpy(), 2.0, decimal=6)

    def test_correctness_with_golden(self):
        """Compare five successive variable states with stored golden rows."""
        adamax = Adamax(
            learning_rate=0.2, beta_1=0.85, beta_2=0.95, epsilon=1e-6
        )

        x = backend.Variable(np.ones([10]))
        warmup_grads = np.full((10,), 0.01)
        step_grads = np.arange(0.1, 1.1, 0.1)

        # fmt: off
        golden = np.array(
            [[0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8],
            [0.6827, 0.6873, 0.6888, 0.6896, 0.6901, 0.6904, 0.6906, 0.6908, 0.6909, 0.691],
            [0.5333, 0.5407, 0.5431, 0.5444, 0.5451, 0.5456, 0.546, 0.5462, 0.5464, 0.5466],
            [0.368, 0.3773, 0.3804, 0.382, 0.3829, 0.3835, 0.384, 0.3843, 0.3846, 0.3848],
            [0.1933, 0.204, 0.2076, 0.2094, 0.2105, 0.2112, 0.2117, 0.2121, 0.2124, 0.2126]]
        )
        # fmt: on

        # The first update uses a small constant gradient; every later
        # update uses the ramp 0.1 .. 1.0.
        adamax.apply_gradients(zip([warmup_grads], [x]))
        for expected_row in golden:
            # Check the state first, then advance it for the next row.
            self.assertAllClose(x, expected_row, rtol=5e-4, atol=5e-4)
            adamax.apply_gradients(zip([step_grads], [x]))

    def test_clip_norm(self):
        """Check `_clip_gradients` output when `clipnorm=1` is set."""
        adamax = Adamax(clipnorm=1)
        raw_grads = [np.array([100.0, 100.0])]
        clipped = adamax._clip_gradients(raw_grads)
        unit_component = 2**0.5 / 2
        self.assertAllClose(clipped[0], [unit_component, unit_component])

    def test_clip_value(self):
        """Check `_clip_gradients` output when `clipvalue=1` is set."""
        adamax = Adamax(clipvalue=1)
        raw_grads = [np.array([100.0, 100.0])]
        clipped = adamax._clip_gradients(raw_grads)
        self.assertAllClose(clipped[0], [1.0, 1.0])