Progress on saving/reloading.

parent 83c356a791
commit 879a6c244c
@@ -1,4 +1,3 @@
from keras_core.backend.common import backend_utils
from keras_core.backend.common import random
from keras_core.backend.common.variables import KerasVariable
from keras_core.backend.common.variables import standardize_dtype
@@ -1,79 +0,0 @@
def _compute_conv_transpose_output_length(
    input_length,
    kernel_size,
    padding,
    output_padding=None,
    stride=1,
    dilation=1,
):
    """Computes output size of a transposed convolution given input size."""
    assert padding in {"same", "valid"}
    if input_length is None:
        return None

    # Get the dilated kernel size
    kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)

    # Infer length if output padding is None, else compute the exact length
    if output_padding is None:
        if padding == "valid":
            length = input_length * stride + max(kernel_size - stride, 0)
        else:
            length = input_length * stride
    else:
        if padding == "same":
            pad = kernel_size // 2
        else:
            pad = 0

        length = (
            (input_length - 1) * stride + kernel_size - 2 * pad + output_padding
        )
    return length


def compute_conv_transpose_output_shape(
    inputs,
    kernel,
    strides,
    padding,
    output_padding=None,
    data_format="channels_last",
    dilation_rate=1,
):
    num_spatial_dims = len(inputs.shape) - 2
    kernel_spatial_shape = kernel.shape[:-2]

    if isinstance(output_padding, int):
        output_padding = (output_padding,) * len(kernel_spatial_shape)
    if isinstance(strides, int):
        strides = (strides,) * num_spatial_dims
    if isinstance(dilation_rate, int):
        dilation_rate = (dilation_rate,) * num_spatial_dims

    if data_format == "channels_last":
        inputs_spatial_shape = inputs.shape[1:-1]
    else:
        inputs_spatial_shape = inputs.shape[2:]

    output_shape = []
    for i in range(num_spatial_dims):
        current_output_padding = (
            None if output_padding is None else output_padding[i]
        )
        output_shape.append(
            _compute_conv_transpose_output_length(
                inputs_spatial_shape[i],
                kernel_spatial_shape[i],
                padding=padding,
                output_padding=current_output_padding,
                stride=strides[i],
                dilation=dilation_rate[0],
            )
        )

    if data_format == "channels_last":
        output_shape = [inputs.shape[0]] + output_shape + [kernel.shape[-2]]
    else:
        output_shape = [inputs.shape[0], kernel.shape[-1]] + output_shape
    return output_shape
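The deleted helper above encodes the standard transposed-convolution output-length rule. As a quick check, here is a hedged, standalone replay of that rule in plain Python; the sample numbers are worked out directly from the formula and are consistent with the test expectations that appear later in this diff.

# Standalone replay of the transposed-conv output-length rule above.
def deconv_length(input_length, kernel_size, padding, output_padding=None,
                  stride=1, dilation=1):
    kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
    if output_padding is None:
        if padding == "valid":
            return input_length * stride + max(kernel_size - stride, 0)
        return input_length * stride  # "same"
    pad = kernel_size // 2 if padding == "same" else 0
    return (input_length - 1) * stride + kernel_size - 2 * pad + output_padding

assert deconv_length(4, 2, "valid", stride=2) == 8   # 4*2 + max(2-2, 0)
assert deconv_length(4, 3, "valid", stride=2) == 9   # 4*2 + max(3-2, 0)
assert deconv_length(4, 2, "valid", output_padding=4, stride=5) == 21  # (4-1)*5 + 2 + 4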
@@ -1,8 +1,6 @@
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax import nn as jnn
from jax import numpy as jnp


def relu(x):
@@ -65,293 +63,46 @@ def log_softmax(x, axis=-1):
    return jnn.log_softmax(x, axis=axis)


def _convert_to_spatial_operand(
    x,
    num_spatial_dims,
    data_format="channels_last",
    include_batch_and_channels=True,
):
    # Helper function that converts an operand to a spatial operand.
    x = (x,) * num_spatial_dims if isinstance(x, int) else x
    if not include_batch_and_channels:
        return x
    if data_format == "channels_last":
        x = (1,) + x + (1,)
    else:
        x = (1,) + (1,) + x
    return x
def max_pool(inputs, pool_size, strides, padding):
    # TODO: Implement `max_pool` with JAX ops.
    raise NotImplementedError


def _pool(
    inputs,
    initial_value,
    reduce_fn,
    pool_size,
    strides=None,
    padding="valid",
):
    """Helper function to define pooling functions.

    Args:
        inputs: input data of shape `N+2`.
        initial_value: the initial value for the reduction.
        reduce_fn: a reduce function of the form `(T, T) -> T`.
        pool_size: a sequence of `N` integers, representing the window size to
            reduce over.
        strides: a sequence of `N` integers, representing the inter-window
            strides (default: `(1, ..., 1)`).
        padding: either the string `same` or `valid`.

    Returns:
        The output of the reduction for each window slice.
    """
    if padding not in ("same", "valid"):
        raise ValueError(
            f"Invalid padding '{padding}', must be 'same' or 'valid'."
        )
    padding = padding.upper()
    return lax.reduce_window(
        inputs,
        initial_value,
        reduce_fn,
        pool_size,
        strides,
        padding,
    )
def average_pool(inputs, pool_size, strides, padding):
    # TODO: Implement `average_pool` with JAX ops.
    raise NotImplementedError


def max_pool(
    inputs,
    pool_size,
    strides=None,
    padding="valid",
    data_format="channels_last",
):
    num_spatial_dims = inputs.ndim - 2
    pool_size = _convert_to_spatial_operand(
        pool_size, num_spatial_dims, data_format
    )
    strides = pool_size if strides is None else strides
    strides = _convert_to_spatial_operand(
        strides, num_spatial_dims, data_format
    )
    return _pool(inputs, -jnp.inf, lax.max, pool_size, strides, padding)


def average_pool(
    inputs,
    pool_size,
    strides,
    padding,
    data_format="channels_last",
):
    num_spatial_dims = inputs.ndim - 2
    pool_size = _convert_to_spatial_operand(
        pool_size, num_spatial_dims, data_format
    )
    strides = pool_size if strides is None else strides
    strides = _convert_to_spatial_operand(
        strides, num_spatial_dims, data_format
    )

    pooled = _pool(inputs, 0.0, lax.add, pool_size, strides, padding)
    if padding == "valid":
        # Avoid the extra reduce_window.
        return pooled / np.prod(pool_size)
    else:
        # Count the number of valid entries at each input point, then use that
        # for computing average. Assumes that any two arrays of same shape will
        # be padded the same. Avoid broadcasting on axis where pooling is
        # skipped.
        shape = [
            (a if b != 1 else 1) for (a, b) in zip(inputs.shape, pool_size)
        ]
        window_counts = _pool(
            jnp.ones(shape, inputs.dtype),
            0.0,
            lax.add,
            pool_size,
            strides,
            padding,
        )
        return pooled / window_counts


def _convert_to_lax_conv_dimension_numbers(
    num_spatial_dims,
    data_format="channels_last",
    transpose=False,
):
    """Create a `lax.ConvDimensionNumbers` for the given inputs."""
    num_dims = num_spatial_dims + 2

    if data_format == "channels_last":
        spatial_dims = tuple(range(1, num_dims - 1))
        inputs_dn = (0, num_dims - 1) + spatial_dims
    else:
        spatial_dims = tuple(range(2, num_dims))
        inputs_dn = (0, 1) + spatial_dims

    if transpose:
        kernel_dn = (num_dims - 2, num_dims - 1) + tuple(range(num_dims - 2))
    else:
        kernel_dn = (num_dims - 1, num_dims - 2) + tuple(range(num_dims - 2))

    return lax.ConvDimensionNumbers(
        lhs_spec=inputs_dn, rhs_spec=kernel_dn, out_spec=inputs_dn
    )


def conv(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    data_format="channel_last",
    dilation_rate=1,
):
    num_spatial_dims = inputs.ndim - 2
    dimension_numbers = _convert_to_lax_conv_dimension_numbers(
        num_spatial_dims,
        data_format,
        transpose=False,
    )
    strides = _convert_to_spatial_operand(
        strides,
        num_spatial_dims,
        data_format,
        include_batch_and_channels=False,
    )
    dilation_rate = _convert_to_spatial_operand(
        dilation_rate,
        num_spatial_dims,
        data_format,
        include_batch_and_channels=False,
    )
def conv(inputs, kernel, strides, padding, dilation_rate=None):
    # TODO: Add missing args.
    return jax.lax.conv_general_dilated(
        inputs,
        kernel,
        strides,
        padding,
        rhs_dilation=dilation_rate,
        dimension_numbers=dimension_numbers,
        inputs, kernel, strides, padding, rhs_dilation=dilation_rate
    )


def depthwise_conv(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    data_format="channel_last",
    dilation_rate=1,
):
    num_spatial_dims = inputs.ndim - 2
    dimension_numbers = _convert_to_lax_conv_dimension_numbers(
        num_spatial_dims,
        data_format,
        transpose=False,
    )
    strides = _convert_to_spatial_operand(
        strides,
        num_spatial_dims,
        data_format,
        include_batch_and_channels=False,
    )
    dilation_rate = _convert_to_spatial_operand(
        dilation_rate,
        num_spatial_dims,
        data_format,
        include_batch_and_channels=False,
    )
    feature_group_count = (
        inputs.shape[-1] if data_format == "channels_last" else inputs.shape[1]
    )
    kernel = jnp.reshape(
        kernel,
        kernel.shape[:-2] + (1, feature_group_count * kernel.shape[-1]),
    )
    return jax.lax.conv_general_dilated(
        inputs,
        kernel,
        strides,
        padding,
        rhs_dilation=dilation_rate,
        dimension_numbers=dimension_numbers,
        feature_group_count=feature_group_count,
    )
def depthwise_conv(inputs, filter, strides, padding):
    # TODO: Implement `depthwise_conv` with `conv_general_dilated`.
    raise NotImplementedError


def separable_conv(
    inputs,
    depthwise_kernel,
    pointwise_kernel,
    strides=1,
    padding="valid",
    data_format="channels_last",
    dilation_rate=1,
    inputs, depthwise_kernel, pointwise_kernel, strides, padding
):
    depthwise_conv_output = depthwise_conv(
        inputs,
        depthwise_kernel,
        strides,
        padding,
        data_format,
        dilation_rate,
    )
    return conv(
        depthwise_conv_output,
        pointwise_kernel,
        strides=1,
        padding="valid",
        data_format=data_format,
        dilation_rate=dilation_rate,
    )
    # TODO: Implement `separable_conv` with `conv_general_dilated`.
    raise NotImplementedError


def conv_transpose(
    inputs,
    kernel,
    strides=1,
    strides,
    output_padding,
    padding="valid",
    output_padding=None,
    data_format="channels_last",
    dilation_rate=1,
):
    num_spatial_dims = inputs.ndim - 2
    dimension_numbers = _convert_to_lax_conv_dimension_numbers(
        num_spatial_dims,
        data_format,
        transpose=False,
    )
    strides = _convert_to_spatial_operand(
        strides,
        num_spatial_dims,
        data_format,
        include_batch_and_channels=False,
    )
    dilation_rate = _convert_to_spatial_operand(
        dilation_rate,
        num_spatial_dims,
        data_format,
        include_batch_and_channels=False,
    )

    if output_padding is not None:
        raise ValueError(
            "Custom `output_padding` is not supported yet, please set "
            "`output_padding=None`."
        )
    padding = padding.upper()
    return jax.lax.conv_transpose(
        inputs,
        kernel,
        strides,
        padding,
        rhs_dilation=dilation_rate,
        dimension_numbers=dimension_numbers,
        transpose_kernel=True,
    )
    # TODO: Implement `conv_transpose`.
    raise NotImplementedError


def one_hot(x, num_classes, axis=-1):
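The `_pool` helper in the removed code is a thin wrapper over `lax.reduce_window`, which slides a reduction window over the input. A minimal sketch of that primitive on its own, assuming a 4D NHWC array; this snippet is illustrative and not part of the commit.

import jax.numpy as jnp
from jax import lax

x = jnp.arange(16.0).reshape(1, 4, 4, 1)  # NHWC: batch, height, width, channels
# 2x2 max pooling with stride 2: window/stride of 1 on batch and channel axes.
out = lax.reduce_window(
    x,
    -jnp.inf,          # initial value for a max reduction
    lax.max,           # reduce_fn
    (1, 2, 2, 1),      # window dimensions
    (1, 2, 2, 1),      # window strides
    "VALID",           # padding, already upper-cased as in `_pool`
)
print(out.shape)  # (1, 2, 2, 1)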
@@ -23,3 +23,11 @@ def mean(x, axis=None, keepdims=False):

def max(x, axis=None, keepdims=False):
    return jnp.max(x, axis=axis, keepdims=keepdims)


def ones(shape, dtype="float32"):
    return jnp.ones(shape, dtype=dtype)


def zeros(shape, dtype="float32"):
    return jnp.zeros(shape, dtype=dtype)
@@ -1,9 +1,5 @@
import tensorflow as tf

from keras_core.backend.common.backend_utils import (
    compute_conv_transpose_output_shape,
)


def relu(x):
    return tf.nn.relu(x)
@@ -256,8 +252,10 @@ def separable_conv(

    if data_format == "channels_last":
        strides = (1,) + strides + (1,)
        spatial_start_dim = 1
    else:
        strides = (1, 1) + strides
        spatial_start_dim = 2
    return tf.nn.separable_conv2d(
        inputs,
        depthwise_kernel,
@@ -269,6 +267,100 @@ def separable_conv(
    )


def _deconv_output_length(
    input_length,
    kernel_size,
    padding,
    output_padding=None,
    stride=1,
    dilation=1,
):
    """Determines output length of a transposed convolution given input length.

    Args:
        input_length: Integer.
        kernel_size: Integer.
        padding: one of `"same"` or `"valid"`.
        output_padding: Integer, amount of padding along the output dimension.
            Can be set to `None` in which case the output length is inferred.
        stride: Integer.
        dilation: Integer.

    Returns:
        The output length (integer).
    """
    assert padding in {"same", "valid"}
    if input_length is None:
        return None

    # Get the dilated kernel size
    kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)

    # Infer length if output padding is None, else compute the exact length
    if output_padding is None:
        if padding == "valid":
            length = input_length * stride + max(kernel_size - stride, 0)
        else:
            length = input_length * stride
    else:
        if padding == "same":
            pad = kernel_size // 2
        else:
            pad = 0

        length = (
            (input_length - 1) * stride + kernel_size - 2 * pad + output_padding
        )
    return length


def compute_output_shape_conv_transpose(
    inputs,
    kernel,
    strides,
    padding,
    output_padding=None,
    data_format="channels_last",
    dilation_rate=1,
):
    num_spatial_dims = len(inputs.shape) - 2
    kernel_spatial_shape = kernel.shape[:-2]

    if isinstance(output_padding, int):
        output_padding = (output_padding,) * len(kernel_spatial_shape)
    if isinstance(strides, int):
        strides = (strides,) * num_spatial_dims
    if isinstance(dilation_rate, int):
        dilation_rate = (dilation_rate,) * num_spatial_dims

    if data_format == "channels_last":
        inputs_spatial_shape = inputs.shape[1:-1]
    else:
        inputs_spatial_shape = inputs.shape[2:]

    output_shape = []
    for i in range(num_spatial_dims):
        current_output_padding = (
            None if output_padding is None else output_padding[i]
        )
        output_shape.append(
            _deconv_output_length(
                inputs_spatial_shape[i],
                kernel_spatial_shape[i],
                padding=padding,
                output_padding=current_output_padding,
                stride=strides[i],
                dilation=dilation_rate[0],
            )
        )

    if data_format == "channels_last":
        output_shape = [inputs.shape[0]] + output_shape + [kernel.shape[-2]]
    else:
        output_shape = [inputs.shape[0], kernel.shape[-1]] + output_shape
    return output_shape


def conv_transpose(
    inputs,
    kernel,
@@ -279,7 +371,7 @@ def conv_transpose(
    dilation_rate=1,
):
    tf_data_format = _convert_data_format(data_format, len(inputs.shape))
    output_shape = compute_conv_transpose_output_shape(
    output_shape = compute_output_shape_conv_transpose(
        inputs,
        kernel,
        strides,
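As a quick sanity check of `compute_output_shape_conv_transpose`, here is a hedged sketch, assuming the function above is importable from this backend module; the shapes mirror the 2D case exercised by the tests later in this diff.

import numpy as np

# Illustrative shapes: a batch of two 4x4 inputs with 3 channels, and a 2x2
# kernel mapping to 5 output channels (channels_last layout).
inputs = np.zeros([2, 4, 4, 3])
kernel = np.zeros([2, 2, 5, 3])

# With stride 5, valid padding and output_padding=4, each spatial dim is
# (4 - 1) * 5 + 2 - 0 + 4 = 21, so the result is [2, 21, 21, 5].
shape = compute_output_shape_conv_transpose(
    inputs, kernel, strides=5, padding="valid", output_padding=4
)
assert shape == [2, 21, 21, 5]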
@@ -1,3 +1,4 @@
import tensorflow as tf
from tensorflow.experimental import numpy as tfnp


@@ -23,3 +24,13 @@ def mean(x, axis=None, keepdims=False):

def max(x, axis=None, keepdims=False):
    return tfnp.max(x, axis=axis, keepdims=keepdims)


def ones(shape, dtype="float32"):
    with tf.init_scope():
        return tf.ones(shape, dtype=dtype)


def zeros(shape, dtype="float32"):
    with tf.init_scope():
        return tf.zeros(shape, dtype=dtype)
@@ -8,6 +8,7 @@ from keras_core.backend.config import floatx

def tf_draw_seed(seed):
    # TF ops only accept int32/64 seeds but our base seed is uint32.
    with tf.init_scope():
        return tf.cast(draw_seed(seed), dtype="int32")


@@ -34,6 +35,7 @@ def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
    """
    dtype = dtype or floatx()
    seed = tf_draw_seed(seed)
    with tf.init_scope():
        return tf.random.stateless_normal(
            shape=shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed
        )
@@ -63,6 +65,7 @@ def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
    """
    dtype = dtype or floatx()
    seed = tf_draw_seed(seed)
    with tf.init_scope():
        return tf.random.stateless_uniform(
            shape=shape,
            minval=minval,
@@ -95,6 +98,7 @@ def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
    """
    dtype = dtype or floatx()
    seed = tf_draw_seed(seed)
    with tf.init_scope():
        return tf.random.stateless_truncated_normal(
            shape=shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed
        )
@@ -102,6 +106,7 @@ def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):

def dropout(inputs, rate, noise_shape=None, seed=None):
    seed = tf_draw_seed(seed)
    with tf.init_scope():
        return tf.nn.experimental.stateless_dropout(
            inputs,
            rate=rate,
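The pattern above wraps every sampling op in `tf.init_scope()` and feeds it an int32 seed, because TF's stateless random ops are pure functions of a shape-[2] integer seed. A small illustrative snippet, standalone and not from the commit:

import tensorflow as tf

# Stateless ops take an explicit int32/int64 seed tensor of shape [2].
seed = tf.constant([123, 0], dtype="int32")
x = tf.random.stateless_normal(shape=(2, 3), seed=seed)
y = tf.random.stateless_normal(shape=(2, 3), seed=seed)
# Same seed in, same values out: reproducibility without global RNG state.
assert bool(tf.reduce_all(x == y))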
@@ -43,12 +43,13 @@ ALL_OBJECTS_DICT.update(

@keras_core_export("keras_core.initializers.serialize")
def serialize(initializer):
    """Returns the initializer configuration as a Python dict."""
    return serialization_lib.serialize_keras_object(initializer)


@keras_core_export("keras_core.initializers.deserialize")
def deserialize(config, custom_objects=None):
    """Return a Keras initializer object via its config."""
    """Returns a Keras initializer object via its configuration."""
    return serialization_lib.deserialize_keras_object(
        config,
        module_objects=ALL_OBJECTS_DICT,
@@ -58,7 +59,7 @@ def deserialize(config, custom_objects=None):

@keras_core_export("keras_core.initializers.get")
def get(identifier):
    """Retrieve a Keras initializer object via an identifier.
    """Retrieves a Keras initializer object via an identifier.

    The `identifier` may be the string name of a initializers function or class
    (case-sensitively).
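Taken together, these entry points give the usual round trip. A hedged usage sketch; the initializer name is illustrative and assumes it is registered in `ALL_OBJECTS_DICT`:

from keras_core import initializers

init = initializers.get("glorot_uniform")   # look up by string name
config = initializers.serialize(init)       # plain-dict representation
same_init = initializers.deserialize(config)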
@@ -470,6 +470,22 @@ class Layer(Operation):
        """
        all_vars = self._variables
        if len(store.keys()) != len(all_vars):
            if len(all_vars) == 0 and not self.built:
                raise ValueError(
                    f"Layer '{self.name}' was never built "
                    "and thus it doesn't have any variables. "
                    f"However the weights file lists {len(store.keys())} "
                    "variables for this layer. In most cases, "
                    "this indicates that you need to implement the "
                    "`def build_from_config(self, config)` method "
                    "on the layer. "
                    "You might also want to implement the method "
                    "that generates the config at saving time, "
                    "`def get_build_config(self)`. "
                    "The method `build_from_config()` is meant "
                    "to create the state "
                    "of the layer (i.e. its variables) upon deserialization.",
                )
            raise ValueError(
                f"Layer '{self.name}' expected {len(all_vars)} variables, "
                "but received "
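The error message above points at the build-config protocol. Below is a hedged sketch of what a custom layer can implement so its variables exist before weights are loaded; the class, names, and the `add_weight` call are illustrative, not taken from this diff.

from keras_core.layers import Layer

class MyDense(Layer):
    def __init__(self, units):
        super().__init__()
        self.units = units

    def build(self, input_shape):
        self._input_shape_arg = tuple(input_shape)
        self.kernel = self.add_weight(shape=(input_shape[-1], self.units))

    def get_build_config(self):
        # Recorded at saving time.
        return {"input_shape": self._input_shape_arg}

    def build_from_config(self, config):
        # Recreate the variables before the weights file is read into them.
        self.build(config["input_shape"])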
@@ -1,15 +1,15 @@
from keras_core import backend
from keras_core import operations as ops
from keras_core.metrics import reduction_metrics
from keras_core.api_export import keras_core_export
from keras_core.losses.loss import squeeze_to_same_rank
from keras_core.backend import floatx
from keras_core.metrics import reduction_metrics


def accuracy(y_true, y_pred):
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
    y_true, y_pred = squeeze_to_same_rank(y_true, y_pred)
    return ops.cast(ops.equal(y_true, y_pred), dtype=floatx())
    return ops.cast(ops.equal(y_true, y_pred), dtype=backend.floatx())


@keras_core_export("keras_core.metrics.Accuracy")
@@ -1,4 +1,5 @@
import os
import warnings

from keras_core import backend
from keras_core.api_export import keras_core_export
@@ -205,6 +206,69 @@ class Model(Trainer, Layer):
                "files."
            )

    def build_from_config(self, config):
        def is_shape_tuple(s):
            return isinstance(s, (list, tuple)) and all(
                d is None or isinstance(d, int) for d in s
            )

        if config:
            failure = False
            if "input_shape" in config:
                # Case: all inputs are in the first arg (possibly nested).
                input_shape = config["input_shape"]
                if is_shape_tuple(input_shape):
                    input_shape = tuple(input_shape)
                if isinstance(input_shape, list):
                    input_tensors = [
                        backend.KerasTensor(shape) for shape in input_shape
                    ]
                elif isinstance(input_shape, dict):
                    input_tensors = {
                        k: backend.KerasTensor(shape)
                        for k, shape in input_shape.items()
                    }
                else:
                    input_tensors = backend.KerasTensor(input_shape)
                try:
                    self(input_tensors)
                    self._build_shapes_dict = config
                except:
                    failure = True
            elif "shapes_dict" in config:
                # Case: inputs were recorded as multiple keyword arguments.
                if all(
                    is_shape_tuple(s) for s in config["shapes_dict"].values()
                ):
                    # Case: all input keyword arguments were plain tensors.
                    input_tensors = {
                        k: backend.KerasTensor(v)
                        for k, v in config["shapes_dict"].items()
                    }
                    try:
                        self(**input_tensors)
                        self._build_shapes_dict = config["shapes_dict"]
                    except:
                        failure = True
                else:
                    # Not supported: nested input keyword arguments.
                    failure = True
            if failure:
                warnings.warn(
                    f"Model '{self.name}' had a build config, but the model "
                    "cannot be built automatically in "
                    "`build_from_config(config)`. "
                    "You should implement "
                    "`def build_from_config(self, config)`, "
                    "and you might also want to implement the method "
                    "that generates the config at saving time, "
                    "`def get_build_config(self)`. "
                    "The method `build_from_config()` is meant to "
                    "create the state of the model (i.e. its variables) "
                    "upon deserialization.",
                    stacklevel=2,
                )

    def export(self, filepath):
        raise NotImplementedError
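In effect, `Model.build_from_config()` replays the input shapes recorded at save time. A hedged illustration of the two config shapes it accepts; the dict contents and the `model` instance are illustrative:

# Single input recorded positionally:
model.build_from_config({"input_shape": (None, 32)})

# Inputs recorded as multiple keyword arguments; keys must match the
# model's call signature:
model.build_from_config(
    {"shapes_dict": {"image": (None, 64, 64, 3), "mask": (None, 64, 64, 1)}}
)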
@@ -25,14 +25,13 @@ conv_transpose
ctc ??
"""

import math

import numpy as np

from keras_core import backend
from keras_core.backend import KerasTensor
from keras_core.backend import any_symbolic_tensors
from keras_core.backend.common.backend_utils import (
    compute_conv_transpose_output_shape,
)
from keras_core.operations.operation import Operation


@@ -902,7 +901,7 @@ class ConvTranspose(Operation):
        )

    def compute_output_spec(self, inputs, kernel):
        output_shape = compute_conv_transpose_output_shape(
        output_shape = backend.nn.compute_output_shape_conv_transpose(
            inputs,
            kernel,
            self.strides,
@@ -935,7 +934,7 @@ def conv_transpose(
            `data_format="channels_first"`. Pooling happens over the spatial
            dimensions only.
        kernel: Tensor of rank N+2. `kernel` has shape
            [kernel_spatial_shape, num_output_channels, num_input_channels],
            [kernel_spatial_shape, num_input_channels, num_output_channels],
            `num_input_channels` should match the number of channels in
            `inputs`.
        strides: int or int tuple/list of `len(inputs_spatial_shape)`,
@@ -281,6 +281,10 @@ class NNOpsDynamicShapeTest(testing.TestCase):
        )


@pytest.mark.skipif(
    backend() != "tensorflow",
    reason="Not have other backend support yet.",
)
class NNOpsStaticShapeTest(testing.TestCase):
    def test_relu(self):
        x = KerasTensor([1, 2, 3])
@@ -543,6 +547,10 @@ class NNOpsStaticShapeTest(testing.TestCase):
        )


@pytest.mark.skipif(
    backend() != "tensorflow",
    reason="Not have other backend support yet.",
)
class NNOpsCorrectnessTest(testing.TestCase):
    def test_relu(self):
        x = np.array([-1, 0, 1, 2, 3], dtype=np.float32)
@@ -769,7 +777,7 @@ class NNOpsCorrectnessTest(testing.TestCase):
        expected = tf.nn.conv3d(
            inputs_3d, kernel, (1, 1, 1, 1, 1), padding="VALID"
        )
        self.assertAllClose(outputs, expected, rtol=1e-5, atol=1e-5)
        self.assertAllClose(outputs, expected)

        outputs = knn.conv(
            inputs_3d,
@@ -785,13 +793,13 @@ class NNOpsCorrectnessTest(testing.TestCase):
            padding="VALID",
            dilations=(1, 1, 1, 1, 1),
        )
        self.assertAllClose(outputs, expected, rtol=1e-5, atol=1e-5)
        self.assertAllClose(outputs, expected)

        outputs = knn.conv(inputs_3d, kernel, 2, padding="same")
        expected = tf.nn.conv3d(
            inputs_3d, kernel, (1, 2, 2, 2, 1), padding="SAME"
        )
        self.assertAllClose(outputs, expected, rtol=1e-5, atol=1e-5)
        self.assertAllClose(outputs, expected)

    def test_depthwise_conv(self):
        # Test 2D conv.
@@ -904,6 +912,14 @@ class NNOpsCorrectnessTest(testing.TestCase):
        )
        self.assertAllClose(outputs, expected)

        outputs = knn.conv_transpose(
            inputs_1d, kernel, 5, output_padding=4, padding="valid"
        )
        expected = tf.nn.conv_transpose(
            inputs_1d, kernel, [2, 21, 5], 5, padding="VALID"
        )
        self.assertAllClose(outputs, expected)

        # Test 2D conv.
        inputs_2d = np.arange(96, dtype=float).reshape([2, 4, 4, 3])
        kernel = np.arange(60, dtype=float).reshape([2, 2, 5, 3])
@@ -919,3 +935,21 @@ class NNOpsCorrectnessTest(testing.TestCase):
            inputs_2d, kernel, [2, 8, 8, 5], 2, padding="SAME"
        )
        self.assertAllClose(outputs, expected)

        outputs = knn.conv_transpose(
            inputs_2d,
            kernel,
            5,
            output_padding=4,
            padding="valid",
            dilation_rate=(1, 1),
        )
        expected = tf.nn.conv_transpose(
            inputs_2d,
            kernel,
            [2, 21, 21, 5],
            5,
            padding="VALID",
            dilations=(1, 1),
        )
        self.assertAllClose(outputs, expected)
@@ -2923,26 +2923,26 @@ def sum(x, axis=None, keepdims=False):

class Zeros(Operation):
    def call(self, shape, dtype="float32"):
        return backend.execute("zeros", shape, dtype)
        return backend.numpy.zeros(shape, dtype=dtype)

    def compute_output_spec(self, shape, dtype="float32"):
        return KerasTensor(shape, dtype=dtype)


def zeros(shape, dtype="float32"):
    return backend.execute("zeros", shape, dtype)
    return backend.numpy.zeros(shape, dtype=dtype)


class Ones(Operation):
    def call(self, shape, dtype="float32"):
        return backend.execute("ones", shape, dtype)
        return backend.numpy.ones(shape, dtype=dtype)

    def compute_output_spec(self, shape, dtype="float32"):
        return KerasTensor(shape, dtype=dtype)


def ones(shape, dtype="float32"):
    return backend.execute("ones", shape, dtype)
    return backend.numpy.ones(shape, dtype=dtype)


class Eye(Operation):
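The change here swaps string-based dispatch (`backend.execute("zeros", ...)`) for direct attribute access on the backend's numpy namespace. A minimal stand-in sketch of the difference; these stubs are illustrative, not keras_core internals:

import numpy as np

class _NumpyNamespace:
    # Stand-in for `backend.numpy`.
    @staticmethod
    def zeros(shape, dtype="float32"):
        return np.zeros(shape, dtype=dtype)

def execute(op_name, *args, **kwargs):
    # Old style: ops looked up by name at call time, so typos only
    # surface when the op runs.
    return getattr(_NumpyNamespace, op_name)(*args, **kwargs)

out_old = execute("zeros", (2, 3))
out_new = _NumpyNamespace.zeros((2, 3))  # new style: explicit and greppable
assert (out_old == out_new).all()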
@@ -1,3 +1,79 @@
from keras_core.api_export import keras_core_export
from keras_core.optimizers.adam import Adam
from keras_core.optimizers.optimizer import Optimizer
from keras_core.optimizers.sgd import SGD
from keras_core.saving import serialization_lib

ALL_OBJECTS = {
    Optimizer,
    Adam,
    SGD,
}
ALL_OBJECTS_DICT = {cls.__name__.lower(): cls for cls in ALL_OBJECTS}


@keras_core_export("keras_core.optimizers.serialize")
def serialize(optimizer):
    """Returns the optimizer configuration as a Python dict.

    Args:
        optimizer: An `Optimizer` instance to serialize.

    Returns:
        Python dict which contains the configuration of the optimizer.
    """
    return serialization_lib.serialize_keras_object(optimizer)


@keras_core_export("keras_core.optimizers.deserialize")
def deserialize(config, custom_objects=None):
    """Returns a Keras optimizer object via its configuration.

    Args:
        config: Optimizer configuration dictionary.
        custom_objects: Optional dictionary mapping names (strings) to custom
            objects (classes and functions) to be considered during
            deserialization.

    Returns:
        A Keras Optimizer instance.
    """
    # Make deserialization case-insensitive for built-in optimizers.
    if config["class_name"].lower() in ALL_OBJECTS_DICT:
        config["class_name"] = config["class_name"].lower()

    print("deserialize:", config)
    return serialization_lib.deserialize_keras_object(
        config,
        module_objects=ALL_OBJECTS_DICT,
        custom_objects=custom_objects,
    )


@keras_core_export("keras_core.optimizers.get")
def get(identifier):
    """Retrieves a Keras Optimizer instance.

    Args:
        identifier: Optimizer identifier, one of:
            - String: name of an optimizer
            - Dictionary: configuration dictionary.
            - Keras Optimizer instance (it will be returned unchanged).

    Returns:
        A Keras Optimizer instance.
    """
    print("call get with", identifier)
    if isinstance(identifier, Optimizer):
        return identifier
    elif isinstance(identifier, dict):
        return deserialize(identifier)
    elif isinstance(identifier, str):
        config = {"class_name": identifier, "config": {}}
        opt = deserialize(config)
        print(opt)
        return opt
    else:
        raise ValueError(
            f"Could not interpret optimizer identifier: {identifier}"
        )
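A hedged usage sketch of the round trip these functions enable; note that the lowering step above is what makes string lookup case-insensitive for the built-ins:

from keras_core import optimizers

opt = optimizers.get("Adam")           # "Adam" is lowered to the "adam" key
config = optimizers.serialize(opt)     # {"class_name": ..., "config": {...}}
same_opt = optimizers.deserialize(config)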
@@ -36,6 +36,8 @@ _ASSETS_DIRNAME = "assets"

ATTR_SKIPLIST = frozenset(
    {
        "_operations",
        "_layers",
        "_functional",
        "_losses",
        "_inbound_nodes",
@@ -108,7 +110,7 @@ def save_model(model, filepath, weights_format="h5"):
        zip_filepath = os.path.join(get_temp_dir(), "tmp_model.keras")
    else:
        zip_filepath = filepath
    try:

    with zipfile.ZipFile(zip_filepath, "w") as zf:
        with zf.open(_METADATA_FILENAME, "w") as f:
            f.write(metadata_json.encode())
@@ -116,9 +118,7 @@ def save_model(model, filepath, weights_format="h5"):
            f.write(config_json.encode())

        if weights_format == "h5":
            weights_store = H5IOStore(
                _VARS_FNAME + ".h5", archive=zf, mode="w"
            )
            weights_store = H5IOStore(_VARS_FNAME + ".h5", archive=zf, mode="w")
        elif weights_format == "npz":
            weights_store = NpzIOStore(
                _VARS_FNAME + ".npz", archive=zf, mode="w"
@@ -147,8 +147,6 @@ def save_model(model, filepath, weights_format="h5"):
        # writing to GCS. Hence writing to local and copying to filepath.
        gfile.copy(zip_filepath, filepath, overwrite=True)
        os.remove(zip_filepath)
    except Exception as e:
        raise e


def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
@@ -161,7 +159,6 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
            f"Received: filepath={filepath}"
        )

    try:
    with gfile.GFile(filepath, mode="r+b") as gfile_handle, zipfile.ZipFile(
        gfile_handle, "r"
    ) as zf:
@@ -182,9 +179,7 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):

        all_filenames = zf.namelist()
        if _VARS_FNAME + ".h5" in all_filenames:
            weights_store = H5IOStore(
                _VARS_FNAME + ".h5", archive=zf, mode="r"
            )
            weights_store = H5IOStore(_VARS_FNAME + ".h5", archive=zf, mode="r")
        elif _VARS_FNAME + ".npz" in all_filenames:
            weights_store = NpzIOStore(
                _VARS_FNAME + ".npz", archive=zf, mode="r"
@@ -209,10 +204,6 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
        weights_store.close()
        if asset_store:
            asset_store.close()

    except Exception as e:
        raise e
    else:
        return model
    return model
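The archive-based flow above means a saved model is just a zip file. A hedged sketch of inspecting one, assuming a built keras_core `model` with a `save()` entry point; the member names come from the `_METADATA_FILENAME` and `_VARS_FNAME` constants, whose exact values are not shown in this diff:

import zipfile

model.save("tmp_model.keras")
with zipfile.ZipFile("tmp_model.keras") as zf:
    # Expect a metadata JSON, a config JSON, and a weights store such as
    # a "<vars>.h5" member, per save_model() above.
    print(zf.namelist())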
@@ -1,5 +1,5 @@
import json
import os
import shutil
import tempfile
import unittest

@@ -12,7 +12,7 @@ class TestCase(unittest.TestCase):

    def get_temp_dir(self):
        temp_dir = tempfile.mkdtemp()
        self.addCleanup(lambda: os.rmdir(temp_dir))
        self.addCleanup(lambda: shutil.rmtree(temp_dir))
        return temp_dir

    def assertAllClose(self, x1, x2, atol=1e-7, rtol=1e-7):
@@ -3,6 +3,7 @@ import warnings
from keras_core import backend
from keras_core import metrics as metrics_module
from keras_core import operations as ops
from keras_core import optimizers
from keras_core.saving import serialization_lib
from keras_core.trainers.compile_utils import CompileLoss
from keras_core.trainers.compile_utils import CompileMetrics
@@ -26,8 +27,7 @@ class Trainer:
        run_eagerly=False,
        jit_compile=True,
    ):
        # TODO: get from module
        self.optimizer = optimizer
        self.optimizer = optimizers.get(optimizer)
        if loss is not None:
            self._compile_loss = CompileLoss(loss, loss_weights)
        else:
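With `optimizers.get()` wired into the constructor, string identifiers are resolved to instances when the trainer is configured. A hedged usage sketch, assuming the usual `compile()` entry point on a keras_core `model`:

from keras_core import optimizers

model.compile(optimizer="adam", loss="mse")
# The trainer now holds an Optimizer instance, not the raw string:
assert isinstance(model.optimizer, optimizers.Optimizer)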