Implements remaining ops for PyTorch numpy (#216)
* Add PyTorch numpy functionality * Add dtype conversion * Partial fix for PyTorch numpy tests * small logic fix * Revert numpy_test * Add tensor conversion to numpy * Fix some arithmetic tests * Fix some torch functions for numpy compatibility * Fix pytorch ops for numpy compatibility, add TODOs * Fix formatting * Implement nits and fix dtype standardization * Add pytest skipif decorator and fix nits * Fix formatting and rename dtypes map * Split tests by backend * Merge space * Fix dtype issues from new type checking * Implement torch.full and torch.full_like numpy compatible * Implements logspace and linspace with tensor support for start and stop * Replace len of shape with ndim * Fix formatting * Implement torch.trace * Implement eye k diagonal arg * Implement torch.tri * Fix formatting issues * Fix torch.take dimensionality * Add split functionality * Revert torch.eye implementation to prevent conflict * Implement all padding modes
This commit is contained in:
parent
b02c29d3a0
commit
a460a35362
@ -343,7 +343,6 @@ def full(shape, fill_value, dtype=None):
|
||||
expand_size = len(shape) - len(fill_value.shape)
|
||||
tile_shape = tuple(shape[:expand_size]) + (1,) * len(fill_value.shape)
|
||||
return torch.tile(fill_value, tile_shape)
|
||||
|
||||
return torch.full(size=shape, fill_value=fill_value, dtype=dtype)
|
||||
|
||||
|
||||
@ -596,8 +595,16 @@ def outer(x1, x2):
|
||||
def pad(x, pad_width, mode="constant"):
|
||||
x = convert_to_tensor(x)
|
||||
pad_sum = ()
|
||||
pad_width = list(pad_width)[::-1] # torch uses reverse order
|
||||
for pad in pad_width:
|
||||
pad_sum += pad
|
||||
if mode == "symmetric":
|
||||
mode = "replicate"
|
||||
if mode != "constant" and x.ndim < 3:
|
||||
new_dims = [1] * (3 - x.ndim)
|
||||
x = cast(x, torch.float32) if x.dtype == torch.int else x
|
||||
x = x.view(*new_dims, *x.shape)
|
||||
return torch.nn.functional.pad(x, pad=pad_sum, mode=mode).squeeze()
|
||||
return torch.nn.functional.pad(x, pad=pad_sum, mode=mode)
|
||||
|
||||
|
||||
@ -665,9 +672,20 @@ def sort(x, axis=-1):
|
||||
|
||||
def split(x, indices_or_sections, axis=0):
|
||||
x = convert_to_tensor(x)
|
||||
if isinstance(indices_or_sections, list):
|
||||
idxs = convert_to_tensor(indices_or_sections)
|
||||
start_size = indices_or_sections[0]
|
||||
end_size = x.shape[axis] - indices_or_sections[-1]
|
||||
chunk_sizes = (
|
||||
[start_size]
|
||||
+ torch.diff(idxs).type(torch.int).tolist()
|
||||
+ [end_size]
|
||||
)
|
||||
else:
|
||||
chunk_sizes = x.shape[axis] // indices_or_sections
|
||||
return torch.split(
|
||||
tensor=x,
|
||||
split_size_or_sections=indices_or_sections,
|
||||
split_size_or_sections=chunk_sizes,
|
||||
dim=axis,
|
||||
)
|
||||
|
||||
@ -694,7 +712,7 @@ def take(x, indices, axis=None):
|
||||
x = convert_to_tensor(x)
|
||||
indices = convert_to_tensor(indices).long()
|
||||
if axis is not None:
|
||||
return torch.index_select(x, dim=axis, index=indices)
|
||||
return torch.index_select(x, dim=axis, index=indices).squeeze(axis)
|
||||
return torch.take(x, index=indices)
|
||||
|
||||
|
||||
@ -734,7 +752,9 @@ def trace(x, offset=None, axis1=None, axis2=None):
|
||||
|
||||
def tri(N, M=None, k=0, dtype="float32"):
|
||||
dtype = to_torch_dtype(dtype)
|
||||
pass
|
||||
M = M or N
|
||||
x = torch.ones((N, M), dtype=dtype)
|
||||
return torch.tril(x, diagonal=k)
|
||||
|
||||
|
||||
def tril(x, k=0):
|
||||
|
@ -1,5 +1,4 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from tensorflow.python.ops.numpy_ops import np_config
|
||||
|
||||
from keras_core import backend
|
||||
@ -2003,11 +2002,6 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase):
|
||||
self.assertAllClose(np.array(knp.outer(x, y)), np.outer(x, y))
|
||||
self.assertAllClose(np.array(knp.Outer()(x, y)), np.outer(x, y))
|
||||
|
||||
# TODO: Fix numpy compatibility (squeeze by one dimension only)
|
||||
@pytest.mark.skipif(
|
||||
backend.backend() == "torch",
|
||||
reason="`torch.take` and `np.take` have return shape divergence.",
|
||||
)
|
||||
def test_take(self):
|
||||
x = np.arange(24).reshape([1, 2, 3, 4])
|
||||
indices = np.array([0, 1])
|
||||
@ -2780,14 +2774,6 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase):
|
||||
np.pad(x, ((1, 1), (1, 1))),
|
||||
)
|
||||
|
||||
# TODO: implement padding with non-constant padding,
|
||||
# bypass NotImplementedError for PyTorch
|
||||
@pytest.mark.skipif(
|
||||
backend.backend() == "torch",
|
||||
reason="padding not implemented for non-constant use case",
|
||||
)
|
||||
def test_pad_without_torch(self):
|
||||
x = np.array([[1, 2], [3, 4]])
|
||||
self.assertAllClose(
|
||||
np.array(knp.pad(x, ((1, 1), (1, 1)), mode="reflect")),
|
||||
np.pad(x, ((1, 1), (1, 1)), mode="reflect"),
|
||||
@ -2919,14 +2905,24 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase):
|
||||
self.assertAllClose(np.array(knp.sort(x, axis=0)), np.sort(x, axis=0))
|
||||
self.assertAllClose(np.array(knp.Sort(axis=0)(x)), np.sort(x, axis=0))
|
||||
|
||||
# TODO: implement split for `torch` with support for conversion
|
||||
# of numpy.split args.
|
||||
@pytest.mark.skipif(
|
||||
backend.backend() == "torch",
|
||||
reason="`torch.split` and `np.split` have return arg divergence.",
|
||||
)
|
||||
def test_split(self):
|
||||
x = np.array([[1, 2, 3], [3, 2, 1]])
|
||||
if backend.backend() == "torch":
|
||||
self.assertAllClose(
|
||||
[t.numpy() for t in knp.split(x, 2)], np.split(x, 2)
|
||||
)
|
||||
self.assertAllClose(
|
||||
[t.numpy() for t in knp.Split(2)(x)], np.split(x, 2)
|
||||
)
|
||||
self.assertAllClose(
|
||||
[t.numpy() for t in knp.split(x, [1, 2], axis=1)],
|
||||
np.split(x, [1, 2], axis=1),
|
||||
)
|
||||
self.assertAllClose(
|
||||
[t.numpy() for t in knp.Split([1, 2], axis=1)(x)],
|
||||
np.split(x, [1, 2], axis=1),
|
||||
)
|
||||
else:
|
||||
self.assertAllClose(np.array(knp.split(x, 2)), np.split(x, 2))
|
||||
self.assertAllClose(np.array(knp.Split(2)(x)), np.split(x, 2))
|
||||
self.assertAllClose(
|
||||
@ -3073,11 +3069,7 @@ class NumpyArrayCreateOpsCorrectnessTest(testing.TestCase):
|
||||
self.assertAllClose(np.array(knp.identity(3)), np.identity(3))
|
||||
self.assertAllClose(np.array(knp.Identity()(3)), np.identity(3))
|
||||
|
||||
@pytest.mark.skipif(
|
||||
backend.backend() == "torch", reason="No torch equivalent for `np.tri`"
|
||||
)
|
||||
def test_tri(self):
|
||||
# TODO: create a manual implementation, as PyTorch has no equivalent
|
||||
self.assertAllClose(np.array(knp.tri(3)), np.tri(3))
|
||||
self.assertAllClose(np.array(knp.tri(3, 4)), np.tri(3, 4))
|
||||
self.assertAllClose(np.array(knp.tri(3, 4, 1)), np.tri(3, 4, 1))
|
||||
|
Loading…
Reference in New Issue
Block a user