Add default optimizer to match tf.keras.
This commit is contained in:
parent
26f0f4f24c
commit
5148978589
@ -375,12 +375,6 @@ PYTHON_DTYPES_MAP = {
|
||||
str: "string",
|
||||
}
|
||||
|
||||
PYTHON_DTYPES_MAP = {
|
||||
bool: "bool",
|
||||
int: "int", # TBD by backend
|
||||
float: "float32",
|
||||
}
|
||||
|
||||
|
||||
def standardize_dtype(dtype):
|
||||
if dtype is None:
|
||||
|
@ -1,16 +1,8 @@
|
||||
import torch
|
||||
|
||||
from keras_core.backend.torch.core import cast
|
||||
from keras_core.backend.torch.core import convert_to_tensor
|
||||
from keras_core.backend.torch.core import to_torch_dtype
|
||||
|
||||
TORCH_INT_TYPES = (
|
||||
torch.int8,
|
||||
torch.int16,
|
||||
torch.int32,
|
||||
torch.int64,
|
||||
)
|
||||
|
||||
|
||||
def add(x1, x2):
|
||||
x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
|
||||
@ -39,8 +31,6 @@ def multiply(x1, x2):
|
||||
|
||||
def mean(x, axis=None, keepdims=False):
|
||||
x = convert_to_tensor(x)
|
||||
# Conversion to float necessary for `torch.mean`
|
||||
x = cast(x, "float32") if x.dtype in TORCH_INT_TYPES else x
|
||||
return torch.mean(x, axis=axis, keepdims=keepdims)
|
||||
|
||||
|
||||
@ -182,8 +172,6 @@ def array(x, dtype=None):
|
||||
|
||||
def average(x, axis=None, weights=None):
|
||||
x = convert_to_tensor(x)
|
||||
# Conversion to float necessary for `torch.mean`
|
||||
x = cast(x, "float32") if x.dtype in TORCH_INT_TYPES else x
|
||||
if weights is not None:
|
||||
weights = convert_to_tensor(weights)
|
||||
return torch.sum(torch.mul(x, weights), dim=axis) / torch.sum(
|
||||
@ -390,10 +378,6 @@ def imag(x):
|
||||
|
||||
def isclose(x1, x2):
|
||||
x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
|
||||
if torch.is_floating_point(x1) and not torch.is_floating_point(x2):
|
||||
x2 = cast(x2, x1.dtype)
|
||||
if torch.is_floating_point(x2) and not torch.is_floating_point(x1):
|
||||
x1 = cast(x1, x2.dtype)
|
||||
return torch.isclose(x1, x2)
|
||||
|
||||
|
||||
@ -466,8 +450,6 @@ def log2(x):
|
||||
|
||||
def logaddexp(x1, x2):
|
||||
x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
|
||||
x1 = cast(x1, "float32") if x1.dtype in TORCH_INT_TYPES else x1
|
||||
x2 = cast(x2, "float32") if x2.dtype in TORCH_INT_TYPES else x2
|
||||
return torch.logaddexp(x1, x2)
|
||||
|
||||
|
||||
@ -603,8 +585,7 @@ def ravel(x):
|
||||
|
||||
|
||||
def real(x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
x = torch.from_numpy(x) # needed for complex type conversion
|
||||
x = convert_to_tensor(x)
|
||||
return torch.real(x)
|
||||
|
||||
|
||||
@ -665,8 +646,6 @@ def stack(x, axis=0):
|
||||
|
||||
def std(x, axis=None, keepdims=False):
|
||||
x = convert_to_tensor(x)
|
||||
# Conversion to float necessary for `torch.std`
|
||||
x = cast(x, "float32") if x.dtype in TORCH_INT_TYPES else x
|
||||
# Remove Bessel correction to align with numpy
|
||||
return torch.std(x, dim=axis, keepdim=keepdims, unbiased=False)
|
||||
|
||||
@ -697,9 +676,6 @@ def tan(x):
|
||||
|
||||
def tensordot(x1, x2, axes=2):
|
||||
x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
|
||||
# Conversion to long necessary for `torch.tensordot`
|
||||
x1 = cast(x1, "int64") if x1.dtype in TORCH_INT_TYPES else x1
|
||||
x2 = cast(x2, "int64") if x2.dtype in TORCH_INT_TYPES else x2
|
||||
return torch.tensordot(x1, x2, dims=axes)
|
||||
|
||||
|
||||
@ -801,9 +777,7 @@ def transpose(x, axes=None):
|
||||
|
||||
|
||||
def var(x, axis=None, keepdims=False):
|
||||
x = convert_to_tensor(x, dtype="float32")
|
||||
# Conversion to float necessary for `torch.var`
|
||||
x = cast(x, "float32") if x.dtype in TORCH_INT_TYPES else x
|
||||
x = convert_to_tensor(x)
|
||||
# Bessel correction removed for numpy compatibility
|
||||
return torch.var(x, dim=axis, keepdim=keepdims, correction=0)
|
||||
|
||||
|
@ -21,7 +21,7 @@ class Trainer:
|
||||
@tracking.no_automatic_dependency_tracking
|
||||
def compile(
|
||||
self,
|
||||
optimizer,
|
||||
optimizer="rmsprop",
|
||||
loss=None,
|
||||
loss_weights=None,
|
||||
metrics=None,
|
||||
|
Loading…
Reference in New Issue
Block a user