diff --git a/.vscode/settings.json b/.vscode/settings.json index f24c7fff2..d79f9cdaf 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -7,4 +7,4 @@ "python.linting.pylintEnabled": false, "python.linting.flake8Enabled": true, "python.linting.enabled": true, -} \ No newline at end of file +} diff --git a/keras_core/backend/common/__init__.py b/keras_core/backend/common/__init__.py index 24df04c67..e6b051a1a 100644 --- a/keras_core/backend/common/__init__.py +++ b/keras_core/backend/common/__init__.py @@ -1,4 +1,3 @@ -from keras_core.backend.common import backend_utils from keras_core.backend.common import random from keras_core.backend.common.variables import KerasVariable from keras_core.backend.common.variables import standardize_dtype diff --git a/keras_core/backend/common/backend_utils.py b/keras_core/backend/common/backend_utils.py deleted file mode 100644 index 7e846232a..000000000 --- a/keras_core/backend/common/backend_utils.py +++ /dev/null @@ -1,79 +0,0 @@ -def _compute_conv_transpose_output_length( - input_length, - kernel_size, - padding, - output_padding=None, - stride=1, - dilation=1, -): - """Computes output size of a transposed convolution given input size.""" - assert padding in {"same", "valid"} - if input_length is None: - return None - - # Get the dilated kernel size - kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1) - - # Infer length if output padding is None, else compute the exact length - if output_padding is None: - if padding == "valid": - length = input_length * stride + max(kernel_size - stride, 0) - else: - length = input_length * stride - else: - if padding == "same": - pad = kernel_size // 2 - else: - pad = 0 - - length = ( - (input_length - 1) * stride + kernel_size - 2 * pad + output_padding - ) - return length - - -def compute_conv_transpose_output_shape( - inputs, - kernel, - strides, - padding, - output_padding=None, - data_format="channels_last", - dilation_rate=1, -): - num_spatial_dims = len(inputs.shape) - 2 - kernel_spatial_shape = kernel.shape[:-2] - - if isinstance(output_padding, int): - output_padding = (output_padding,) * len(kernel_spatial_shape) - if isinstance(strides, int): - strides = (strides,) * num_spatial_dims - if isinstance(dilation_rate, int): - dilation_rate = (dilation_rate,) * num_spatial_dims - - if data_format == "channels_last": - inputs_spatial_shape = inputs.shape[1:-1] - else: - inputs_spatial_shape = inputs.shape[2:] - - output_shape = [] - for i in range(num_spatial_dims): - current_output_padding = ( - None if output_padding is None else output_padding[i] - ) - output_shape.append( - _compute_conv_transpose_output_length( - inputs_spatial_shape[i], - kernel_spatial_shape[i], - padding=padding, - output_padding=current_output_padding, - stride=strides[i], - dilation=dilation_rate[0], - ) - ) - - if data_format == "channels_last": - output_shape = [inputs.shape[0]] + output_shape + [kernel.shape[-2]] - else: - output_shape = [inputs.shape[0], kernel.shape[-1]] + output_shape - return output_shape diff --git a/keras_core/backend/jax/nn.py b/keras_core/backend/jax/nn.py index 22bcba23d..eb0ebff1a 100644 --- a/keras_core/backend/jax/nn.py +++ b/keras_core/backend/jax/nn.py @@ -1,8 +1,6 @@ import jax -import jax.numpy as jnp -import numpy as np -from jax import lax from jax import nn as jnn +from jax import numpy as jnp def relu(x): @@ -65,293 +63,46 @@ def log_softmax(x, axis=-1): return jnn.log_softmax(x, axis=axis) -def _convert_to_spatial_operand( - x, - num_spatial_dims, - 
data_format="channels_last", - include_batch_and_channels=True, -): - # Helper function that converts an operand to a spatial operand. - x = (x,) * num_spatial_dims if isinstance(x, int) else x - if not include_batch_and_channels: - return x - if data_format == "channels_last": - x = (1,) + x + (1,) - else: - x = (1,) + (1,) + x - return x +def max_pool(inputs, pool_size, strides, padding): + # TODO: Implement `max_pool` with JAX ops. + raise NotImplementedError -def _pool( - inputs, - initial_value, - reduce_fn, - pool_size, - strides=None, - padding="valid", -): - """Helper function to define pooling functions. - - Args: - inputs: input data of shape `N+2`. - initial_value: the initial value for the reduction. - reduce_fn: a reduce function of the form `(T, T) -> T`. - pool_size: a sequence of `N` integers, representing the window size to - reduce over. - strides: a sequence of `N` integers, representing the inter-window - strides (default: `(1, ..., 1)`). - padding: either the string `same` or `valid`. - - Returns: - The output of the reduction for each window slice. - """ - if padding not in ("same", "valid"): - raise ValueError( - f"Invalid padding '{padding}', must be 'same' or 'valid'." - ) - padding = padding.upper() - return lax.reduce_window( - inputs, - initial_value, - reduce_fn, - pool_size, - strides, - padding, - ) +def average_pool(inputs, pool_size, strides, padding): + # TODO: Implement `average_pool` with JAX ops. + raise NotImplementedError -def max_pool( - inputs, - pool_size, - strides=None, - padding="valid", - data_format="channels_last", -): - num_spatial_dims = inputs.ndim - 2 - pool_size = _convert_to_spatial_operand( - pool_size, num_spatial_dims, data_format - ) - strides = pool_size if strides is None else strides - strides = _convert_to_spatial_operand( - strides, num_spatial_dims, data_format - ) - return _pool(inputs, -jnp.inf, lax.max, pool_size, strides, padding) - - -def average_pool( - inputs, - pool_size, - strides, - padding, - data_format="channels_last", -): - num_spatial_dims = inputs.ndim - 2 - pool_size = _convert_to_spatial_operand( - pool_size, num_spatial_dims, data_format - ) - strides = pool_size if strides is None else strides - strides = _convert_to_spatial_operand( - strides, num_spatial_dims, data_format - ) - - pooled = _pool(inputs, 0.0, lax.add, pool_size, strides, padding) - if padding == "valid": - # Avoid the extra reduce_window. - return pooled / np.prod(pool_size) - else: - # Count the number of valid entries at each input point, then use that - # for computing average. Assumes that any two arrays of same shape will - # be padded the same. Avoid broadcasting on axis where pooling is - # skipped. 
- shape = [ - (a if b != 1 else 1) for (a, b) in zip(inputs.shape, pool_size) - ] - window_counts = _pool( - jnp.ones(shape, inputs.dtype), - 0.0, - lax.add, - pool_size, - strides, - padding, - ) - return pooled / window_counts - - -def _convert_to_lax_conv_dimension_numbers( - num_spatial_dims, - data_format="channels_last", - transpose=False, -): - """Create a `lax.ConvDimensionNumbers` for the given inputs.""" - num_dims = num_spatial_dims + 2 - - if data_format == "channels_last": - spatial_dims = tuple(range(1, num_dims - 1)) - inputs_dn = (0, num_dims - 1) + spatial_dims - else: - spatial_dims = tuple(range(2, num_dims)) - inputs_dn = (0, 1) + spatial_dims - - if transpose: - kernel_dn = (num_dims - 2, num_dims - 1) + tuple(range(num_dims - 2)) - else: - kernel_dn = (num_dims - 1, num_dims - 2) + tuple(range(num_dims - 2)) - - return lax.ConvDimensionNumbers( - lhs_spec=inputs_dn, rhs_spec=kernel_dn, out_spec=inputs_dn - ) - - -def conv( - inputs, - kernel, - strides=1, - padding="valid", - data_format="channel_last", - dilation_rate=1, -): - num_spatial_dims = inputs.ndim - 2 - dimension_numbers = _convert_to_lax_conv_dimension_numbers( - num_spatial_dims, - data_format, - transpose=False, - ) - strides = _convert_to_spatial_operand( - strides, - num_spatial_dims, - data_format, - include_batch_and_channels=False, - ) - dilation_rate = _convert_to_spatial_operand( - dilation_rate, - num_spatial_dims, - data_format, - include_batch_and_channels=False, - ) +def conv(inputs, kernel, strides, padding, dilation_rate=None): + # TODO: Add missing args. return jax.lax.conv_general_dilated( - inputs, - kernel, - strides, - padding, - rhs_dilation=dilation_rate, - dimension_numbers=dimension_numbers, + inputs, kernel, strides, padding, rhs_dilation=dilation_rate ) -def depthwise_conv( - inputs, - kernel, - strides=1, - padding="valid", - data_format="channel_last", - dilation_rate=1, -): - num_spatial_dims = inputs.ndim - 2 - dimension_numbers = _convert_to_lax_conv_dimension_numbers( - num_spatial_dims, - data_format, - transpose=False, - ) - strides = _convert_to_spatial_operand( - strides, - num_spatial_dims, - data_format, - include_batch_and_channels=False, - ) - dilation_rate = _convert_to_spatial_operand( - dilation_rate, - num_spatial_dims, - data_format, - include_batch_and_channels=False, - ) - feature_group_count = ( - inputs.shape[-1] if data_format == "channels_last" else inputs.shape[1] - ) - kernel = jnp.reshape( - kernel, - kernel.shape[:-2] + (1, feature_group_count * kernel.shape[-1]), - ) - return jax.lax.conv_general_dilated( - inputs, - kernel, - strides, - padding, - rhs_dilation=dilation_rate, - dimension_numbers=dimension_numbers, - feature_group_count=feature_group_count, - ) +def depthwise_conv(inputs, filter, strides, padding): + # TODO: Implement `depthwise_conv` with `conv_general_dilated`. + raise NotImplementedError def separable_conv( - inputs, - depthwise_kernel, - pointwise_kernel, - strides=1, - padding="valid", - data_format="channels_last", - dilation_rate=1, + inputs, depthwise_kernel, pointwise_kernel, strides, padding ): - depthwise_conv_output = depthwise_conv( - inputs, - depthwise_kernel, - strides, - padding, - data_format, - dilation_rate, - ) - return conv( - depthwise_conv_output, - pointwise_kernel, - strides=1, - padding="valid", - data_format=data_format, - dilation_rate=dilation_rate, - ) + # TODO: Implement `separable_conv` with `conv_general_dilated`. 
+ raise NotImplementedError def conv_transpose( inputs, kernel, - strides=1, + strides, + output_padding, padding="valid", - output_padding=None, data_format="channels_last", dilation_rate=1, ): - num_spatial_dims = inputs.ndim - 2 - dimension_numbers = _convert_to_lax_conv_dimension_numbers( - num_spatial_dims, - data_format, - transpose=False, - ) - strides = _convert_to_spatial_operand( - strides, - num_spatial_dims, - data_format, - include_batch_and_channels=False, - ) - dilation_rate = _convert_to_spatial_operand( - dilation_rate, - num_spatial_dims, - data_format, - include_batch_and_channels=False, - ) - - if output_padding is not None: - raise ValueError( - "Custom `output_padding` is not supported yet, please set " - "`output_padding=None`." - ) - padding = padding.upper() - return jax.lax.conv_transpose( - inputs, - kernel, - strides, - padding, - rhs_dilation=dilation_rate, - dimension_numbers=dimension_numbers, - transpose_kernel=True, - ) + # TODO: Implement `conv_transpose`. + raise NotImplementedError def one_hot(x, num_classes, axis=-1): diff --git a/keras_core/backend/jax/numpy.py b/keras_core/backend/jax/numpy.py index 7761ff3bc..b1519b409 100644 --- a/keras_core/backend/jax/numpy.py +++ b/keras_core/backend/jax/numpy.py @@ -23,3 +23,11 @@ def mean(x, axis=None, keepdims=False): def max(x, axis=None, keepdims=False): return jnp.max(x, axis=axis, keepdims=keepdims) + + +def ones(shape, dtype="float32"): + return jnp.ones(shape, dtype=dtype) + + +def zeros(shape, dtype="float32"): + return jnp.zeros(shape, dtype=dtype) diff --git a/keras_core/backend/tensorflow/nn.py b/keras_core/backend/tensorflow/nn.py index 3fb7d92b7..a2e782c00 100644 --- a/keras_core/backend/tensorflow/nn.py +++ b/keras_core/backend/tensorflow/nn.py @@ -1,9 +1,5 @@ import tensorflow as tf -from keras_core.backend.common.backend_utils import ( - compute_conv_transpose_output_shape, -) - def relu(x): return tf.nn.relu(x) @@ -256,8 +252,10 @@ def separable_conv( if data_format == "channels_last": strides = (1,) + strides + (1,) + spatial_start_dim = 1 else: strides = (1, 1) + strides + spatial_start_dim = 2 return tf.nn.separable_conv2d( inputs, depthwise_kernel, @@ -269,6 +267,100 @@ def separable_conv( ) +def _deconv_output_length( + input_length, + kernel_size, + padding, + output_padding=None, + stride=1, + dilation=1, +): + """Determines output length of a transposed convolution given input length. + + Args: + input_length: Integer. + kernel_size: Integer. + padding: one of `"same"` or `"valid"`. + output_padding: Integer, amount of padding along the output dimension. + Can be set to `None` in which case the output length is inferred. + stride: Integer. + dilation: Integer. + + Returns: + The output length (integer). 
+ """ + assert padding in {"same", "valid"} + if input_length is None: + return None + + # Get the dilated kernel size + kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1) + + # Infer length if output padding is None, else compute the exact length + if output_padding is None: + if padding == "valid": + length = input_length * stride + max(kernel_size - stride, 0) + else: + length = input_length * stride + else: + if padding == "same": + pad = kernel_size // 2 + else: + pad = 0 + + length = ( + (input_length - 1) * stride + kernel_size - 2 * pad + output_padding + ) + return length + + +def compute_output_shape_conv_transpose( + inputs, + kernel, + strides, + padding, + output_padding=None, + data_format="channels_last", + dilation_rate=1, +): + num_spatial_dims = len(inputs.shape) - 2 + kernel_spatial_shape = kernel.shape[:-2] + + if isinstance(output_padding, int): + output_padding = (output_padding,) * len(kernel_spatial_shape) + if isinstance(strides, int): + strides = (strides,) * num_spatial_dims + if isinstance(dilation_rate, int): + dilation_rate = (dilation_rate,) * num_spatial_dims + + if data_format == "channels_last": + inputs_spatial_shape = inputs.shape[1:-1] + else: + inputs_spatial_shape = inputs.shape[2:] + + output_shape = [] + for i in range(num_spatial_dims): + current_output_padding = ( + None if output_padding is None else output_padding[i] + ) + output_shape.append( + _deconv_output_length( + inputs_spatial_shape[i], + kernel_spatial_shape[i], + padding=padding, + output_padding=current_output_padding, + stride=strides[i], + dilation=dilation_rate[0], + ) + ) + + if data_format == "channels_last": + output_shape = [inputs.shape[0]] + output_shape + [kernel.shape[-2]] + else: + output_shape = [inputs.shape[0], kernel.shape[-1]] + output_shape + return output_shape + + def conv_transpose( inputs, kernel, @@ -279,7 +371,7 @@ def conv_transpose( dilation_rate=1, ): tf_data_format = _convert_data_format(data_format, len(inputs.shape)) - output_shape = compute_conv_transpose_output_shape( + output_shape = compute_output_shape_conv_transpose( inputs, kernel, strides, diff --git a/keras_core/backend/tensorflow/numpy.py b/keras_core/backend/tensorflow/numpy.py index dfba6c067..b1d7dadcd 100644 --- a/keras_core/backend/tensorflow/numpy.py +++ b/keras_core/backend/tensorflow/numpy.py @@ -1,3 +1,4 @@ +import tensorflow as tf from tensorflow.experimental import numpy as tfnp @@ -23,3 +24,13 @@ def mean(x, axis=None, keepdims=False): def max(x, axis=None, keepdims=False): return tfnp.max(x, axis=axis, keepdims=keepdims) + + +def ones(shape, dtype="float32"): + with tf.init_scope(): + return tf.ones(shape, dtype=dtype) + + +def zeros(shape, dtype="float32"): + with tf.init_scope(): + return tf.zeros(shape, dtype=dtype) diff --git a/keras_core/backend/tensorflow/random.py b/keras_core/backend/tensorflow/random.py index c0dea4268..456e0af66 100644 --- a/keras_core/backend/tensorflow/random.py +++ b/keras_core/backend/tensorflow/random.py @@ -8,7 +8,8 @@ from keras_core.backend.config import floatx def tf_draw_seed(seed): # TF ops only accept int32/64 seeds but our base seed is uint32. 
-    return tf.cast(draw_seed(seed), dtype="int32")
+    with tf.init_scope():
+        return tf.cast(draw_seed(seed), dtype="int32")
 
 
 def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
@@ -34,9 +35,10 @@ def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
     """
     dtype = dtype or floatx()
     seed = tf_draw_seed(seed)
-    return tf.random.stateless_normal(
-        shape=shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed
-    )
+    with tf.init_scope():
+        return tf.random.stateless_normal(
+            shape=shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed
+        )
 
 
 def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
@@ -63,13 +65,14 @@ def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
     """
     dtype = dtype or floatx()
     seed = tf_draw_seed(seed)
-    return tf.random.stateless_uniform(
-        shape=shape,
-        minval=minval,
-        maxval=maxval,
-        dtype=dtype,
-        seed=seed,
-    )
+    with tf.init_scope():
+        return tf.random.stateless_uniform(
+            shape=shape,
+            minval=minval,
+            maxval=maxval,
+            dtype=dtype,
+            seed=seed,
+        )
 
 
 def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
@@ -95,16 +98,18 @@ def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
     """
     dtype = dtype or floatx()
     seed = tf_draw_seed(seed)
-    return tf.random.stateless_truncated_normal(
-        shape=shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed
-    )
+    with tf.init_scope():
+        return tf.random.stateless_truncated_normal(
+            shape=shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed
+        )
 
 
 def dropout(inputs, rate, noise_shape=None, seed=None):
     seed = tf_draw_seed(seed)
-    return tf.nn.experimental.stateless_dropout(
-        inputs,
-        rate=rate,
-        noise_shape=noise_shape,
-        seed=seed,
-    )
+    with tf.init_scope():
+        return tf.nn.experimental.stateless_dropout(
+            inputs,
+            rate=rate,
+            noise_shape=noise_shape,
+            seed=seed,
+        )
diff --git a/keras_core/initializers/__init__.py b/keras_core/initializers/__init__.py
index 755b29789..87d295431 100644
--- a/keras_core/initializers/__init__.py
+++ b/keras_core/initializers/__init__.py
@@ -43,12 +43,13 @@ ALL_OBJECTS_DICT.update(
 
 @keras_core_export("keras_core.initializers.serialize")
 def serialize(initializer):
+    """Returns the initializer configuration as a Python dict."""
     return serialization_lib.serialize_keras_object(initializer)
 
 
 @keras_core_export("keras_core.initializers.deserialize")
 def deserialize(config, custom_objects=None):
-    """Return a Keras initializer object via its config."""
+    """Returns a Keras initializer object via its configuration."""
     return serialization_lib.deserialize_keras_object(
         config,
         module_objects=ALL_OBJECTS_DICT,
@@ -58,7 +59,7 @@ def deserialize(config, custom_objects=None):
 
 @keras_core_export("keras_core.initializers.get")
 def get(identifier):
-    """Retrieve a Keras initializer object via an identifier.
+    """Retrieves a Keras initializer object via an identifier.
 
     The `identifier` may be the string name of a initializers function or
     class (case-sensitively).
diff --git a/keras_core/layers/layer.py b/keras_core/layers/layer.py
index 98282dd89..4fe2617bb 100644
--- a/keras_core/layers/layer.py
+++ b/keras_core/layers/layer.py
@@ -470,6 +470,22 @@ class Layer(Operation):
         """
         all_vars = self._variables
         if len(store.keys()) != len(all_vars):
+            if len(all_vars) == 0 and not self.built:
+                raise ValueError(
+                    f"Layer '{self.name}' was never built "
+                    "and thus it doesn't have any variables. "
+                    f"However, the weights file lists {len(store.keys())} "
+                    "variables for this layer. In most cases, "
+                    "this indicates that you need to implement the "
+                    "`def build_from_config(self, config)` method "
+                    "on the layer. "
+                    "You might also want to implement the method "
+                    "that generates the config at saving time, "
+                    "`def get_build_config(self)`. "
+                    "The method `build_from_config()` is meant "
+                    "to create the state "
+                    "of the layer (i.e. its variables) upon deserialization."
+                )
             raise ValueError(
                 f"Layer '{self.name}' expected {len(all_vars)} variables, "
                 "but received "
diff --git a/keras_core/metrics/accuracy_metrics.py b/keras_core/metrics/accuracy_metrics.py
index 1d7c9b004..f79ad888f 100644
--- a/keras_core/metrics/accuracy_metrics.py
+++ b/keras_core/metrics/accuracy_metrics.py
@@ -1,15 +1,15 @@
+from keras_core import backend
 from keras_core import operations as ops
-from keras_core.metrics import reduction_metrics
 from keras_core.api_export import keras_core_export
 from keras_core.losses.loss import squeeze_to_same_rank
-from keras_core.backend import floatx
+from keras_core.metrics import reduction_metrics
 
 
 def accuracy(y_true, y_pred):
     y_pred = ops.convert_to_tensor(y_pred)
     y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
     y_true, y_pred = squeeze_to_same_rank(y_true, y_pred)
-    return ops.cast(ops.equal(y_true, y_pred), dtype=floatx())
+    return ops.cast(ops.equal(y_true, y_pred), dtype=backend.floatx())
 
 
 @keras_core_export("keras_core.metrics.Accuracy")
diff --git a/keras_core/models/model.py b/keras_core/models/model.py
index ffc7c322d..56ff5767c 100644
--- a/keras_core/models/model.py
+++ b/keras_core/models/model.py
@@ -1,4 +1,5 @@
 import os
+import warnings
 
 from keras_core import backend
 from keras_core.api_export import keras_core_export
@@ -205,6 +206,69 @@ class Model(Trainer, Layer):
                 "files."
             )
 
+    def build_from_config(self, config):
+        def is_shape_tuple(s):
+            return isinstance(s, (list, tuple)) and all(
+                d is None or isinstance(d, int) for d in s
+            )
+
+        if config:
+            failure = False
+            if "input_shape" in config:
+                # Case: all inputs are in the first arg (possibly nested).
+                input_shape = config["input_shape"]
+                if is_shape_tuple(input_shape):
+                    input_shape = tuple(input_shape)
+                if isinstance(input_shape, list):
+                    input_tensors = [
+                        backend.KerasTensor(shape) for shape in input_shape
+                    ]
+                elif isinstance(input_shape, dict):
+                    input_tensors = {
+                        k: backend.KerasTensor(shape)
+                        for k, shape in input_shape.items()
+                    }
+                else:
+                    input_tensors = backend.KerasTensor(input_shape)
+                try:
+                    self(input_tensors)
+                    self._build_shapes_dict = config
+                except Exception:
+                    failure = True
+            elif "shapes_dict" in config:
+                # Case: inputs were recorded as multiple keyword arguments.
+                if all(
+                    is_shape_tuple(s) for s in config["shapes_dict"].values()
+                ):
+                    # Case: all input keyword arguments were plain tensors.
+                    input_tensors = {
+                        k: backend.KerasTensor(v)
+                        for k, v in config["shapes_dict"].items()
+                    }
+                    try:
+                        self(**input_tensors)
+                        self._build_shapes_dict = config["shapes_dict"]
+                    except Exception:
+                        failure = True
+                else:
+                    # Not supported: nested input keyword arguments.
+                    failure = True
+            if failure:
+                warnings.warn(
+                    f"Model '{self.name}' had a build config, but the model "
+                    "cannot be built automatically in "
+                    "`build_from_config(config)`. "
+                    "You should implement "
+                    "`def build_from_config(self, config)`, "
+                    "and you might also want to implement the method "
+                    "that generates the config at saving time, "
+                    "`def get_build_config(self)`. "
+                    "The method `build_from_config()` is meant to "
+                    "create the state of the model (i.e. its variables) "
+                    "upon deserialization.",
+                    stacklevel=2,
+                )
+
     def export(self, filepath):
         raise NotImplementedError
 
diff --git a/keras_core/operations/nn.py b/keras_core/operations/nn.py
index 7ffd2cd8f..b3411f739 100644
--- a/keras_core/operations/nn.py
+++ b/keras_core/operations/nn.py
@@ -25,14 +25,13 @@ conv_transpose
 ctc ??
 """
 
+import math
+
 import numpy as np
 
 from keras_core import backend
 from keras_core.backend import KerasTensor
 from keras_core.backend import any_symbolic_tensors
-from keras_core.backend.common.backend_utils import (
-    compute_conv_transpose_output_shape,
-)
 from keras_core.operations.operation import Operation
 
 
@@ -902,7 +901,7 @@ class ConvTranspose(Operation):
         )
 
     def compute_output_spec(self, inputs, kernel):
-        output_shape = compute_conv_transpose_output_shape(
+        output_shape = backend.nn.compute_output_shape_conv_transpose(
             inputs,
             kernel,
             self.strides,
@@ -935,7 +934,7 @@ def conv_transpose(
             `data_format="channels_first"`. Pooling happens over the spatial
             dimensions only.
         kernel: Tensor of rank N+2. `kernel` has shape
-            [kernel_spatial_shape, num_output_channels, num_input_channels],
+            [kernel_spatial_shape, num_input_channels, num_output_channels],
             `num_input_channels` should match the number of channels in
             `inputs`.
         strides: int or int tuple/list of `len(inputs_spatial_shape)`,
diff --git a/keras_core/operations/nn_test.py b/keras_core/operations/nn_test.py
index decfcc09b..fbe3a05e6 100644
--- a/keras_core/operations/nn_test.py
+++ b/keras_core/operations/nn_test.py
@@ -281,6 +281,10 @@ class NNOpsDynamicShapeTest(testing.TestCase):
         )
 
 
+@pytest.mark.skipif(
+    backend() != "tensorflow",
+    reason="Other backends are not supported yet.",
+)
 class NNOpsStaticShapeTest(testing.TestCase):
     def test_relu(self):
         x = KerasTensor([1, 2, 3])
@@ -543,6 +547,10 @@ class NNOpsStaticShapeTest(testing.TestCase):
         )
 
 
+@pytest.mark.skipif(
+    backend() != "tensorflow",
+    reason="Other backends are not supported yet.",
+)
 class NNOpsCorrectnessTest(testing.TestCase):
     def test_relu(self):
         x = np.array([-1, 0, 1, 2, 3], dtype=np.float32)
@@ -769,7 +777,7 @@ class NNOpsCorrectnessTest(testing.TestCase):
         expected = tf.nn.conv3d(
             inputs_3d, kernel, (1, 1, 1, 1, 1), padding="VALID"
         )
-        self.assertAllClose(outputs, expected, rtol=1e-5, atol=1e-5)
+        self.assertAllClose(outputs, expected)
 
         outputs = knn.conv(
             inputs_3d,
@@ -785,13 +793,13 @@ class NNOpsCorrectnessTest(testing.TestCase):
             padding="VALID",
             dilations=(1, 1, 1, 1, 1),
         )
-        self.assertAllClose(outputs, expected, rtol=1e-5, atol=1e-5)
+        self.assertAllClose(outputs, expected)
 
         outputs = knn.conv(inputs_3d, kernel, 2, padding="same")
         expected = tf.nn.conv3d(
             inputs_3d, kernel, (1, 2, 2, 2, 1), padding="SAME"
         )
-        self.assertAllClose(outputs, expected, rtol=1e-5, atol=1e-5)
+        self.assertAllClose(outputs, expected)
 
     def test_depthwise_conv(self):
         # Test 2D conv.
@@ -904,6 +912,14 @@ class NNOpsCorrectnessTest(testing.TestCase):
         )
         self.assertAllClose(outputs, expected)
 
+        outputs = knn.conv_transpose(
+            inputs_1d, kernel, 5, output_padding=4, padding="valid"
+        )
+        expected = tf.nn.conv_transpose(
+            inputs_1d, kernel, [2, 21, 5], 5, padding="VALID"
+        )
+        self.assertAllClose(outputs, expected)
+
         # Test 2D conv.
inputs_2d = np.arange(96, dtype=float).reshape([2, 4, 4, 3])
         kernel = np.arange(60, dtype=float).reshape([2, 2, 5, 3])
@@ -919,3 +935,21 @@ class NNOpsCorrectnessTest(testing.TestCase):
             inputs_2d, kernel, [2, 8, 8, 5], 2, padding="SAME"
         )
         self.assertAllClose(outputs, expected)
+
+        outputs = knn.conv_transpose(
+            inputs_2d,
+            kernel,
+            5,
+            output_padding=4,
+            padding="valid",
+            dilation_rate=(1, 1),
+        )
+        expected = tf.nn.conv_transpose(
+            inputs_2d,
+            kernel,
+            [2, 21, 21, 5],
+            5,
+            padding="VALID",
+            dilations=(1, 1),
+        )
+        self.assertAllClose(outputs, expected)
diff --git a/keras_core/operations/numpy.py b/keras_core/operations/numpy.py
index 5b036273d..f936084ea 100644
--- a/keras_core/operations/numpy.py
+++ b/keras_core/operations/numpy.py
@@ -2923,26 +2923,26 @@ def sum(x, axis=None, keepdims=False):
 
 class Zeros(Operation):
     def call(self, shape, dtype="float32"):
-        return backend.execute("zeros", shape, dtype)
+        return backend.numpy.zeros(shape, dtype=dtype)
 
     def compute_output_spec(self, shape, dtype="float32"):
         return KerasTensor(shape, dtype=dtype)
 
 
 def zeros(shape, dtype="float32"):
-    return backend.execute("zeros", shape, dtype)
+    return backend.numpy.zeros(shape, dtype=dtype)
 
 
 class Ones(Operation):
     def call(self, shape, dtype="float32"):
-        return backend.execute("ones", shape, dtype)
+        return backend.numpy.ones(shape, dtype=dtype)
 
     def compute_output_spec(self, shape, dtype="float32"):
         return KerasTensor(shape, dtype=dtype)
 
 
 def ones(shape, dtype="float32"):
-    return backend.execute("ones", shape, dtype)
+    return backend.numpy.ones(shape, dtype=dtype)
 
 
 class Eye(Operation):
diff --git a/keras_core/optimizers/__init__.py b/keras_core/optimizers/__init__.py
index e1de44dcf..0064074ff 100644
--- a/keras_core/optimizers/__init__.py
+++ b/keras_core/optimizers/__init__.py
@@ -1,3 +1,75 @@
+from keras_core.api_export import keras_core_export
 from keras_core.optimizers.adam import Adam
 from keras_core.optimizers.optimizer import Optimizer
 from keras_core.optimizers.sgd import SGD
+from keras_core.saving import serialization_lib
+
+ALL_OBJECTS = {
+    Optimizer,
+    Adam,
+    SGD,
+}
+ALL_OBJECTS_DICT = {cls.__name__.lower(): cls for cls in ALL_OBJECTS}
+
+
+@keras_core_export("keras_core.optimizers.serialize")
+def serialize(optimizer):
+    """Returns the optimizer configuration as a Python dict.
+
+    Args:
+        optimizer: An `Optimizer` instance to serialize.
+
+    Returns:
+        Python dict which contains the configuration of the optimizer.
+    """
+    return serialization_lib.serialize_keras_object(optimizer)
+
+
+@keras_core_export("keras_core.optimizers.deserialize")
+def deserialize(config, custom_objects=None):
+    """Returns a Keras optimizer object via its configuration.
+
+    Args:
+        config: Optimizer configuration dictionary.
+        custom_objects: Optional dictionary mapping names (strings) to custom
+            objects (classes and functions) to be considered during
+            deserialization.
+
+    Returns:
+        A Keras Optimizer instance.
+    """
+    # Make deserialization case-insensitive for built-in optimizers.
+    if config["class_name"].lower() in ALL_OBJECTS_DICT:
+        config["class_name"] = config["class_name"].lower()
+
+    return serialization_lib.deserialize_keras_object(
+        config,
+        module_objects=ALL_OBJECTS_DICT,
+        custom_objects=custom_objects,
+    )
+
+
+@keras_core_export("keras_core.optimizers.get")
+def get(identifier):
+    """Retrieves a Keras Optimizer instance.
+
+    Args:
+        identifier: Optimizer identifier, one of:
+            - String: name of an optimizer.
+            - Dictionary: configuration dictionary.
+            - Keras Optimizer instance (it will be returned unchanged).
+
+    Returns:
+        A Keras Optimizer instance.
+    """
+    if isinstance(identifier, Optimizer):
+        return identifier
+    elif isinstance(identifier, dict):
+        return deserialize(identifier)
+    elif isinstance(identifier, str):
+        config = {"class_name": identifier, "config": {}}
+        return deserialize(config)
+    else:
+        raise ValueError(
+            f"Could not interpret optimizer identifier: {identifier}"
+        )
diff --git a/keras_core/saving/saving_lib.py b/keras_core/saving/saving_lib.py
index a22549572..f545c5ba9 100644
--- a/keras_core/saving/saving_lib.py
+++ b/keras_core/saving/saving_lib.py
@@ -36,6 +36,8 @@ _ASSETS_DIRNAME = "assets"
 
 ATTR_SKIPLIST = frozenset(
     {
+        "_operations",
+        "_layers",
         "_functional",
         "_losses",
         "_inbound_nodes",
@@ -108,47 +110,43 @@ def save_model(model, filepath, weights_format="h5"):
         zip_filepath = os.path.join(get_temp_dir(), "tmp_model.keras")
     else:
         zip_filepath = filepath
-    try:
-        with zipfile.ZipFile(zip_filepath, "w") as zf:
-            with zf.open(_METADATA_FILENAME, "w") as f:
-                f.write(metadata_json.encode())
-            with zf.open(_CONFIG_FILENAME, "w") as f:
-                f.write(config_json.encode())
-
-            if weights_format == "h5":
-                weights_store = H5IOStore(
-                    _VARS_FNAME + ".h5", archive=zf, mode="w"
-                )
-            elif weights_format == "npz":
-                weights_store = NpzIOStore(
-                    _VARS_FNAME + ".npz", archive=zf, mode="w"
-                )
-            else:
-                raise ValueError(
-                    "Unknown `weights_format` argument. "
-                    "Expected 'h5' or 'npz'. "
-                    f"Received: weights_format={weights_format}"
-                )
+    with zipfile.ZipFile(zip_filepath, "w") as zf:
+        with zf.open(_METADATA_FILENAME, "w") as f:
+            f.write(metadata_json.encode())
+        with zf.open(_CONFIG_FILENAME, "w") as f:
+            f.write(config_json.encode())
 
-            asset_store = DiskIOStore(_ASSETS_DIRNAME, archive=zf, mode="w")
-
-            _save_state(
-                model,
-                weights_store=weights_store,
-                assets_store=asset_store,
-                inner_path="",
-                visited_trackables=set(),
+        if weights_format == "h5":
+            weights_store = H5IOStore(_VARS_FNAME + ".h5", archive=zf, mode="w")
+        elif weights_format == "npz":
+            weights_store = NpzIOStore(
+                _VARS_FNAME + ".npz", archive=zf, mode="w"
+            )
+        else:
+            raise ValueError(
+                "Unknown `weights_format` argument. "
+                "Expected 'h5' or 'npz'. "
+                f"Received: weights_format={weights_format}"
             )
-            weights_store.close()
-            asset_store.close()
-
-            if is_remote_path(filepath):
-                # Using gfile context manager doesn't close zip file when
-                # writing to GCS. Hence writing to local and copying to filepath.
-                gfile.copy(zip_filepath, filepath, overwrite=True)
-                os.remove(zip_filepath)
-    except Exception as e:
-        raise e
+        asset_store = DiskIOStore(_ASSETS_DIRNAME, archive=zf, mode="w")
+
+        _save_state(
+            model,
+            weights_store=weights_store,
+            assets_store=asset_store,
+            inner_path="",
+            visited_trackables=set(),
+        )
+        weights_store.close()
+        asset_store.close()
+
+    if is_remote_path(filepath):
+        # Using gfile context manager doesn't close zip file when
+        # writing to GCS. Hence writing to local and copying to filepath.
+ gfile.copy(zip_filepath, filepath, overwrite=True) + os.remove(zip_filepath) def load_model(filepath, custom_objects=None, compile=True, safe_mode=True): @@ -161,59 +159,52 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True): f"Received: filepath={filepath}" ) - try: - with gfile.GFile(filepath, mode="r+b") as gfile_handle, zipfile.ZipFile( - gfile_handle, "r" - ) as zf: - with zf.open(_CONFIG_FILENAME, "r") as f: - config_json = f.read() + with gfile.GFile(filepath, mode="r+b") as gfile_handle, zipfile.ZipFile( + gfile_handle, "r" + ) as zf: + with zf.open(_CONFIG_FILENAME, "r") as f: + config_json = f.read() - # Note: we should NOT use a custom JSON decoder. Anything that - # needs custom decoding must be handled in deserialize_keras_object. - config_dict = json.loads(config_json) - if not compile: - # Disable compilation - config_dict["compile_config"] = None - # Construct the model from the configuration file in the archive. - with ObjectSharingScope(): - model = deserialize_keras_object( - config_dict, custom_objects, safe_mode=safe_mode - ) - - all_filenames = zf.namelist() - if _VARS_FNAME + ".h5" in all_filenames: - weights_store = H5IOStore( - _VARS_FNAME + ".h5", archive=zf, mode="r" - ) - elif _VARS_FNAME + ".npz" in all_filenames: - weights_store = NpzIOStore( - _VARS_FNAME + ".npz", archive=zf, mode="r" - ) - else: - raise ValueError( - f"Expected a {_VARS_FNAME}.h5 or {_VARS_FNAME}.npz file." - ) - - if len(all_filenames) > 3: - asset_store = DiskIOStore(_ASSETS_DIRNAME, archive=zf, mode="r") - else: - asset_store = None - - _load_state( - model, - weights_store=weights_store, - assets_store=asset_store, - inner_path="", - visited_trackables=set(), + # Note: we should NOT use a custom JSON decoder. Anything that + # needs custom decoding must be handled in deserialize_keras_object. + config_dict = json.loads(config_json) + if not compile: + # Disable compilation + config_dict["compile_config"] = None + # Construct the model from the configuration file in the archive. + with ObjectSharingScope(): + model = deserialize_keras_object( + config_dict, custom_objects, safe_mode=safe_mode ) - weights_store.close() - if asset_store: - asset_store.close() - except Exception as e: - raise e - else: - return model + all_filenames = zf.namelist() + if _VARS_FNAME + ".h5" in all_filenames: + weights_store = H5IOStore(_VARS_FNAME + ".h5", archive=zf, mode="r") + elif _VARS_FNAME + ".npz" in all_filenames: + weights_store = NpzIOStore( + _VARS_FNAME + ".npz", archive=zf, mode="r" + ) + else: + raise ValueError( + f"Expected a {_VARS_FNAME}.h5 or {_VARS_FNAME}.npz file." 
+ ) + + if len(all_filenames) > 3: + asset_store = DiskIOStore(_ASSETS_DIRNAME, archive=zf, mode="r") + else: + asset_store = None + + _load_state( + model, + weights_store=weights_store, + assets_store=asset_store, + inner_path="", + visited_trackables=set(), + ) + weights_store.close() + if asset_store: + asset_store.close() + return model def save_weights_only(model, filepath): diff --git a/keras_core/testing/test_case.py b/keras_core/testing/test_case.py index 80db96175..89ec0167b 100644 --- a/keras_core/testing/test_case.py +++ b/keras_core/testing/test_case.py @@ -1,5 +1,5 @@ import json -import os +import shutil import tempfile import unittest @@ -12,7 +12,7 @@ class TestCase(unittest.TestCase): def get_temp_dir(self): temp_dir = tempfile.mkdtemp() - self.addCleanup(lambda: os.rmdir(temp_dir)) + self.addCleanup(lambda: shutil.rmtree(temp_dir)) return temp_dir def assertAllClose(self, x1, x2, atol=1e-7, rtol=1e-7): diff --git a/keras_core/trainers/trainer.py b/keras_core/trainers/trainer.py index df2ad3cee..934ea8dc2 100644 --- a/keras_core/trainers/trainer.py +++ b/keras_core/trainers/trainer.py @@ -3,6 +3,7 @@ import warnings from keras_core import backend from keras_core import metrics as metrics_module from keras_core import operations as ops +from keras_core import optimizers from keras_core.saving import serialization_lib from keras_core.trainers.compile_utils import CompileLoss from keras_core.trainers.compile_utils import CompileMetrics @@ -26,8 +27,7 @@ class Trainer: run_eagerly=False, jit_compile=True, ): - # TODO: get from module - self.optimizer = optimizer + self.optimizer = optimizers.get(optimizer) if loss is not None: self._compile_loss = CompileLoss(loss, loss_weights) else:
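
Reviewer note (not part of the patch): the output-length arithmetic in `_deconv_output_length` can be sanity-checked against the new `conv_transpose` test expectations. Below is a minimal standalone sketch that mirrors the function's branches; the name `deconv_output_length` is ours, not part of the patch.

    # Mirrors `_deconv_output_length` in keras_core/backend/tensorflow/nn.py.
    def deconv_output_length(
        input_length,
        kernel_size,
        padding,
        output_padding=None,
        stride=1,
        dilation=1,
    ):
        # Effective kernel size once dilation is applied.
        kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
        if output_padding is None:
            if padding == "valid":
                return input_length * stride + max(kernel_size - stride, 0)
            return input_length * stride  # "same"
        pad = kernel_size // 2 if padding == "same" else 0
        return (input_length - 1) * stride + kernel_size - 2 * pad + output_padding

    # The new 2D test uses 4x4 spatial inputs, a 2x2 kernel, stride 5,
    # output_padding=4 and padding="valid"; each spatial dim comes out as
    # (4 - 1) * 5 + 2 - 0 + 4 = 21, matching the expected shape [2, 21, 21, 5].
    assert deconv_output_length(4, 2, "valid", output_padding=4, stride=5) == 21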
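On the `tf.init_scope()` wrappers added in the TensorFlow backend (`numpy.py` and `random.py`): `tf.init_scope` lifts op creation out of any `tf.function` graph currently being traced, so seeds and initial values are materialized eagerly even when the call happens during tracing. A small sketch of the behavior, assuming TensorFlow 2.x:

    import tensorflow as tf

    @tf.function
    def build_initial_value():
        # Without init_scope this op would be traced into the function's
        # graph; with it, the tensor is created eagerly at trace time and
        # captured by the graph as a constant.
        with tf.init_scope():
            value = tf.ones((2, 2))
        return value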
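Finally, a usage sketch for the new `keras_core.optimizers.serialize`/`deserialize`/`get` helpers, based on the code in this patch. String lookup is effectively case-insensitive because both the registry keys and the incoming class name are lowercased:

    from keras_core import optimizers

    opt = optimizers.get("adam")  # string identifier, matched case-insensitively
    config = optimizers.serialize(opt)  # config dict round-trip
    restored = optimizers.deserialize(config)
    assert optimizers.get(restored) is restored  # instances pass through unchanged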