Minor refactor
This commit is contained in:
parent
1c102bd48c
commit
1006d2eca9
@ -157,8 +157,9 @@ def _transpose_spatial_inputs(inputs):
|
|||||||
inputs = torch.permute(inputs, (0, 4, 1, 2, 3))
|
inputs = torch.permute(inputs, (0, 4, 1, 2, 3))
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Pooling inputs's shape must be 3, 4 or 5, corresponding to 1D, 2D "
|
"Inputs must have ndim=3, 4 or 5, "
|
||||||
f"and 3D inputs. But received shape: {inputs.shape}."
|
"corresponding to 1D, 2D and 3D inputs. "
|
||||||
|
f"Received input shape: {inputs.shape}."
|
||||||
)
|
)
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
@ -222,8 +223,9 @@ def max_pool(
|
|||||||
outputs = tnn.max_pool3d(inputs, kernel_size=pool_size, stride=strides)
|
outputs = tnn.max_pool3d(inputs, kernel_size=pool_size, stride=strides)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Pooling inputs's shape must be 3, 4 or 5, corresponding to 1D, 2D "
|
"Inputs to pooling op must have ndim=3, 4 or 5, "
|
||||||
f"and 3D inputs. But received shape: {inputs.shape}."
|
"corresponding to 1D, 2D and 3D inputs. "
|
||||||
|
f"Received input shape: {inputs.shape}."
|
||||||
)
|
)
|
||||||
if data_format == "channels_last":
|
if data_format == "channels_last":
|
||||||
outputs = _transpose_spatial_outputs(outputs)
|
outputs = _transpose_spatial_outputs(outputs)
|
||||||
@ -264,8 +266,9 @@ def average_pool(
|
|||||||
outputs = tnn.avg_pool3d(inputs, kernel_size=pool_size, stride=strides)
|
outputs = tnn.avg_pool3d(inputs, kernel_size=pool_size, stride=strides)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Pooling inputs's shape must be 3, 4 or 5, corresponding to 1D, 2D "
|
"Inputs to pooling op must have ndim=3, 4 or 5, "
|
||||||
f"and 3D inputs. But received shape: {inputs.shape}."
|
"corresponding to 1D, 2D and 3D inputs. "
|
||||||
|
f"Received input shape: {inputs.shape}."
|
||||||
)
|
)
|
||||||
if data_format == "channels_last":
|
if data_format == "channels_last":
|
||||||
outputs = _transpose_spatial_outputs(outputs)
|
outputs = _transpose_spatial_outputs(outputs)
|
||||||
@ -303,8 +306,8 @@ def conv(
|
|||||||
if channels % kernel_in_channels > 0:
|
if channels % kernel_in_channels > 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The number of input channels must be evenly divisible by "
|
"The number of input channels must be evenly divisible by "
|
||||||
f"kernel's in_channels. Received input shape {inputs.shape} and "
|
f"kernel.shape[1]. Received: inputs.shape={inputs.shape}, "
|
||||||
f"kernel shape {kernel.shape}. "
|
f"kernel.shape={kernel.shape}"
|
||||||
)
|
)
|
||||||
groups = channels // kernel_in_channels
|
groups = channels // kernel_in_channels
|
||||||
if num_spatial_dims == 1:
|
if num_spatial_dims == 1:
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pytest
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
from absl.testing import parameterized
|
||||||
|
|
||||||
from keras_core import testing
|
from keras_core import testing
|
||||||
from keras_core.backend.common.keras_tensor import KerasTensor
|
from keras_core.backend.common.keras_tensor import KerasTensor
|
||||||
@ -565,7 +567,7 @@ class NNOpsStaticShapeTest(testing.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class NNOpsCorrectnessTest(testing.TestCase):
|
class NNOpsCorrectnessTest(testing.TestCase, parameterized.TestCase):
|
||||||
def test_relu(self):
|
def test_relu(self):
|
||||||
x = np.array([-1, 0, 1, 2, 3], dtype=np.float32)
|
x = np.array([-1, 0, 1, 2, 3], dtype=np.float32)
|
||||||
self.assertAllClose(knn.relu(x), [0, 0, 1, 2, 3])
|
self.assertAllClose(knn.relu(x), [0, 0, 1, 2, 3])
|
||||||
@ -738,28 +740,35 @@ class NNOpsCorrectnessTest(testing.TestCase):
|
|||||||
tf.nn.avg_pool2d(x, 2, (2, 1), padding="SAME"),
|
tf.nn.avg_pool2d(x, 2, (2, 1), padding="SAME"),
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_conv(self):
|
@parameterized.product(
|
||||||
# Test 1D conv.
|
strides=(1, 2, 3),
|
||||||
|
padding=("valid", "same"),
|
||||||
|
dilation_rate=(1, 2),
|
||||||
|
)
|
||||||
|
def test_conv_1d(self, strides, padding, dilation_rate):
|
||||||
|
if strides > 1 and dilation_rate > 1:
|
||||||
|
pytest.skip("Unsupported configuration")
|
||||||
|
|
||||||
inputs_1d = np.arange(120, dtype=float).reshape([2, 20, 3])
|
inputs_1d = np.arange(120, dtype=float).reshape([2, 20, 3])
|
||||||
kernel = np.arange(24, dtype=float).reshape([4, 3, 2])
|
kernel = np.arange(24, dtype=float).reshape([4, 3, 2])
|
||||||
|
|
||||||
outputs = knn.conv(inputs_1d, kernel, 1, padding="valid")
|
|
||||||
expected = tf.nn.conv1d(inputs_1d, kernel, 1, padding="VALID")
|
|
||||||
self.assertAllClose(outputs, expected)
|
|
||||||
|
|
||||||
outputs = knn.conv(inputs_1d, kernel, 2, padding="same")
|
|
||||||
expected = tf.nn.conv1d(inputs_1d, kernel, 2, padding="SAME")
|
|
||||||
self.assertAllClose(outputs, expected)
|
|
||||||
|
|
||||||
outputs = knn.conv(
|
outputs = knn.conv(
|
||||||
inputs_1d, kernel, 1, padding="same", dilation_rate=2
|
inputs_1d,
|
||||||
|
kernel,
|
||||||
|
strides=strides,
|
||||||
|
padding=padding,
|
||||||
|
dilation_rate=dilation_rate,
|
||||||
)
|
)
|
||||||
expected = tf.nn.conv1d(
|
expected = tf.nn.conv1d(
|
||||||
inputs_1d, kernel, 1, padding="SAME", dilations=2
|
inputs_1d,
|
||||||
|
kernel,
|
||||||
|
strides,
|
||||||
|
padding=padding.upper(),
|
||||||
|
dilations=dilation_rate,
|
||||||
)
|
)
|
||||||
self.assertAllClose(outputs, expected)
|
self.assertAllClose(outputs, expected)
|
||||||
|
|
||||||
# Test 2D conv.
|
def test_conv_2d(self):
|
||||||
inputs_2d = np.arange(600, dtype=float).reshape([2, 10, 10, 3])
|
inputs_2d = np.arange(600, dtype=float).reshape([2, 10, 10, 3])
|
||||||
kernel = np.arange(24, dtype=float).reshape([2, 2, 3, 2])
|
kernel = np.arange(24, dtype=float).reshape([2, 2, 3, 2])
|
||||||
|
|
||||||
@ -806,7 +815,7 @@ class NNOpsCorrectnessTest(testing.TestCase):
|
|||||||
)
|
)
|
||||||
self.assertAllClose(outputs, expected)
|
self.assertAllClose(outputs, expected)
|
||||||
|
|
||||||
# Test 3D conv.
|
def test_conv_3d(self):
|
||||||
inputs_3d = np.arange(3072, dtype=float).reshape([2, 8, 8, 8, 3])
|
inputs_3d = np.arange(3072, dtype=float).reshape([2, 8, 8, 8, 3])
|
||||||
kernel = np.arange(162, dtype=float).reshape([3, 3, 3, 3, 2])
|
kernel = np.arange(162, dtype=float).reshape([3, 3, 3, 3, 2])
|
||||||
|
|
||||||
@ -844,8 +853,7 @@ class NNOpsCorrectnessTest(testing.TestCase):
|
|||||||
)
|
)
|
||||||
self.assertAllClose(outputs, expected, rtol=1e-5, atol=1e-5)
|
self.assertAllClose(outputs, expected, rtol=1e-5, atol=1e-5)
|
||||||
|
|
||||||
def test_depthwise_conv(self):
|
def test_depthwise_conv_2d(self):
|
||||||
# Test 2D conv.
|
|
||||||
inputs_2d = np.arange(600, dtype=float).reshape([2, 10, 10, 3])
|
inputs_2d = np.arange(600, dtype=float).reshape([2, 10, 10, 3])
|
||||||
kernel = np.arange(24, dtype=float).reshape([2, 2, 3, 2])
|
kernel = np.arange(24, dtype=float).reshape([2, 2, 3, 2])
|
||||||
|
|
||||||
@ -875,7 +883,7 @@ class NNOpsCorrectnessTest(testing.TestCase):
|
|||||||
)
|
)
|
||||||
self.assertAllClose(outputs, expected)
|
self.assertAllClose(outputs, expected)
|
||||||
|
|
||||||
def test_separable_conv(self):
|
def test_separable_conv_2d(self):
|
||||||
# Test 2D conv.
|
# Test 2D conv.
|
||||||
inputs_2d = np.arange(600, dtype=float).reshape([2, 10, 10, 3])
|
inputs_2d = np.arange(600, dtype=float).reshape([2, 10, 10, 3])
|
||||||
depthwise_kernel = np.arange(24, dtype=float).reshape([2, 2, 3, 2])
|
depthwise_kernel = np.arange(24, dtype=float).reshape([2, 2, 3, 2])
|
||||||
@ -939,8 +947,7 @@ class NNOpsCorrectnessTest(testing.TestCase):
|
|||||||
)
|
)
|
||||||
self.assertAllClose(outputs, expected)
|
self.assertAllClose(outputs, expected)
|
||||||
|
|
||||||
def test_conv_transpose(self):
|
def test_conv_transpose_1d(self):
|
||||||
# Test 1D conv.
|
|
||||||
inputs_1d = np.arange(24, dtype=float).reshape([2, 4, 3])
|
inputs_1d = np.arange(24, dtype=float).reshape([2, 4, 3])
|
||||||
kernel = np.arange(30, dtype=float).reshape([2, 5, 3])
|
kernel = np.arange(30, dtype=float).reshape([2, 5, 3])
|
||||||
outputs = knn.conv_transpose(inputs_1d, kernel, 2, padding="valid")
|
outputs = knn.conv_transpose(inputs_1d, kernel, 2, padding="valid")
|
||||||
@ -955,7 +962,7 @@ class NNOpsCorrectnessTest(testing.TestCase):
|
|||||||
)
|
)
|
||||||
self.assertAllClose(outputs, expected)
|
self.assertAllClose(outputs, expected)
|
||||||
|
|
||||||
# Test 2D conv.
|
def test_conv_transpose_2d(self):
|
||||||
inputs_2d = np.arange(96, dtype=float).reshape([2, 4, 4, 3])
|
inputs_2d = np.arange(96, dtype=float).reshape([2, 4, 4, 3])
|
||||||
kernel = np.arange(60, dtype=float).reshape([2, 2, 5, 3])
|
kernel = np.arange(60, dtype=float).reshape([2, 2, 5, 3])
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user