keras/keras_core/regularizers/regularizers_test.py
Chen Qian eabdb87f9f Add some numpy ops (#1)
* Add numpy ops (initial batch) and some config

* Add unit test

* fix call

* Revert "fix call"

This reverts commit 6748ad183029ff4b97317b77ceed8661916bb9a0.

* full unit test coverage

* fix setup.py
2023-04-12 11:31:58 -07:00

37 lines
1.1 KiB
Python

import numpy as np
from keras_core import backend
from keras_core import initializers
from keras_core import operations as ops
from keras_core import regularizers
from keras_core import testing
# TODO: serialization tests
class RegularizersTest(testing.TestCase):
    """Compares each regularizer's output with a NumPy reference value.

    Every test draws a random 4x4 matrix, wraps it in a backend variable,
    applies the regularizer, and checks the resulting penalty against the
    same quantity computed directly with NumPy.
    """

    def test_l1(self):
        """L1(0.1) is checked against 0.1 * sum(|value|)."""
        value = np.random.random((4, 4))
        x = backend.Variable(value)
        y = regularizers.L1(0.1)(x)
        self.assertAllClose(y, 0.1 * np.sum(np.abs(value)))

    def test_l2(self):
        """L2(0.1) is checked against 0.1 * sum(value ** 2)."""
        value = np.random.random((4, 4))
        x = backend.Variable(value)
        y = regularizers.L2(0.1)(x)
        self.assertAllClose(y, 0.1 * np.sum(np.square(value)))

    def test_l1_l2(self):
        """L1L2 is checked against the sum of its L1 and L2 terms."""
        value = np.random.random((4, 4))
        x = backend.Variable(value)
        y = regularizers.L1L2(l1=0.1, l2=0.2)(x)
        self.assertAllClose(
            y, 0.1 * np.sum(np.abs(value)) + 0.2 * np.sum(np.square(value))
        )

    def test_orthogonal_regularizer(self):
        """OrthogonalRegularizer in "rows" mode is checked against NumPy.

        The reference follows the documented definition: the penalty is
        proportional to `factor` times the mean of the dot products between
        the L2-normalized rows, excluding each row's product with itself.
        NOTE(review): confirm this matches the implementation in this
        revision (normalization axis and Gram-matrix orientation).
        """
        value = np.random.random((4, 4))
        x = backend.Variable(value)
        y = regularizers.OrthogonalRegularizer(factor=0.1, mode="rows")(x)

        size = value.shape[0]
        # Scale every row to unit L2 norm so the Gram matrix holds cosines.
        unit_rows = value / np.linalg.norm(value, axis=1, keepdims=True)
        gram = np.dot(unit_rows, np.transpose(unit_rows))
        # Zero the diagonal: a row's product with itself is not penalized.
        off_diagonal = gram * (1.0 - np.eye(size))
        # Each unordered pair appears twice in the symmetric Gram matrix,
        # hence the 0.5 factor in front of the sum.
        num_pairs = size * (size - 1.0) / 2.0
        self.assertAllClose(
            y, 0.1 * 0.5 * np.sum(np.abs(off_diagonal)) / num_pairs
        )