keras/keras_core/utils/numerical_utils.py
Aritra Roy Gosthipaty 2481069ed4 Adding: Numpy Backend (#483)
* chore: adding numpy backend

* review comments

* review comments

* chore: adding math

* chore: adding random module

* chore: adding random in init

* review comments

* chore: adding numpy and nn for numpy backend

* chore: adding generic pool, max, and average pool

* chore: adding the conv ops

* chore: reformat code and using jax for conv and pool

* chore:  added self value

* chore: activation tests pass

* chore: adding post build method

* chore: adding necessary methods to the numpy trainer

* chore: fixing utils test

* chore: fixing losses test suite

* chore: fix backend tests

* chore: fixing initializers test

* chore: fixing accuracy metrics test

* chore: fixing ops test

* chore: review comments

* chore: init with image and fixing random tests

* chore: skipping random seed set for numpy backend

* chore: adding single resize image method

* chore: skipping tests for applications and layers

* chore: skipping tests for models

* chore: skipping tests for saving

* chore: skipping tests for trainers

* chore: fixing one hot

* chore: fixing vmap in numpy and metrics test

* chore: adding a wrapper to numpy sum, started fixing layer tests

* fix: is_tensor now accepts numpy scalars

* chore: adding draw seed

* fix: warn message for numpy masking

* fix: checking whether kernel are tensors

* chore: adding rnn

* chore: adding dynamic backend for numpy

* fix: axis cannot be None for normalize

* chore: adding jax resize for numpy image

* chore: adding rnn implementation in numpy

* chore: using pytest fixtures

* change: numpy import string

* chore: review comments

* chore: adding numpy to backend list of github actions

* chore: remove debug print statements
2023-07-19 01:08:48 +05:30

108 lines
3.3 KiB
Python

import numpy as np
from keras_core import backend
from keras_core.api_export import keras_core_export
@keras_core_export("keras_core.utils.normalize")
def normalize(x, axis=-1, order=2):
    """Normalizes an array.

    If the input is a NumPy array, a NumPy array will be returned.
    If it's a backend tensor, a backend tensor will be returned.

    Args:
        x: Array to normalize.
        axis: axis along which to normalize. For NumPy inputs, `None` is
            treated as `-1`.
        order: Normalization order (e.g. `order=2` for L2 norm). Must be an
            int >= 1. The norm is taken over absolute values, i.e.
            `sum(abs(x) ** order) ** (1 / order)`, for both NumPy and
            backend inputs.

    Returns:
        A normalized copy of the array.

    Raises:
        ValueError: If `order` is not an int >= 1.
    """
    from keras_core import ops

    if not isinstance(order, int) or not order >= 1:
        raise ValueError(
            "Argument `order` must be an int >= 1. " f"Received: order={order}"
        )

    if isinstance(x, np.ndarray):
        # NumPy input
        norm = np.atleast_1d(np.linalg.norm(x, order, axis))
        # All-zero slices have norm 0; divide them by 1 so they stay zero
        # instead of becoming NaN.
        norm[norm == 0] = 1
        # `np.expand_dims` needs an integer axis. Only `None` falls back to
        # -1. A truthiness test (`axis or -1`) would also rewrite `axis=0`
        # and re-insert the reduced dimension in the wrong position.
        if axis is None:
            axis = -1
        return x / np.expand_dims(norm, axis)

    # Backend tensor input
    if len(x.shape) == 0:
        x = ops.expand_dims(x, axis=0)
    epsilon = backend.epsilon()
    if order == 2:
        power_sum = ops.sum(ops.square(x), axis=axis, keepdims=True)
        norm = ops.reciprocal(ops.sqrt(ops.maximum(power_sum, epsilon)))
    else:
        # Lp norms are defined over absolute values. Without `absolute`,
        # odd orders let negative entries cancel positive ones, and the
        # result would disagree with `np.linalg.norm` in the NumPy branch.
        power_sum = ops.sum(
            ops.power(ops.absolute(x), order), axis=axis, keepdims=True
        )
        norm = ops.reciprocal(
            ops.power(ops.maximum(power_sum, epsilon), 1.0 / order)
        )
    return ops.multiply(x, norm)
@keras_core_export("keras_core.utils.to_categorical")
def to_categorical(x, num_classes=None):
    """Converts a class vector (integers) to binary class matrix.

    E.g. for use with `categorical_crossentropy`.

    Args:
        x: Array-like with class values to be converted into a matrix
            (integers from 0 to `num_classes - 1`).
        num_classes: Total number of classes. If `None`, this would be inferred
            as `max(x) + 1`. Defaults to `None`.

    Returns:
        A binary matrix representation of the input as a NumPy array. The class
        axis is placed last. If `x` is a backend tensor, the result is instead
        the backend tensor returned by `backend.nn.one_hot`.

    Example:

    >>> a = keras_core.utils.to_categorical([0, 1, 2, 3], num_classes=4)
    >>> print(a)
    [[1. 0. 0. 0.]
     [0. 1. 0. 0.]
     [0. 0. 1. 0.]
     [0. 0. 0. 1.]]

    >>> b = np.array([.9, .04, .03, .03,
    ...               .3, .45, .15, .13,
    ...               .04, .01, .94, .05,
    ...               .12, .21, .5, .17]).reshape(4, 4)
    >>> loss = keras_core.losses.categorical_crossentropy(a, b)
    >>> print(np.around(loss, 5))
    [0.10536 0.82807 0.1011  1.77196]

    >>> loss = keras_core.losses.categorical_crossentropy(a, a)
    >>> print(np.around(loss, 5))
    [0. 0. 0. 0.]
    """
    if backend.is_tensor(x):
        # NOTE(review): unlike the NumPy branch below, a trailing dimension
        # of size 1 is not squeezed here, so a (..., 1) tensor yields
        # (..., 1, num_classes). Kept as-is for backward compatibility.
        if not num_classes:
            # Honour the documented default (`max(x) + 1`) for tensors too,
            # rather than handing `None` to `one_hot`. `int()` needs a
            # concrete (eager) value, so pass `num_classes` explicitly when
            # tracing.
            from keras_core import ops

            num_classes = int(ops.max(x)) + 1
        return backend.nn.one_hot(x, num_classes)

    x = np.array(x, dtype="int64")
    input_shape = x.shape
    # Shrink the last dimension if the shape is (..., 1).
    if input_shape and input_shape[-1] == 1 and len(input_shape) > 1:
        input_shape = tuple(input_shape[:-1])
    # Flatten so that every label becomes one row of the one-hot matrix.
    x = x.reshape(-1)
    if not num_classes:
        num_classes = np.max(x) + 1
    batch_size = x.shape[0]
    categorical = np.zeros((batch_size, num_classes))
    # Put a single 1 in each row, at the column given by that row's label.
    # NOTE(review): negative labels are not rejected; NumPy fancy indexing
    # counts them from the last column. Labels >= num_classes raise
    # IndexError.
    categorical[np.arange(batch_size), x] = 1
    # Restore the original leading dimensions, with the class axis last.
    output_shape = input_shape + (num_classes,)
    categorical = np.reshape(categorical, output_shape)
    return categorical