Minor refactor
This commit is contained in:
parent
1c102bd48c
commit
1006d2eca9
@ -157,8 +157,9 @@ def _transpose_spatial_inputs(inputs):
|
||||
inputs = torch.permute(inputs, (0, 4, 1, 2, 3))
|
||||
else:
|
||||
raise ValueError(
|
||||
"Pooling inputs's shape must be 3, 4 or 5, corresponding to 1D, 2D "
|
||||
f"and 3D inputs. But received shape: {inputs.shape}."
|
||||
"Inputs must have ndim=3, 4 or 5, "
|
||||
"corresponding to 1D, 2D and 3D inputs. "
|
||||
f"Received input shape: {inputs.shape}."
|
||||
)
|
||||
return inputs
|
||||
|
||||
@ -222,8 +223,9 @@ def max_pool(
|
||||
outputs = tnn.max_pool3d(inputs, kernel_size=pool_size, stride=strides)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Pooling inputs's shape must be 3, 4 or 5, corresponding to 1D, 2D "
|
||||
f"and 3D inputs. But received shape: {inputs.shape}."
|
||||
"Inputs to pooling op must have ndim=3, 4 or 5, "
|
||||
"corresponding to 1D, 2D and 3D inputs. "
|
||||
f"Received input shape: {inputs.shape}."
|
||||
)
|
||||
if data_format == "channels_last":
|
||||
outputs = _transpose_spatial_outputs(outputs)
|
||||
@ -264,8 +266,9 @@ def average_pool(
|
||||
outputs = tnn.avg_pool3d(inputs, kernel_size=pool_size, stride=strides)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Pooling inputs's shape must be 3, 4 or 5, corresponding to 1D, 2D "
|
||||
f"and 3D inputs. But received shape: {inputs.shape}."
|
||||
"Inputs to pooling op must have ndim=3, 4 or 5, "
|
||||
"corresponding to 1D, 2D and 3D inputs. "
|
||||
f"Received input shape: {inputs.shape}."
|
||||
)
|
||||
if data_format == "channels_last":
|
||||
outputs = _transpose_spatial_outputs(outputs)
|
||||
@ -303,8 +306,8 @@ def conv(
|
||||
if channels % kernel_in_channels > 0:
|
||||
raise ValueError(
|
||||
"The number of input channels must be evenly divisible by "
|
||||
f"kernel's in_channels. Received input shape {inputs.shape} and "
|
||||
f"kernel shape {kernel.shape}. "
|
||||
f"kernel.shape[1]. Received: inputs.shape={inputs.shape}, "
|
||||
f"kernel.shape={kernel.shape}"
|
||||
)
|
||||
groups = channels // kernel_in_channels
|
||||
if num_spatial_dims == 1:
|
||||
|
@ -1,5 +1,7 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
import tensorflow as tf
|
||||
from absl.testing import parameterized
|
||||
|
||||
from keras_core import testing
|
||||
from keras_core.backend.common.keras_tensor import KerasTensor
|
||||
@ -565,7 +567,7 @@ class NNOpsStaticShapeTest(testing.TestCase):
|
||||
)
|
||||
|
||||
|
||||
class NNOpsCorrectnessTest(testing.TestCase):
|
||||
class NNOpsCorrectnessTest(testing.TestCase, parameterized.TestCase):
|
||||
def test_relu(self):
|
||||
x = np.array([-1, 0, 1, 2, 3], dtype=np.float32)
|
||||
self.assertAllClose(knn.relu(x), [0, 0, 1, 2, 3])
|
||||
@ -738,28 +740,35 @@ class NNOpsCorrectnessTest(testing.TestCase):
|
||||
tf.nn.avg_pool2d(x, 2, (2, 1), padding="SAME"),
|
||||
)
|
||||
|
||||
def test_conv(self):
|
||||
# Test 1D conv.
|
||||
@parameterized.product(
|
||||
strides=(1, 2, 3),
|
||||
padding=("valid", "same"),
|
||||
dilation_rate=(1, 2),
|
||||
)
|
||||
def test_conv_1d(self, strides, padding, dilation_rate):
|
||||
if strides > 1 and dilation_rate > 1:
|
||||
pytest.skip("Unsupported configuration")
|
||||
|
||||
inputs_1d = np.arange(120, dtype=float).reshape([2, 20, 3])
|
||||
kernel = np.arange(24, dtype=float).reshape([4, 3, 2])
|
||||
|
||||
outputs = knn.conv(inputs_1d, kernel, 1, padding="valid")
|
||||
expected = tf.nn.conv1d(inputs_1d, kernel, 1, padding="VALID")
|
||||
self.assertAllClose(outputs, expected)
|
||||
|
||||
outputs = knn.conv(inputs_1d, kernel, 2, padding="same")
|
||||
expected = tf.nn.conv1d(inputs_1d, kernel, 2, padding="SAME")
|
||||
self.assertAllClose(outputs, expected)
|
||||
|
||||
outputs = knn.conv(
|
||||
inputs_1d, kernel, 1, padding="same", dilation_rate=2
|
||||
inputs_1d,
|
||||
kernel,
|
||||
strides=strides,
|
||||
padding=padding,
|
||||
dilation_rate=dilation_rate,
|
||||
)
|
||||
expected = tf.nn.conv1d(
|
||||
inputs_1d, kernel, 1, padding="SAME", dilations=2
|
||||
inputs_1d,
|
||||
kernel,
|
||||
strides,
|
||||
padding=padding.upper(),
|
||||
dilations=dilation_rate,
|
||||
)
|
||||
self.assertAllClose(outputs, expected)
|
||||
|
||||
# Test 2D conv.
|
||||
def test_conv_2d(self):
|
||||
inputs_2d = np.arange(600, dtype=float).reshape([2, 10, 10, 3])
|
||||
kernel = np.arange(24, dtype=float).reshape([2, 2, 3, 2])
|
||||
|
||||
@ -806,7 +815,7 @@ class NNOpsCorrectnessTest(testing.TestCase):
|
||||
)
|
||||
self.assertAllClose(outputs, expected)
|
||||
|
||||
# Test 3D conv.
|
||||
def test_conv_3d(self):
|
||||
inputs_3d = np.arange(3072, dtype=float).reshape([2, 8, 8, 8, 3])
|
||||
kernel = np.arange(162, dtype=float).reshape([3, 3, 3, 3, 2])
|
||||
|
||||
@ -844,8 +853,7 @@ class NNOpsCorrectnessTest(testing.TestCase):
|
||||
)
|
||||
self.assertAllClose(outputs, expected, rtol=1e-5, atol=1e-5)
|
||||
|
||||
def test_depthwise_conv(self):
|
||||
# Test 2D conv.
|
||||
def test_depthwise_conv_2d(self):
|
||||
inputs_2d = np.arange(600, dtype=float).reshape([2, 10, 10, 3])
|
||||
kernel = np.arange(24, dtype=float).reshape([2, 2, 3, 2])
|
||||
|
||||
@ -875,7 +883,7 @@ class NNOpsCorrectnessTest(testing.TestCase):
|
||||
)
|
||||
self.assertAllClose(outputs, expected)
|
||||
|
||||
def test_separable_conv(self):
|
||||
def test_separable_conv_2d(self):
|
||||
# Test 2D conv.
|
||||
inputs_2d = np.arange(600, dtype=float).reshape([2, 10, 10, 3])
|
||||
depthwise_kernel = np.arange(24, dtype=float).reshape([2, 2, 3, 2])
|
||||
@ -939,8 +947,7 @@ class NNOpsCorrectnessTest(testing.TestCase):
|
||||
)
|
||||
self.assertAllClose(outputs, expected)
|
||||
|
||||
def test_conv_transpose(self):
|
||||
# Test 1D conv.
|
||||
def test_conv_transpose_1d(self):
|
||||
inputs_1d = np.arange(24, dtype=float).reshape([2, 4, 3])
|
||||
kernel = np.arange(30, dtype=float).reshape([2, 5, 3])
|
||||
outputs = knn.conv_transpose(inputs_1d, kernel, 2, padding="valid")
|
||||
@ -955,7 +962,7 @@ class NNOpsCorrectnessTest(testing.TestCase):
|
||||
)
|
||||
self.assertAllClose(outputs, expected)
|
||||
|
||||
# Test 2D conv.
|
||||
def test_conv_transpose_2d(self):
|
||||
inputs_2d = np.arange(96, dtype=float).reshape([2, 4, 4, 3])
|
||||
kernel = np.arange(60, dtype=float).reshape([2, 2, 5, 3])
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user