Merge branch 'main' of github.com:keras-team/keras-core

This commit is contained in:
Francois Chollet 2023-07-30 10:28:45 -07:00
parent d2c913e2f5
commit 4135b933b4
9 changed files with 353 additions and 26 deletions

@@ -58,9 +58,9 @@ def get_model():
# Make a simple convnet with batch normalization and dropout.
inputs = keras.Input(shape=(28, 28, 1))
x = keras.layers.Rescaling(1.0 / 255.0)(inputs)
x = keras.layers.Conv2D(filters=12, kernel_size=3, padding="same", use_bias=False)(
x
)
x = keras.layers.Conv2D(
filters=12, kernel_size=3, padding="same", use_bias=False
)(x)
x = keras.layers.BatchNormalization(scale=False, center=True)(x)
x = keras.layers.ReLU()(x)
x = keras.layers.Conv2D(
@@ -187,7 +187,11 @@ compute_gradients = jax.value_and_grad(compute_loss, has_aux=True)
# Training step, Keras provides a pure functional optimizer.stateless_apply
@jax.jit
def train_step(train_state, x, y):
trainable_variables, non_trainable_variables, optimizer_variables = train_state
(
trainable_variables,
non_trainable_variables,
optimizer_variables,
) = train_state
(loss_value, non_trainable_variables), grads = compute_gradients(
trainable_variables, non_trainable_variables, x, y
)
@@ -211,7 +215,9 @@ def get_replicated_train_state(devices):
var_replication = NamedSharding(var_mesh, P())
# Apply the distribution settings to the model variables
trainable_variables = jax.device_put(model.trainable_variables, var_replication)
trainable_variables = jax.device_put(
model.trainable_variables, var_replication
)
non_trainable_variables = jax.device_put(
model.non_trainable_variables, var_replication
)
@@ -255,7 +261,9 @@ for epoch in range(num_epochs):
trainable_variables, non_trainable_variables, optimizer_variables = train_state
for variable, value in zip(model.trainable_variables, trainable_variables):
variable.assign(value)
for variable, value in zip(model.non_trainable_variables, non_trainable_variables):
for variable, value in zip(
model.non_trainable_variables, non_trainable_variables
):
variable.assign(value)
"""

@@ -53,9 +53,9 @@ def get_model():
# Make a simple convnet with batch normalization and dropout.
inputs = keras.Input(shape=(28, 28, 1))
x = keras.layers.Rescaling(1.0 / 255.0)(inputs)
x = keras.layers.Conv2D(filters=12, kernel_size=3, padding="same", use_bias=False)(
x
)
x = keras.layers.Conv2D(
filters=12, kernel_size=3, padding="same", use_bias=False
)(x)
x = keras.layers.BatchNormalization(scale=False, center=True)(x)
x = keras.layers.ReLU()(x)
x = keras.layers.Conv2D(
@@ -231,7 +231,9 @@ def per_device_launch_fn(current_gpu_index, num_gpu):
model = get_model()
# prepare the dataloader
dataloader = prepare_dataloader(dataset, current_gpu_index, num_gpu, batch_size)
dataloader = prepare_dataloader(
dataset, current_gpu_index, num_gpu, batch_size
)
# Instantiate the torch optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

@@ -107,10 +107,18 @@ def arccos(x):
return jnp.arccos(x)
def arccosh(x):
return jnp.arccosh(x)
def arcsin(x):
return jnp.arcsin(x)
def arcsinh(x):
return jnp.arcsinh(x)
def arctan(x):
return jnp.arctan(x)
@@ -119,6 +127,10 @@ def arctan2(x1, x2):
return jnp.arctan2(x1, x2)
def arctanh(x):
return jnp.arctanh(x)
def argmax(x, axis=None):
return jnp.argmax(x, axis=axis)
@@ -171,6 +183,10 @@ def cos(x):
return jnp.cos(x)
def cosh(x):
return jnp.cosh(x)
def count_nonzero(x, axis=None):
return jnp.count_nonzero(x, axis=axis)
@@ -441,6 +457,10 @@ def sin(x):
return jnp.sin(x)
def sinh(x):
return jnp.sinh(x)
def size(x):
return jnp.size(x)
@@ -479,6 +499,10 @@ def tan(x):
return jnp.tan(x)
def tanh(x):
return jnp.tanh(x)
def tensordot(x1, x2, axes=2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)

@@ -84,10 +84,18 @@ def arccos(x):
return np.arccos(x)
def arccosh(x):
return np.arccosh(x)
def arcsin(x):
return np.arcsin(x)
def arcsinh(x):
return np.arcsinh(x)
def arctan(x):
return np.arctan(x)
@@ -96,6 +104,10 @@ def arctan2(x1, x2):
return np.arctan2(x1, x2)
def arctanh(x):
return np.arctanh(x)
def argmax(x, axis=None):
axis = tuple(axis) if isinstance(axis, list) else axis
return np.argmax(x, axis=axis)
@@ -157,6 +169,10 @@ def cos(x):
return np.cos(x)
def cosh(x):
return np.cosh(x)
def count_nonzero(x, axis=None):
axis = tuple(axis) if isinstance(axis, list) else axis
return np.count_nonzero(x, axis=axis)
@@ -438,6 +454,10 @@ def sin(x):
return np.sin(x)
def sinh(x):
return np.sinh(x)
def size(x):
return np.size(x)
@@ -480,6 +500,10 @@ def tan(x):
return np.tan(x)
def tanh(x):
return np.tanh(x)
def tensordot(x1, x2, axes=2):
axes = tuple(axes) if isinstance(axes, list) else axes
return np.tensordot(x1, x2, axes=axes)

@@ -105,10 +105,18 @@ def arccos(x):
return tfnp.arccos(x)
def arccosh(x):
return tfnp.arccosh(x)
def arcsin(x):
return tfnp.arcsin(x)
def arcsinh(x):
return tfnp.arcsinh(x)
def arctan(x):
return tfnp.arctan(x)
@@ -117,6 +125,10 @@ def arctan2(x1, x2):
return tfnp.arctan2(x1, x2)
def arctanh(x):
return tfnp.arctanh(x)
def argmax(x, axis=None):
return tfnp.argmax(x, axis=axis)
@@ -174,6 +186,10 @@ def cos(x):
return tfnp.cos(x)
def cosh(x):
return tfnp.cosh(x)
def count_nonzero(x, axis=None):
return tfnp.count_nonzero(x, axis=axis)
@@ -472,6 +488,10 @@ def sin(x):
return tfnp.sin(x)
def sinh(x):
return tfnp.sinh(x)
def size(x):
return tfnp.size(x)
@@ -508,6 +528,10 @@ def tan(x):
return tfnp.tan(x)
def tanh(x):
return tfnp.tanh(x)
def tensordot(x1, x2, axes=2):
return tfnp.tensordot(x1, x2, axes=axes)

@@ -173,11 +173,21 @@ def arccos(x):
return torch.arccos(x)
def arccosh(x):
x = convert_to_tensor(x)
return torch.arccosh(x)
def arcsin(x):
x = convert_to_tensor(x)
return torch.arcsin(x)
def arcsinh(x):
x = convert_to_tensor(x)
return torch.arcsinh(x)
def arctan(x):
x = convert_to_tensor(x)
return torch.arctan(x)
@@ -188,6 +198,11 @@ def arctan2(x1, x2):
return torch.arctan2(x1, x2)
def arctanh(x):
x = convert_to_tensor(x)
return torch.arctanh(x)
def argmax(x, axis=None):
x = convert_to_tensor(x)
return torch.argmax(x, dim=axis)
@@ -277,6 +292,11 @@ def cos(x):
return torch.cos(x)
def cosh(x):
x = convert_to_tensor(x)
return torch.cosh(x)
def count_nonzero(x, axis=None):
x = convert_to_tensor(x)
if axis == () or axis == []:
@@ -729,6 +749,11 @@ def sin(x):
return torch.sin(x)
def sinh(x):
x = convert_to_tensor(x)
return torch.sinh(x)
def size(x):
x_shape = convert_to_tensor(tuple(x.shape))
return torch.prod(x_shape)
@@ -806,6 +831,11 @@ def tan(x):
return torch.tan(x)
def tanh(x):
x = convert_to_tensor(x)
return torch.tanh(x)
def tensordot(x1, x2, axes=2):
x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
# Conversion to long necessary for `torch.tensordot`

@@ -57,21 +57,6 @@ def sigmoid(x):
return backend.nn.sigmoid(x)
class Tanh(Operation):
def call(self, x):
return backend.nn.tanh(x)
def compute_output_spec(self, x):
return KerasTensor(x.shape, dtype=x.dtype)
@keras_core_export(["keras_core.ops.tanh", "keras_core.ops.nn.tanh"])
def tanh(x):
if any_symbolic_tensors((x,)):
return Tanh().symbolic_call(x)
return backend.nn.tanh(x)
class Softplus(Operation):
def call(self, x):
return backend.nn.softplus(x)

@@ -10,9 +10,12 @@ amin
append
arange
arccos
arccosh
arcsin
arcsinh
arctan
arctan2
arctanh
argmax
argmin
argsort
@@ -27,6 +30,7 @@ conj
conjugate
copy
cos
cosh
count_nonzero
cross
cumprod
@@ -102,6 +106,7 @@ roll
round
sign
sin
sinh
size
sort
split
@@ -116,6 +121,7 @@ swapaxes
take
take_along_axis
tan
tanh
tensordot
tile
trace
@@ -713,6 +719,28 @@ def arccos(x):
return backend.numpy.arccos(x)
class Arccosh(Operation):
def call(self, x):
return backend.numpy.arccosh(x)
def compute_output_spec(self, x):
return KerasTensor(x.shape, dtype=x.dtype)
def arccosh(x):
"""Inverse hyperbolic cosine, element-wise.
Arguments:
x: Input tensor.
Returns:
Output tensor of same shape as x.
"""
if any_symbolic_tensors((x,)):
return Arccosh().symbolic_call(x)
return backend.numpy.arccosh(x)
class Arcsin(Operation):
def call(self, x):
return backend.numpy.arcsin(x)
@@ -742,6 +770,29 @@ def arcsin(x):
return backend.numpy.arcsin(x)
class Arcsinh(Operation):
def call(self, x):
return backend.numpy.arcsinh(x)
def compute_output_spec(self, x):
return KerasTensor(x.shape, dtype=x.dtype)
@keras_core_export(["keras_core.ops.arcsinh", "keras_core.ops.numpy.arcsinh"])
def arcsinh(x):
"""Inverse hyperbolic sine, element-wise.
Arguments:
x: Input tensor.
Returns:
Output tensor of same shape as x.
"""
if any_symbolic_tensors((x,)):
return Arcsinh().symbolic_call(x)
return backend.numpy.arcsinh(x)
class Arctan(Operation):
def call(self, x):
return backend.numpy.arctan(x)
@@ -826,6 +877,29 @@ def arctan2(x1, x2):
return backend.numpy.arctan2(x1, x2)
class Arctanh(Operation):
def call(self, x):
return backend.numpy.arctanh(x)
def compute_output_spec(self, x):
return KerasTensor(x.shape, dtype=x.dtype)
@keras_core_export(["keras_core.ops.arctanh", "keras_core.ops.numpy.arctanh"])
def arctanh(x):
"""Inverse hyperbolic tangent, element-wise.
Arguments:
x: Input tensor.
Returns:
Output tensor of same shape as x.
"""
if any_symbolic_tensors((x,)):
return Arctanh().symbolic_call(x)
return backend.numpy.arctanh(x)
class Argmax(Operation):
def __init__(self, axis=None):
super().__init__()
@@ -1289,6 +1363,29 @@ def cos(x):
return backend.numpy.cos(x)
class Cosh(Operation):
def call(self, x):
return backend.numpy.cosh(x)
def compute_output_spec(self, x):
return KerasTensor(x.shape, dtype=x.dtype)
@keras_core_export(["keras_core.ops.cosh", "keras_core.ops.numpy.cosh"])
def cosh(x):
"""Hyperbolic cosine, element-wise.
Arguments:
x: Input tensor.
Returns:
Output tensor of same shape as x.
"""
if any_symbolic_tensors((x,)):
return Cosh().symbolic_call(x)
return backend.numpy.cosh(x)
class CountNonzero(Operation):
def __init__(self, axis=None):
super().__init__()
@@ -3115,6 +3212,29 @@ def sin(x):
return backend.numpy.sin(x)
class Sinh(Operation):
def call(self, x):
return backend.numpy.sinh(x)
def compute_output_spec(self, x):
return KerasTensor(x.shape, dtype=x.dtype)
@keras_core_export(["keras_core.ops.sinh", "keras_core.ops.numpy.sinh"])
def sinh(x):
"""Hyperbolic sine, element-wise.
Arguments:
x: Input tensor.
Returns:
Output tensor of same shape as x.
"""
if any_symbolic_tensors((x,)):
return Sinh().symbolic_call(x)
return backend.numpy.sinh(x)
class Size(Operation):
def call(self, x):
return backend.numpy.size(x)
@@ -3376,6 +3496,29 @@ def tan(x):
return backend.numpy.tan(x)
class Tanh(Operation):
def call(self, x):
return backend.numpy.tanh(x)
def compute_output_spec(self, x):
return KerasTensor(x.shape, dtype=x.dtype)
@keras_core_export(["keras_core.ops.tanh", "keras_core.ops.numpy.tanh"])
def tanh(x):
"""Hyperbolic tangent, element-wise.
Arguments:
x: Input tensor.
Returns:
Output tensor of same shape as x.
"""
if any_symbolic_tensors((x,)):
return Tanh().symbolic_call(x)
return backend.numpy.tanh(x)
class Tensordot(Operation):
def __init__(self, axes=2):
super().__init__()

@@ -754,14 +754,26 @@ class NumpyOneInputOpsDynamicShapeTest(testing.TestCase):
x = KerasTensor([None, 3])
self.assertEqual(knp.arccos(x).shape, (None, 3))
def test_arccosh(self):
x = KerasTensor([None, 3])
self.assertEqual(knp.arccosh(x).shape, (None, 3))
def test_arcsin(self):
x = KerasTensor([None, 3])
self.assertEqual(knp.arcsin(x).shape, (None, 3))
def test_arcsinh(self):
x = KerasTensor([None, 3])
self.assertEqual(knp.arcsinh(x).shape, (None, 3))
def test_arctan(self):
x = KerasTensor([None, 3])
self.assertEqual(knp.arctan(x).shape, (None, 3))
def test_arctanh(self):
x = KerasTensor([None, 3])
self.assertEqual(knp.arctanh(x).shape, (None, 3))
def test_argmax(self):
x = KerasTensor([None, 3])
self.assertEqual(knp.argmax(x).shape, ())
@@ -855,6 +867,10 @@ class NumpyOneInputOpsDynamicShapeTest(testing.TestCase):
x = KerasTensor([None, 3])
self.assertEqual(knp.cos(x).shape, (None, 3))
def test_cosh(self):
x = KerasTensor([None, 3])
self.assertEqual(knp.cosh(x).shape, (None, 3))
def test_count_nonzero(self):
x = KerasTensor([None, 3])
self.assertEqual(knp.count_nonzero(x).shape, ())
@@ -1095,6 +1111,10 @@ class NumpyOneInputOpsDynamicShapeTest(testing.TestCase):
x = KerasTensor([None, 3])
self.assertEqual(knp.sin(x).shape, (None, 3))
def test_sinh(self):
x = KerasTensor([None, 3])
self.assertEqual(knp.sinh(x).shape, (None, 3))
def test_size(self):
x = KerasTensor([None, 3])
self.assertEqual(knp.size(x).shape, ())
@@ -1137,6 +1157,10 @@ class NumpyOneInputOpsDynamicShapeTest(testing.TestCase):
x = KerasTensor([None, 3])
self.assertEqual(knp.tan(x).shape, (None, 3))
def test_tanh(self):
x = KerasTensor([None, 3])
self.assertEqual(knp.tanh(x).shape, (None, 3))
def test_tile(self):
x = KerasTensor([None, 3])
self.assertEqual(knp.tile(x, [2]).shape, (None, 6))
@@ -1227,14 +1251,26 @@ class NumpyOneInputOpsStaticShapeTest(testing.TestCase):
x = KerasTensor([2, 3])
self.assertEqual(knp.arccos(x).shape, (2, 3))
def test_arccosh(self):
x = KerasTensor([2, 3])
self.assertEqual(knp.arccosh(x).shape, (2, 3))
def test_arcsin(self):
x = KerasTensor([2, 3])
self.assertEqual(knp.arcsin(x).shape, (2, 3))
def test_arcsinh(self):
x = KerasTensor([2, 3])
self.assertEqual(knp.arcsinh(x).shape, (2, 3))
def test_arctan(self):
x = KerasTensor([2, 3])
self.assertEqual(knp.arctan(x).shape, (2, 3))
def test_arctanh(self):
x = KerasTensor([2, 3])
self.assertEqual(knp.arctanh(x).shape, (2, 3))
def test_argmax(self):
x = KerasTensor([2, 3])
self.assertEqual(knp.argmax(x).shape, ())
@@ -1297,6 +1333,10 @@ class NumpyOneInputOpsStaticShapeTest(testing.TestCase):
x = KerasTensor([2, 3])
self.assertEqual(knp.cos(x).shape, (2, 3))
def test_cosh(self):
x = KerasTensor([2, 3])
self.assertEqual(knp.cosh(x).shape, (2, 3))
def test_count_nonzero(self):
x = KerasTensor([2, 3])
self.assertEqual(knp.count_nonzero(x).shape, ())
@@ -1532,6 +1572,10 @@ class NumpyOneInputOpsStaticShapeTest(testing.TestCase):
x = KerasTensor([2, 3])
self.assertEqual(knp.sin(x).shape, (2, 3))
def test_sinh(self):
x = KerasTensor([2, 3])
self.assertEqual(knp.sinh(x).shape, (2, 3))
def test_size(self):
x = KerasTensor([2, 3])
self.assertEqual(knp.size(x).shape, ())
@@ -1579,6 +1623,10 @@ class NumpyOneInputOpsStaticShapeTest(testing.TestCase):
x = KerasTensor([2, 3])
self.assertEqual(knp.tan(x).shape, (2, 3))
def test_tanh(self):
x = KerasTensor([2, 3])
self.assertEqual(knp.tanh(x).shape, (2, 3))
def test_tile(self):
x = KerasTensor([2, 3])
self.assertEqual(knp.tile(x, [2]).shape, (2, 6))
@@ -2266,18 +2314,42 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase):
np.transpose(x, axes=(1, 0, 3, 2, 4)),
)
def test_arcos(self):
def test_arccos(self):
x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]])
self.assertAllClose(knp.arccos(x), np.arccos(x))
self.assertAllClose(knp.Arccos()(x), np.arccos(x))
def test_arccosh(self):
x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]])
self.assertAllClose(knp.arccosh(x), np.arccosh(x))
self.assertAllClose(knp.Arccosh()(x), np.arccosh(x))
def test_arcsin(self):
x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]])
self.assertAllClose(knp.arcsin(x), np.arcsin(x))
self.assertAllClose(knp.Arcsin()(x), np.arcsin(x))
def test_arcsinh(self):
x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]])
self.assertAllClose(knp.arcsinh(x), np.arcsinh(x))
self.assertAllClose(knp.Arcsinh()(x), np.arcsinh(x))
def test_arctan(self):
x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]])
self.assertAllClose(knp.arctan(x), np.arctan(x))
self.assertAllClose(knp.Arctan()(x), np.arctan(x))
def test_arctanh(self):
x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]])
self.assertAllClose(knp.arctanh(x), np.arctanh(x))
self.assertAllClose(knp.Arctanh()(x), np.arctanh(x))
def test_argmax(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
self.assertAllClose(knp.argmax(x), np.argmax(x))
@@ -2433,6 +2505,11 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase):
self.assertAllClose(knp.cos(x), np.cos(x))
self.assertAllClose(knp.Cos()(x), np.cos(x))
def test_cosh(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
self.assertAllClose(knp.cosh(x), np.cosh(x))
self.assertAllClose(knp.Cosh()(x), np.cosh(x))
def test_count_nonzero(self):
x = np.array([[0, 2, 3], [3, 2, 0]])
self.assertAllClose(knp.count_nonzero(x), np.count_nonzero(x))
@@ -2899,6 +2976,11 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase):
self.assertAllClose(knp.sin(x), np.sin(x))
self.assertAllClose(knp.Sin()(x), np.sin(x))
def test_sinh(self):
x = np.array([[1, -2, 3], [-3, 2, -1]])
self.assertAllClose(knp.sinh(x), np.sinh(x))
self.assertAllClose(knp.Sinh()(x), np.sinh(x))
def test_size(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
self.assertAllClose(knp.size(x), np.size(x))
@@ -2993,6 +3075,11 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase):
self.assertAllClose(knp.tan(x), np.tan(x))
self.assertAllClose(knp.Tan()(x), np.tan(x))
def test_tanh(self):
x = np.array([[1, -2, 3], [-3, 2, -1]])
self.assertAllClose(knp.tanh(x), np.tanh(x))
self.assertAllClose(knp.Tanh()(x), np.tanh(x))
def test_tile(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
self.assertAllClose(knp.tile(x, [2, 3]), np.tile(x, [2, 3]))