Implements remaining ops for PyTorch numpy (#216)

* Add PyTorch numpy functionality

* Add dtype conversion

* Partial fix for PyTorch numpy tests

* Small logic fix

* Revert numpy_test

* Add tensor conversion to numpy

* Fix some arithmetic tests

* Fix some torch functions for numpy compatibility

* Fix pytorch ops for numpy compatibility, add TODOs

* Fix formatting

* Address review nits and fix dtype standardization

* Add pytest skipif decorator and fix nits

* Fix formatting and rename dtypes map

* Split tests by backend

* Merge space

* Fix dtype issues from new type checking

* Make torch.full and torch.full_like numpy-compatible

* Implement logspace and linspace with tensor support for start and stop

* Replace len of shape with ndim

* Fix formatting

* Implement torch.trace

* Implement the k diagonal arg for eye

* Implement torch.tri

* Fix formatting issues

* Fix torch.take dimensionality

* Add split functionality

* Revert torch.eye implementation to prevent conflict

* Implement all padding modes
Authored by Neel Kovelamudi on 2023-05-25 20:41:36 +00:00, committed by Francois Chollet
parent b02c29d3a0
commit a460a35362
2 changed files with 50 additions and 38 deletions

@@ -343,7 +343,6 @@ def full(shape, fill_value, dtype=None):
         expand_size = len(shape) - len(fill_value.shape)
         tile_shape = tuple(shape[:expand_size]) + (1,) * len(fill_value.shape)
         return torch.tile(fill_value, tile_shape)
-
     return torch.full(size=shape, fill_value=fill_value, dtype=dtype)
 
 
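
The tile logic above mirrors np.full's handling of array-valued fill values. A minimal standalone sketch of the same broadcasting (shape and values chosen here for illustration, not taken from the patch):

import torch

# Broadcast a 1-D fill_value up to the requested shape, as np.full does.
shape = (2, 3, 2)
fill_value = torch.tensor([1.0, 2.0])

expand_size = len(shape) - fill_value.ndim  # number of leading dims to add
tile_shape = tuple(shape[:expand_size]) + (1,) * fill_value.ndim
out = torch.tile(fill_value, tile_shape)  # torch.tile left-pads the tensor
                                          # shape with 1s, giving (2, 3, 2)
assert out.shape == shape
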
@@ -596,8 +595,16 @@ def outer(x1, x2):
 
 
 def pad(x, pad_width, mode="constant"):
     x = convert_to_tensor(x)
     pad_sum = ()
+    pad_width = list(pad_width)[::-1]  # torch uses reverse order
     for pad in pad_width:
         pad_sum += pad
+    if mode == "symmetric":
+        mode = "replicate"
+    if mode != "constant" and x.ndim < 3:
+        new_dims = [1] * (3 - x.ndim)
+        x = cast(x, torch.float32) if x.dtype == torch.int else x
+        x = x.view(*new_dims, *x.shape)
+        return torch.nn.functional.pad(x, pad=pad_sum, mode=mode).squeeze()
     return torch.nn.functional.pad(x, pad=pad_sum, mode=mode)
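
Two torch quirks motivate these changes: torch.nn.functional.pad takes a flat tuple of widths ordered from the last dimension backwards (np.pad orders pairs from the first dimension), and its non-constant modes only accept batched 3-D+ inputs, hence the temporary view. A small standalone illustration of the ordering fix (array and widths are illustrative):

import numpy as np
import torch
import torch.nn.functional as F

x = np.arange(6, dtype=np.float32).reshape(2, 3)
pad_width = ((1, 1), (2, 2))  # np.pad order: dim 0 first, then dim 1

# Reverse the per-dimension pairs, then flatten, to get torch's order:
# (last-dim-left, last-dim-right, ..., first-dim-left, first-dim-right)
pad_sum = ()
for p in list(pad_width)[::-1]:
    pad_sum += p  # -> (2, 2, 1, 1)

out = F.pad(torch.from_numpy(x), pad=pad_sum)  # constant mode, any ndim
assert np.allclose(out.numpy(), np.pad(x, pad_width))
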
@@ -665,9 +672,20 @@ def sort(x, axis=-1):
 
 
 def split(x, indices_or_sections, axis=0):
     x = convert_to_tensor(x)
+    if isinstance(indices_or_sections, list):
+        idxs = convert_to_tensor(indices_or_sections)
+        start_size = indices_or_sections[0]
+        end_size = x.shape[axis] - indices_or_sections[-1]
+        chunk_sizes = (
+            [start_size]
+            + torch.diff(idxs).type(torch.int).tolist()
+            + [end_size]
+        )
+    else:
+        chunk_sizes = x.shape[axis] // indices_or_sections
     return torch.split(
         tensor=x,
-        split_size_or_sections=indices_or_sections,
+        split_size_or_sections=chunk_sizes,
         dim=axis,
     )
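
The branch above exists because np.split takes cut indices while torch.split takes section sizes. A self-contained check of that index-to-size conversion (values chosen for illustration):

import numpy as np
import torch

x = np.arange(12).reshape(6, 2)
indices = [1, 4]  # np.split semantics: cut before rows 1 and 4

# Convert cut points into section sizes (1, 3, 2) for torch.split
sizes = [indices[0]] + np.diff(indices).tolist() + [x.shape[0] - indices[-1]]

np_parts = np.split(x, indices, axis=0)
pt_parts = torch.split(torch.from_numpy(x), sizes, dim=0)
for a, b in zip(np_parts, pt_parts):
    assert np.array_equal(a, b.numpy())
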
@@ -694,7 +712,7 @@ def take(x, indices, axis=None):
     x = convert_to_tensor(x)
     indices = convert_to_tensor(indices).long()
     if axis is not None:
-        return torch.index_select(x, dim=axis, index=indices)
+        return torch.index_select(x, dim=axis, index=indices).squeeze(axis)
     return torch.take(x, index=indices)
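
The added .squeeze(axis) compensates for a shape divergence: given a scalar index, np.take drops the indexed axis entirely, while torch.index_select (whose index must be a 1-D tensor) keeps it at size 1. A brief illustration:

import numpy as np
import torch

x = np.arange(24).reshape(2, 3, 4)

print(np.take(x, 0, axis=1).shape)  # (2, 4) -- axis dropped

out = torch.index_select(torch.from_numpy(x), dim=1, index=torch.tensor([0]))
print(out.shape)             # torch.Size([2, 1, 4]) -- axis kept
print(out.squeeze(1).shape)  # torch.Size([2, 4])
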
@@ -734,7 +752,9 @@ def trace(x, offset=None, axis1=None, axis2=None):
 
 def tri(N, M=None, k=0, dtype="float32"):
     dtype = to_torch_dtype(dtype)
-    pass
+    M = M or N
+    x = torch.ones((N, M), dtype=dtype)
+    return torch.tril(x, diagonal=k)
 
 
 def tril(x, k=0):
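
Since PyTorch has no direct np.tri equivalent, the new body builds one from torch.ones and torch.tril. For reference, a standalone sketch of the same construction, checked against np.tri:

import numpy as np
import torch

def tri(N, M=None, k=0, dtype=torch.float32):
    # All-ones matrix, keeping the lower triangle up to (and including)
    # diagonal k -- the same semantics as np.tri.
    M = M or N
    return torch.tril(torch.ones((N, M), dtype=dtype), diagonal=k)

assert np.array_equal(tri(3).numpy(), np.tri(3))
assert np.array_equal(tri(3, 4, k=1).numpy(), np.tri(3, 4, 1))
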

@@ -1,5 +1,4 @@
 import numpy as np
-import pytest
 from tensorflow.python.ops.numpy_ops import np_config
 
 from keras_core import backend
@@ -2003,11 +2002,6 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase):
         self.assertAllClose(np.array(knp.outer(x, y)), np.outer(x, y))
         self.assertAllClose(np.array(knp.Outer()(x, y)), np.outer(x, y))
 
-    # TODO: Fix numpy compatibility (squeeze by one dimension only)
-    @pytest.mark.skipif(
-        backend.backend() == "torch",
-        reason="`torch.take` and `np.take` have return shape divergence.",
-    )
     def test_take(self):
         x = np.arange(24).reshape([1, 2, 3, 4])
         indices = np.array([0, 1])
@@ -2780,14 +2774,6 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase):
             np.pad(x, ((1, 1), (1, 1))),
         )
 
-    # TODO: implement padding with non-constant padding,
-    # bypass NotImplementedError for PyTorch
-    @pytest.mark.skipif(
-        backend.backend() == "torch",
-        reason="padding not implemented for non-constant use case",
-    )
-    def test_pad_without_torch(self):
-        x = np.array([[1, 2], [3, 4]])
         self.assertAllClose(
             np.array(knp.pad(x, ((1, 1), (1, 1)), mode="reflect")),
             np.pad(x, ((1, 1), (1, 1)), mode="reflect"),
@@ -2919,24 +2905,34 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase):
         self.assertAllClose(np.array(knp.sort(x, axis=0)), np.sort(x, axis=0))
         self.assertAllClose(np.array(knp.Sort(axis=0)(x)), np.sort(x, axis=0))
 
-    # TODO: implement split for `torch` with support for conversion
-    # of numpy.split args.
-    @pytest.mark.skipif(
-        backend.backend() == "torch",
-        reason="`torch.split` and `np.split` have return arg divergence.",
-    )
     def test_split(self):
         x = np.array([[1, 2, 3], [3, 2, 1]])
-        self.assertAllClose(np.array(knp.split(x, 2)), np.split(x, 2))
-        self.assertAllClose(np.array(knp.Split(2)(x)), np.split(x, 2))
-        self.assertAllClose(
-            np.array(knp.split(x, [1, 2], axis=1)),
-            np.split(x, [1, 2], axis=1),
-        )
-        self.assertAllClose(
-            np.array(knp.Split([1, 2], axis=1)(x)),
-            np.split(x, [1, 2], axis=1),
-        )
+        if backend.backend() == "torch":
+            self.assertAllClose(
+                [t.numpy() for t in knp.split(x, 2)], np.split(x, 2)
+            )
+            self.assertAllClose(
+                [t.numpy() for t in knp.Split(2)(x)], np.split(x, 2)
+            )
+            self.assertAllClose(
+                [t.numpy() for t in knp.split(x, [1, 2], axis=1)],
+                np.split(x, [1, 2], axis=1),
+            )
+            self.assertAllClose(
+                [t.numpy() for t in knp.Split([1, 2], axis=1)(x)],
+                np.split(x, [1, 2], axis=1),
+            )
+        else:
+            self.assertAllClose(np.array(knp.split(x, 2)), np.split(x, 2))
+            self.assertAllClose(np.array(knp.Split(2)(x)), np.split(x, 2))
+            self.assertAllClose(
+                np.array(knp.split(x, [1, 2], axis=1)),
+                np.split(x, [1, 2], axis=1),
+            )
+            self.assertAllClose(
+                np.array(knp.Split([1, 2], axis=1)(x)),
+                np.split(x, [1, 2], axis=1),
+            )
 
     def test_sqrt(self):
         x = np.array([[1, 4, 9], [16, 25, 36]])
@@ -3073,11 +3069,7 @@ class NumpyArrayCreateOpsCorrectnessTest(testing.TestCase):
         self.assertAllClose(np.array(knp.identity(3)), np.identity(3))
         self.assertAllClose(np.array(knp.Identity()(3)), np.identity(3))
 
-    @pytest.mark.skipif(
-        backend.backend() == "torch", reason="No torch equivalent for `np.tri`"
-    )
     def test_tri(self):
-        # TODO: create a manual implementation, as PyTorch has no equivalent
         self.assertAllClose(np.array(knp.tri(3)), np.tri(3))
         self.assertAllClose(np.array(knp.tri(3, 4)), np.tri(3, 4))
         self.assertAllClose(np.array(knp.tri(3, 4, 1)), np.tri(3, 4, 1))