Add Activations Test (#81)

* Add Activations Test

* fix tests for activations

* Add dtype to KerasTensor
Ramesh Sampath 2023-05-03 17:00:01 -05:00 committed by Francois Chollet
parent 459b1b13fc
commit a0f94646c3
5 changed files with 216 additions and 20 deletions

@@ -267,7 +267,7 @@ def silu(x):
@keras_core_export("keras_core.activations.gelu")
def gelu(x):
def gelu(x, approximate=False):
"""Gaussian error linear unit (GELU) activation function.
The Gaussian error linear unit (GELU) is defined as:
@@ -280,12 +280,13 @@ def gelu(x):
Args:
x: Input tensor.
approximate: A `bool`, whether to enable approximation.
Reference:
- [Hendrycks et al., 2016](https://arxiv.org/abs/1606.08415)
"""
return ops.gelu(x)
return ops.gelu(x, approximate=approximate)
@keras_core_export("keras_core.activations.tanh")
@@ -380,7 +381,7 @@ class Mish(ops.Operation):
@staticmethod
def static_call(x):
return x * backend.tanh(backend.softplus(x))
return x * backend.nn.tanh(backend.nn.softplus(x))
@keras_core_export("keras_core.activations.mish")

@@ -1 +1,174 @@
# TODO
import numpy as np
from keras_core import activations
from keras_core import backend
from keras_core import testing
def _ref_softmax(values):
m = np.max(values)
e = np.exp(values - m)
return e / np.sum(e)
def _ref_softplus(x):
return np.log(np.ones_like(x) + np.exp(x))
class ActivationsTest(testing.TestCase):
def test_softmax(self):
x = np.random.random((2, 5))
result = activations.softmax(x[np.newaxis, :])[0]
expected = _ref_softmax(x[0])
self.assertAllClose(result[0], expected, rtol=1e-05)
def test_softmax_2d_axis_0(self):
x = np.random.random((2, 5))
result = activations.softmax(x[np.newaxis, :], axis=1)[0]
expected = np.zeros((2, 5))
for i in range(5):
expected[:, i] = _ref_softmax(x[:, i])
self.assertAllClose(result, expected, rtol=1e-05)
# TODO: Fails on Tuple Axis
# ops/nn_ops.py:3824: TypeError:
# '<=' not supported between instances of 'int' and 'tuple'
# def test_softmax_3d_axis_tuple(self):
# x = np.random.random((2, 3, 5))
# result = activations.softmax([x], axis=(1, 2))[0]
# expected = np.zeros((2, 3, 5))
# for i in range(2):
# expected[i, :, :] = _ref_softmax(x[i, :, :])
# self.assertAllClose(result, expected, rtol=1e-05)
def test_temporal_softmax(self):
x = np.random.random((2, 2, 3)) * 10
result = activations.softmax(x[np.newaxis, :])[0]
expected = _ref_softmax(x[0, 0])
self.assertAllClose(result[0, 0], expected, rtol=1e-05)
def test_selu(self):
alpha = 1.6732632423543772848170429916717
scale = 1.0507009873554804934193349852946
positive_values = np.array([[1, 2]], dtype=backend.floatx())
result = activations.selu(positive_values[np.newaxis, :])[0]
self.assertAllClose(result, positive_values * scale, rtol=1e-05)
negative_values = np.array([[-1, -2]], dtype=backend.floatx())
result = activations.selu(negative_values[np.newaxis, :])[0]
true_result = (np.exp(negative_values) - 1) * scale * alpha
self.assertAllClose(result, true_result)
def test_softplus(self):
x = np.random.random((2, 5))
result = activations.softplus(x[np.newaxis, :])[0]
expected = _ref_softplus(x)
self.assertAllClose(result, expected, rtol=1e-05)
def test_softsign(self):
def softsign(x):
return np.divide(x, np.ones_like(x) + np.absolute(x))
x = np.random.random((2, 5))
result = activations.softsign(x[np.newaxis, :])[0]
expected = softsign(x)
self.assertAllClose(result, expected, rtol=1e-05)
def test_sigmoid(self):
def ref_sigmoid(x):
if x >= 0:
return 1 / (1 + np.exp(-x))
else:
z = np.exp(x)
return z / (1 + z)
sigmoid = np.vectorize(ref_sigmoid)
x = np.random.random((2, 5))
result = activations.sigmoid(x[np.newaxis, :])[0]
expected = sigmoid(x)
self.assertAllClose(result, expected, rtol=1e-05)
def test_hard_sigmoid(self):
def ref_hard_sigmoid(x):
x = (x / 6.0) + 0.5
z = 0.0 if x <= 0 else (1.0 if x >= 1 else x)
return z
hard_sigmoid = np.vectorize(ref_hard_sigmoid)
x = np.random.random((2, 5))
result = activations.hard_sigmoid(x[np.newaxis, :])[0]
expected = hard_sigmoid(x)
self.assertAllClose(result, expected, rtol=1e-05)
def test_relu(self):
positive_values = np.random.random((2, 5))
result = activations.relu(positive_values[np.newaxis, :])[0]
self.assertAllClose(result, positive_values, rtol=1e-05)
negative_values = np.random.uniform(-1, 0, (2, 5))
result = activations.relu(negative_values[np.newaxis, :])[0]
expected = np.zeros((2, 5))
self.assertAllClose(result, expected, rtol=1e-05)
def test_gelu(self):
def gelu(x, approximate=False):
if approximate:
return (
0.5
* x
* (
1.0
+ np.tanh(
np.sqrt(2.0 / np.pi)
* (x + 0.044715 * np.power(x, 3))
)
)
)
else:
from scipy.stats import norm
return x * norm.cdf(x)
x = np.random.random((2, 5))
result = activations.gelu(x[np.newaxis, :])[0]
expected = gelu(x)
self.assertAllClose(result, expected, rtol=1e-05)
x = np.random.random((2, 5))
result = activations.gelu(x[np.newaxis, :], approximate=True)[0]
expected = gelu(x, True)
self.assertAllClose(result, expected, rtol=1e-05)
def test_elu(self):
x = np.random.random((2, 5))
result = activations.elu(x[np.newaxis, :])[0]
self.assertAllClose(result, x, rtol=1e-05)
negative_values = np.array([[-1, -2]], dtype=backend.floatx())
result = activations.elu(negative_values[np.newaxis, :])[0]
true_result = np.exp(negative_values) - 1
self.assertAllClose(result, true_result)
def test_tanh(self):
x = np.random.random((2, 5))
result = activations.tanh(x[np.newaxis, :])[0]
expected = np.tanh(x)
self.assertAllClose(result, expected, rtol=1e-05)
def test_exponential(self):
x = np.random.random((2, 5))
result = activations.exponential(x[np.newaxis, :])[0]
expected = np.exp(x)
self.assertAllClose(result, expected, rtol=1e-05)
def test_mish(self):
x = np.random.random((2, 5))
result = activations.mish(x[np.newaxis, :])[0]
expected = x * np.tanh(_ref_softplus(x))
self.assertAllClose(result, expected, rtol=1e-05)
def test_linear(self):
x = np.random.random((10, 5))
self.assertAllClose(x, activations.linear(x))
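A note on the pattern used throughout these tests: each activation is called on `x[np.newaxis, :]`, which adds a leading batch dimension, and the result is indexed with `[0]` before comparing against the un-batched NumPy reference. A minimal sketch of what that wrapping does:

    import numpy as np
    x = np.random.random((2, 5))
    batched = x[np.newaxis, :]                   # shape (1, 2, 5): leading batch axis added
    # result = activations.softmax(batched)[0]   # shape (2, 5) again, comparable to the NumPy reference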

@@ -19,6 +19,10 @@ def sigmoid(x):
return jnn.sigmoid(x)
def tanh(x):
return jnn.tanh(x)
def softplus(x):
return jnn.softplus(x)
@@ -59,8 +63,8 @@ def gelu(x, approximate=True):
return jnn.gelu(x, approximate)
def softmax(x):
return jnn.softmax(x)
def softmax(x, axis=None):
return jnn.softmax(x, axis=axis)
def log_softmax(x, axis=-1):

@@ -20,6 +20,10 @@ def sigmoid(x):
return tf.nn.sigmoid(x)
def tanh(x):
return tf.nn.tanh(x)
def softplus(x):
return tf.math.softplus(x)

@@ -74,7 +74,7 @@ class Sigmoid(Operation):
return backend.nn.sigmoid(x)
def compute_output_spec(self, x):
return KerasTensor(x.shape)
return KerasTensor(x.shape, x.dtype)
def sigmoid(x):
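This is the "Add dtype to KerasTensor" part of the commit: every `compute_output_spec` below now forwards `x.dtype`, so symbolic outputs carry the input's dtype instead of falling back to the constructor default. A minimal sketch of the effect, with the import path and the direct use of the op class assumed for illustration:

    from keras_core.backend import KerasTensor    # import path assumed
    x = KerasTensor(shape=(2, 5), dtype="float16")
    y = Sigmoid()(x)                              # symbolic call routes through compute_output_spec
    assert y.dtype == "float16"                   # dtype now preserved end to end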
@@ -83,12 +83,26 @@ def sigmoid(x):
return backend.nn.sigmoid(x)
class Tanh(Operation):
def call(self, x):
return backend.nn.tanh(x)
def compute_output_spec(self, x):
return KerasTensor(x.shape, x.dtype)
def tanh(x):
if any_symbolic_tensors((x,)):
return Tanh().symbolic_call(x)
return backend.nn.tanh(x)
class Softplus(Operation):
def call(self, x):
return backend.nn.softplus(x)
def compute_output_spec(self, x):
return KerasTensor(x.shape)
return KerasTensor(x.shape, x.dtype)
def softplus(x):
@@ -102,7 +116,7 @@ class Softsign(Operation):
return backend.nn.softsign(x)
def compute_output_spec(self, x):
return KerasTensor(x.shape)
return KerasTensor(x.shape, x.dtype)
def softsign(x):
@@ -116,7 +130,7 @@ class Silu(Operation):
return backend.nn.silu(x)
def compute_output_spec(self, x):
return KerasTensor(x.shape)
return KerasTensor(x.shape, x.dtype)
def silu(x):
@@ -130,7 +144,7 @@ class Swish(Operation):
return backend.nn.swish(x)
def compute_output_spec(self, x):
return KerasTensor(x.shape)
return KerasTensor(x.shape, x.dtype)
def swish(x):
@@ -144,7 +158,7 @@ class LogSigmoid(Operation):
return backend.nn.log_sigmoid(x)
def compute_output_spec(self, x):
return KerasTensor(x.shape)
return KerasTensor(x.shape, x.dtype)
def log_sigmoid(x):
@@ -162,7 +176,7 @@ class LeakyRelu(Operation):
return backend.nn.leaky_relu(x, self.negative_slope)
def compute_output_spec(self, x):
return KerasTensor(x.shape)
return KerasTensor(x.shape, x.dtype)
def leaky_relu(x, negative_slope=0.2):
@@ -176,7 +190,7 @@ class HardSigmoid(Operation):
return backend.nn.hard_sigmoid(x)
def compute_output_spec(self, x):
return KerasTensor(x.shape)
return KerasTensor(x.shape, x.dtype)
def hard_sigmoid(x):
@@ -190,7 +204,7 @@ class Elu(Operation):
return backend.nn.elu(x)
def compute_output_spec(self, x):
return KerasTensor(x.shape)
return KerasTensor(x.shape, x.dtype)
def elu(x):
@@ -204,7 +218,7 @@ class Selu(Operation):
return backend.nn.selu(x)
def compute_output_spec(self, x):
return KerasTensor(x.shape)
return KerasTensor(x.shape, x.dtype)
def selu(x):
@@ -222,7 +236,7 @@ class Gelu(Operation):
return backend.nn.gelu(x, self.approximate)
def compute_output_spec(self, x):
return KerasTensor(x.shape)
return KerasTensor(x.shape, x.dtype)
def gelu(x, approximate=True):
@@ -240,13 +254,13 @@ class Softmax(Operation):
return backend.nn.softmax(x, axis=self.axis)
def compute_output_spec(self, x):
return KerasTensor(x.shape)
return KerasTensor(x.shape, x.dtype)
def softmax(x, axis=None):
if any_symbolic_tensors((x,)):
return Softmax(axis).symbolic_call(x)
return backend.nn.softmax(x)
return backend.nn.softmax(x, axis=axis)
class LogSoftmax(Operation):
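Besides the dtype change, this hunk fixes the eager path: the old code called `backend.nn.softmax(x)` and silently dropped the requested `axis`, which is exactly what `test_softmax_2d_axis_0` above exercises. A small NumPy illustration of why forwarding the axis matters:

    import numpy as np
    x = np.random.random((2, 5))
    e = np.exp(x - x.max(axis=0, keepdims=True))
    col_softmax = e / e.sum(axis=0, keepdims=True)   # softmax over axis 0: every column sums to 1
    # If axis were ignored and the backend default used instead, rows rather than columns
    # would be normalized, so activations.softmax(x, axis=0) would not match this reference.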
@@ -258,7 +272,7 @@ class LogSoftmax(Operation):
return backend.nn.log_softmax(x, axis=self.axis)
def compute_output_spec(self, x):
return KerasTensor(x.shape)
return KerasTensor(x.shape, x.dtype)
def log_softmax(x, axis=None):