Update Torch ops.array (#375)

* Update Torch ops.array

This was previously inconsistent with Jax and TF, where ops.array produces an array/tensor of the native backend type. Instead this produced a NumPy array

* Add ops test

* Fix metrics tests

* Update torch np.tile implementation
This commit is contained in:
Ian Stenbit 2023-06-20 18:39:12 -06:00 committed by Francois Chollet
parent a97cee1736
commit 9a2ee731b8
3 changed files with 8 additions and 3 deletions

@ -4,6 +4,7 @@ import torch
from keras_core.backend.torch.core import cast
from keras_core.backend.torch.core import convert_to_tensor
from keras_core.backend.torch.core import get_device
from keras_core.backend.torch.core import is_tensor
from keras_core.backend.torch.core import to_torch_dtype
TORCH_INT_TYPES = (
@ -187,9 +188,9 @@ def argsort(x, axis=-1):
def array(x, dtype=None):
dtype = to_torch_dtype(dtype)
if not isinstance(x, torch.Tensor):
if isinstance(x, torch.Tensor):
return x
return x.numpy()
return torch.tensor(x, dtype=dtype)
def average(x, axis=None, weights=None):
@ -754,6 +755,8 @@ def round(x, decimals=0):
def tile(x, repeats):
if is_tensor(repeats):
repeats = tuple(repeats.int().numpy())
x = convert_to_tensor(x)
return torch.tile(x, dims=repeats)

@ -509,7 +509,7 @@ def update_confusion_matrix_variables(
data_tiles = [num_thresholds, 1]
thresh_tiled = ops.tile(
ops.reshape(thresholds, thresh_pretile_shape), ops.array(thresh_tiles)
ops.reshape(thresholds, thresh_pretile_shape), thresh_tiles
)
# Tile the predictions for every threshold.

@ -2242,6 +2242,8 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase):
x = np.array([[1, 2, 3], [3, 2, 1]])
self.assertAllClose(knp.array(x), np.array(x))
self.assertAllClose(knp.Array()(x), np.array(x))
self.assertTrue(backend.is_tensor(knp.array(x)))
self.assertTrue(backend.is_tensor(knp.Array()(x)))
def test_average(self):
x = np.array([[1, 2, 3], [3, 2, 1]])