Progress on saving/reloading.

Francois Chollet 2023-04-25 12:59:32 -07:00
parent 83c356a791
commit 879a6c244c
19 changed files with 451 additions and 483 deletions

@@ -1,4 +1,3 @@
from keras_core.backend.common import backend_utils
from keras_core.backend.common import random
from keras_core.backend.common.variables import KerasVariable
from keras_core.backend.common.variables import standardize_dtype

@@ -1,79 +0,0 @@
def _compute_conv_transpose_output_length(
input_length,
kernel_size,
padding,
output_padding=None,
stride=1,
dilation=1,
):
"""Computes output size of a transposed convolution given input size."""
assert padding in {"same", "valid"}
if input_length is None:
return None
# Get the dilated kernel size
kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
# Infer length if output padding is None, else compute the exact length
if output_padding is None:
if padding == "valid":
length = input_length * stride + max(kernel_size - stride, 0)
else:
length = input_length * stride
else:
if padding == "same":
pad = kernel_size // 2
else:
pad = 0
length = (
(input_length - 1) * stride + kernel_size - 2 * pad + output_padding
)
return length
def compute_conv_transpose_output_shape(
inputs,
kernel,
strides,
padding,
output_padding=None,
data_format="channels_last",
dilation_rate=1,
):
num_spatial_dims = len(inputs.shape) - 2
kernel_spatial_shape = kernel.shape[:-2]
if isinstance(output_padding, int):
output_padding = (output_padding,) * len(kernel_spatial_shape)
if isinstance(strides, int):
strides = (strides,) * num_spatial_dims
if isinstance(dilation_rate, int):
dilation_rate = (dilation_rate,) * num_spatial_dims
if data_format == "channels_last":
inputs_spatial_shape = inputs.shape[1:-1]
else:
inputs_spatial_shape = inputs.shape[2:]
output_shape = []
for i in range(num_spatial_dims):
current_output_padding = (
None if output_padding is None else output_padding[i]
)
output_shape.append(
_compute_conv_transpose_output_length(
inputs_spatial_shape[i],
kernel_spatial_shape[i],
padding=padding,
output_padding=current_output_padding,
stride=strides[i],
dilation=dilation_rate[i],
)
)
if data_format == "channels_last":
output_shape = [inputs.shape[0]] + output_shape + [kernel.shape[-2]]
else:
output_shape = [inputs.shape[0], kernel.shape[-1]] + output_shape
return output_shape

@@ -1,8 +1,6 @@
import jax
import numpy as np
from jax import lax
from jax import nn as jnn
from jax import numpy as jnp
def relu(x):
@@ -65,293 +63,46 @@ def log_softmax(x, axis=-1):
return jnn.log_softmax(x, axis=axis)
def _convert_to_spatial_operand(
x,
num_spatial_dims,
data_format="channels_last",
include_batch_and_channels=True,
):
# Helper function that converts an operand to a spatial operand.
x = (x,) * num_spatial_dims if isinstance(x, int) else x
if not include_batch_and_channels:
return x
if data_format == "channels_last":
x = (1,) + x + (1,)
else:
x = (1,) + (1,) + x
return x
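For intuition, here is what this helper returns for a typical 2D pooling call (hypothetical values, worked from the code above):
# _convert_to_spatial_operand(2, 2, "channels_last") == (1, 2, 2, 1)
# _convert_to_spatial_operand(2, 2, "channels_first") == (1, 1, 2, 2)
# With include_batch_and_channels=False the result is just (2, 2);
# the batch and channel dimensions get a window/stride of 1 so the
# operation acts on the spatial dimensions only.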
def _pool(
inputs,
initial_value,
reduce_fn,
pool_size,
strides=None,
padding="valid",
):
"""Helper function to define pooling functions.
Args:
inputs: input data of shape `N+2`.
initial_value: the initial value for the reduction.
reduce_fn: a reduce function of the form `(T, T) -> T`.
pool_size: a sequence of `N` integers, representing the window size to
reduce over.
strides: a sequence of `N` integers, representing the inter-window
strides (default: `(1, ..., 1)`).
padding: either the string `same` or `valid`.
Returns:
The output of the reduction for each window slice.
"""
if padding not in ("same", "valid"):
raise ValueError(
f"Invalid padding '{padding}', must be 'same' or 'valid'."
)
padding = padding.upper()
return lax.reduce_window(
inputs,
initial_value,
reduce_fn,
pool_size,
strides,
padding,
)
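As a sketch of what `_pool` reduces to, here is the underlying `lax.reduce_window` call for a 2x2 max pooling over a 4x4 NHWC input (a standalone example, not part of this commit):

import jax.numpy as jnp
from jax import lax

x = jnp.arange(16.0).reshape(1, 4, 4, 1)  # NHWC input
y = lax.reduce_window(
    x,
    -jnp.inf,  # initial value for a max reduction
    lax.max,
    window_dimensions=(1, 2, 2, 1),
    window_strides=(1, 2, 2, 1),
    padding="VALID",
)
# y.shape == (1, 2, 2, 1); y[0, :, :, 0] == [[5., 7.], [13., 15.]]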
def max_pool(
inputs,
pool_size,
strides=None,
padding="valid",
data_format="channels_last",
):
num_spatial_dims = inputs.ndim - 2
pool_size = _convert_to_spatial_operand(
pool_size, num_spatial_dims, data_format
)
strides = pool_size if strides is None else strides
strides = _convert_to_spatial_operand(
strides, num_spatial_dims, data_format
)
return _pool(inputs, -jnp.inf, lax.max, pool_size, strides, padding)
def average_pool(
inputs,
pool_size,
strides,
padding,
data_format="channels_last",
):
num_spatial_dims = inputs.ndim - 2
pool_size = _convert_to_spatial_operand(
pool_size, num_spatial_dims, data_format
)
strides = pool_size if strides is None else strides
strides = _convert_to_spatial_operand(
strides, num_spatial_dims, data_format
)
pooled = _pool(inputs, 0.0, lax.add, pool_size, strides, padding)
if padding == "valid":
# Avoid the extra reduce_window.
return pooled / np.prod(pool_size)
else:
# Count the number of valid entries at each input point, then use that
# for computing average. Assumes that any two arrays of same shape will
# be padded the same. Avoid broadcasting on axis where pooling is
# skipped.
shape = [
(a if b != 1 else 1) for (a, b) in zip(inputs.shape, pool_size)
]
window_counts = _pool(
jnp.ones(shape, inputs.dtype),
0.0,
lax.add,
pool_size,
strides,
padding,
)
return pooled / window_counts
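The window-count division matters for `same` padding: edge windows overlap the zero padding, and dividing by the full window size would bias those averages toward zero. A toy 1D illustration (hypothetical values):
# inputs = [1, 1, 1], pool_size=2, stride=1, padding="same"
# sum-pooled:     [2, 2, 1]  (the last window hangs over the zero pad)
# window_counts:  [2, 2, 1]  (valid inputs per window, via pooled ones)
# averages:       [1, 1, 1]  rather than the biased [1, 1, 0.5]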
def _convert_to_lax_conv_dimension_numbers(
num_spatial_dims,
data_format="channels_last",
transpose=False,
):
"""Create a `lax.ConvDimensionNumbers` for the given inputs."""
num_dims = num_spatial_dims + 2
if data_format == "channels_last":
spatial_dims = tuple(range(1, num_dims - 1))
inputs_dn = (0, num_dims - 1) + spatial_dims
else:
spatial_dims = tuple(range(2, num_dims))
inputs_dn = (0, 1) + spatial_dims
if transpose:
kernel_dn = (num_dims - 2, num_dims - 1) + tuple(range(num_dims - 2))
else:
kernel_dn = (num_dims - 1, num_dims - 2) + tuple(range(num_dims - 2))
return lax.ConvDimensionNumbers(
lhs_spec=inputs_dn, rhs_spec=kernel_dn, out_spec=inputs_dn
)
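For instance, for a 2D convolution with `data_format="channels_last"` and `transpose=False`, this helper produces (worked from the code above):
# lhs_spec = (0, 3, 1, 2)  # NHWC inputs: batch=0, channels=3, spatial=(1, 2)
# rhs_spec = (3, 2, 0, 1)  # kernel dims: out at index 3, in at index 2, spatial at (0, 1)
# out_spec = (0, 3, 1, 2)  # outputs use the same layout as the inputs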
def conv(
inputs,
kernel,
strides=1,
padding="valid",
data_format="channel_last",
dilation_rate=1,
):
num_spatial_dims = inputs.ndim - 2
dimension_numbers = _convert_to_lax_conv_dimension_numbers(
num_spatial_dims,
data_format,
transpose=False,
)
strides = _convert_to_spatial_operand(
strides,
num_spatial_dims,
data_format,
include_batch_and_channels=False,
)
dilation_rate = _convert_to_spatial_operand(
dilation_rate,
num_spatial_dims,
data_format,
include_batch_and_channels=False,
)
return jax.lax.conv_general_dilated(
inputs,
kernel,
strides,
padding,
rhs_dilation=dilation_rate,
dimension_numbers=dimension_numbers,
)
def depthwise_conv(
inputs,
kernel,
strides=1,
padding="valid",
data_format="channel_last",
dilation_rate=1,
):
num_spatial_dims = inputs.ndim - 2
dimension_numbers = _convert_to_lax_conv_dimension_numbers(
num_spatial_dims,
data_format,
transpose=False,
)
strides = _convert_to_spatial_operand(
strides,
num_spatial_dims,
data_format,
include_batch_and_channels=False,
)
dilation_rate = _convert_to_spatial_operand(
dilation_rate,
num_spatial_dims,
data_format,
include_batch_and_channels=False,
)
feature_group_count = (
inputs.shape[-1] if data_format == "channels_last" else inputs.shape[1]
)
kernel = jnp.reshape(
kernel,
kernel.shape[:-2] + (1, feature_group_count * kernel.shape[-1]),
)
return jax.lax.conv_general_dilated(
inputs,
kernel,
strides,
padding,
rhs_dilation=dilation_rate,
dimension_numbers=dimension_numbers,
feature_group_count=feature_group_count,
)
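The reshape above is the standard way to express a depthwise convolution through `conv_general_dilated`: with `feature_group_count` set to the number of input channels, XLA splits the input into per-channel groups. A shape sketch (hypothetical sizes):
# 3x3 depthwise kernel, 8 input channels, depth multiplier 2:
# kernel.shape before reshape == (3, 3, 8, 2)
# kernel.shape after reshape  == (3, 3, 1, 16)
# feature_group_count         == 8
# Each group convolves one input channel with its own 2 filters,
# yielding 16 output channels in total.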
def separable_conv(
inputs,
depthwise_kernel,
pointwise_kernel,
strides=1,
padding="valid",
data_format="channels_last",
dilation_rate=1,
):
depthwise_conv_output = depthwise_conv(
inputs,
depthwise_kernel,
strides,
padding,
data_format,
dilation_rate,
)
return conv(
depthwise_conv_output,
pointwise_kernel,
strides=1,
padding="valid",
data_format=data_format,
dilation_rate=dilation_rate,
)
def conv_transpose(
inputs,
kernel,
strides=1,
padding="valid",
output_padding=None,
data_format="channels_last",
dilation_rate=1,
):
num_spatial_dims = inputs.ndim - 2
dimension_numbers = _convert_to_lax_conv_dimension_numbers(
num_spatial_dims,
data_format,
transpose=False,
)
strides = _convert_to_spatial_operand(
strides,
num_spatial_dims,
data_format,
include_batch_and_channels=False,
)
dilation_rate = _convert_to_spatial_operand(
dilation_rate,
num_spatial_dims,
data_format,
include_batch_and_channels=False,
)
if output_padding is not None:
raise ValueError(
"Custom `output_padding` is not supported yet, please set "
"`output_padding=None`."
)
padding = padding.upper()
return jax.lax.conv_transpose(
inputs,
kernel,
strides,
padding,
rhs_dilation=dilation_rate,
dimension_numbers=dimension_numbers,
transpose_kernel=True,
)
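A minimal shape sketch for this function (hypothetical shapes, 1D case with `data_format="channels_last"`):
# inputs: (2, 10, 3); kernel: (4, 5, 3)  # (spatial, out_channels, in_channels)
# strides=2, padding="valid":
#   output length = 10 * 2 + max(4 - 2, 0) = 22
#   output shape == (2, 22, 5)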
def one_hot(x, num_classes, axis=-1):

@@ -23,3 +23,11 @@ def mean(x, axis=None, keepdims=False):
def max(x, axis=None, keepdims=False):
return jnp.max(x, axis=axis, keepdims=keepdims)
def ones(shape, dtype="float32"):
return jnp.ones(shape, dtype=dtype)
def zeros(shape, dtype="float32"):
return jnp.zeros(shape, dtype=dtype)

@@ -1,9 +1,5 @@
import tensorflow as tf
def relu(x):
return tf.nn.relu(x)
@@ -256,8 +252,10 @@ def separable_conv(
if data_format == "channels_last":
strides = (1,) + strides + (1,)
spatial_start_dim = 1
else:
strides = (1, 1) + strides
spatial_start_dim = 2
return tf.nn.separable_conv2d(
inputs,
depthwise_kernel,
@@ -269,6 +267,100 @@ def separable_conv(
)
def _deconv_output_length(
input_length,
kernel_size,
padding,
output_padding=None,
stride=1,
dilation=1,
):
"""Determines output length of a transposed convolution given input length.
Args:
input_length: Integer.
kernel_size: Integer.
padding: one of `"same"` or `"valid"`.
output_padding: Integer, amount of padding along the output dimension.
Can be set to `None` in which case the output length is inferred.
stride: Integer.
dilation: Integer.
Returns:
The output length (integer).
"""
assert padding in {"same", "valid"}
if input_length is None:
return None
# Get the dilated kernel size
kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
# Infer length if output padding is None, else compute the exact length
if output_padding is None:
if padding == "valid":
length = input_length * stride + max(kernel_size - stride, 0)
else:
length = input_length * stride
else:
if padding == "same":
pad = kernel_size // 2
else:
pad = 0
length = (
(input_length - 1) * stride + kernel_size - 2 * pad + output_padding
)
return length
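A worked instance of the formula with hypothetical numbers, `input_length=4`, `kernel_size=3`, `stride=2`, `dilation=1`:
# padding="valid", output_padding=None:
#   length = 4 * 2 + max(3 - 2, 0) = 9
# padding="same", output_padding=None:
#   length = 4 * 2 = 8
# padding="same", output_padding=1:
#   pad = 3 // 2 = 1
#   length = (4 - 1) * 2 + 3 - 2 * 1 + 1 = 8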
def compute_output_shape_conv_transpose(
inputs,
kernel,
strides,
padding,
output_padding=None,
data_format="channels_last",
dilation_rate=1,
):
num_spatial_dims = len(inputs.shape) - 2
kernel_spatial_shape = kernel.shape[:-2]
if isinstance(output_padding, int):
output_padding = (output_padding,) * len(kernel_spatial_shape)
if isinstance(strides, int):
strides = (strides,) * num_spatial_dims
if isinstance(dilation_rate, int):
dilation_rate = (dilation_rate,) * num_spatial_dims
if data_format == "channels_last":
inputs_spatial_shape = inputs.shape[1:-1]
else:
inputs_spatial_shape = inputs.shape[2:]
output_shape = []
for i in range(num_spatial_dims):
current_output_padding = (
None if output_padding is None else output_padding[i]
)
output_shape.append(
_deconv_output_length(
inputs_spatial_shape[i],
kernel_spatial_shape[i],
padding=padding,
output_padding=current_output_padding,
stride=strides[i],
dilation=dilation_rate[i],
)
)
if data_format == "channels_last":
output_shape = [inputs.shape[0]] + output_shape + [kernel.shape[-2]]
else:
output_shape = [inputs.shape[0], kernel.shape[-1]] + output_shape
return output_shape
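For example (shapes chosen to mirror the 2D test further below, otherwise hypothetical):
# inputs.shape == (2, 4, 4, 3); kernel.shape == (2, 2, 5, 3)
# strides=2, padding="same", data_format="channels_last":
#   each spatial length = 4 * 2 = 8
#   output_shape == [2, 8, 8, 5]  # batch + spatial dims + kernel.shape[-2]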
def conv_transpose(
inputs,
kernel,
@@ -279,7 +371,7 @@ def conv_transpose(
dilation_rate=1,
):
tf_data_format = _convert_data_format(data_format, len(inputs.shape))
output_shape = compute_output_shape_conv_transpose(
inputs,
kernel,
strides,

@@ -1,3 +1,4 @@
import tensorflow as tf
from tensorflow.experimental import numpy as tfnp
@@ -23,3 +24,13 @@ def mean(x, axis=None, keepdims=False):
def max(x, axis=None, keepdims=False):
return tfnp.max(x, axis=axis, keepdims=keepdims)
def ones(shape, dtype="float32"):
with tf.init_scope():
return tf.ones(shape, dtype=dtype)
def zeros(shape, dtype="float32"):
with tf.init_scope():
return tf.zeros(shape, dtype=dtype)
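The `tf.init_scope()` here presumably lifts the creation of these constants out of any surrounding `tf.function` trace, so they are built eagerly (useful while restoring variables during deserialization). A small sketch of the behavior, assuming TF 2.x eager execution:

import tensorflow as tf

@tf.function
def make_initial_value():
    # Inside init_scope the op runs eagerly at trace time instead of
    # being embedded in the traced graph.
    with tf.init_scope():
        return tf.zeros((2, 2), dtype="float32")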

@@ -8,6 +8,7 @@ from keras_core.backend.config import floatx
def tf_draw_seed(seed):
# TF ops only accept int32/64 seeds but our base seed is uint32.
with tf.init_scope():
return tf.cast(draw_seed(seed), dtype="int32")
@@ -34,6 +35,7 @@ def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
"""
dtype = dtype or floatx()
seed = tf_draw_seed(seed)
with tf.init_scope():
return tf.random.stateless_normal(
shape=shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed
)
@@ -63,6 +65,7 @@ def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
"""
dtype = dtype or floatx()
seed = tf_draw_seed(seed)
with tf.init_scope():
return tf.random.stateless_uniform(
shape=shape,
minval=minval,
@@ -95,6 +98,7 @@ def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
"""
dtype = dtype or floatx()
seed = tf_draw_seed(seed)
with tf.init_scope():
return tf.random.stateless_truncated_normal(
shape=shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed
)
@@ -102,6 +106,7 @@ def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
def dropout(inputs, rate, noise_shape=None, seed=None):
seed = tf_draw_seed(seed)
with tf.init_scope():
return tf.nn.experimental.stateless_dropout(
inputs,
rate=rate,

@@ -43,12 +43,13 @@ ALL_OBJECTS_DICT.update(
@keras_core_export("keras_core.initializers.serialize")
def serialize(initializer):
"""Returns the initializer configuration as a Python dict."""
return serialization_lib.serialize_keras_object(initializer)
@keras_core_export("keras_core.initializers.deserialize")
def deserialize(config, custom_objects=None):
"""Return a Keras initializer object via its config."""
"""Returns a Keras initializer object via its configuration."""
return serialization_lib.deserialize_keras_object(
config,
module_objects=ALL_OBJECTS_DICT,
@@ -58,7 +59,7 @@ def deserialize(config, custom_objects=None):
@keras_core_export("keras_core.initializers.get")
def get(identifier):
"""Retrieve a Keras initializer object via an identifier.
"""Retrieves a Keras initializer object via an identifier.
The `identifier` may be the string name of a initializers function or class
(case-sensitively).

@@ -470,6 +470,22 @@ class Layer(Operation):
"""
all_vars = self._variables
if len(store.keys()) != len(all_vars):
if len(all_vars) == 0 and not self.built:
raise ValueError(
f"Layer '{self.name}' was never built "
"and thus it doesn't have any variables. "
f"However the weights file lists {len(store.keys())} "
"variables for this layer. In most cases, "
"this indicates that you need to implement the "
"`def build_from_config(self, config)` method "
"on the layer. "
"You might also want to implement the method "
"that generates the config at saving time, "
"`def get_build_config(self)`. "
"The method `build_from_config()` is meant "
"to create the state "
"of the layer (i.e. its variables) upon deserialization.",
)
raise ValueError(
f"Layer '{self.name}' expected {len(all_vars)} variables, "
"but received "

@@ -1,15 +1,15 @@
from keras_core import backend
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.losses.loss import squeeze_to_same_rank
from keras_core.metrics import reduction_metrics
def accuracy(y_true, y_pred):
y_pred = ops.convert_to_tensor(y_pred)
y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
y_true, y_pred = squeeze_to_same_rank(y_true, y_pred)
return ops.cast(ops.equal(y_true, y_pred), dtype=backend.floatx())
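A small illustration of the rank handling (hypothetical tensors):
# y_true shape (3,), y_pred shape (3, 1): squeezed to a common rank first,
# so accuracy([1, 2, 3], [[1], [2], [0]]) -> [1., 1., 0.] in floatx().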
@keras_core_export("keras_core.metrics.Accuracy")

@@ -1,4 +1,5 @@
import os
import warnings
from keras_core import backend
from keras_core.api_export import keras_core_export
@@ -205,6 +206,69 @@ class Model(Trainer, Layer):
"files."
)
def build_from_config(self, config):
def is_shape_tuple(s):
return isinstance(s, (list, tuple)) and all(
d is None or isinstance(d, int) for d in s
)
if config:
failure = False
if "input_shape" in config:
# Case: all inputs are in the first arg (possibly nested).
input_shape = config["input_shape"]
if is_shape_tuple(input_shape):
input_shape = tuple(input_shape)
if isinstance(input_shape, list):
input_tensors = [
backend.KerasTensor(shape) for shape in input_shape
]
elif isinstance(input_shape, dict):
input_tensors = {
k: backend.KerasTensor(shape)
for k, shape in input_shape.items()
}
else:
input_tensors = backend.KerasTensor(input_shape)
try:
self(input_tensors)
self._build_shapes_dict = config
except Exception:
failure = True
elif "shapes_dict" in config:
# Case: inputs were recorded as multiple keyword arguments.
if all(
is_shape_tuple(s) for s in config["shapes_dict"].values()
):
# Case: all input keyword arguments were plain tensors.
input_tensors = {
k: backend.KerasTensor(v)
for k, v in config["shapes_dict"].items()
}
try:
self(**input_tensors)
self._build_shapes_dict = config["shapes_dict"]
except Exception:
failure = True
else:
# Not supported: nested input keyword arguments.
failure = True
if failure:
warnings.warn(
f"Model '{self.name}' had a build config, but the model "
"cannot be built automatically in "
"`build_from_config(config)`. "
"You should implement "
"`def build_from_config(self, config)`, "
"and you might also want to implement the method "
" that generates the config at saving time, "
"`def get_build_config(self)`. "
"The method `build_from_config()` is meant to "
"create the state of the model (i.e. its variables) "
"upon deserialization.",
stacklevel=2,
)
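Concretely, the two build-config layouts this method understands look like the following (illustrative values):
# Case 1: the model was built by calling it on one (possibly nested) input:
#   config == {"input_shape": (None, 28, 28, 1)}
#   -> self(backend.KerasTensor((None, 28, 28, 1)))
# Case 2: the model was built with several named tensor arguments:
#   config == {"shapes_dict": {"encoder_input": (None, 16),
#                              "decoder_input": (None, 8)}}
#   -> self(encoder_input=..., decoder_input=...) with KerasTensors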
def export(self, filepath):
raise NotImplementedError

@@ -25,14 +25,13 @@ conv_transpose
ctc ??
"""
import math
import numpy as np
from keras_core import backend
from keras_core.backend import KerasTensor
from keras_core.backend import any_symbolic_tensors
from keras_core.operations.operation import Operation
@@ -902,7 +901,7 @@ class ConvTranspose(Operation):
)
def compute_output_spec(self, inputs, kernel):
output_shape = backend.nn.compute_output_shape_conv_transpose(
inputs,
kernel,
self.strides,
@@ -935,7 +934,7 @@ def conv_transpose(
`data_format="channels_first"`. Convolution happens over the spatial
dimensions only.
kernel: Tensor of rank N+2. `kernel` has shape
[kernel_spatial_shape, num_output_channels, num_input_channels],
`num_input_channels` should match the number of channels in
`inputs`.
strides: int or int tuple/list of `len(inputs_spatial_shape)`,

@@ -281,6 +281,10 @@ class NNOpsDynamicShapeTest(testing.TestCase):
)
@pytest.mark.skipif(
backend() != "tensorflow",
reason="Not have other backend support yet.",
)
class NNOpsStaticShapeTest(testing.TestCase):
def test_relu(self):
x = KerasTensor([1, 2, 3])
@@ -543,6 +547,10 @@ class NNOpsStaticShapeTest(testing.TestCase):
)
@pytest.mark.skipif(
backend() != "tensorflow",
reason="Not have other backend support yet.",
)
class NNOpsCorrectnessTest(testing.TestCase):
def test_relu(self):
x = np.array([-1, 0, 1, 2, 3], dtype=np.float32)
@@ -769,7 +777,7 @@ class NNOpsCorrectnessTest(testing.TestCase):
expected = tf.nn.conv3d(
inputs_3d, kernel, (1, 1, 1, 1, 1), padding="VALID"
)
self.assertAllClose(outputs, expected)
outputs = knn.conv(
inputs_3d,
@@ -785,13 +793,13 @@
padding="VALID",
dilations=(1, 1, 1, 1, 1),
)
self.assertAllClose(outputs, expected)
outputs = knn.conv(inputs_3d, kernel, 2, padding="same")
expected = tf.nn.conv3d(
inputs_3d, kernel, (1, 2, 2, 2, 1), padding="SAME"
)
self.assertAllClose(outputs, expected)
def test_depthwise_conv(self):
# Test 2D conv.
@@ -904,6 +912,14 @@ class NNOpsCorrectnessTest(testing.TestCase):
)
self.assertAllClose(outputs, expected)
outputs = knn.conv_transpose(
inputs_1d, kernel, 5, output_padding=4, padding="valid"
)
expected = tf.nn.conv_transpose(
inputs_1d, kernel, [2, 21, 5], 5, padding="VALID"
)
self.assertAllClose(outputs, expected)
# Test 2D conv.
inputs_2d = np.arange(96, dtype=float).reshape([2, 4, 4, 3])
kernel = np.arange(60, dtype=float).reshape([2, 2, 5, 3])
@@ -919,3 +935,21 @@
inputs_2d, kernel, [2, 8, 8, 5], 2, padding="SAME"
)
self.assertAllClose(outputs, expected)
outputs = knn.conv_transpose(
inputs_2d,
kernel,
5,
output_padding=4,
padding="valid",
dilation_rate=(1, 1),
)
expected = tf.nn.conv_transpose(
inputs_2d,
kernel,
[2, 21, 21, 5],
5,
padding="VALID",
dilations=(1, 1),
)
self.assertAllClose(outputs, expected)

@@ -2923,26 +2923,26 @@ def sum(x, axis=None, keepdims=False):
class Zeros(Operation):
def call(self, shape, dtype="float32"):
return backend.execute("zeros", shape, dtype)
return backend.numpy.zeros(shape, dtype=dtype)
def compute_output_spec(self, shape, dtype="float32"):
return KerasTensor(shape, dtype=dtype)
def zeros(shape, dtype="float32"):
return backend.execute("zeros", shape, dtype)
return backend.numpy.zeros(shape, dtype=dtype)
class Ones(Operation):
def call(self, shape, dtype="float32"):
return backend.execute("ones", shape, dtype)
return backend.numpy.ones(shape, dtype=dtype)
def compute_output_spec(self, shape, dtype="float32"):
return KerasTensor(shape, dtype=dtype)
def ones(shape, dtype="float32"):
return backend.execute("ones", shape, dtype)
return backend.numpy.ones(shape, dtype=dtype)
class Eye(Operation):

@@ -1,3 +1,79 @@
from keras_core.api_export import keras_core_export
from keras_core.optimizers.adam import Adam
from keras_core.optimizers.optimizer import Optimizer
from keras_core.optimizers.sgd import SGD
from keras_core.saving import serialization_lib
ALL_OBJECTS = {
Optimizer,
Adam,
SGD,
}
ALL_OBJECTS_DICT = {cls.__name__.lower(): cls for cls in ALL_OBJECTS}
@keras_core_export("keras_core.optimizers.serialize")
def serialize(optimizer):
"""Returns the optimizer configuration as a Python dict.
Args:
optimizer: An `Optimizer` instance to serialize.
Returns:
Python dict which contains the configuration of the optimizer.
"""
return serialization_lib.serialize_keras_object(optimizer)
@keras_core_export("keras_core.optimizers.deserialize")
def deserialize(config, custom_objects=None):
"""Returns a Keras optimizer object via its configuration.
Args:
config: Optimizer configuration dictionary.
custom_objects: Optional dictionary mapping names (strings) to custom
objects (classes and functions) to be considered during
deserialization.
Returns:
A Keras Optimizer instance.
"""
# Make deserialization case-insensitive for built-in optimizers.
if config["class_name"].lower() in ALL_OBJECTS_DICT:
config["class_name"] = config["class_name"].lower()
print("deserialize:", config)
return serialization_lib.deserialize_keras_object(
config,
module_objects=ALL_OBJECTS_DICT,
custom_objects=custom_objects,
)
@keras_core_export("keras_core.optimizers.get")
def get(identifier):
"""Retrieves a Keras Optimizer instance.
Args:
identifier: Optimizer identifier, one of:
- String: name of an optimizer
- Dictionary: configuration dictionary.
- Keras Optimizer instance (it will be returned unchanged).
Returns:
A Keras Optimizer instance.
"""
print("call get with", identifier)
if isinstance(identifier, Optimizer):
return identifier
elif isinstance(identifier, dict):
return deserialize(identifier)
elif isinstance(identifier, str):
config = {"class_name": identifier, "config": {}}
opt = deserialize(config)
print(opt)
return opt
else:
raise ValueError(
f"Could not interpret optimizer identifier: {identifier}"
)
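A quick usage sketch of the resolution order (identifiers illustrative):

from keras_core import optimizers

opt = optimizers.get("adam")  # string, case-insensitive -> Adam()
opt = optimizers.get(
    {"class_name": "SGD", "config": {"learning_rate": 0.1}}
)  # config dict -> SGD instance
opt = optimizers.get(optimizers.SGD(learning_rate=0.01))  # returned as-is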

@ -36,6 +36,8 @@ _ASSETS_DIRNAME = "assets"
ATTR_SKIPLIST = frozenset(
{
"_operations",
"_layers",
"_functional",
"_losses",
"_inbound_nodes",
@ -108,7 +110,7 @@ def save_model(model, filepath, weights_format="h5"):
zip_filepath = os.path.join(get_temp_dir(), "tmp_model.keras")
else:
zip_filepath = filepath
with zipfile.ZipFile(zip_filepath, "w") as zf:
with zf.open(_METADATA_FILENAME, "w") as f:
f.write(metadata_json.encode())
@ -116,9 +118,7 @@ def save_model(model, filepath, weights_format="h5"):
f.write(config_json.encode())
if weights_format == "h5":
weights_store = H5IOStore(_VARS_FNAME + ".h5", archive=zf, mode="w")
elif weights_format == "npz":
weights_store = NpzIOStore(
_VARS_FNAME + ".npz", archive=zf, mode="w"
@ -147,8 +147,6 @@ def save_model(model, filepath, weights_format="h5"):
# writing to GCS. Hence writing to local and copying to filepath.
gfile.copy(zip_filepath, filepath, overwrite=True)
os.remove(zip_filepath)
def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
@@ -161,7 +159,6 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
f"Received: filepath={filepath}"
)
with gfile.GFile(filepath, mode="r+b") as gfile_handle, zipfile.ZipFile(
gfile_handle, "r"
) as zf:
@@ -182,9 +179,7 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
all_filenames = zf.namelist()
if _VARS_FNAME + ".h5" in all_filenames:
weights_store = H5IOStore(_VARS_FNAME + ".h5", archive=zf, mode="r")
elif _VARS_FNAME + ".npz" in all_filenames:
weights_store = NpzIOStore(
_VARS_FNAME + ".npz", archive=zf, mode="r"
@@ -209,10 +204,6 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
weights_store.close()
if asset_store:
asset_store.close()
return model

@@ -1,5 +1,5 @@
import json
import os
import shutil
import tempfile
import unittest
@@ -12,7 +12,7 @@ class TestCase(unittest.TestCase):
def get_temp_dir(self):
temp_dir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(temp_dir))
return temp_dir
def assertAllClose(self, x1, x2, atol=1e-7, rtol=1e-7):

@@ -3,6 +3,7 @@ import warnings
from keras_core import backend
from keras_core import metrics as metrics_module
from keras_core import operations as ops
from keras_core import optimizers
from keras_core.saving import serialization_lib
from keras_core.trainers.compile_utils import CompileLoss
from keras_core.trainers.compile_utils import CompileMetrics
@@ -26,8 +27,7 @@ class Trainer:
run_eagerly=False,
jit_compile=True,
):
self.optimizer = optimizers.get(optimizer)
if loss is not None:
self._compile_loss = CompileLoss(loss, loss_weights)
else:
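With this change, `compile()` accepts the usual identifier forms for the optimizer, resolved through `optimizers.get()` (a hedged sketch; the model itself is illustrative):

model.compile(optimizer="adam", loss="mse")  # string -> Adam instance
model.compile(optimizer=optimizers.SGD(learning_rate=0.1), loss="mse")  # as-is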