Add Activations Test (#81)
* Add Activations Test * fix tests for activations * Add dtype to KerasTensor
This commit is contained in:
parent
459b1b13fc
commit
a0f94646c3
@ -267,7 +267,7 @@ def silu(x):
|
||||
|
||||
|
||||
@keras_core_export("keras_core.activations.gelu")
|
||||
def gelu(x):
|
||||
def gelu(x, approximate=False):
|
||||
"""Gaussian error linear unit (GELU) activation function.
|
||||
|
||||
The Gaussian error linear unit (GELU) is defined as:
|
||||
@ -280,12 +280,13 @@ def gelu(x):
|
||||
|
||||
Args:
|
||||
x: Input tensor.
|
||||
approximate: A `bool`, whether to enable approximation.
|
||||
|
||||
Reference:
|
||||
|
||||
- [Hendrycks et al., 2016](https://arxiv.org/abs/1606.08415)
|
||||
"""
|
||||
return ops.gelu(x)
|
||||
return ops.gelu(x, approximate=approximate)
|
||||
|
||||
|
||||
@keras_core_export("keras_core.activations.tanh")
|
||||
@ -380,7 +381,7 @@ class Mish(ops.Operation):
|
||||
|
||||
@staticmethod
|
||||
def static_call(x):
|
||||
return x * backend.tanh(backend.softplus(x))
|
||||
return x * backend.nn.tanh(backend.nn.softplus(x))
|
||||
|
||||
|
||||
@keras_core_export("keras_core.activations.mish")
|
||||
|
@ -1 +1,174 @@
|
||||
# TODO
|
||||
import numpy as np
|
||||
|
||||
from keras_core import activations
|
||||
from keras_core import backend
|
||||
from keras_core import testing
|
||||
|
||||
|
||||
def _ref_softmax(values):
|
||||
m = np.max(values)
|
||||
e = np.exp(values - m)
|
||||
return e / np.sum(e)
|
||||
|
||||
|
||||
def _ref_softplus(x):
|
||||
return np.log(np.ones_like(x) + np.exp(x))
|
||||
|
||||
|
||||
class ActivationsTest(testing.TestCase):
|
||||
def test_softmax(self):
|
||||
x = np.random.random((2, 5))
|
||||
|
||||
result = activations.softmax(x[np.newaxis, :])[0]
|
||||
expected = _ref_softmax(x[0])
|
||||
self.assertAllClose(result[0], expected, rtol=1e-05)
|
||||
|
||||
def test_softmax_2d_axis_0(self):
|
||||
x = np.random.random((2, 5))
|
||||
result = activations.softmax(x[np.newaxis, :], axis=1)[0]
|
||||
expected = np.zeros((2, 5))
|
||||
for i in range(5):
|
||||
expected[:, i] = _ref_softmax(x[:, i])
|
||||
self.assertAllClose(result, expected, rtol=1e-05)
|
||||
|
||||
# TODO: Fails on Tuple Axis
|
||||
# ops/nn_ops.py:3824: TypeError:
|
||||
# '<=' not supported between instances of 'int' and 'tuple'
|
||||
# def test_softmax_3d_axis_tuple(self):
|
||||
# x = np.random.random((2, 3, 5))
|
||||
# result = activations.softmax([x], axis=(1, 2))[0]
|
||||
# expected = np.zeros((2, 3, 5))
|
||||
# for i in range(2):
|
||||
# expected[i, :, :] = _ref_softmax(x[i, :, :])
|
||||
# self.assertAllClose(result, expected, rtol=1e-05)
|
||||
|
||||
def test_temporal_softmax(self):
|
||||
x = np.random.random((2, 2, 3)) * 10
|
||||
result = activations.softmax(x[np.newaxis, :])[0]
|
||||
expected = _ref_softmax(x[0, 0])
|
||||
self.assertAllClose(result[0, 0], expected, rtol=1e-05)
|
||||
|
||||
def test_selu(self):
|
||||
alpha = 1.6732632423543772848170429916717
|
||||
scale = 1.0507009873554804934193349852946
|
||||
|
||||
positive_values = np.array([[1, 2]], dtype=backend.floatx())
|
||||
result = activations.selu(positive_values[np.newaxis, :])[0]
|
||||
self.assertAllClose(result, positive_values * scale, rtol=1e-05)
|
||||
|
||||
negative_values = np.array([[-1, -2]], dtype=backend.floatx())
|
||||
result = activations.selu(negative_values[np.newaxis, :])[0]
|
||||
true_result = (np.exp(negative_values) - 1) * scale * alpha
|
||||
self.assertAllClose(result, true_result)
|
||||
|
||||
def test_softplus(self):
|
||||
x = np.random.random((2, 5))
|
||||
result = activations.softplus(x[np.newaxis, :])[0]
|
||||
expected = _ref_softplus(x)
|
||||
self.assertAllClose(result, expected, rtol=1e-05)
|
||||
|
||||
def test_softsign(self):
|
||||
def softsign(x):
|
||||
return np.divide(x, np.ones_like(x) + np.absolute(x))
|
||||
|
||||
x = np.random.random((2, 5))
|
||||
result = activations.softsign(x[np.newaxis, :])[0]
|
||||
expected = softsign(x)
|
||||
self.assertAllClose(result, expected, rtol=1e-05)
|
||||
|
||||
def test_sigmoid(self):
|
||||
def ref_sigmoid(x):
|
||||
if x >= 0:
|
||||
return 1 / (1 + np.exp(-x))
|
||||
else:
|
||||
z = np.exp(x)
|
||||
return z / (1 + z)
|
||||
|
||||
sigmoid = np.vectorize(ref_sigmoid)
|
||||
|
||||
x = np.random.random((2, 5))
|
||||
result = activations.sigmoid(x[np.newaxis, :])[0]
|
||||
expected = sigmoid(x)
|
||||
self.assertAllClose(result, expected, rtol=1e-05)
|
||||
|
||||
def test_hard_sigmoid(self):
|
||||
def ref_hard_sigmoid(x):
|
||||
x = (x / 6.0) + 0.5
|
||||
z = 0.0 if x <= 0 else (1.0 if x >= 1 else x)
|
||||
return z
|
||||
|
||||
hard_sigmoid = np.vectorize(ref_hard_sigmoid)
|
||||
x = np.random.random((2, 5))
|
||||
result = activations.hard_sigmoid(x[np.newaxis, :])[0]
|
||||
expected = hard_sigmoid(x)
|
||||
self.assertAllClose(result, expected, rtol=1e-05)
|
||||
|
||||
def test_relu(self):
|
||||
positive_values = np.random.random((2, 5))
|
||||
result = activations.relu(positive_values[np.newaxis, :])[0]
|
||||
self.assertAllClose(result, positive_values, rtol=1e-05)
|
||||
|
||||
negative_values = np.random.uniform(-1, 0, (2, 5))
|
||||
result = activations.relu(negative_values[np.newaxis, :])[0]
|
||||
expected = np.zeros((2, 5))
|
||||
self.assertAllClose(result, expected, rtol=1e-05)
|
||||
|
||||
def test_gelu(self):
|
||||
def gelu(x, approximate=False):
|
||||
if approximate:
|
||||
return (
|
||||
0.5
|
||||
* x
|
||||
* (
|
||||
1.0
|
||||
+ np.tanh(
|
||||
np.sqrt(2.0 / np.pi)
|
||||
* (x + 0.044715 * np.power(x, 3))
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
from scipy.stats import norm
|
||||
|
||||
return x * norm.cdf(x)
|
||||
|
||||
x = np.random.random((2, 5))
|
||||
result = activations.gelu(x[np.newaxis, :])[0]
|
||||
expected = gelu(x)
|
||||
self.assertAllClose(result, expected, rtol=1e-05)
|
||||
|
||||
x = np.random.random((2, 5))
|
||||
result = activations.gelu(x[np.newaxis, :], approximate=True)[0]
|
||||
expected = gelu(x, True)
|
||||
self.assertAllClose(result, expected, rtol=1e-05)
|
||||
|
||||
def test_elu(self):
|
||||
x = np.random.random((2, 5))
|
||||
result = activations.elu(x[np.newaxis, :])[0]
|
||||
self.assertAllClose(result, x, rtol=1e-05)
|
||||
negative_values = np.array([[-1, -2]], dtype=backend.floatx())
|
||||
result = activations.elu(negative_values[np.newaxis, :])[0]
|
||||
true_result = np.exp(negative_values) - 1
|
||||
self.assertAllClose(result, true_result)
|
||||
|
||||
def test_tanh(self):
|
||||
x = np.random.random((2, 5))
|
||||
result = activations.tanh(x[np.newaxis, :])[0]
|
||||
expected = np.tanh(x)
|
||||
self.assertAllClose(result, expected, rtol=1e-05)
|
||||
|
||||
def test_exponential(self):
|
||||
x = np.random.random((2, 5))
|
||||
result = activations.exponential(x[np.newaxis, :])[0]
|
||||
expected = np.exp(x)
|
||||
self.assertAllClose(result, expected, rtol=1e-05)
|
||||
|
||||
def test_mish(self):
|
||||
x = np.random.random((2, 5))
|
||||
result = activations.mish(x[np.newaxis, :])[0]
|
||||
expected = x * np.tanh(_ref_softplus(x))
|
||||
self.assertAllClose(result, expected, rtol=1e-05)
|
||||
|
||||
def test_linear(self):
|
||||
x = np.random.random((10, 5))
|
||||
self.assertAllClose(x, activations.linear(x))
|
||||
|
@ -19,6 +19,10 @@ def sigmoid(x):
|
||||
return jnn.sigmoid(x)
|
||||
|
||||
|
||||
def tanh(x):
|
||||
return jnn.tanh(x)
|
||||
|
||||
|
||||
def softplus(x):
|
||||
return jnn.softplus(x)
|
||||
|
||||
@ -59,8 +63,8 @@ def gelu(x, approximate=True):
|
||||
return jnn.gelu(x, approximate)
|
||||
|
||||
|
||||
def softmax(x):
|
||||
return jnn.softmax(x)
|
||||
def softmax(x, axis=None):
|
||||
return jnn.softmax(x, axis=axis)
|
||||
|
||||
|
||||
def log_softmax(x, axis=-1):
|
||||
|
@ -20,6 +20,10 @@ def sigmoid(x):
|
||||
return tf.nn.sigmoid(x)
|
||||
|
||||
|
||||
def tanh(x):
|
||||
return tf.nn.tanh(x)
|
||||
|
||||
|
||||
def softplus(x):
|
||||
return tf.math.softplus(x)
|
||||
|
||||
|
@ -74,7 +74,7 @@ class Sigmoid(Operation):
|
||||
return backend.nn.sigmoid(x)
|
||||
|
||||
def compute_output_spec(self, x):
|
||||
return KerasTensor(x.shape)
|
||||
return KerasTensor(x.shape, x.dtype)
|
||||
|
||||
|
||||
def sigmoid(x):
|
||||
@ -83,12 +83,26 @@ def sigmoid(x):
|
||||
return backend.nn.sigmoid(x)
|
||||
|
||||
|
||||
class Tanh(Operation):
|
||||
def call(self, x):
|
||||
return backend.nn.tanh(x)
|
||||
|
||||
def compute_output_spec(self, x):
|
||||
return KerasTensor(x.shape, x.dtype)
|
||||
|
||||
|
||||
def tanh(x):
|
||||
if any_symbolic_tensors((x,)):
|
||||
return Tanh().symbolic_call(x)
|
||||
return backend.nn.tanh(x)
|
||||
|
||||
|
||||
class Softplus(Operation):
|
||||
def call(self, x):
|
||||
return backend.nn.softplus(x)
|
||||
|
||||
def compute_output_spec(self, x):
|
||||
return KerasTensor(x.shape)
|
||||
return KerasTensor(x.shape, x.dtype)
|
||||
|
||||
|
||||
def softplus(x):
|
||||
@ -102,7 +116,7 @@ class Softsign(Operation):
|
||||
return backend.nn.softsign(x)
|
||||
|
||||
def compute_output_spec(self, x):
|
||||
return KerasTensor(x.shape)
|
||||
return KerasTensor(x.shape, x.dtype)
|
||||
|
||||
|
||||
def softsign(x):
|
||||
@ -116,7 +130,7 @@ class Silu(Operation):
|
||||
return backend.nn.silu(x)
|
||||
|
||||
def compute_output_spec(self, x):
|
||||
return KerasTensor(x.shape)
|
||||
return KerasTensor(x.shape, x.dtype)
|
||||
|
||||
|
||||
def silu(x):
|
||||
@ -130,7 +144,7 @@ class Swish(Operation):
|
||||
return backend.nn.swish(x)
|
||||
|
||||
def compute_output_spec(self, x):
|
||||
return KerasTensor(x.shape)
|
||||
return KerasTensor(x.shape, x.dtype)
|
||||
|
||||
|
||||
def swish(x):
|
||||
@ -144,7 +158,7 @@ class LogSigmoid(Operation):
|
||||
return backend.nn.log_sigmoid(x)
|
||||
|
||||
def compute_output_spec(self, x):
|
||||
return KerasTensor(x.shape)
|
||||
return KerasTensor(x.shape, x.dtype)
|
||||
|
||||
|
||||
def log_sigmoid(x):
|
||||
@ -162,7 +176,7 @@ class LeakyRelu(Operation):
|
||||
return backend.nn.leaky_relu(x, self.negative_slope)
|
||||
|
||||
def compute_output_spec(self, x):
|
||||
return KerasTensor(x.shape)
|
||||
return KerasTensor(x.shape, x.dtype)
|
||||
|
||||
|
||||
def leaky_relu(x, negative_slope=0.2):
|
||||
@ -176,7 +190,7 @@ class HardSigmoid(Operation):
|
||||
return backend.nn.hard_sigmoid(x)
|
||||
|
||||
def compute_output_spec(self, x):
|
||||
return KerasTensor(x.shape)
|
||||
return KerasTensor(x.shape, x.dtype)
|
||||
|
||||
|
||||
def hard_sigmoid(x):
|
||||
@ -190,7 +204,7 @@ class Elu(Operation):
|
||||
return backend.nn.elu(x)
|
||||
|
||||
def compute_output_spec(self, x):
|
||||
return KerasTensor(x.shape)
|
||||
return KerasTensor(x.shape, x.dtype)
|
||||
|
||||
|
||||
def elu(x):
|
||||
@ -204,7 +218,7 @@ class Selu(Operation):
|
||||
return backend.nn.selu(x)
|
||||
|
||||
def compute_output_spec(self, x):
|
||||
return KerasTensor(x.shape)
|
||||
return KerasTensor(x.shape, x.dtype)
|
||||
|
||||
|
||||
def selu(x):
|
||||
@ -222,7 +236,7 @@ class Gelu(Operation):
|
||||
return backend.nn.gelu(x, self.approximate)
|
||||
|
||||
def compute_output_spec(self, x):
|
||||
return KerasTensor(x.shape)
|
||||
return KerasTensor(x.shape, x.dtype)
|
||||
|
||||
|
||||
def gelu(x, approximate=True):
|
||||
@ -240,13 +254,13 @@ class Softmax(Operation):
|
||||
return backend.nn.softmax(x, axis=self.axis)
|
||||
|
||||
def compute_output_spec(self, x):
|
||||
return KerasTensor(x.shape)
|
||||
return KerasTensor(x.shape, x.dtype)
|
||||
|
||||
|
||||
def softmax(x, axis=None):
|
||||
if any_symbolic_tensors((x,)):
|
||||
return Softmax(axis).symbolic_call(x)
|
||||
return backend.nn.softmax(x)
|
||||
return backend.nn.softmax(x, axis=axis)
|
||||
|
||||
|
||||
class LogSoftmax(Operation):
|
||||
@ -258,7 +272,7 @@ class LogSoftmax(Operation):
|
||||
return backend.nn.log_softmax(x, axis=self.axis)
|
||||
|
||||
def compute_output_spec(self, x):
|
||||
return KerasTensor(x.shape)
|
||||
return KerasTensor(x.shape, x.dtype)
|
||||
|
||||
|
||||
def log_softmax(x, axis=None):
|
||||
|
Loading…
Reference in New Issue
Block a user