Fix dtype issues with PyTorch numpy from #190 (#210)

* Add PyTorch numpy functionality

* Add dtype conversion

* Partial fix for PyTorch numpy tests

* Small logic fix

* Revert numpy_test

* Add tensor conversion to numpy

* Fix some arithmetic tests

* Fix some torch functions for numpy compatibility

* Fix pytorch ops for numpy compatibility, add TODOs

* Fix formatting

* Implement nits and fix dtype standardization

* Add pytest skipif decorator and fix nits

* Fix formatting and rename dtypes map

* Split tests by backend

* Merge space

* Fix dtype issues from new type checking
Neel Kovelamudi authored 2023-05-24 17:15:44 +00:00, committed by Francois Chollet
parent 1006d2eca9
commit 26f0f4f24c
2 changed files with 34 additions and 2 deletions

@@ -375,6 +375,12 @@ PYTHON_DTYPES_MAP = {
     str: "string",
 }
 
+PYTHON_DTYPES_MAP = {
+    bool: "bool",
+    int: "int",  # TBD by backend
+    float: "float32",
+}
+
 
 def standardize_dtype(dtype):
     if dtype is None:
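For context, a minimal sketch of how the new map is meant to be consulted; `to_dtype_string` is a hypothetical stand-in for the lookup `standardize_dtype` performs, and the bare "int" entry stays symbolic because the concrete width is chosen per backend:

    PYTHON_DTYPES_MAP = {
        bool: "bool",
        int: "int",  # concrete width resolved by the active backend
        float: "float32",
    }

    def to_dtype_string(dtype):
        # Hypothetical helper mirroring the map lookup in `standardize_dtype`.
        return PYTHON_DTYPES_MAP.get(dtype, dtype)

    print(to_dtype_string(float))   # "float32"
    print(to_dtype_string("int8"))  # "int8" (non-builtin keys pass through)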

@@ -1,8 +1,16 @@
 import torch
 
+from keras_core.backend.torch.core import cast
 from keras_core.backend.torch.core import convert_to_tensor
 from keras_core.backend.torch.core import to_torch_dtype
 
+TORCH_INT_TYPES = (
+    torch.int8,
+    torch.int16,
+    torch.int32,
+    torch.int64,
+)
+
 
 def add(x1, x2):
     x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
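The new tuple lets each op guard float-only torch kernels with a plain membership test; a minimal sketch of the pattern, assuming stock PyTorch:

    import torch

    TORCH_INT_TYPES = (torch.int8, torch.int16, torch.int32, torch.int64)

    x = torch.arange(4)             # dtype is torch.int64
    if x.dtype in TORCH_INT_TYPES:  # the guard used throughout this file
        x = x.to(torch.float32)
    print(x.dtype)                  # torch.float32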
@@ -31,6 +39,8 @@ def multiply(x1, x2):
 
 def mean(x, axis=None, keepdims=False):
     x = convert_to_tensor(x)
+    # Conversion to float necessary for `torch.mean`
+    x = cast(x, "float32") if x.dtype in TORCH_INT_TYPES else x
     return torch.mean(x, axis=axis, keepdims=keepdims)
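The cast is needed because `torch.mean` rejects integer inputs outright, while `np.mean` promotes them; a quick repro with stock PyTorch and NumPy:

    import numpy as np
    import torch

    x = torch.tensor([1, 2, 3, 4])          # int64
    # torch.mean(x) raises RuntimeError: mean needs a floating/complex dtype
    print(torch.mean(x.to(torch.float32)))  # tensor(2.5000)
    print(np.mean(np.array([1, 2, 3, 4])))  # 2.5 -- numpy promotes implicitly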
@@ -172,6 +182,8 @@ def array(x, dtype=None):
 
 def average(x, axis=None, weights=None):
     x = convert_to_tensor(x)
+    # Conversion to float necessary for `torch.mean`
+    x = cast(x, "float32") if x.dtype in TORCH_INT_TYPES else x
     if weights is not None:
         weights = convert_to_tensor(weights)
         return torch.sum(torch.mul(x, weights), dim=axis) / torch.sum(
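The weighted branch computes sum(x * w) / sum(w), which lines up with `np.average`; a quick parity check:

    import numpy as np
    import torch

    x = torch.tensor([1.0, 2.0, 3.0])
    w = torch.tensor([3.0, 1.0, 1.0])
    print(torch.sum(x * w) / torch.sum(w))                       # tensor(1.6000)
    print(np.average([1.0, 2.0, 3.0], weights=[3.0, 1.0, 1.0]))  # 1.6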
@@ -378,6 +390,10 @@ def imag(x):
 
 def isclose(x1, x2):
     x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
+    if torch.is_floating_point(x1) and not torch.is_floating_point(x2):
+        x2 = cast(x2, x1.dtype)
+    if torch.is_floating_point(x2) and not torch.is_floating_point(x1):
+        x1 = cast(x1, x2.dtype)
     return torch.isclose(x1, x2)
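Unlike `np.isclose`, `torch.isclose` expects both operands to share a dtype, so the mixed float/int case has to be aligned by hand; a small illustration:

    import torch

    a = torch.tensor([1.0, 2.0])  # float32
    b = torch.tensor([1, 2])      # int64
    # torch.isclose(a, b) raises because the dtypes differ
    print(torch.isclose(a, b.to(a.dtype)))  # tensor([True, True])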
@@ -450,6 +466,8 @@ def log2(x):
 
 def logaddexp(x1, x2):
     x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
+    x1 = cast(x1, "float32") if x1.dtype in TORCH_INT_TYPES else x1
+    x2 = cast(x2, "float32") if x2.dtype in TORCH_INT_TYPES else x2
     return torch.logaddexp(x1, x2)
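`torch.logaddexp` only accepts floating inputs, and the float path is also what keeps it numerically safe; for example:

    import torch

    x1 = torch.tensor([1000.0])
    x2 = torch.tensor([1000.0])
    print(torch.log(torch.exp(x1) + torch.exp(x2)))  # tensor([inf]) -- overflows
    print(torch.logaddexp(x1, x2))                   # tensor([1000.6931])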
@@ -585,7 +603,8 @@ def ravel(x):
 
 def real(x):
-    x = convert_to_tensor(x)
+    if not isinstance(x, torch.Tensor):
+        x = torch.from_numpy(x)  # needed for complex type conversion
     return torch.real(x)
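Going through `torch.from_numpy` keeps complex dtypes intact; standard PyTorch/NumPy behavior:

    import numpy as np
    import torch

    z = np.array([1 + 2j, 3 - 4j])  # complex128
    t = torch.from_numpy(z)         # stays complex128
    print(torch.real(t))            # tensor([1., 3.], dtype=torch.float64)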
@@ -646,6 +665,8 @@ def stack(x, axis=0):
 
 def std(x, axis=None, keepdims=False):
     x = convert_to_tensor(x)
+    # Conversion to float necessary for `torch.std`
+    x = cast(x, "float32") if x.dtype in TORCH_INT_TYPES else x
     # Remove Bessel correction to align with numpy
     return torch.std(x, dim=axis, keepdim=keepdims, unbiased=False)
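`unbiased=False` divides by N instead of N - 1, which is what `np.std` does by default:

    import numpy as np
    import torch

    x = torch.tensor([1.0, 2.0, 3.0, 4.0])
    print(torch.std(x))                            # tensor(1.2910) -- N - 1
    print(torch.std(x, unbiased=False))            # tensor(1.1180) -- N
    print(np.std(np.array([1.0, 2.0, 3.0, 4.0])))  # 1.118033988749895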
@@ -676,6 +697,9 @@ def tan(x):
 
 def tensordot(x1, x2, axes=2):
     x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
+    # Conversion to long necessary for `torch.tensordot`
+    x1 = cast(x1, "int64") if x1.dtype in TORCH_INT_TYPES else x1
+    x2 = cast(x2, "int64") if x2.dtype in TORCH_INT_TYPES else x2
     return torch.tensordot(x1, x2, dims=axes)
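`torch.tensordot` lowers to matmul kernels that are not implemented for the narrower integer types on CPU, hence the widening to int64; a sketch (the exact failure message varies by PyTorch build):

    import torch

    a = torch.ones((2, 3), dtype=torch.int16)
    b = torch.ones((3, 4), dtype=torch.int16)
    # torch.tensordot(a, b, dims=1) can fail on int16 (addmm not implemented)
    out = torch.tensordot(a.long(), b.long(), dims=1)
    print(out.dtype, out.shape)  # torch.int64 torch.Size([2, 4])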
@@ -777,7 +801,9 @@ def transpose(x, axes=None):
 
 def var(x, axis=None, keepdims=False):
-    x = convert_to_tensor(x, dtype="float32")
+    x = convert_to_tensor(x)
+    # Conversion to float necessary for `torch.var`
+    x = cast(x, "float32") if x.dtype in TORCH_INT_TYPES else x
     # Bessel correction removed for numpy compatibility
     return torch.var(x, dim=axis, keepdim=keepdims, correction=0)
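`correction=0` is the variance analogue of `unbiased=False`, matching numpy's default ddof=0 (assuming a PyTorch version that has the `correction` keyword):

    import numpy as np
    import torch

    x = torch.tensor([1.0, 2.0, 3.0, 4.0])
    print(torch.var(x, correction=0))              # tensor(1.2500)
    print(np.var(np.array([1.0, 2.0, 3.0, 4.0])))  # 1.25 -- ddof defaults to 0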