Minor refactor

This commit is contained in:
Francois Chollet 2023-05-23 16:01:22 -07:00
parent 1c102bd48c
commit 1006d2eca9
2 changed files with 39 additions and 29 deletions

@ -157,8 +157,9 @@ def _transpose_spatial_inputs(inputs):
inputs = torch.permute(inputs, (0, 4, 1, 2, 3))
else:
raise ValueError(
"Pooling inputs's shape must be 3, 4 or 5, corresponding to 1D, 2D "
f"and 3D inputs. But received shape: {inputs.shape}."
"Inputs must have ndim=3, 4 or 5, "
"corresponding to 1D, 2D and 3D inputs. "
f"Received input shape: {inputs.shape}."
)
return inputs
@ -222,8 +223,9 @@ def max_pool(
outputs = tnn.max_pool3d(inputs, kernel_size=pool_size, stride=strides)
else:
raise ValueError(
"Pooling inputs's shape must be 3, 4 or 5, corresponding to 1D, 2D "
f"and 3D inputs. But received shape: {inputs.shape}."
"Inputs to pooling op must have ndim=3, 4 or 5, "
"corresponding to 1D, 2D and 3D inputs. "
f"Received input shape: {inputs.shape}."
)
if data_format == "channels_last":
outputs = _transpose_spatial_outputs(outputs)
@ -264,8 +266,9 @@ def average_pool(
outputs = tnn.avg_pool3d(inputs, kernel_size=pool_size, stride=strides)
else:
raise ValueError(
"Pooling inputs's shape must be 3, 4 or 5, corresponding to 1D, 2D "
f"and 3D inputs. But received shape: {inputs.shape}."
"Inputs to pooling op must have ndim=3, 4 or 5, "
"corresponding to 1D, 2D and 3D inputs. "
f"Received input shape: {inputs.shape}."
)
if data_format == "channels_last":
outputs = _transpose_spatial_outputs(outputs)
@ -303,8 +306,8 @@ def conv(
if channels % kernel_in_channels > 0:
raise ValueError(
"The number of input channels must be evenly divisible by "
f"kernel's in_channels. Received input shape {inputs.shape} and "
f"kernel shape {kernel.shape}. "
f"kernel.shape[1]. Received: inputs.shape={inputs.shape}, "
f"kernel.shape={kernel.shape}"
)
groups = channels // kernel_in_channels
if num_spatial_dims == 1:

@ -1,5 +1,7 @@
import numpy as np
import pytest
import tensorflow as tf
from absl.testing import parameterized
from keras_core import testing
from keras_core.backend.common.keras_tensor import KerasTensor
@ -565,7 +567,7 @@ class NNOpsStaticShapeTest(testing.TestCase):
)
class NNOpsCorrectnessTest(testing.TestCase):
class NNOpsCorrectnessTest(testing.TestCase, parameterized.TestCase):
def test_relu(self):
x = np.array([-1, 0, 1, 2, 3], dtype=np.float32)
self.assertAllClose(knn.relu(x), [0, 0, 1, 2, 3])
@ -738,28 +740,35 @@ class NNOpsCorrectnessTest(testing.TestCase):
tf.nn.avg_pool2d(x, 2, (2, 1), padding="SAME"),
)
def test_conv(self):
# Test 1D conv.
@parameterized.product(
strides=(1, 2, 3),
padding=("valid", "same"),
dilation_rate=(1, 2),
)
def test_conv_1d(self, strides, padding, dilation_rate):
if strides > 1 and dilation_rate > 1:
pytest.skip("Unsupported configuration")
inputs_1d = np.arange(120, dtype=float).reshape([2, 20, 3])
kernel = np.arange(24, dtype=float).reshape([4, 3, 2])
outputs = knn.conv(inputs_1d, kernel, 1, padding="valid")
expected = tf.nn.conv1d(inputs_1d, kernel, 1, padding="VALID")
self.assertAllClose(outputs, expected)
outputs = knn.conv(inputs_1d, kernel, 2, padding="same")
expected = tf.nn.conv1d(inputs_1d, kernel, 2, padding="SAME")
self.assertAllClose(outputs, expected)
outputs = knn.conv(
inputs_1d, kernel, 1, padding="same", dilation_rate=2
inputs_1d,
kernel,
strides=strides,
padding=padding,
dilation_rate=dilation_rate,
)
expected = tf.nn.conv1d(
inputs_1d, kernel, 1, padding="SAME", dilations=2
inputs_1d,
kernel,
strides,
padding=padding.upper(),
dilations=dilation_rate,
)
self.assertAllClose(outputs, expected)
# Test 2D conv.
def test_conv_2d(self):
inputs_2d = np.arange(600, dtype=float).reshape([2, 10, 10, 3])
kernel = np.arange(24, dtype=float).reshape([2, 2, 3, 2])
@ -806,7 +815,7 @@ class NNOpsCorrectnessTest(testing.TestCase):
)
self.assertAllClose(outputs, expected)
# Test 3D conv.
def test_conv_3d(self):
inputs_3d = np.arange(3072, dtype=float).reshape([2, 8, 8, 8, 3])
kernel = np.arange(162, dtype=float).reshape([3, 3, 3, 3, 2])
@ -844,8 +853,7 @@ class NNOpsCorrectnessTest(testing.TestCase):
)
self.assertAllClose(outputs, expected, rtol=1e-5, atol=1e-5)
def test_depthwise_conv(self):
# Test 2D conv.
def test_depthwise_conv_2d(self):
inputs_2d = np.arange(600, dtype=float).reshape([2, 10, 10, 3])
kernel = np.arange(24, dtype=float).reshape([2, 2, 3, 2])
@ -875,7 +883,7 @@ class NNOpsCorrectnessTest(testing.TestCase):
)
self.assertAllClose(outputs, expected)
def test_separable_conv(self):
def test_separable_conv_2d(self):
# Test 2D conv.
inputs_2d = np.arange(600, dtype=float).reshape([2, 10, 10, 3])
depthwise_kernel = np.arange(24, dtype=float).reshape([2, 2, 3, 2])
@ -939,8 +947,7 @@ class NNOpsCorrectnessTest(testing.TestCase):
)
self.assertAllClose(outputs, expected)
def test_conv_transpose(self):
# Test 1D conv.
def test_conv_transpose_1d(self):
inputs_1d = np.arange(24, dtype=float).reshape([2, 4, 3])
kernel = np.arange(30, dtype=float).reshape([2, 5, 3])
outputs = knn.conv_transpose(inputs_1d, kernel, 2, padding="valid")
@ -955,7 +962,7 @@ class NNOpsCorrectnessTest(testing.TestCase):
)
self.assertAllClose(outputs, expected)
# Test 2D conv.
def test_conv_transpose_2d(self):
inputs_2d = np.arange(96, dtype=float).reshape([2, 4, 4, 3])
kernel = np.arange(60, dtype=float).reshape([2, 2, 5, 3])