diff --git a/keras_core/backend/torch/nn.py b/keras_core/backend/torch/nn.py
index 44dec5ff4..f9b4fce34 100644
--- a/keras_core/backend/torch/nn.py
+++ b/keras_core/backend/torch/nn.py
@@ -157,8 +157,9 @@ def _transpose_spatial_inputs(inputs):
         inputs = torch.permute(inputs, (0, 4, 1, 2, 3))
     else:
         raise ValueError(
-            "Pooling inputs's shape must be 3, 4 or 5, corresponding to 1D, 2D "
-            f"and 3D inputs. But received shape: {inputs.shape}."
+            "Inputs must have ndim=3, 4 or 5, "
+            "corresponding to 1D, 2D and 3D inputs. "
+            f"Received input shape: {inputs.shape}."
         )
     return inputs

@@ -222,8 +223,9 @@ def max_pool(
         outputs = tnn.max_pool3d(inputs, kernel_size=pool_size, stride=strides)
     else:
         raise ValueError(
-            "Pooling inputs's shape must be 3, 4 or 5, corresponding to 1D, 2D "
-            f"and 3D inputs. But received shape: {inputs.shape}."
+            "Inputs to pooling op must have ndim=3, 4 or 5, "
+            "corresponding to 1D, 2D and 3D inputs. "
+            f"Received input shape: {inputs.shape}."
         )
     if data_format == "channels_last":
         outputs = _transpose_spatial_outputs(outputs)
@@ -264,8 +266,9 @@ def average_pool(
         outputs = tnn.avg_pool3d(inputs, kernel_size=pool_size, stride=strides)
     else:
         raise ValueError(
-            "Pooling inputs's shape must be 3, 4 or 5, corresponding to 1D, 2D "
-            f"and 3D inputs. But received shape: {inputs.shape}."
+            "Inputs to pooling op must have ndim=3, 4 or 5, "
+            "corresponding to 1D, 2D and 3D inputs. "
+            f"Received input shape: {inputs.shape}."
         )
     if data_format == "channels_last":
         outputs = _transpose_spatial_outputs(outputs)
@@ -303,8 +306,8 @@ def conv(
     if channels % kernel_in_channels > 0:
         raise ValueError(
             "The number of input channels must be evenly divisible by "
-            f"kernel's in_channels. Received input shape {inputs.shape} and "
-            f"kernel shape {kernel.shape}. "
+            f"kernel.shape[1]. Received: inputs.shape={inputs.shape}, "
+            f"kernel.shape={kernel.shape}"
         )
     groups = channels // kernel_in_channels
     if num_spatial_dims == 1:
diff --git a/keras_core/operations/nn_test.py b/keras_core/operations/nn_test.py
index 677e65a07..eec32defb 100644
--- a/keras_core/operations/nn_test.py
+++ b/keras_core/operations/nn_test.py
@@ -1,5 +1,7 @@
 import numpy as np
+import pytest
 import tensorflow as tf
+from absl.testing import parameterized

 from keras_core import testing
 from keras_core.backend.common.keras_tensor import KerasTensor
@@ -565,7 +567,7 @@ class NNOpsStaticShapeTest(testing.TestCase):
         )


-class NNOpsCorrectnessTest(testing.TestCase):
+class NNOpsCorrectnessTest(testing.TestCase, parameterized.TestCase):
     def test_relu(self):
         x = np.array([-1, 0, 1, 2, 3], dtype=np.float32)
         self.assertAllClose(knn.relu(x), [0, 0, 1, 2, 3])
@@ -738,28 +740,35 @@ class NNOpsCorrectnessTest(testing.TestCase):
             tf.nn.avg_pool2d(x, 2, (2, 1), padding="SAME"),
         )

-    def test_conv(self):
-        # Test 1D conv.
+    @parameterized.product(
+        strides=(1, 2, 3),
+        padding=("valid", "same"),
+        dilation_rate=(1, 2),
+    )
+    def test_conv_1d(self, strides, padding, dilation_rate):
+        if strides > 1 and dilation_rate > 1:
+            pytest.skip("Unsupported configuration")
+
         inputs_1d = np.arange(120, dtype=float).reshape([2, 20, 3])
         kernel = np.arange(24, dtype=float).reshape([4, 3, 2])

-        outputs = knn.conv(inputs_1d, kernel, 1, padding="valid")
-        expected = tf.nn.conv1d(inputs_1d, kernel, 1, padding="VALID")
-        self.assertAllClose(outputs, expected)
-
-        outputs = knn.conv(inputs_1d, kernel, 2, padding="same")
-        expected = tf.nn.conv1d(inputs_1d, kernel, 2, padding="SAME")
-        self.assertAllClose(outputs, expected)
-
         outputs = knn.conv(
-            inputs_1d, kernel, 1, padding="same", dilation_rate=2
+            inputs_1d,
+            kernel,
+            strides=strides,
+            padding=padding,
+            dilation_rate=dilation_rate,
         )
         expected = tf.nn.conv1d(
-            inputs_1d, kernel, 1, padding="SAME", dilations=2
+            inputs_1d,
+            kernel,
+            strides,
+            padding=padding.upper(),
+            dilations=dilation_rate,
         )
         self.assertAllClose(outputs, expected)

-        # Test 2D conv.
+    def test_conv_2d(self):
         inputs_2d = np.arange(600, dtype=float).reshape([2, 10, 10, 3])
         kernel = np.arange(24, dtype=float).reshape([2, 2, 3, 2])

@@ -806,7 +815,7 @@ class NNOpsCorrectnessTest(testing.TestCase):
         )
         self.assertAllClose(outputs, expected)

-        # Test 3D conv.
+    def test_conv_3d(self):
         inputs_3d = np.arange(3072, dtype=float).reshape([2, 8, 8, 8, 3])
         kernel = np.arange(162, dtype=float).reshape([3, 3, 3, 3, 2])

@@ -844,8 +853,7 @@ class NNOpsCorrectnessTest(testing.TestCase):
         )
         self.assertAllClose(outputs, expected, rtol=1e-5, atol=1e-5)

-    def test_depthwise_conv(self):
-        # Test 2D conv.
+    def test_depthwise_conv_2d(self):
         inputs_2d = np.arange(600, dtype=float).reshape([2, 10, 10, 3])
         kernel = np.arange(24, dtype=float).reshape([2, 2, 3, 2])

@@ -875,7 +883,7 @@ class NNOpsCorrectnessTest(testing.TestCase):
         )
         self.assertAllClose(outputs, expected)

-    def test_separable_conv(self):
+    def test_separable_conv_2d(self):
         # Test 2D conv.
         inputs_2d = np.arange(600, dtype=float).reshape([2, 10, 10, 3])
         depthwise_kernel = np.arange(24, dtype=float).reshape([2, 2, 3, 2])
@@ -939,8 +947,7 @@ class NNOpsCorrectnessTest(testing.TestCase):
         )
         self.assertAllClose(outputs, expected)

-    def test_conv_transpose(self):
-        # Test 1D conv.
+    def test_conv_transpose_1d(self):
         inputs_1d = np.arange(24, dtype=float).reshape([2, 4, 3])
         kernel = np.arange(30, dtype=float).reshape([2, 5, 3])
         outputs = knn.conv_transpose(inputs_1d, kernel, 2, padding="valid")
@@ -955,7 +962,7 @@ class NNOpsCorrectnessTest(testing.TestCase):
         )
         self.assertAllClose(outputs, expected)

-        # Test 2D conv.
+    def test_conv_transpose_2d(self):
         inputs_2d = np.arange(96, dtype=float).reshape([2, 4, 4, 3])
         kernel = np.arange(60, dtype=float).reshape([2, 2, 5, 3])
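
As background for the test-side change, here is a minimal standalone sketch (not part of the patch; the class and method names are illustrative) of how absl's parameterized.product expands a single test method into the cross product of its keyword arguments, with pytest.skip dropping combinations a backend cannot run. It assumes absl-py and pytest are installed and runs under a normal pytest invocation.

import pytest
from absl.testing import parameterized


class ExampleProductTest(parameterized.TestCase):
    # parameterized.product generates one test case per combination of the
    # keyword arguments below: 3 strides x 2 paddings = 6 generated cases.
    @parameterized.product(
        strides=(1, 2, 3),
        padding=("valid", "same"),
    )
    def test_example(self, strides, padding):
        if strides > 1 and padding == "same":
            # Skip, rather than fail, combinations that are not supported.
            pytest.skip("Unsupported configuration")
        self.assertIn(padding, ("valid", "same"))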