Fix some problems in torch backend numpy (#215)
* Add missing part from torch backend * fix tests
This commit is contained in:
parent
fab6abfff5
commit
b02c29d3a0
@ -1,3 +1,4 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from keras_core.backend.torch.core import cast
|
||||
@ -246,16 +247,12 @@ def count_nonzero(x, axis=None):
|
||||
return torch.count_nonzero(x, dim=axis).T
|
||||
|
||||
|
||||
def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None):
|
||||
# TODO: There is API divergence between np.cross and torch.cross,
|
||||
# preventing `axisa`, `axisb`, and `axisc` parameters from
|
||||
# being used.
|
||||
# https://github.com/pytorch/pytorch/issues/50273
|
||||
def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=-1):
|
||||
if axisa != -1 or axisb != -1 or axisc != -1:
|
||||
raise NotImplementedError(
|
||||
"Due to API divergence between `torch.cross()` and "
|
||||
"`np.cross`, the following arguments are not supported: "
|
||||
f"axisa={axisa}, axisb={axisb}, axisc={axisc}"
|
||||
raise ValueError(
|
||||
"Torch backend does not support `axisa`, `axisb`, or `axisc`. "
|
||||
f"Received: axisa={axisa}, axisb={axisb}, axisc={axisc}. Please "
|
||||
"use `axis` arg in torch backend."
|
||||
)
|
||||
x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
|
||||
return torch.cross(x1, x2, dim=axis)
|
||||
@ -340,24 +337,18 @@ def floor(x):
|
||||
|
||||
def full(shape, fill_value, dtype=None):
|
||||
dtype = to_torch_dtype(dtype)
|
||||
if hasattr(fill_value, "__len__"):
|
||||
fill_value = convert_to_tensor(fill_value)
|
||||
reps = shape[-1] // fill_value.shape[-1] # channels-last reduction
|
||||
reps_by_dim = (*shape[:-1], reps)
|
||||
return torch.tile(fill_value, reps_by_dim)
|
||||
fill_value = convert_to_tensor(fill_value, dtype=dtype)
|
||||
if len(fill_value.shape) > 0:
|
||||
# `torch.full` only supports scala `fill_value`.
|
||||
expand_size = len(shape) - len(fill_value.shape)
|
||||
tile_shape = tuple(shape[:expand_size]) + (1,) * len(fill_value.shape)
|
||||
return torch.tile(fill_value, tile_shape)
|
||||
|
||||
return torch.full(size=shape, fill_value=fill_value, dtype=dtype)
|
||||
|
||||
|
||||
def full_like(x, fill_value, dtype=None):
|
||||
dtype = to_torch_dtype(dtype)
|
||||
if hasattr(fill_value, "__len__"):
|
||||
fill_value = convert_to_tensor(fill_value)
|
||||
reps_by_dim = tuple(
|
||||
[x.shape[i] // fill_value.shape[i] for i in range(x.ndim)]
|
||||
)
|
||||
return torch.tile(fill_value, reps_by_dim)
|
||||
x = convert_to_tensor(x)
|
||||
return torch.full_like(input=x, fill_value=fill_value, dtype=dtype)
|
||||
return full(shape=x.shape, fill_value=fill_value, dtype=dtype)
|
||||
|
||||
|
||||
def greater(x1, x2):
|
||||
@ -832,15 +823,11 @@ def sum(x, axis=None, keepdims=False):
|
||||
|
||||
|
||||
def eye(N, M=None, k=None, dtype="float32"):
|
||||
# TODO: implement support for `k` diagonal arg,
|
||||
# does not exist in torch.eye()
|
||||
if k is not None:
|
||||
raise NotImplementedError(
|
||||
"Due to API divergence bewtween `torch.eye` "
|
||||
"and `np.eye`, the argument k is not supported: "
|
||||
f"Received: k={k}"
|
||||
)
|
||||
dtype = to_torch_dtype(dtype)
|
||||
if M is not None:
|
||||
return torch.eye(n=N, m=M, dtype=dtype)
|
||||
return torch.eye(n=N, dtype=dtype)
|
||||
M = N if M is None else M
|
||||
k = 0 if k is None else k
|
||||
if k == 0:
|
||||
return torch.eye(N, M, dtype=dtype)
|
||||
diag_length = np.maximum(N, M)
|
||||
diag = torch.ones(diag_length, dtype=dtype)
|
||||
return torch.diag(diag, diagonal=k)[:N, :M]
|
||||
|
@ -3032,14 +3032,7 @@ class NumpyArrayCreateOpsCorrectnessTest(testing.TestCase):
|
||||
self.assertAllClose(np.array(knp.zeros([2, 3])), np.zeros([2, 3]))
|
||||
self.assertAllClose(np.array(knp.Zeros()([2, 3])), np.zeros([2, 3]))
|
||||
|
||||
@pytest.mark.skipif(
|
||||
backend.backend() == "torch",
|
||||
reason="`torch.eye` does not support arg `k`.",
|
||||
)
|
||||
def test_eye(self):
|
||||
# TODO: implement support for `k` diagonal arg,
|
||||
# does not exist in torch.eye()
|
||||
|
||||
self.assertAllClose(np.array(knp.eye(3)), np.eye(3))
|
||||
self.assertAllClose(np.array(knp.eye(3, 4)), np.eye(3, 4))
|
||||
self.assertAllClose(np.array(knp.eye(3, 4, 1)), np.eye(3, 4, 1))
|
||||
|
Loading…
Reference in New Issue
Block a user