Update Torch ops.array (#375)
* Update Torch ops.array This was previously inconsistent with Jax and TF, where ops.array produces an array/tensor of the native backend type. Instead this produced a NumPy array * Add ops test * Fix metrics tests * Update torch np.tile implementation
This commit is contained in:
parent
a97cee1736
commit
9a2ee731b8
@ -4,6 +4,7 @@ import torch
|
|||||||
from keras_core.backend.torch.core import cast
|
from keras_core.backend.torch.core import cast
|
||||||
from keras_core.backend.torch.core import convert_to_tensor
|
from keras_core.backend.torch.core import convert_to_tensor
|
||||||
from keras_core.backend.torch.core import get_device
|
from keras_core.backend.torch.core import get_device
|
||||||
|
from keras_core.backend.torch.core import is_tensor
|
||||||
from keras_core.backend.torch.core import to_torch_dtype
|
from keras_core.backend.torch.core import to_torch_dtype
|
||||||
|
|
||||||
TORCH_INT_TYPES = (
|
TORCH_INT_TYPES = (
|
||||||
@ -187,9 +188,9 @@ def argsort(x, axis=-1):
|
|||||||
|
|
||||||
def array(x, dtype=None):
|
def array(x, dtype=None):
|
||||||
dtype = to_torch_dtype(dtype)
|
dtype = to_torch_dtype(dtype)
|
||||||
if not isinstance(x, torch.Tensor):
|
if isinstance(x, torch.Tensor):
|
||||||
return x
|
return x
|
||||||
return x.numpy()
|
return torch.tensor(x, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
def average(x, axis=None, weights=None):
|
def average(x, axis=None, weights=None):
|
||||||
@ -754,6 +755,8 @@ def round(x, decimals=0):
|
|||||||
|
|
||||||
|
|
||||||
def tile(x, repeats):
|
def tile(x, repeats):
|
||||||
|
if is_tensor(repeats):
|
||||||
|
repeats = tuple(repeats.int().numpy())
|
||||||
x = convert_to_tensor(x)
|
x = convert_to_tensor(x)
|
||||||
return torch.tile(x, dims=repeats)
|
return torch.tile(x, dims=repeats)
|
||||||
|
|
||||||
|
@ -509,7 +509,7 @@ def update_confusion_matrix_variables(
|
|||||||
data_tiles = [num_thresholds, 1]
|
data_tiles = [num_thresholds, 1]
|
||||||
|
|
||||||
thresh_tiled = ops.tile(
|
thresh_tiled = ops.tile(
|
||||||
ops.reshape(thresholds, thresh_pretile_shape), ops.array(thresh_tiles)
|
ops.reshape(thresholds, thresh_pretile_shape), thresh_tiles
|
||||||
)
|
)
|
||||||
|
|
||||||
# Tile the predictions for every threshold.
|
# Tile the predictions for every threshold.
|
||||||
|
@ -2242,6 +2242,8 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase):
|
|||||||
x = np.array([[1, 2, 3], [3, 2, 1]])
|
x = np.array([[1, 2, 3], [3, 2, 1]])
|
||||||
self.assertAllClose(knp.array(x), np.array(x))
|
self.assertAllClose(knp.array(x), np.array(x))
|
||||||
self.assertAllClose(knp.Array()(x), np.array(x))
|
self.assertAllClose(knp.Array()(x), np.array(x))
|
||||||
|
self.assertTrue(backend.is_tensor(knp.array(x)))
|
||||||
|
self.assertTrue(backend.is_tensor(knp.Array()(x)))
|
||||||
|
|
||||||
def test_average(self):
|
def test_average(self):
|
||||||
x = np.array([[1, 2, 3], [3, 2, 1]])
|
x = np.array([[1, 2, 3], [3, 2, 1]])
|
||||||
|
Loading…
Reference in New Issue
Block a user