Add normalize() to numerical utils.

Author: Francois Chollet
Date: 2023-04-29 15:08:38 -07:00
parent 125ef355e7
commit 38fda5a175
4 changed files with 64 additions and 10 deletions

@@ -4,7 +4,7 @@ from keras_core.api_export import keras_core_export
 from keras_core.losses.loss import Loss
 from keras_core.losses.loss import squeeze_to_same_rank
 from keras_core.saving import serialization_lib
-from keras_core.utils.numerical_utils import l2_normalize
+from keras_core.utils.numerical_utils import normalize


 class LossFunctionWrapper(Loss):
@@ -579,6 +579,6 @@ def cosine_similarity(y_true, y_pred, axis=-1):
     y_pred = ops.convert_to_tensor(y_pred)
     y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
     y_true, y_pred = squeeze_to_same_rank(y_true, y_pred)
-    y_pred = l2_normalize(y_pred, axis=axis)
-    y_true = l2_normalize(y_true, axis=axis)
+    y_pred = normalize(y_pred, axis=axis)
+    y_true = normalize(y_true, axis=axis)
     return -ops.sum(y_true * y_pred, axis=axis)
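
The swap is behavior-preserving here, since `normalize` with the default `order=2` reproduces the old `l2_normalize`. For reference, a minimal NumPy sketch of what `cosine_similarity` computes; the helper name and epsilon guard are illustrative, not part of the diff:

import numpy as np

def cosine_similarity_ref(y_true, y_pred, axis=-1):
    # L2-normalize both inputs along `axis`; the epsilon guard mirrors
    # the zero-norm protection in the library code.
    def l2(v):
        norm = np.linalg.norm(v, ord=2, axis=axis, keepdims=True)
        return v / np.maximum(norm, 1e-7)

    # Negative sum of elementwise products of unit vectors: this is
    # -cos(angle), so perfectly aligned vectors score -1.
    return -np.sum(l2(y_true) * l2(y_pred), axis=axis)

y_true = np.array([[0.0, 1.0], [1.0, 1.0]])
y_pred = np.array([[1.0, 0.0], [1.0, 1.0]])
print(cosine_similarity_ref(y_true, y_pred))  # -> approx. [0.0, -1.0]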

@@ -2,7 +2,7 @@ import math

 from keras_core import operations as ops
 from keras_core.api_export import keras_core_export
-from keras_core.utils.numerical_utils import l2_normalize
+from keras_core.utils.numerical_utils import normalize


 @keras_core_export(
@@ -313,11 +313,11 @@ class OrthogonalRegularizer(Regularizer):
                 f"inputs.shape={inputs.shape}"
             )
         if self.mode == "rows":
-            inputs = l2_normalize(inputs, axis=1)
+            inputs = normalize(inputs, axis=1)
             product = ops.matmul(inputs, ops.transpose(inputs))
             size = inputs.shape[0]
         else:
-            inputs = l2_normalize(inputs, axis=0)
+            inputs = normalize(inputs, axis=0)
             product = ops.matmul(ops.transpose(inputs), inputs)
             size = inputs.shape[1]
         product_no_diagonal = product * (
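
To see what this regularizer measures, here is a hedged NumPy sketch of the "rows" branch: rows are unit-normalized, so the Gram matrix holds pairwise cosine similarities, and its off-diagonal entries vanish exactly when the rows are orthogonal. The final reduction (`0.5 * factor * mean |off-diagonal|`) is an assumption based on the usual Keras convention, since the hunk cuts off at `product_no_diagonal`:

import numpy as np

def orthogonal_penalty_rows(weights, factor=0.01):
    # Normalize each row to unit L2 norm, as in the "rows" branch above.
    norms = np.linalg.norm(weights, axis=1, keepdims=True)
    normalized = weights / np.maximum(norms, 1e-7)
    # Gram matrix: entry (i, j) is the cosine similarity of rows i and j.
    product = normalized @ normalized.T
    size = weights.shape[0]
    # Keep only the off-diagonal entries; they are zero iff the rows are
    # mutually orthogonal.
    product_no_diagonal = product * (1.0 - np.eye(size))
    num_pairs = size * (size - 1) / 2.0
    # Assumed scaling: mean absolute off-diagonal similarity, times factor.
    return factor * 0.5 * np.sum(np.abs(product_no_diagonal)) / num_pairs

w = np.eye(3)  # orthogonal rows -> zero penalty
print(orthogonal_penalty_rows(w))  # -> 0.0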

@@ -5,11 +5,44 @@ from keras_core import operations as ops
 from keras_core.api_export import keras_core_export


-def l2_normalize(x, axis=0):
+@keras_core_export("keras_core.utils.normalize")
+def normalize(x, axis=-1, order=2):
+    """Normalizes an array.
+
+    If the input is a NumPy array, a NumPy array will be returned.
+    If it's a backend tensor, a backend tensor will be returned.
+
+    Args:
+        x: Array to normalize.
+        axis: axis along which to normalize.
+        order: Normalization order (e.g. `order=2` for L2 norm).
+
+    Returns:
+        A normalized copy of the array.
+    """
+    if not isinstance(order, int) or not order >= 1:
+        raise ValueError(
+            "Argument `order` must be an int >= 1. " f"Received: order={order}"
+        )
+    if isinstance(x, np.ndarray):
+        # NumPy input
+        norm = np.atleast_1d(np.linalg.norm(x, order, axis))
+        norm[norm == 0] = 1
+        return x / np.expand_dims(norm, axis)
+
+    # Backend tensor input
+    if len(x.shape) == 0:
+        x = ops.expand_dims(x, axis=0)
     epsilon = backend.epsilon()
-    square_sum = ops.sum(ops.square(x), axis=axis, keepdims=True)
-    l2_norm = ops.reciprocal(ops.sqrt(ops.maximum(square_sum, epsilon)))
-    return ops.multiply(x, l2_norm)
+    if order == 2:
+        power_sum = ops.sum(ops.square(x), axis=axis, keepdims=True)
+        norm = ops.reciprocal(ops.sqrt(ops.maximum(power_sum, epsilon)))
+    else:
+        power_sum = ops.sum(ops.power(x, order), axis=axis, keepdims=True)
+        norm = ops.reciprocal(
+            ops.power(ops.maximum(power_sum, epsilon), 1.0 / order)
+        )
+    return ops.multiply(x, norm)


 @keras_core_export("keras_core.utils.to_categorical")
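
Example usage of the new utility (outputs abbreviated). One caveat worth noting: for odd `order`, the backend branch raises `x` itself to `order` rather than `abs(x)`, so it appears to agree with `np.linalg.norm` only for non-negative inputs; the NumPy branch delegates to `np.linalg.norm` directly.

import numpy as np
from keras_core.utils.numerical_utils import normalize

x = np.array([[1.0, 2.0, 2.0]])

# Default order=2: divide each row by its L2 norm (3.0 here).
print(normalize(x, axis=-1))           # [[0.333..., 0.666..., 0.666...]]

# order=1: divide each row by its L1 norm (5.0 here).
print(normalize(x, axis=-1, order=1))  # [[0.2, 0.4, 0.4]]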

@@ -1,4 +1,5 @@
 import numpy as np
+import tensorflow as tf
 from absl.testing import parameterized

 from keras_core import backend
@@ -52,3 +53,23 @@ class TestNumericalUtils(testing.TestCase, parameterized.TestCase):
         one_hot = numerical_utils.to_categorical(label, NUM_CLASSES)
         assert backend.is_tensor(one_hot)
         self.assertAllClose(one_hot, expected)
+
+    @parameterized.parameters([1, 2, 3])
+    def test_normalize(self, order):
+        xnp = np.random.random((3, 3))
+        xb = backend.random.uniform((3, 3))
+
+        # Test NumPy
+        out = numerical_utils.normalize(xnp, axis=-1, order=order)
+        self.assertTrue(isinstance(out, np.ndarray))
+        self.assertAllClose(
+            tf.keras.utils.normalize(xnp, axis=-1, order=order), out
+        )
+
+        # Test backend
+        out = numerical_utils.normalize(xb, axis=-1, order=order)
+        self.assertTrue(backend.is_tensor(out))
+        self.assertAllClose(
+            tf.keras.utils.normalize(np.array(xb), axis=-1, order=order),
+            np.array(out),
+        )
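
The test checks parity with `tf.keras.utils.normalize`. An equivalent standalone check against the `np.linalg.norm` definition, with no TensorFlow dependency, looks like this (a sketch, assuming `keras_core` is importable):

import numpy as np
from keras_core.utils import numerical_utils

rng = np.random.default_rng(0)
x = rng.random((3, 3))  # non-negative, matching the test's inputs

for order in (1, 2, 3):
    out = numerical_utils.normalize(x, axis=-1, order=order)
    # Reference: divide each row by its order-norm as np.linalg.norm
    # defines it; nonzero norms make the zero-guard branch irrelevant.
    ref = x / np.linalg.norm(x, ord=order, axis=-1, keepdims=True)
    np.testing.assert_allclose(out, ref, rtol=1e-5)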