keras/keras_core/regularizers/regularizers_test.py
hertschuh c2abf535e1 Add lowercase aliases for some regularizers for backwards compatibility. (#304)
Also fix assertion in regularizers unit tests.
2023-06-09 10:04:36 -07:00

54 lines
1.6 KiB
Python

import numpy as np
from keras_core import backend
from keras_core import regularizers
from keras_core import testing
# TODO: serialization tests
class RegularizersTest(testing.TestCase):
    """Numerical and lookup tests for the built-in weight regularizers.

    Each numerical test wraps a random NumPy array in a `backend.Variable`,
    applies a regularizer to it, and compares the result against a reference
    value computed directly in NumPy from the same array.
    """

    def _random_variable(self, shape=(4, 4), seed=1337):
        """Return `(value, variable)` for a seeded random array.

        Args:
            shape: Shape of the random array.
            seed: Seed for the NumPy generator, so failures are reproducible.

        Returns:
            A tuple of the NumPy array (values drawn from [0, 1)) and a
            `backend.Variable` initialized from it.
        """
        value = np.random.default_rng(seed).random(shape)
        return value, backend.Variable(value)

    def test_l1(self):
        value, x = self._random_variable()
        y = regularizers.L1(0.1)(x)
        # L1 penalty: factor * sum(|w|).
        self.assertAllClose(y, 0.1 * np.sum(np.abs(value)))

    def test_l2(self):
        value, x = self._random_variable()
        y = regularizers.L2(0.1)(x)
        # L2 penalty: factor * sum(w ** 2), with no extra 1/2 factor.
        self.assertAllClose(y, 0.1 * np.sum(np.square(value)))

    def test_l1_l2(self):
        value, x = self._random_variable()
        y = regularizers.L1L2(l1=0.1, l2=0.2)(x)
        # The combined penalty is the sum of independent L1 and L2 terms.
        self.assertAllClose(
            y, 0.1 * np.sum(np.abs(value)) + 0.2 * np.sum(np.square(value))
        )

    @staticmethod
    def _orthogonal_reference(value, factor, mode):
        """NumPy reference value for `OrthogonalRegularizer`.

        L2-normalizes the rows (`mode="rows"`) or columns
        (`mode="columns"`) of the 2D array `value` and forms their pairwise
        dot products. It returns
        `factor * 0.5 * sum(|off-diagonal products|) / num_pairs`, where
        `num_pairs = size * (size - 1) / 2`. That equals `factor` times the
        mean absolute dot product over distinct row/column pairs.

        NOTE(review): this mirrors the documented penalty of
        `OrthogonalRegularizer`; update it if that formula changes.
        """
        if mode == "rows":
            unit = value / np.linalg.norm(value, axis=1, keepdims=True)
            product = unit @ unit.T
        else:
            unit = value / np.linalg.norm(value, axis=0, keepdims=True)
            product = unit.T @ unit
        size = product.shape[0]
        # Drop each vector's product with itself (always 1 after normalizing).
        off_diagonal = product * (1.0 - np.eye(size))
        num_pairs = size * (size - 1.0) / 2.0
        return factor * 0.5 * np.sum(np.abs(off_diagonal)) / num_pairs

    def test_orthogonal_regularizer(self):
        value, x = self._random_variable()
        for mode in ("rows", "columns"):
            with self.subTest(mode=mode):
                y = regularizers.OrthogonalRegularizer(factor=0.1, mode=mode)(
                    x
                )
                self.assertAllClose(
                    y,
                    self._orthogonal_reference(value, factor=0.1, mode=mode),
                )

    def test_get_method(self):
        # Lowercase string identifiers resolve to the regularizer classes.
        obj = regularizers.get("l1l2")
        self.assertIsInstance(obj, regularizers.L1L2)
        obj = regularizers.get("l1")
        self.assertIsInstance(obj, regularizers.L1)
        obj = regularizers.get("l2")
        self.assertIsInstance(obj, regularizers.L2)
        obj = regularizers.get("orthogonal_regularizer")
        self.assertIsInstance(obj, regularizers.OrthogonalRegularizer)
        # `None` passes through unchanged.
        obj = regularizers.get(None)
        self.assertIsNone(obj)
        # Unknown identifiers are rejected.
        with self.assertRaises(ValueError):
            regularizers.get("typo")