From 143cd2ae93ee8ac3d335fcd8c237b6a95533e256 Mon Sep 17 00:00:00 2001 From: Fayaz Rahman Date: Sun, 30 Jul 2023 22:50:17 +0530 Subject: [PATCH] Add hyperbolic ops (#634) * add hyperbolic numpy functions * update backend modules + tests + black * docstrings * fix * remove ops.nn.tanh --- guides/distributed_training_with_jax.py | 20 ++- guides/distributed_training_with_torch.py | 10 +- keras_core/backend/jax/numpy.py | 24 ++++ keras_core/backend/numpy/numpy.py | 24 ++++ keras_core/backend/tensorflow/numpy.py | 24 ++++ keras_core/backend/torch/numpy.py | 30 +++++ keras_core/ops/nn.py | 15 --- keras_core/ops/numpy.py | 143 ++++++++++++++++++++++ keras_core/ops/numpy_test.py | 89 +++++++++++++- 9 files changed, 353 insertions(+), 26 deletions(-) diff --git a/guides/distributed_training_with_jax.py b/guides/distributed_training_with_jax.py index fe436a9c3..8d55d2acd 100644 --- a/guides/distributed_training_with_jax.py +++ b/guides/distributed_training_with_jax.py @@ -58,9 +58,9 @@ def get_model(): # Make a simple convnet with batch normalization and dropout. inputs = keras.Input(shape=(28, 28, 1)) x = keras.layers.Rescaling(1.0 / 255.0)(inputs) - x = keras.layers.Conv2D(filters=12, kernel_size=3, padding="same", use_bias=False)( - x - ) + x = keras.layers.Conv2D( + filters=12, kernel_size=3, padding="same", use_bias=False + )(x) x = keras.layers.BatchNormalization(scale=False, center=True)(x) x = keras.layers.ReLU()(x) x = keras.layers.Conv2D( @@ -187,7 +187,11 @@ compute_gradients = jax.value_and_grad(compute_loss, has_aux=True) # Training step, Keras provides a pure functional optimizer.stateless_apply @jax.jit def train_step(train_state, x, y): - trainable_variables, non_trainable_variables, optimizer_variables = train_state + ( + trainable_variables, + non_trainable_variables, + optimizer_variables, + ) = train_state (loss_value, non_trainable_variables), grads = compute_gradients( trainable_variables, non_trainable_variables, x, y ) @@ -211,7 +215,9 @@ def get_replicated_train_state(devices): var_replication = NamedSharding(var_mesh, P()) # Apply the distribution settings to the model variables - trainable_variables = jax.device_put(model.trainable_variables, var_replication) + trainable_variables = jax.device_put( + model.trainable_variables, var_replication + ) non_trainable_variables = jax.device_put( model.non_trainable_variables, var_replication ) @@ -255,7 +261,9 @@ for epoch in range(num_epochs): trainable_variables, non_trainable_variables, optimizer_variables = train_state for variable, value in zip(model.trainable_variables, trainable_variables): variable.assign(value) -for variable, value in zip(model.non_trainable_variables, non_trainable_variables): +for variable, value in zip( + model.non_trainable_variables, non_trainable_variables +): variable.assign(value) """ diff --git a/guides/distributed_training_with_torch.py b/guides/distributed_training_with_torch.py index 383aeb214..30ff06d3a 100644 --- a/guides/distributed_training_with_torch.py +++ b/guides/distributed_training_with_torch.py @@ -53,9 +53,9 @@ def get_model(): # Make a simple convnet with batch normalization and dropout. inputs = keras.Input(shape=(28, 28, 1)) x = keras.layers.Rescaling(1.0 / 255.0)(inputs) - x = keras.layers.Conv2D(filters=12, kernel_size=3, padding="same", use_bias=False)( - x - ) + x = keras.layers.Conv2D( + filters=12, kernel_size=3, padding="same", use_bias=False + )(x) x = keras.layers.BatchNormalization(scale=False, center=True)(x) x = keras.layers.ReLU()(x) x = keras.layers.Conv2D( @@ -231,7 +231,9 @@ def per_device_launch_fn(current_gpu_index, num_gpu): model = get_model() # prepare the dataloader - dataloader = prepare_dataloader(dataset, current_gpu_index, num_gpu, batch_size) + dataloader = prepare_dataloader( + dataset, current_gpu_index, num_gpu, batch_size + ) # Instantiate the torch optimizer optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) diff --git a/keras_core/backend/jax/numpy.py b/keras_core/backend/jax/numpy.py index c45a33ed8..48ffd13a5 100644 --- a/keras_core/backend/jax/numpy.py +++ b/keras_core/backend/jax/numpy.py @@ -107,10 +107,18 @@ def arccos(x): return jnp.arccos(x) +def arccosh(x): + return jnp.arccosh(x) + + def arcsin(x): return jnp.arcsin(x) +def arcsinh(x): + return jnp.arcsinh(x) + + def arctan(x): return jnp.arctan(x) @@ -119,6 +127,10 @@ def arctan2(x1, x2): return jnp.arctan2(x1, x2) +def arctanh(x): + return jnp.arctanh(x) + + def argmax(x, axis=None): return jnp.argmax(x, axis=axis) @@ -171,6 +183,10 @@ def cos(x): return jnp.cos(x) +def cosh(x): + return jnp.cosh(x) + + def count_nonzero(x, axis=None): return jnp.count_nonzero(x, axis=axis) @@ -441,6 +457,10 @@ def sin(x): return jnp.sin(x) +def sinh(x): + return jnp.sinh(x) + + def size(x): return jnp.size(x) @@ -479,6 +499,10 @@ def tan(x): return jnp.tan(x) +def tanh(x): + return jnp.tanh(x) + + def tensordot(x1, x2, axes=2): x1 = convert_to_tensor(x1) x2 = convert_to_tensor(x2) diff --git a/keras_core/backend/numpy/numpy.py b/keras_core/backend/numpy/numpy.py index 23d8068ca..ea682fc5a 100644 --- a/keras_core/backend/numpy/numpy.py +++ b/keras_core/backend/numpy/numpy.py @@ -84,10 +84,18 @@ def arccos(x): return np.arccos(x) +def arccosh(x): + return np.arccosh(x) + + def arcsin(x): return np.arcsin(x) +def arcsinh(x): + return np.arcsinh(x) + + def arctan(x): return np.arctan(x) @@ -96,6 +104,10 @@ def arctan2(x1, x2): return np.arctan2(x1, x2) +def arctanh(x): + return np.arctanh(x) + + def argmax(x, axis=None): axis = tuple(axis) if isinstance(axis, list) else axis return np.argmax(x, axis=axis) @@ -157,6 +169,10 @@ def cos(x): return np.cos(x) +def cosh(x): + return np.cosh(x) + + def count_nonzero(x, axis=None): axis = tuple(axis) if isinstance(axis, list) else axis return np.count_nonzero(x, axis=axis) @@ -438,6 +454,10 @@ def sin(x): return np.sin(x) +def sinh(x): + return np.sinh(x) + + def size(x): return np.size(x) @@ -480,6 +500,10 @@ def tan(x): return np.tan(x) +def tanh(x): + return np.tanh(x) + + def tensordot(x1, x2, axes=2): axes = tuple(axes) if isinstance(axes, list) else axes return np.tensordot(x1, x2, axes=axes) diff --git a/keras_core/backend/tensorflow/numpy.py b/keras_core/backend/tensorflow/numpy.py index 954e1e08a..3a9c614a5 100644 --- a/keras_core/backend/tensorflow/numpy.py +++ b/keras_core/backend/tensorflow/numpy.py @@ -105,10 +105,18 @@ def arccos(x): return tfnp.arccos(x) +def arccosh(x): + return tfnp.arccosh(x) + + def arcsin(x): return tfnp.arcsin(x) +def arcsinh(x): + return tfnp.arcsinh(x) + + def arctan(x): return tfnp.arctan(x) @@ -117,6 +125,10 @@ def arctan2(x1, x2): return tfnp.arctan2(x1, x2) +def arctanh(x): + return tfnp.arctanh(x) + + def argmax(x, axis=None): return tfnp.argmax(x, axis=axis) @@ -174,6 +186,10 @@ def cos(x): return tfnp.cos(x) +def cosh(x): + return tfnp.cosh(x) + + def count_nonzero(x, axis=None): return tfnp.count_nonzero(x, axis=axis) @@ -472,6 +488,10 @@ def sin(x): return tfnp.sin(x) +def sinh(x): + return tfnp.sinh(x) + + def size(x): return tfnp.size(x) @@ -508,6 +528,10 @@ def tan(x): return tfnp.tan(x) +def tanh(x): + return tfnp.tanh(x) + + def tensordot(x1, x2, axes=2): return tfnp.tensordot(x1, x2, axes=axes) diff --git a/keras_core/backend/torch/numpy.py b/keras_core/backend/torch/numpy.py index 0b3606d7a..8a62f223e 100644 --- a/keras_core/backend/torch/numpy.py +++ b/keras_core/backend/torch/numpy.py @@ -173,11 +173,21 @@ def arccos(x): return torch.arccos(x) +def arccosh(x): + x = convert_to_tensor(x) + return torch.arccosh(x) + + def arcsin(x): x = convert_to_tensor(x) return torch.arcsin(x) +def arcsinh(x): + x = convert_to_tensor(x) + return torch.arcsinh(x) + + def arctan(x): x = convert_to_tensor(x) return torch.arctan(x) @@ -188,6 +198,11 @@ def arctan2(x1, x2): return torch.arctan2(x1, x2) +def arctanh(x): + x = convert_to_tensor(x) + return torch.arctanh(x) + + def argmax(x, axis=None): x = convert_to_tensor(x) return torch.argmax(x, dim=axis) @@ -277,6 +292,11 @@ def cos(x): return torch.cos(x) +def cosh(x): + x = convert_to_tensor(x) + return torch.cosh(x) + + def count_nonzero(x, axis=None): x = convert_to_tensor(x) if axis == () or axis == []: @@ -729,6 +749,11 @@ def sin(x): return torch.sin(x) +def sinh(x): + x = convert_to_tensor(x) + return torch.sinh(x) + + def size(x): x_shape = convert_to_tensor(tuple(x.shape)) return torch.prod(x_shape) @@ -806,6 +831,11 @@ def tan(x): return torch.tan(x) +def tanh(x): + x = convert_to_tensor(x) + return torch.tanh(x) + + def tensordot(x1, x2, axes=2): x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) # Conversion to long necessary for `torch.tensordot` diff --git a/keras_core/ops/nn.py b/keras_core/ops/nn.py index 51aacaaac..b2fbd3a44 100644 --- a/keras_core/ops/nn.py +++ b/keras_core/ops/nn.py @@ -57,21 +57,6 @@ def sigmoid(x): return backend.nn.sigmoid(x) -class Tanh(Operation): - def call(self, x): - return backend.nn.tanh(x) - - def compute_output_spec(self, x): - return KerasTensor(x.shape, dtype=x.dtype) - - -@keras_core_export(["keras_core.ops.tanh", "keras_core.ops.nn.tanh"]) -def tanh(x): - if any_symbolic_tensors((x,)): - return Tanh().symbolic_call(x) - return backend.nn.tanh(x) - - class Softplus(Operation): def call(self, x): return backend.nn.softplus(x) diff --git a/keras_core/ops/numpy.py b/keras_core/ops/numpy.py index 820357f49..6c248fc9b 100644 --- a/keras_core/ops/numpy.py +++ b/keras_core/ops/numpy.py @@ -10,9 +10,12 @@ amin append arange arccos +arccosh arcsin +arcsinh arctan arctan2 +arctanh argmax argmin argsort @@ -27,6 +30,7 @@ conj conjugate copy cos +cosh count_nonzero cross cumprod @@ -102,6 +106,7 @@ roll round sign sin +sinh size sort split @@ -116,6 +121,7 @@ swapaxes take take_along_axis tan +tanh tensordot tile trace @@ -713,6 +719,28 @@ def arccos(x): return backend.numpy.arccos(x) +class Arccosh(Operation): + def call(self, x): + return backend.numpy.arccosh(x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +def arccosh(x): + """Inverse hyperbolic cosine, element-wise. + + Arguments: + x: Input tensor. + + Returns: + Output tensor of same shape as x. + """ + if any_symbolic_tensors((x,)): + return Arccosh().symbolic_call(x) + return backend.numpy.arccosh(x) + + class Arcsin(Operation): def call(self, x): return backend.numpy.arcsin(x) @@ -742,6 +770,29 @@ def arcsin(x): return backend.numpy.arcsin(x) +class Arcsinh(Operation): + def call(self, x): + return backend.numpy.arcsinh(x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_core_export(["keras_core.ops.arcsinh", "keras_core.ops.numpy.arcsinh"]) +def arcsinh(x): + """Inverse hyperbolic sine, element-wise. + + Arguments: + x: Input tensor. + + Returns: + Output tensor of same shape as x. + """ + if any_symbolic_tensors((x,)): + return Arcsinh().symbolic_call(x) + return backend.numpy.arcsinh(x) + + class Arctan(Operation): def call(self, x): return backend.numpy.arctan(x) @@ -826,6 +877,29 @@ def arctan2(x1, x2): return backend.numpy.arctan2(x1, x2) +class Arctanh(Operation): + def call(self, x): + return backend.numpy.arctanh(x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_core_export(["keras_core.ops.arctanh", "keras_core.ops.numpy.arctanh"]) +def arctanh(x): + """Inverse hyperbolic tangent, element-wise. + + Arguments: + x: Input tensor. + + Returns: + Output tensor of same shape as x. + """ + if any_symbolic_tensors((x,)): + return Arctanh().symbolic_call(x) + return backend.numpy.arctanh(x) + + class Argmax(Operation): def __init__(self, axis=None): super().__init__() @@ -1289,6 +1363,29 @@ def cos(x): return backend.numpy.cos(x) +class Cosh(Operation): + def call(self, x): + return backend.numpy.cosh(x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_core_export(["keras_core.ops.cosh", "keras_core.ops.numpy.cosh"]) +def cosh(x): + """Hyperbolic cosine, element-wise. + + Arguments: + x: Input tensor. + + Returns: + Output tensor of same shape as x. + """ + if any_symbolic_tensors((x,)): + return Cosh().symbolic_call(x) + return backend.numpy.cosh(x) + + class CountNonzero(Operation): def __init__(self, axis=None): super().__init__() @@ -3115,6 +3212,29 @@ def sin(x): return backend.numpy.sin(x) +class Sinh(Operation): + def call(self, x): + return backend.numpy.sinh(x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_core_export(["keras_core.ops.sinh", "keras_core.ops.numpy.sinh"]) +def sinh(x): + """Hyperbolic sine, element-wise. + + Arguments: + x: Input tensor. + + Returns: + Output tensor of same shape as x. + """ + if any_symbolic_tensors((x,)): + return Sinh().symbolic_call(x) + return backend.numpy.sinh(x) + + class Size(Operation): def call(self, x): return backend.numpy.size(x) @@ -3376,6 +3496,29 @@ def tan(x): return backend.numpy.tan(x) +class Tanh(Operation): + def call(self, x): + return backend.numpy.tanh(x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_core_export(["keras_core.ops.tanh", "keras_core.ops.numpy.tanh"]) +def tanh(x): + """Hyperbolic tangent, element-wise. + + Arguments: + x: Input tensor. + + Returns: + Output tensor of same shape as x. + """ + if any_symbolic_tensors((x,)): + return Tanh().symbolic_call(x) + return backend.numpy.tanh(x) + + class Tensordot(Operation): def __init__(self, axes=2): super().__init__() diff --git a/keras_core/ops/numpy_test.py b/keras_core/ops/numpy_test.py index df97cdd3f..7dbc63e6b 100644 --- a/keras_core/ops/numpy_test.py +++ b/keras_core/ops/numpy_test.py @@ -754,14 +754,26 @@ class NumpyOneInputOpsDynamicShapeTest(testing.TestCase): x = KerasTensor([None, 3]) self.assertEqual(knp.arccos(x).shape, (None, 3)) + def test_arccosh(self): + x = KerasTensor([None, 3]) + self.assertEqual(knp.arccosh(x).shape, (None, 3)) + def test_arcsin(self): x = KerasTensor([None, 3]) self.assertEqual(knp.arcsin(x).shape, (None, 3)) + def test_arcsinh(self): + x = KerasTensor([None, 3]) + self.assertEqual(knp.arcsinh(x).shape, (None, 3)) + def test_arctan(self): x = KerasTensor([None, 3]) self.assertEqual(knp.arctan(x).shape, (None, 3)) + def test_arctanh(self): + x = KerasTensor([None, 3]) + self.assertEqual(knp.arctanh(x).shape, (None, 3)) + def test_argmax(self): x = KerasTensor([None, 3]) self.assertEqual(knp.argmax(x).shape, ()) @@ -855,6 +867,10 @@ class NumpyOneInputOpsDynamicShapeTest(testing.TestCase): x = KerasTensor([None, 3]) self.assertEqual(knp.cos(x).shape, (None, 3)) + def test_cosh(self): + x = KerasTensor([None, 3]) + self.assertEqual(knp.cosh(x).shape, (None, 3)) + def test_count_nonzero(self): x = KerasTensor([None, 3]) self.assertEqual(knp.count_nonzero(x).shape, ()) @@ -1095,6 +1111,10 @@ class NumpyOneInputOpsDynamicShapeTest(testing.TestCase): x = KerasTensor([None, 3]) self.assertEqual(knp.sin(x).shape, (None, 3)) + def test_sinh(self): + x = KerasTensor([None, 3]) + self.assertEqual(knp.sinh(x).shape, (None, 3)) + def test_size(self): x = KerasTensor([None, 3]) self.assertEqual(knp.size(x).shape, ()) @@ -1137,6 +1157,10 @@ class NumpyOneInputOpsDynamicShapeTest(testing.TestCase): x = KerasTensor([None, 3]) self.assertEqual(knp.tan(x).shape, (None, 3)) + def test_tanh(self): + x = KerasTensor([None, 3]) + self.assertEqual(knp.tanh(x).shape, (None, 3)) + def test_tile(self): x = KerasTensor([None, 3]) self.assertEqual(knp.tile(x, [2]).shape, (None, 6)) @@ -1227,14 +1251,26 @@ class NumpyOneInputOpsStaticShapeTest(testing.TestCase): x = KerasTensor([2, 3]) self.assertEqual(knp.arccos(x).shape, (2, 3)) + def test_arccosh(self): + x = KerasTensor([2, 3]) + self.assertEqual(knp.arccosh(x).shape, (2, 3)) + def test_arcsin(self): x = KerasTensor([2, 3]) self.assertEqual(knp.arcsin(x).shape, (2, 3)) + def test_arcsinh(self): + x = KerasTensor([2, 3]) + self.assertEqual(knp.arcsinh(x).shape, (2, 3)) + def test_arctan(self): x = KerasTensor([2, 3]) self.assertEqual(knp.arctan(x).shape, (2, 3)) + def test_arctanh(self): + x = KerasTensor([2, 3]) + self.assertEqual(knp.arctanh(x).shape, (2, 3)) + def test_argmax(self): x = KerasTensor([2, 3]) self.assertEqual(knp.argmax(x).shape, ()) @@ -1297,6 +1333,10 @@ class NumpyOneInputOpsStaticShapeTest(testing.TestCase): x = KerasTensor([2, 3]) self.assertEqual(knp.cos(x).shape, (2, 3)) + def test_cosh(self): + x = KerasTensor([2, 3]) + self.assertEqual(knp.cosh(x).shape, (2, 3)) + def test_count_nonzero(self): x = KerasTensor([2, 3]) self.assertEqual(knp.count_nonzero(x).shape, ()) @@ -1532,6 +1572,10 @@ class NumpyOneInputOpsStaticShapeTest(testing.TestCase): x = KerasTensor([2, 3]) self.assertEqual(knp.sin(x).shape, (2, 3)) + def test_sinh(self): + x = KerasTensor([2, 3]) + self.assertEqual(knp.sinh(x).shape, (2, 3)) + def test_size(self): x = KerasTensor([2, 3]) self.assertEqual(knp.size(x).shape, ()) @@ -1579,6 +1623,10 @@ class NumpyOneInputOpsStaticShapeTest(testing.TestCase): x = KerasTensor([2, 3]) self.assertEqual(knp.tan(x).shape, (2, 3)) + def test_tanh(self): + x = KerasTensor([2, 3]) + self.assertEqual(knp.tanh(x).shape, (2, 3)) + def test_tile(self): x = KerasTensor([2, 3]) self.assertEqual(knp.tile(x, [2]).shape, (2, 6)) @@ -2266,18 +2314,42 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase): np.transpose(x, axes=(1, 0, 3, 2, 4)), ) - def test_arcos(self): + def test_arccos(self): x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]]) self.assertAllClose(knp.arccos(x), np.arccos(x)) self.assertAllClose(knp.Arccos()(x), np.arccos(x)) + def test_arccosh(self): + x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]]) + self.assertAllClose(knp.arccosh(x), np.arccosh(x)) + + self.assertAllClose(knp.Arccosh()(x), np.arccosh(x)) + def test_arcsin(self): x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]]) self.assertAllClose(knp.arcsin(x), np.arcsin(x)) self.assertAllClose(knp.Arcsin()(x), np.arcsin(x)) + def test_arcsinh(self): + x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]]) + self.assertAllClose(knp.arcsinh(x), np.arcsinh(x)) + + self.assertAllClose(knp.Arcsinh()(x), np.arcsinh(x)) + + def test_arctan(self): + x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]]) + self.assertAllClose(knp.arctan(x), np.arctan(x)) + + self.assertAllClose(knp.Arctan()(x), np.arctan(x)) + + def test_arctanh(self): + x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]]) + self.assertAllClose(knp.arctanh(x), np.arctanh(x)) + + self.assertAllClose(knp.Arctanh()(x), np.arctanh(x)) + def test_argmax(self): x = np.array([[1, 2, 3], [3, 2, 1]]) self.assertAllClose(knp.argmax(x), np.argmax(x)) @@ -2433,6 +2505,11 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase): self.assertAllClose(knp.cos(x), np.cos(x)) self.assertAllClose(knp.Cos()(x), np.cos(x)) + def test_cosh(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.cosh(x), np.cosh(x)) + self.assertAllClose(knp.Cosh()(x), np.cosh(x)) + def test_count_nonzero(self): x = np.array([[0, 2, 3], [3, 2, 0]]) self.assertAllClose(knp.count_nonzero(x), np.count_nonzero(x)) @@ -2899,6 +2976,11 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase): self.assertAllClose(knp.sin(x), np.sin(x)) self.assertAllClose(knp.Sin()(x), np.sin(x)) + def test_sinh(self): + x = np.array([[1, -2, 3], [-3, 2, -1]]) + self.assertAllClose(knp.sinh(x), np.sinh(x)) + self.assertAllClose(knp.Sinh()(x), np.sinh(x)) + def test_size(self): x = np.array([[1, 2, 3], [3, 2, 1]]) self.assertAllClose(knp.size(x), np.size(x)) @@ -2993,6 +3075,11 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase): self.assertAllClose(knp.tan(x), np.tan(x)) self.assertAllClose(knp.Tan()(x), np.tan(x)) + def test_tanh(self): + x = np.array([[1, -2, 3], [-3, 2, -1]]) + self.assertAllClose(knp.tanh(x), np.tanh(x)) + self.assertAllClose(knp.Tanh()(x), np.tanh(x)) + def test_tile(self): x = np.array([[1, 2, 3], [3, 2, 1]]) self.assertAllClose(knp.tile(x, [2, 3]), np.tile(x, [2, 3]))