Fix some problems in torch backend numpy (#215)

* Add missing part from torch backend

* fix tests
Chen Qian 2023-05-25 11:41:56 -07:00 committed by Francois Chollet
parent fab6abfff5
commit b02c29d3a0
2 changed files with 21 additions and 41 deletions

@@ -1,3 +1,4 @@
+import numpy as np
 import torch
 from keras_core.backend.torch.core import cast
@@ -246,16 +247,12 @@ def count_nonzero(x, axis=None):
     return torch.count_nonzero(x, dim=axis).T
 
 
-def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None):
-    # TODO: There is API divergence between np.cross and torch.cross,
-    # preventing `axisa`, `axisb`, and `axisc` parameters from
-    # being used.
-    # https://github.com/pytorch/pytorch/issues/50273
+def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=-1):
     if axisa != -1 or axisb != -1 or axisc != -1:
-        raise NotImplementedError(
-            "Due to API divergence between `torch.cross()` and "
-            "`np.cross`, the following arguments are not supported: "
-            f"axisa={axisa}, axisb={axisb}, axisc={axisc}"
+        raise ValueError(
+            "Torch backend does not support `axisa`, `axisb`, or `axisc`. "
+            f"Received: axisa={axisa}, axisb={axisb}, axisc={axisc}. Please "
+            "use `axis` arg in torch backend."
         )
     x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
     return torch.cross(x1, x2, dim=axis)
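For reference, a minimal sketch of how the reworked cross() now behaves; the import path below is an assumption (file names are not shown in this diff), and the snippet is illustrative rather than part of the change:

import numpy as np
from keras_core.backend.torch import numpy as tnp  # assumed import path

a = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
b = np.array([[7.0, 8.0, 9.0], [1.0, 2.0, 3.0]])

# `axis` now defaults to -1, matching np.cross on the last dimension.
print(tnp.cross(a, b))

# axisa/axisb/axisc remain unsupported, but now raise ValueError
# instead of NotImplementedError.
try:
    tnp.cross(a, b, axisa=0)
except ValueError as err:
    print(err)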
@@ -340,24 +337,18 @@ def floor(x):
 def full(shape, fill_value, dtype=None):
     dtype = to_torch_dtype(dtype)
-    if hasattr(fill_value, "__len__"):
-        fill_value = convert_to_tensor(fill_value)
-        reps = shape[-1] // fill_value.shape[-1]  # channels-last reduction
-        reps_by_dim = (*shape[:-1], reps)
-        return torch.tile(fill_value, reps_by_dim)
+    fill_value = convert_to_tensor(fill_value, dtype=dtype)
+    if len(fill_value.shape) > 0:
+        # `torch.full` only supports scalar `fill_value`.
+        expand_size = len(shape) - len(fill_value.shape)
+        tile_shape = tuple(shape[:expand_size]) + (1,) * len(fill_value.shape)
+        return torch.tile(fill_value, tile_shape)
     return torch.full(size=shape, fill_value=fill_value, dtype=dtype)
 
 
 def full_like(x, fill_value, dtype=None):
-    dtype = to_torch_dtype(dtype)
-    if hasattr(fill_value, "__len__"):
-        fill_value = convert_to_tensor(fill_value)
-        reps_by_dim = tuple(
-            [x.shape[i] // fill_value.shape[i] for i in range(x.ndim)]
-        )
-        return torch.tile(fill_value, reps_by_dim)
-    x = convert_to_tensor(x)
-    return torch.full_like(input=x, fill_value=fill_value, dtype=dtype)
+    return full(shape=x.shape, fill_value=fill_value, dtype=dtype)
 
 
 def greater(x1, x2):
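The rewritten full() handles a non-scalar fill_value by tiling it across the leading dimensions of shape, and full_like() now simply delegates to full(). A standalone sketch of that tiling logic in plain torch; full_sketch is an illustrative helper, not library code:

import torch

# Mirrors the tiling logic from the diff above (not part of keras_core).
def full_sketch(shape, fill_value, dtype=torch.float32):
    fill_value = torch.as_tensor(fill_value, dtype=dtype)
    if fill_value.ndim > 0:
        # Tile the non-scalar fill across the leading dims of `shape`,
        # keeping its own trailing dims intact.
        expand_size = len(shape) - fill_value.ndim
        tile_shape = tuple(shape[:expand_size]) + (1,) * fill_value.ndim
        return torch.tile(fill_value, tile_shape)
    return torch.full(size=shape, fill_value=fill_value.item(), dtype=dtype)

print(full_sketch((2, 3, 4), [0.0, 1.0, 2.0, 3.0]).shape)  # torch.Size([2, 3, 4])
print(full_sketch((2, 3), 7.0))  # a 2x3 tensor of sevens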
@@ -832,15 +823,11 @@ def sum(x, axis=None, keepdims=False):
 def eye(N, M=None, k=None, dtype="float32"):
-    # TODO: implement support for `k` diagonal arg,
-    # does not exist in torch.eye()
-    if k is not None:
-        raise NotImplementedError(
-            "Due to API divergence bewtween `torch.eye` "
-            "and `np.eye`, the argument k is not supported: "
-            f"Received: k={k}"
-        )
     dtype = to_torch_dtype(dtype)
-    if M is not None:
-        return torch.eye(n=N, m=M, dtype=dtype)
-    return torch.eye(n=N, dtype=dtype)
+    M = N if M is None else M
+    k = 0 if k is None else k
+    if k == 0:
+        return torch.eye(N, M, dtype=dtype)
+    diag_length = np.maximum(N, M)
+    diag = torch.ones(diag_length, dtype=dtype)
+    return torch.diag(diag, diagonal=k)[:N, :M]
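A quick standalone check, independent of keras_core, that the diag-and-crop construction used in the new eye() matches np.eye for nonzero k; eye_sketch is an illustrative helper, not library code:

import numpy as np
import torch

# Reproduces the offset-diagonal logic from the diff above.
def eye_sketch(N, M=None, k=0, dtype=torch.float32):
    M = N if M is None else M
    if k == 0:
        return torch.eye(N, M, dtype=dtype)
    diag_length = max(N, M)
    diag = torch.ones(diag_length, dtype=dtype)
    # Build a square matrix with ones on the k-th diagonal, then crop to (N, M).
    return torch.diag(diag, diagonal=k)[:N, :M]

for n, m, k in [(3, None, 0), (3, 4, 1), (4, 3, -2)]:
    expected = np.eye(n, m if m is not None else n, k)
    assert np.array_equal(eye_sketch(n, m, k).numpy(), expected), (n, m, k)
print("eye offsets match np.eye")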

@@ -3032,14 +3032,7 @@ class NumpyArrayCreateOpsCorrectnessTest(testing.TestCase):
         self.assertAllClose(np.array(knp.zeros([2, 3])), np.zeros([2, 3]))
         self.assertAllClose(np.array(knp.Zeros()([2, 3])), np.zeros([2, 3]))
 
-    @pytest.mark.skipif(
-        backend.backend() == "torch",
-        reason="`torch.eye` does not support arg `k`.",
-    )
     def test_eye(self):
-        # TODO: implement support for `k` diagonal arg,
-        # does not exist in torch.eye()
         self.assertAllClose(np.array(knp.eye(3)), np.eye(3))
         self.assertAllClose(np.array(knp.eye(3, 4)), np.eye(3, 4))
         self.assertAllClose(np.array(knp.eye(3, 4, 1)), np.eye(3, 4, 1))