* Add PyTorch numpy functionality * Add dtype conversion * Partial fix for PyTorch numpy tests * small logic fix * Revert numpy_test * Add tensor conversion to numpy * Fix some arithmetic tests * Fix some torch functions for numpy compatibility * Fix pytorch ops for numpy compatibility, add TODOs * Fix formatting * Implement nits and fix dtype standardization * Add pytest skipif decorator and fix nits * Fix formatting and rename dtypes map * Split tests by backend * Merge space * Fix dtype issues from new type checking
This commit is contained in:
parent
1006d2eca9
commit
26f0f4f24c
@ -375,6 +375,12 @@ PYTHON_DTYPES_MAP = {
|
||||
str: "string",
|
||||
}
|
||||
|
||||
PYTHON_DTYPES_MAP = {
|
||||
bool: "bool",
|
||||
int: "int", # TBD by backend
|
||||
float: "float32",
|
||||
}
|
||||
|
||||
|
||||
def standardize_dtype(dtype):
|
||||
if dtype is None:
|
||||
|
@ -1,8 +1,16 @@
|
||||
import torch
|
||||
|
||||
from keras_core.backend.torch.core import cast
|
||||
from keras_core.backend.torch.core import convert_to_tensor
|
||||
from keras_core.backend.torch.core import to_torch_dtype
|
||||
|
||||
TORCH_INT_TYPES = (
|
||||
torch.int8,
|
||||
torch.int16,
|
||||
torch.int32,
|
||||
torch.int64,
|
||||
)
|
||||
|
||||
|
||||
def add(x1, x2):
|
||||
x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
|
||||
@ -31,6 +39,8 @@ def multiply(x1, x2):
|
||||
|
||||
def mean(x, axis=None, keepdims=False):
|
||||
x = convert_to_tensor(x)
|
||||
# Conversion to float necessary for `torch.mean`
|
||||
x = cast(x, "float32") if x.dtype in TORCH_INT_TYPES else x
|
||||
return torch.mean(x, axis=axis, keepdims=keepdims)
|
||||
|
||||
|
||||
@ -172,6 +182,8 @@ def array(x, dtype=None):
|
||||
|
||||
def average(x, axis=None, weights=None):
|
||||
x = convert_to_tensor(x)
|
||||
# Conversion to float necessary for `torch.mean`
|
||||
x = cast(x, "float32") if x.dtype in TORCH_INT_TYPES else x
|
||||
if weights is not None:
|
||||
weights = convert_to_tensor(weights)
|
||||
return torch.sum(torch.mul(x, weights), dim=axis) / torch.sum(
|
||||
@ -378,6 +390,10 @@ def imag(x):
|
||||
|
||||
def isclose(x1, x2):
|
||||
x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
|
||||
if torch.is_floating_point(x1) and not torch.is_floating_point(x2):
|
||||
x2 = cast(x2, x1.dtype)
|
||||
if torch.is_floating_point(x2) and not torch.is_floating_point(x1):
|
||||
x1 = cast(x1, x2.dtype)
|
||||
return torch.isclose(x1, x2)
|
||||
|
||||
|
||||
@ -450,6 +466,8 @@ def log2(x):
|
||||
|
||||
def logaddexp(x1, x2):
|
||||
x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
|
||||
x1 = cast(x1, "float32") if x1.dtype in TORCH_INT_TYPES else x1
|
||||
x2 = cast(x2, "float32") if x2.dtype in TORCH_INT_TYPES else x2
|
||||
return torch.logaddexp(x1, x2)
|
||||
|
||||
|
||||
@ -585,7 +603,8 @@ def ravel(x):
|
||||
|
||||
|
||||
def real(x):
|
||||
x = convert_to_tensor(x)
|
||||
if not isinstance(x, torch.Tensor):
|
||||
x = torch.from_numpy(x) # needed for complex type conversion
|
||||
return torch.real(x)
|
||||
|
||||
|
||||
@ -646,6 +665,8 @@ def stack(x, axis=0):
|
||||
|
||||
def std(x, axis=None, keepdims=False):
|
||||
x = convert_to_tensor(x)
|
||||
# Conversion to float necessary for `torch.std`
|
||||
x = cast(x, "float32") if x.dtype in TORCH_INT_TYPES else x
|
||||
# Remove Bessel correction to align with numpy
|
||||
return torch.std(x, dim=axis, keepdim=keepdims, unbiased=False)
|
||||
|
||||
@ -676,6 +697,9 @@ def tan(x):
|
||||
|
||||
def tensordot(x1, x2, axes=2):
|
||||
x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
|
||||
# Conversion to long necessary for `torch.tensordot`
|
||||
x1 = cast(x1, "int64") if x1.dtype in TORCH_INT_TYPES else x1
|
||||
x2 = cast(x2, "int64") if x2.dtype in TORCH_INT_TYPES else x2
|
||||
return torch.tensordot(x1, x2, dims=axes)
|
||||
|
||||
|
||||
@ -777,7 +801,9 @@ def transpose(x, axes=None):
|
||||
|
||||
|
||||
def var(x, axis=None, keepdims=False):
|
||||
x = convert_to_tensor(x)
|
||||
x = convert_to_tensor(x, dtype="float32")
|
||||
# Conversion to float necessary for `torch.var`
|
||||
x = cast(x, "float32") if x.dtype in TORCH_INT_TYPES else x
|
||||
# Bessel correction removed for numpy compatibility
|
||||
return torch.var(x, dim=axis, keepdim=keepdims, correction=0)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user