Update Torch ops.array (#375)
* Update Torch ops.array This was previously inconsistent with Jax and TF, where ops.array produces an array/tensor of the native backend type. Instead this produced a NumPy array * Add ops test * Fix metrics tests * Update torch np.tile implementation
This commit is contained in:
parent
a97cee1736
commit
9a2ee731b8
@ -4,6 +4,7 @@ import torch
|
||||
from keras_core.backend.torch.core import cast
|
||||
from keras_core.backend.torch.core import convert_to_tensor
|
||||
from keras_core.backend.torch.core import get_device
|
||||
from keras_core.backend.torch.core import is_tensor
|
||||
from keras_core.backend.torch.core import to_torch_dtype
|
||||
|
||||
TORCH_INT_TYPES = (
|
||||
@ -187,9 +188,9 @@ def argsort(x, axis=-1):
|
||||
|
||||
def array(x, dtype=None):
|
||||
dtype = to_torch_dtype(dtype)
|
||||
if not isinstance(x, torch.Tensor):
|
||||
if isinstance(x, torch.Tensor):
|
||||
return x
|
||||
return x.numpy()
|
||||
return torch.tensor(x, dtype=dtype)
|
||||
|
||||
|
||||
def average(x, axis=None, weights=None):
|
||||
@ -754,6 +755,8 @@ def round(x, decimals=0):
|
||||
|
||||
|
||||
def tile(x, repeats):
|
||||
if is_tensor(repeats):
|
||||
repeats = tuple(repeats.int().numpy())
|
||||
x = convert_to_tensor(x)
|
||||
return torch.tile(x, dims=repeats)
|
||||
|
||||
|
@ -509,7 +509,7 @@ def update_confusion_matrix_variables(
|
||||
data_tiles = [num_thresholds, 1]
|
||||
|
||||
thresh_tiled = ops.tile(
|
||||
ops.reshape(thresholds, thresh_pretile_shape), ops.array(thresh_tiles)
|
||||
ops.reshape(thresholds, thresh_pretile_shape), thresh_tiles
|
||||
)
|
||||
|
||||
# Tile the predictions for every threshold.
|
||||
|
@ -2242,6 +2242,8 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase):
|
||||
x = np.array([[1, 2, 3], [3, 2, 1]])
|
||||
self.assertAllClose(knp.array(x), np.array(x))
|
||||
self.assertAllClose(knp.Array()(x), np.array(x))
|
||||
self.assertTrue(backend.is_tensor(knp.array(x)))
|
||||
self.assertTrue(backend.is_tensor(knp.Array()(x)))
|
||||
|
||||
def test_average(self):
|
||||
x = np.array([[1, 2, 3], [3, 2, 1]])
|
||||
|
Loading…
Reference in New Issue
Block a user