Add default optimizer to match tf.keras.

This commit is contained in:
Francois Chollet 2023-05-24 11:15:15 -07:00
parent 26f0f4f24c
commit 5148978589
3 changed files with 3 additions and 35 deletions

@ -375,12 +375,6 @@ PYTHON_DTYPES_MAP = {
str: "string",
}
PYTHON_DTYPES_MAP = {
bool: "bool",
int: "int", # TBD by backend
float: "float32",
}
def standardize_dtype(dtype):
if dtype is None:

@ -1,16 +1,8 @@
import torch
from keras_core.backend.torch.core import cast
from keras_core.backend.torch.core import convert_to_tensor
from keras_core.backend.torch.core import to_torch_dtype
TORCH_INT_TYPES = (
torch.int8,
torch.int16,
torch.int32,
torch.int64,
)
def add(x1, x2):
x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
@ -39,8 +31,6 @@ def multiply(x1, x2):
def mean(x, axis=None, keepdims=False):
x = convert_to_tensor(x)
# Conversion to float necessary for `torch.mean`
x = cast(x, "float32") if x.dtype in TORCH_INT_TYPES else x
return torch.mean(x, axis=axis, keepdims=keepdims)
@ -182,8 +172,6 @@ def array(x, dtype=None):
def average(x, axis=None, weights=None):
x = convert_to_tensor(x)
# Conversion to float necessary for `torch.mean`
x = cast(x, "float32") if x.dtype in TORCH_INT_TYPES else x
if weights is not None:
weights = convert_to_tensor(weights)
return torch.sum(torch.mul(x, weights), dim=axis) / torch.sum(
@ -390,10 +378,6 @@ def imag(x):
def isclose(x1, x2):
x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
if torch.is_floating_point(x1) and not torch.is_floating_point(x2):
x2 = cast(x2, x1.dtype)
if torch.is_floating_point(x2) and not torch.is_floating_point(x1):
x1 = cast(x1, x2.dtype)
return torch.isclose(x1, x2)
@ -466,8 +450,6 @@ def log2(x):
def logaddexp(x1, x2):
x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
x1 = cast(x1, "float32") if x1.dtype in TORCH_INT_TYPES else x1
x2 = cast(x2, "float32") if x2.dtype in TORCH_INT_TYPES else x2
return torch.logaddexp(x1, x2)
@ -603,8 +585,7 @@ def ravel(x):
def real(x):
if not isinstance(x, torch.Tensor):
x = torch.from_numpy(x) # needed for complex type conversion
x = convert_to_tensor(x)
return torch.real(x)
@ -665,8 +646,6 @@ def stack(x, axis=0):
def std(x, axis=None, keepdims=False):
x = convert_to_tensor(x)
# Conversion to float necessary for `torch.std`
x = cast(x, "float32") if x.dtype in TORCH_INT_TYPES else x
# Remove Bessel correction to align with numpy
return torch.std(x, dim=axis, keepdim=keepdims, unbiased=False)
@ -697,9 +676,6 @@ def tan(x):
def tensordot(x1, x2, axes=2):
x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
# Conversion to long necessary for `torch.tensordot`
x1 = cast(x1, "int64") if x1.dtype in TORCH_INT_TYPES else x1
x2 = cast(x2, "int64") if x2.dtype in TORCH_INT_TYPES else x2
return torch.tensordot(x1, x2, dims=axes)
@ -801,9 +777,7 @@ def transpose(x, axes=None):
def var(x, axis=None, keepdims=False):
x = convert_to_tensor(x, dtype="float32")
# Conversion to float necessary for `torch.var`
x = cast(x, "float32") if x.dtype in TORCH_INT_TYPES else x
x = convert_to_tensor(x)
# Bessel correction removed for numpy compatibility
return torch.var(x, dim=axis, keepdim=keepdims, correction=0)

@ -21,7 +21,7 @@ class Trainer:
@tracking.no_automatic_dependency_tracking
def compile(
self,
optimizer,
optimizer="rmsprop",
loss=None,
loss_weights=None,
metrics=None,