Implements remaining ops for PyTorch numpy (#216)

* Add PyTorch numpy functionality

* Add dtype conversion

* Partial fix for PyTorch numpy tests

* small logic fix

* Revert numpy_test

* Add tensor conversion to numpy

* Fix some arithmetic tests

* Fix some torch functions for numpy compatibility

* Fix pytorch ops for numpy compatibility, add TODOs

* Fix formatting

* Implement nits and fix dtype standardization

* Add pytest skipif decorator and fix nits

* Fix formatting and rename dtypes map

* Split tests by backend

* Merge space

* Fix dtype issues from new type checking

* Implement torch.full and torch.full_like numpy compatible

* Implements logspace and linspace with tensor support for start and stop

* Replace len of shape with ndim

* Fix formatting

* Implement torch.trace

* Implement eye k diagonal arg

* Implement torch.tri

* Fix formatting issues

* Fix torch.take dimensionality

* Add split functionality

* Revert torch.eye implementation to prevent conflict

* Implement all padding modes
This commit is contained in:
Neel Kovelamudi 2023-05-25 20:41:36 +00:00 committed by Francois Chollet
parent b02c29d3a0
commit a460a35362
2 changed files with 50 additions and 38 deletions

@ -343,7 +343,6 @@ def full(shape, fill_value, dtype=None):
expand_size = len(shape) - len(fill_value.shape)
tile_shape = tuple(shape[:expand_size]) + (1,) * len(fill_value.shape)
return torch.tile(fill_value, tile_shape)
return torch.full(size=shape, fill_value=fill_value, dtype=dtype)
@ -596,8 +595,16 @@ def outer(x1, x2):
def pad(x, pad_width, mode="constant"):
x = convert_to_tensor(x)
pad_sum = ()
pad_width = list(pad_width)[::-1] # torch uses reverse order
for pad in pad_width:
pad_sum += pad
if mode == "symmetric":
mode = "replicate"
if mode != "constant" and x.ndim < 3:
new_dims = [1] * (3 - x.ndim)
x = cast(x, torch.float32) if x.dtype == torch.int else x
x = x.view(*new_dims, *x.shape)
return torch.nn.functional.pad(x, pad=pad_sum, mode=mode).squeeze()
return torch.nn.functional.pad(x, pad=pad_sum, mode=mode)
@ -665,9 +672,20 @@ def sort(x, axis=-1):
def split(x, indices_or_sections, axis=0):
x = convert_to_tensor(x)
if isinstance(indices_or_sections, list):
idxs = convert_to_tensor(indices_or_sections)
start_size = indices_or_sections[0]
end_size = x.shape[axis] - indices_or_sections[-1]
chunk_sizes = (
[start_size]
+ torch.diff(idxs).type(torch.int).tolist()
+ [end_size]
)
else:
chunk_sizes = x.shape[axis] // indices_or_sections
return torch.split(
tensor=x,
split_size_or_sections=indices_or_sections,
split_size_or_sections=chunk_sizes,
dim=axis,
)
@ -694,7 +712,7 @@ def take(x, indices, axis=None):
x = convert_to_tensor(x)
indices = convert_to_tensor(indices).long()
if axis is not None:
return torch.index_select(x, dim=axis, index=indices)
return torch.index_select(x, dim=axis, index=indices).squeeze(axis)
return torch.take(x, index=indices)
@ -734,7 +752,9 @@ def trace(x, offset=None, axis1=None, axis2=None):
def tri(N, M=None, k=0, dtype="float32"):
dtype = to_torch_dtype(dtype)
pass
M = M or N
x = torch.ones((N, M), dtype=dtype)
return torch.tril(x, diagonal=k)
def tril(x, k=0):

@ -1,5 +1,4 @@
import numpy as np
import pytest
from tensorflow.python.ops.numpy_ops import np_config
from keras_core import backend
@ -2003,11 +2002,6 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase):
self.assertAllClose(np.array(knp.outer(x, y)), np.outer(x, y))
self.assertAllClose(np.array(knp.Outer()(x, y)), np.outer(x, y))
# TODO: Fix numpy compatibility (squeeze by one dimension only)
@pytest.mark.skipif(
backend.backend() == "torch",
reason="`torch.take` and `np.take` have return shape divergence.",
)
def test_take(self):
x = np.arange(24).reshape([1, 2, 3, 4])
indices = np.array([0, 1])
@ -2780,14 +2774,6 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase):
np.pad(x, ((1, 1), (1, 1))),
)
# TODO: implement padding with non-constant padding,
# bypass NotImplementedError for PyTorch
@pytest.mark.skipif(
backend.backend() == "torch",
reason="padding not implemented for non-constant use case",
)
def test_pad_without_torch(self):
x = np.array([[1, 2], [3, 4]])
self.assertAllClose(
np.array(knp.pad(x, ((1, 1), (1, 1)), mode="reflect")),
np.pad(x, ((1, 1), (1, 1)), mode="reflect"),
@ -2919,24 +2905,34 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase):
self.assertAllClose(np.array(knp.sort(x, axis=0)), np.sort(x, axis=0))
self.assertAllClose(np.array(knp.Sort(axis=0)(x)), np.sort(x, axis=0))
# TODO: implement split for `torch` with support for conversion
# of numpy.split args.
@pytest.mark.skipif(
backend.backend() == "torch",
reason="`torch.split` and `np.split` have return arg divergence.",
)
def test_split(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
self.assertAllClose(np.array(knp.split(x, 2)), np.split(x, 2))
self.assertAllClose(np.array(knp.Split(2)(x)), np.split(x, 2))
self.assertAllClose(
np.array(knp.split(x, [1, 2], axis=1)),
np.split(x, [1, 2], axis=1),
)
self.assertAllClose(
np.array(knp.Split([1, 2], axis=1)(x)),
np.split(x, [1, 2], axis=1),
)
if backend.backend() == "torch":
self.assertAllClose(
[t.numpy() for t in knp.split(x, 2)], np.split(x, 2)
)
self.assertAllClose(
[t.numpy() for t in knp.Split(2)(x)], np.split(x, 2)
)
self.assertAllClose(
[t.numpy() for t in knp.split(x, [1, 2], axis=1)],
np.split(x, [1, 2], axis=1),
)
self.assertAllClose(
[t.numpy() for t in knp.Split([1, 2], axis=1)(x)],
np.split(x, [1, 2], axis=1),
)
else:
self.assertAllClose(np.array(knp.split(x, 2)), np.split(x, 2))
self.assertAllClose(np.array(knp.Split(2)(x)), np.split(x, 2))
self.assertAllClose(
np.array(knp.split(x, [1, 2], axis=1)),
np.split(x, [1, 2], axis=1),
)
self.assertAllClose(
np.array(knp.Split([1, 2], axis=1)(x)),
np.split(x, [1, 2], axis=1),
)
def test_sqrt(self):
x = np.array([[1, 4, 9], [16, 25, 36]])
@ -3073,11 +3069,7 @@ class NumpyArrayCreateOpsCorrectnessTest(testing.TestCase):
self.assertAllClose(np.array(knp.identity(3)), np.identity(3))
self.assertAllClose(np.array(knp.Identity()(3)), np.identity(3))
@pytest.mark.skipif(
backend.backend() == "torch", reason="No torch equivalent for `np.tri`"
)
def test_tri(self):
# TODO: create a manual implementation, as PyTorch has no equivalent
self.assertAllClose(np.array(knp.tri(3)), np.tri(3))
self.assertAllClose(np.array(knp.tri(3, 4)), np.tri(3, 4))
self.assertAllClose(np.array(knp.tri(3, 4, 1)), np.tri(3, 4, 1))