From 2481069ed4e43ce19c60e66fb5276450031aef56 Mon Sep 17 00:00:00 2001 From: Aritra Roy Gosthipaty Date: Wed, 19 Jul 2023 01:08:48 +0530 Subject: [PATCH] Adding: Numpy Backend (#483) * chore: adding numpy backend * review comments * review comments * chore: adding math * chore: adding random module * chore: adding random in init * review comments * chore: adding numpy and nn for numpy backend * chore: adding generic pool, max, and average pool * chore: adding the conv ops * chore: reformat code and using jax for conv and pool * chore: added self value * chore: activation tests pass * chore: adding post build method * chore: adding necessary methods to the numpy trainer * chore: fixing utils test * chore: fixing losses test suite * chore: fix backend tests * chore: fixing initializers test * chore: fixing accuracy metrics test * chore: fixing ops test * chore: review comments * chore: init with image and fixing random tests * chore: skipping random seed set for numpy backend * chore: adding single resize image method * chore: skipping tests for applications and layers * chore: skipping tests for models * chore: skipping tests for saving * chore: skipping tests for trainers * chore: fixing one hot * chore: fixing vmap in numpy and metrics test * chore: adding a wrapper to numpy sum, started fixing layer tests * fix: is_tensor now accepts numpy scalars * chore: adding draw seed * fix: warn message for numpy masking * fix: checking whether kernel are tensors * chore: adding rnn * chore: adding dynamic backend for numpy * fix: axis cannot be None for normalize * chore: adding jax resize for numpy image * chore: adding rnn implementation in numpy * chore: using pytest fixtures * change: numpy import string * chore: review comments * chore: adding numpy to backend list of github actions * chore: remove debug print statements --- conftest.py | 21 + keras_core/applications/applications_test.py | 1 + .../applications/imagenet_utils_test.py | 2 + keras_core/backend/__init__.py 
| 7 + keras_core/backend/numpy/__init__.py | 20 + keras_core/backend/numpy/core.py | 212 +++++++ keras_core/backend/numpy/image.py | 45 ++ keras_core/backend/numpy/layer.py | 3 + keras_core/backend/numpy/math.py | 76 +++ keras_core/backend/numpy/nn.py | 517 ++++++++++++++++ keras_core/backend/numpy/numpy.py | 571 ++++++++++++++++++ keras_core/backend/numpy/random.py | 88 +++ keras_core/backend/numpy/rnn.py | 236 ++++++++ keras_core/backend/numpy/trainer.py | 18 + keras_core/callbacks/callback_test.py | 2 + keras_core/callbacks/csv_logger_test.py | 3 + keras_core/callbacks/early_stopping_test.py | 6 + keras_core/callbacks/lambda_callback_test.py | 2 + .../callbacks/learning_rate_scheduler_test.py | 6 + keras_core/callbacks/model_checkpoint_test.py | 2 + .../callbacks/reduce_lr_on_plateau_test.py | 7 + keras_core/callbacks/remote_monitor_test.py | 3 + keras_core/callbacks/tensorboard_test.py | 11 + keras_core/callbacks/terminate_on_nan_test.py | 2 + .../layers/activations/activation_test.py | 3 + keras_core/layers/activations/elu_test.py | 2 + .../layers/activations/leaky_relu_test.py | 2 + keras_core/layers/activations/prelu_test.py | 2 + keras_core/layers/activations/relu_test.py | 2 + keras_core/layers/activations/softmax_test.py | 2 + .../attention/multi_head_attention_test.py | 10 + keras_core/layers/convolutional/conv_test.py | 4 + .../convolutional/conv_transpose_test.py | 3 + .../convolutional/depthwise_conv_test.py | 3 + .../convolutional/separable_conv_test.py | 3 + keras_core/layers/core/dense_test.py | 2 + keras_core/layers/core/einsum_dense_test.py | 2 + keras_core/layers/core/embedding_test.py | 2 + keras_core/layers/core/identity_test.py | 3 + keras_core/layers/core/lambda_layer_test.py | 2 + keras_core/layers/core/masking_test.py | 3 + keras_core/layers/core/wrapper_test.py | 3 + keras_core/layers/layer.py | 8 + keras_core/layers/layer_test.py | 5 + keras_core/layers/merging/merging_test.py | 1 + .../normalization/batch_normalization_test.py | 2 + 
.../normalization/group_normalization_test.py | 2 + .../normalization/layer_normalization_test.py | 2 + .../spectral_normalization_test.py | 2 + .../normalization/unit_normalization_test.py | 2 + .../layers/pooling/average_pooling_test.py | 1 + .../pooling/global_average_pooling_test.py | 2 + .../layers/pooling/global_max_pooling_test.py | 2 + keras_core/layers/pooling/max_pooling_test.py | 2 + .../layers/preprocessing/center_crop_test.py | 2 + .../preprocessing/normalization_test.py | 2 + .../preprocessing/random_brightness_test.py | 2 + .../preprocessing/random_contrast_test.py | 2 + .../layers/preprocessing/rescaling_test.py | 3 + .../activity_regularization_test.py | 2 + .../layers/regularization/dropout_test.py | 1 + .../regularization/gaussian_dropout_test.py | 2 + .../regularization/gaussian_noise_test.py | 2 + .../regularization/spatial_dropout_test.py | 4 + .../layers/reshaping/cropping1d_test.py | 7 +- .../layers/reshaping/cropping2d_test.py | 1 + .../layers/reshaping/cropping3d_test.py | 2 + keras_core/layers/reshaping/flatten_test.py | 2 + keras_core/layers/reshaping/permute_test.py | 1 + .../layers/reshaping/repeat_vector_test.py | 1 + keras_core/layers/reshaping/reshape_test.py | 1 + .../layers/reshaping/up_sampling1d_test.py | 1 + .../layers/reshaping/up_sampling2d_test.py | 2 + .../layers/reshaping/up_sampling3d_test.py | 2 + keras_core/layers/rnn/bidirectional_test.py | 2 + keras_core/layers/rnn/conv_lstm1d_test.py | 2 + keras_core/layers/rnn/conv_lstm2d_test.py | 2 + keras_core/layers/rnn/conv_lstm3d_test.py | 2 + .../layers/rnn/dropout_rnn_cell_test.py | 3 + keras_core/layers/rnn/gru_test.py | 2 + keras_core/layers/rnn/lstm_test.py | 2 + keras_core/layers/rnn/rnn_test.py | 1 + keras_core/layers/rnn/simple_rnn_test.py | 2 + .../layers/rnn/stacked_rnn_cells_test.py | 2 + .../layers/rnn/time_distributed_test.py | 2 + keras_core/losses/loss_test.py | 13 + keras_core/metrics/accuracy_metrics.py | 11 +- keras_core/metrics/confusion_metrics_test.py | 10 
+- keras_core/models/cloning_test.py | 2 + keras_core/models/functional_test.py | 15 + keras_core/models/model.py | 2 + keras_core/models/model_test.py | 2 + keras_core/models/sequential_test.py | 2 + keras_core/ops/core_test.py | 4 +- keras_core/optimizers/adam_test.py | 2 + .../schedules/learning_rate_schedule_test.py | 2 + .../saving/legacy/legacy_h5_format_test.py | 2 + keras_core/saving/saving_lib_test.py | 2 + keras_core/saving/serialization_lib_test.py | 3 + keras_core/trainers/trainer_test.py | 3 + keras_core/utils/backend_utils.py | 8 + keras_core/utils/numerical_utils.py | 3 + keras_core/utils/rng_utils_test.py | 6 + keras_core/utils/traceback_utils.py | 2 + 104 files changed, 2089 insertions(+), 14 deletions(-) create mode 100644 keras_core/backend/numpy/__init__.py create mode 100644 keras_core/backend/numpy/core.py create mode 100644 keras_core/backend/numpy/image.py create mode 100644 keras_core/backend/numpy/layer.py create mode 100644 keras_core/backend/numpy/math.py create mode 100644 keras_core/backend/numpy/nn.py create mode 100644 keras_core/backend/numpy/numpy.py create mode 100644 keras_core/backend/numpy/random.py create mode 100644 keras_core/backend/numpy/rnn.py create mode 100644 keras_core/backend/numpy/trainer.py diff --git a/conftest.py b/conftest.py index e9cc25098..25cbb4398 100644 --- a/conftest.py +++ b/conftest.py @@ -5,3 +5,24 @@ try: import torch # noqa: F401 except ImportError: pass + +import pytest + +from keras_core.backend import backend + + +def pytest_configure(config): + config.addinivalue_line( + "markers", + "requires_trainable_backend: mark test for trainable backend only", + ) + + +def pytest_collection_modifyitems(config, items): + requires_trainable_backend = pytest.mark.skipif( + backend() == "numpy", + reason="Trainer not implemented for NumPy backend.", + ) + for item in items: + if "requires_trainable_backend" in item.keywords: + item.add_marker(requires_trainable_backend) diff --git 
a/keras_core/applications/applications_test.py b/keras_core/applications/applications_test.py index b1049e544..24f23fd30 100644 --- a/keras_core/applications/applications_test.py +++ b/keras_core/applications/applications_test.py @@ -107,6 +107,7 @@ def _get_elephant(target_size): os.environ.get("SKIP_APPLICATIONS_TESTS"), reason="Env variable set to skip.", ) +@pytest.mark.requires_trainable_backend class ApplicationsTest(testing.TestCase, parameterized.TestCase): @parameterized.named_parameters(MODEL_LIST) def test_application_notop_variable_input_channels(self, app, last_dim, _): diff --git a/keras_core/applications/imagenet_utils_test.py b/keras_core/applications/imagenet_utils_test.py index ba6b3d294..6e61ccefb 100644 --- a/keras_core/applications/imagenet_utils_test.py +++ b/keras_core/applications/imagenet_utils_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from absl.testing import parameterized import keras_core as keras @@ -74,6 +75,7 @@ class TestImageNetUtils(testing.TestCase, parameterized.TestCase): {"testcase_name": "mode_caffe", "mode": "caffe"}, ] ) + @pytest.mark.requires_trainable_backend def test_preprocess_input_symbolic(self, mode): # Test image batch x = np.random.uniform(0, 255, (2, 10, 10, 3)) diff --git a/keras_core/backend/__init__.py b/keras_core/backend/__init__.py index ab1d839ec..3f6ab97df 100644 --- a/keras_core/backend/__init__.py +++ b/keras_core/backend/__init__.py @@ -37,5 +37,12 @@ elif backend() == "jax": elif backend() == "torch": print_msg("Using PyTorch backend.") from keras_core.backend.torch import * # noqa: F403 +elif backend() == "numpy": + print_msg( + "Using NumPy backend.\nThe NumPy backend does not support " + "training. It should only be used for inference, evaluation, " + "and debugging." 
+ ) + from keras_core.backend.numpy import * # noqa: F403 else: raise ValueError(f"Unable to import backend : {backend()}") diff --git a/keras_core/backend/numpy/__init__.py b/keras_core/backend/numpy/__init__.py new file mode 100644 index 000000000..3dc0f6ad3 --- /dev/null +++ b/keras_core/backend/numpy/__init__.py @@ -0,0 +1,20 @@ +from keras_core.backend.numpy import core +from keras_core.backend.numpy import image +from keras_core.backend.numpy import math +from keras_core.backend.numpy import nn +from keras_core.backend.numpy import numpy +from keras_core.backend.numpy import random +from keras_core.backend.numpy.core import DYNAMIC_SHAPES_OK +from keras_core.backend.numpy.core import Variable +from keras_core.backend.numpy.core import cast +from keras_core.backend.numpy.core import compute_output_spec +from keras_core.backend.numpy.core import cond +from keras_core.backend.numpy.core import convert_to_numpy +from keras_core.backend.numpy.core import convert_to_tensor +from keras_core.backend.numpy.core import is_tensor +from keras_core.backend.numpy.core import name_scope +from keras_core.backend.numpy.core import shape +from keras_core.backend.numpy.core import vectorized_map +from keras_core.backend.numpy.rnn import gru +from keras_core.backend.numpy.rnn import lstm +from keras_core.backend.numpy.rnn import rnn diff --git a/keras_core/backend/numpy/core.py b/keras_core/backend/numpy/core.py new file mode 100644 index 000000000..1168105dd --- /dev/null +++ b/keras_core/backend/numpy/core.py @@ -0,0 +1,212 @@ +from contextlib import nullcontext + +import numpy as np +from tensorflow import nest + +from keras_core.backend.common import KerasVariable +from keras_core.backend.common import standardize_dtype +from keras_core.backend.common.keras_tensor import KerasTensor +from keras_core.backend.common.stateless_scope import StatelessScope + +DYNAMIC_SHAPES_OK = True + + +class Variable(KerasVariable): + def _initialize(self, value): + self._value = 
np.array(value, dtype=self._dtype) + + def _direct_assign(self, value): + self._value = np.array(value, dtype=self._dtype) + + def _convert_to_tensor(self, value, dtype=None): + return convert_to_tensor(value, dtype=dtype) + + # Overload native accessor. + def __array__(self): + return self.value + + +def convert_to_tensor(x, dtype=None): + if dtype is not None: + dtype = standardize_dtype(dtype) + if isinstance(x, Variable): + if dtype and dtype != x.dtype: + return x.value.astype(dtype) + return x.value + return np.array(x, dtype=dtype) + + +def convert_to_numpy(x): + return np.array(x) + + +def is_tensor(x): + if isinstance(x, (np.generic, np.ndarray)): + return True + return False + + +def shape(x): + return x.shape + + +def cast(x, dtype): + return convert_to_tensor(x, dtype=dtype) + + +def cond(pred, true_fn, false_fn): + if pred: + return true_fn() + return false_fn() + + +def name_scope(name): + # There is no need for a named context for NumPy. + return nullcontext() + + +def vectorized_map(function, elements): + if len(elements) == 1: + return function(elements) + else: + batch_size = elements[0].shape[0] + output_store = list() + for index in range(batch_size): + output_store.append(function([x[index] for x in elements])) + return np.stack(output_store) + + +# Shape / dtype inference util +def compute_output_spec(fn, *args, **kwargs): + with StatelessScope(): + + def has_none_shape(x): + if isinstance(x, KerasTensor): + return None in x.shape + return False + + none_in_shape = any(map(has_none_shape, nest.flatten((args, kwargs)))) + + def convert_keras_tensor_to_numpy(x, fill_value=None): + if isinstance(x, KerasTensor): + shape = list(x.shape) + if fill_value: + for i, e in enumerate(shape): + if e is None: + shape[i] = fill_value + return np.empty( + shape=shape, + dtype=x.dtype, + ) + return x + + args_1, kwargs_1 = nest.map_structure( + lambda x: convert_keras_tensor_to_numpy(x, fill_value=83), + (args, kwargs), + ) + outputs_1 = fn(*args_1, 
**kwargs_1) + + outputs = outputs_1 + + if none_in_shape: + args_2, kwargs_2 = nest.map_structure( + lambda x: convert_keras_tensor_to_numpy(x, fill_value=89), + (args, kwargs), + ) + outputs_2 = fn(*args_2, **kwargs_2) + + flat_out_1 = nest.flatten(outputs_1) + flat_out_2 = nest.flatten(outputs_2) + + flat_out = [] + for x1, x2 in zip(flat_out_1, flat_out_2): + shape = list(x1.shape) + for i, e in enumerate(x2.shape): + if e != shape[i]: + shape[i] = None + flat_out.append(KerasTensor(shape, standardize_dtype(x1.dtype))) + outputs = nest.pack_sequence_as(outputs_1, flat_out) + + def convert_numpy_to_keras_tensor(x): + if is_tensor(x): + return KerasTensor(x.shape, standardize_dtype(x.dtype)) + return x + + output_spec = nest.map_structure(convert_numpy_to_keras_tensor, outputs) + return output_spec + + +def scatter(indices, values, shape): + indices = convert_to_tensor(indices) + values = convert_to_tensor(values) + zeros = np.zeros(shape, dtype=values.dtype) + + index_length = indices.shape[-1] + value_shape = shape[index_length:] + indices = np.reshape(indices, [-1, index_length]) + values = np.reshape(values, [-1] + list(value_shape)) + + for i in range(indices.shape[0]): + index = indices[i] + zeros[tuple(index)] += values[i] + return zeros + + +def scatter_update(inputs, indices, updates): + indices = np.array(indices) + indices = np.transpose(indices) + inputs[tuple(indices)] = updates + return inputs + + +def slice(inputs, start_indices, lengths): + # Validate inputs + assert len(start_indices) == len(lengths) + + # Generate list of indices arrays for each dimension + indices = [ + np.arange(start, start + length) + for start, length in zip(start_indices, lengths) + ] + + # Use np.ix_ to create a multidimensional index array + mesh = np.ix_(*indices) + + return inputs[mesh] + + +def slice_update(inputs, start_indices, updates): + # Generate list of indices arrays for each dimension + indices = [ + np.arange(start, start + length) + for start, length in 
zip(start_indices, updates.shape) + ] + + # Use np.ix_ to create a multidimensional index array + mesh = np.ix_(*indices) + inputs[mesh] = updates + return inputs + + +def while_loop( + cond, + body, + loop_vars, + maximum_iterations=None, +): + current_iter = 0 + iteration_check = ( + lambda iter: maximum_iterations is None or iter < maximum_iterations + ) + loop_vars = tuple([convert_to_tensor(v) for v in loop_vars]) + while cond(*loop_vars) and iteration_check(current_iter): + loop_vars = body(*loop_vars) + if not isinstance(loop_vars, (list, tuple)): + loop_vars = (loop_vars,) + loop_vars = tuple(loop_vars) + current_iter += 1 + return loop_vars + + +def stop_gradient(x): + pass diff --git a/keras_core/backend/numpy/image.py b/keras_core/backend/numpy/image.py new file mode 100644 index 000000000..33736f0f1 --- /dev/null +++ b/keras_core/backend/numpy/image.py @@ -0,0 +1,45 @@ +import jax +import numpy as np + +RESIZE_METHODS = ( + "bilinear", + "nearest", + "lanczos3", + "lanczos5", + "bicubic", +) + + +def resize( + image, size, method="bilinear", antialias=False, data_format="channels_last" +): + if method not in RESIZE_METHODS: + raise ValueError( + "Invalid value for argument `method`. Expected of one " + f"{RESIZE_METHODS}. Received: method={method}" + ) + if not len(size) == 2: + raise ValueError( + "Argument `size` must be a tuple of two elements " + f"(height, width). Received: size={size}" + ) + size = tuple(size) + if len(image.shape) == 4: + if data_format == "channels_last": + size = (image.shape[0],) + size + (image.shape[-1],) + else: + size = (image.shape[0], image.shape[1]) + size + elif len(image.shape) == 3: + if data_format == "channels_last": + size = size + (image.shape[-1],) + else: + size = (image.shape[0],) + size + else: + raise ValueError( + "Invalid input rank: expected rank 3 (single image) " + "or rank 4 (batch of images). 
Received input with shape: " + f"image.shape={image.shape}" + ) + return np.array( + jax.image.resize(image, size, method=method, antialias=antialias) + ) diff --git a/keras_core/backend/numpy/layer.py b/keras_core/backend/numpy/layer.py new file mode 100644 index 000000000..daf2bb96f --- /dev/null +++ b/keras_core/backend/numpy/layer.py @@ -0,0 +1,3 @@ +class NumpyLayer: + def _post_build(self): + pass diff --git a/keras_core/backend/numpy/math.py b/keras_core/backend/numpy/math.py new file mode 100644 index 000000000..864274891 --- /dev/null +++ b/keras_core/backend/numpy/math.py @@ -0,0 +1,76 @@ +import numpy as np + + +def segment_sum(data, segment_ids, num_segments=None, sorted=False): + if num_segments is None: + num_segments = np.amax(segment_ids) + 1 + + valid_indices = segment_ids >= 0 # Ignore segment_ids that are -1 + valid_data = data[valid_indices] + valid_segment_ids = segment_ids[valid_indices] + + data_shape = list(valid_data.shape) + data_shape[ + 0 + ] = num_segments # Replace first dimension (which corresponds to segments) + + if sorted: + result = np.zeros(data_shape, dtype=valid_data.dtype) + np.add.at(result, valid_segment_ids, valid_data) + else: + sort_indices = np.argsort(valid_segment_ids) + sorted_segment_ids = valid_segment_ids[sort_indices] + sorted_data = valid_data[sort_indices] + + result = np.zeros(data_shape, dtype=valid_data.dtype) + np.add.at(result, sorted_segment_ids, sorted_data) + + return result + + +def top_k(x, k, sorted=False): + sorted_indices = np.argsort(x, axis=-1)[..., ::-1] + sorted_values = np.sort(x, axis=-1)[..., ::-1] + + if sorted: + # Take the k largest values. + top_k_values = sorted_values[..., :k] + top_k_indices = sorted_indices[..., :k] + else: + # Partition the array such that all values larger than the k-th + # largest value are to the right of it. + top_k_values = np.partition(x, -k, axis=-1)[..., -k:] + top_k_indices = np.argpartition(x, -k, axis=-1)[..., -k:] + + # Get the indices in sorted order. 
+ idx = np.argsort(-top_k_values, axis=-1) + + # Get the top k values and their indices. + top_k_values = np.take_along_axis(top_k_values, idx, axis=-1) + top_k_indices = np.take_along_axis(top_k_indices, idx, axis=-1) + + return top_k_values, top_k_indices + + +def in_top_k(targets, predictions, k): + targets = targets[:, None] + topk_values = top_k(predictions, k)[0] + targets_values = np.take_along_axis(predictions, targets, axis=-1) + mask = targets_values >= topk_values + return np.any(mask, axis=-1) + + +def logsumexp(x, axis=None, keepdims=False): + max_x = np.max(x, axis=axis, keepdims=True) + result = np.log(np.sum(np.exp(x - max_x), axis=axis, keepdims=True)) + max_x + return np.squeeze(result) if not keepdims else result + + +def qr(x, mode="reduced"): + if mode not in {"reduced", "complete"}: + raise ValueError( + "`mode` argument value not supported. " + "Expected one of {'reduced', 'complete'}. " + f"Received: mode={mode}" + ) + return np.linalg.qr(x, mode=mode) diff --git a/keras_core/backend/numpy/nn.py b/keras_core/backend/numpy/nn.py new file mode 100644 index 000000000..d83f0ba83 --- /dev/null +++ b/keras_core/backend/numpy/nn.py @@ -0,0 +1,517 @@ +import jax +import numpy as np +from jax import lax +from jax import numpy as jnp + +from keras_core.backend import standardize_data_format +from keras_core.backend.common.backend_utils import ( + compute_conv_transpose_padding, +) +from keras_core.backend.config import epsilon +from keras_core.backend.numpy.core import is_tensor + + +def relu(x): + return np.maximum(x, 0.0) + + +def relu6(x): + return np.clip(x, 0.0, 6.0) + + +def sigmoid(x): + return 1.0 / (1.0 + np.exp(-x)) + + +def tanh(x): + return np.tanh(x) + + +def softplus(x): + return np.log(1.0 + np.exp(x)) + + +def softsign(x): + return x / (1.0 + np.abs(x)) + + +def silu(x): + return x * (1.0 / (1.0 + np.exp(-x))) + + +def swish(x): + return x * (1.0 / (1.0 + np.exp(-x))) + + +def log_sigmoid(x): + return np.log(1.0 / (1.0 + np.exp(-x))) + 
+ +def leaky_relu(x, negative_slope=0.2): + return np.maximum(x, negative_slope * x) + + +def hard_sigmoid(x): + x = (x / 6.0) + 0.5 + return np.where(x <= 0.0, 0.0, np.where(x >= 1.0, 1.0, x)) + + +def elu(x, alpha=1.0): + return np.where(x >= 0.0, x, alpha * (np.exp(x) - 1.0)) + + +def selu( + x, + alpha=1.6732632423543772848170429916717, + scale=1.0507009873554804934193349852946, +): + return scale * np.where(x >= 0.0, x, alpha * (np.exp(x) - 1.0)) + + +def gelu(x, approximate=True): + if approximate: + return ( + 0.5 + * x + * ( + 1.0 + + np.tanh( + np.sqrt(2.0 / np.pi) * (x + 0.044715 * np.power(x, 3)) + ) + ) + ) + else: + from scipy.stats import norm + + return x * norm.cdf(x) + + +def softmax(x, axis=None): + exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True)) + return exp_x / np.sum(exp_x, axis=axis, keepdims=True) + + +def log_softmax(x, axis=None): + max_x = np.max(x, axis=axis, keepdims=True) + logsumexp = np.log(np.exp(x - max_x).sum(axis=axis, keepdims=True)) + return x - max_x - logsumexp + + +def _convert_to_spatial_operand( + x, + num_spatial_dims, + data_format="channels_last", + include_batch_and_channels=True, +): + # Helper function that converts an operand to a spatial operand. + x = (x,) * num_spatial_dims if isinstance(x, int) else x + if not include_batch_and_channels: + return x + if data_format == "channels_last": + x = (1,) + x + (1,) + else: + x = (1,) + (1,) + x + return x + + +def _pool( + inputs, + initial_value, + reduce_fn, + pool_size, + strides=None, + padding="valid", +): + """Helper function to define pooling functions. + + Args: + inputs: input data of shape `N+2`. + initial_value: the initial value for the reduction. + reduce_fn: a reduce function of the form `(T, T) -> T`. + pool_size: a sequence of `N` integers, representing the window size to + reduce over. + strides: a sequence of `N` integers, representing the inter-window + strides (default: `(1, ..., 1)`). + padding: either the string `same` or `valid`. 
+ + Returns: + The output of the reduction for each window slice. + """ + if padding not in ("same", "valid"): + raise ValueError( + f"Invalid padding '{padding}', must be 'same' or 'valid'." + ) + padding = padding.upper() + return np.array( + lax.reduce_window( + inputs, + initial_value, + reduce_fn, + pool_size, + strides, + padding, + ) + ) + + +def max_pool( + inputs, + pool_size, + strides=None, + padding="valid", + data_format=None, +): + data_format = standardize_data_format(data_format) + num_spatial_dims = inputs.ndim - 2 + pool_size = _convert_to_spatial_operand( + pool_size, num_spatial_dims, data_format + ) + strides = pool_size if strides is None else strides + strides = _convert_to_spatial_operand( + strides, num_spatial_dims, data_format + ) + return _pool(inputs, -jnp.inf, lax.max, pool_size, strides, padding) + + +def average_pool( + inputs, + pool_size, + strides, + padding, + data_format=None, +): + data_format = standardize_data_format(data_format) + num_spatial_dims = inputs.ndim - 2 + pool_size = _convert_to_spatial_operand( + pool_size, num_spatial_dims, data_format + ) + strides = pool_size if strides is None else strides + strides = _convert_to_spatial_operand( + strides, num_spatial_dims, data_format + ) + + pooled = _pool(inputs, 0.0, lax.add, pool_size, strides, padding) + if padding == "valid": + # Avoid the extra reduce_window. + return pooled / np.prod(pool_size) + else: + # Count the number of valid entries at each input point, then use that + # for computing average. Assumes that any two arrays of same shape will + # be padded the same. Avoid broadcasting on axis where pooling is + # skipped. 
+ shape = [ + (a if b != 1 else 1) for (a, b) in zip(inputs.shape, pool_size) + ] + window_counts = _pool( + jnp.ones(shape, inputs.dtype), + 0.0, + lax.add, + pool_size, + strides, + padding, + ) + return pooled / window_counts + + +def _convert_to_lax_conv_dimension_numbers( + num_spatial_dims, + data_format="channels_last", + transpose=False, +): + """Create a `lax.ConvDimensionNumbers` for the given inputs.""" + num_dims = num_spatial_dims + 2 + + if data_format == "channels_last": + spatial_dims = tuple(range(1, num_dims - 1)) + inputs_dn = (0, num_dims - 1) + spatial_dims + else: + spatial_dims = tuple(range(2, num_dims)) + inputs_dn = (0, 1) + spatial_dims + + if transpose: + kernel_dn = (num_dims - 2, num_dims - 1) + tuple(range(num_dims - 2)) + else: + kernel_dn = (num_dims - 1, num_dims - 2) + tuple(range(num_dims - 2)) + + return lax.ConvDimensionNumbers( + lhs_spec=inputs_dn, rhs_spec=kernel_dn, out_spec=inputs_dn + ) + + +def conv( + inputs, + kernel, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, +): + data_format = standardize_data_format(data_format) + num_spatial_dims = inputs.ndim - 2 + dimension_numbers = _convert_to_lax_conv_dimension_numbers( + num_spatial_dims, + data_format, + transpose=False, + ) + strides = _convert_to_spatial_operand( + strides, + num_spatial_dims, + data_format, + include_batch_and_channels=False, + ) + dilation_rate = _convert_to_spatial_operand( + dilation_rate, + num_spatial_dims, + data_format, + include_batch_and_channels=False, + ) + if data_format == "channels_last": + channels = inputs.shape[-1] + else: + channels = inputs.shape[1] + kernel_in_channels = kernel.shape[-2] + if channels % kernel_in_channels > 0: + raise ValueError( + "The number of input channels must be evenly divisible by " + f"kernel's in_channels. Received input channels {channels} and " + f"kernel in_channels {kernel_in_channels}. 
" + ) + feature_group_count = channels // kernel_in_channels + return np.array( + jax.lax.conv_general_dilated( + inputs, + kernel if is_tensor(kernel) else kernel.numpy(), + strides, + padding, + rhs_dilation=dilation_rate, + dimension_numbers=dimension_numbers, + feature_group_count=feature_group_count, + ) + ) + + +def depthwise_conv( + inputs, + kernel, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, +): + data_format = standardize_data_format(data_format) + num_spatial_dims = inputs.ndim - 2 + dimension_numbers = _convert_to_lax_conv_dimension_numbers( + num_spatial_dims, + data_format, + transpose=False, + ) + strides = _convert_to_spatial_operand( + strides, + num_spatial_dims, + data_format, + include_batch_and_channels=False, + ) + dilation_rate = _convert_to_spatial_operand( + dilation_rate, + num_spatial_dims, + data_format, + include_batch_and_channels=False, + ) + feature_group_count = ( + inputs.shape[-1] if data_format == "channels_last" else inputs.shape[1] + ) + kernel = jnp.reshape( + kernel if is_tensor(kernel) else kernel.numpy(), + kernel.shape[:-2] + (1, feature_group_count * kernel.shape[-1]), + ) + return np.array( + jax.lax.conv_general_dilated( + inputs, + kernel, + strides, + padding, + rhs_dilation=dilation_rate, + dimension_numbers=dimension_numbers, + feature_group_count=feature_group_count, + ) + ) + + +def separable_conv( + inputs, + depthwise_kernel, + pointwise_kernel, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, +): + data_format = standardize_data_format(data_format) + depthwise_conv_output = depthwise_conv( + inputs, + depthwise_kernel, + strides, + padding, + data_format, + dilation_rate, + ) + return conv( + depthwise_conv_output, + pointwise_kernel, + strides=1, + padding="valid", + data_format=data_format, + dilation_rate=dilation_rate, + ) + + +def conv_transpose( + inputs, + kernel, + strides=1, + padding="valid", + output_padding=None, + data_format=None, + dilation_rate=1, 
+): + data_format = standardize_data_format(data_format) + num_spatial_dims = inputs.ndim - 2 + padding_values = compute_conv_transpose_padding( + inputs.shape, + kernel.shape, + strides, + padding, + output_padding, + data_format, + dilation_rate, + ) + dimension_numbers = _convert_to_lax_conv_dimension_numbers( + num_spatial_dims, + data_format, + transpose=False, + ) + strides = _convert_to_spatial_operand( + strides, + num_spatial_dims, + data_format, + include_batch_and_channels=False, + ) + dilation_rate = _convert_to_spatial_operand( + dilation_rate, + num_spatial_dims, + data_format, + include_batch_and_channels=False, + ) + + return np.array( + jax.lax.conv_transpose( + inputs, + kernel if is_tensor(kernel) else kernel.numpy(), + strides, + padding=padding_values, + rhs_dilation=dilation_rate, + dimension_numbers=dimension_numbers, + transpose_kernel=True, + ) + ) + + +def one_hot(x, num_classes, axis=-1, dtype="float32"): + input_shape = x.shape + + # Shrink the last dimension if the shape is (..., 1). + if input_shape and input_shape[-1] == 1 and len(input_shape) > 1: + input_shape = tuple(input_shape[:-1]) + + x = x.reshape(-1) + if not num_classes: + num_classes = np.max(x) + 1 + + batch_size = x.shape[0] + categorical = np.zeros((batch_size, num_classes), dtype=dtype) + valid_indices = x >= 0 + categorical[np.arange(batch_size)[valid_indices], x[valid_indices]] = 1 + + # First, reshape the array with the extra dimension at the end + output_shape = input_shape + (num_classes,) + categorical = np.reshape(categorical, output_shape) + + # Then, move this new dimension to the right place (according to axis) + if axis != -1: + categorical = np.moveaxis(categorical, -1, axis) + + return categorical + + +def categorical_crossentropy(target, output, from_logits=False, axis=-1): + target = np.array(target) + output = np.array(output) + + if target.shape != output.shape: + raise ValueError( + "Arguments `target` and `output` must have the same shape. 
" + "Received: " + f"target.shape={target.shape}, output.shape={output.shape}" + ) + if len(target.shape) < 1: + raise ValueError( + "Arguments `target` and `output` must be at least rank 1. " + "Received: " + f"target.shape={target.shape}, output.shape={output.shape}" + ) + + if from_logits: + log_prob = log_softmax(output, axis=axis) + else: + output = output / np.sum(output, axis, keepdims=True) + output = np.clip(output, epsilon(), 1.0 - epsilon()) + log_prob = np.log(output) + return -np.sum(target * log_prob, axis=axis) + + +def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1): + target = np.array(target, dtype="int32") + output = np.array(output) + if len(target.shape) == len(output.shape) and target.shape[-1] == 1: + target = np.squeeze(target, axis=-1) + + if len(output.shape) < 1: + raise ValueError( + "Argument `output` must be at least rank 1. " + "Received: " + f"output.shape={output.shape}" + ) + if target.shape != output.shape[:-1]: + raise ValueError( + "Arguments `target` and `output` must have the same shape " + "up until the last dimension: " + f"target.shape={target.shape}, output.shape={output.shape}" + ) + if from_logits: + log_prob = log_softmax(output, axis=axis) + else: + output = output / np.sum(output, axis, keepdims=True) + output = np.clip(output, epsilon(), 1.0 - epsilon()) + log_prob = np.log(output) + target = one_hot(target, output.shape[axis], axis=axis) + return -np.sum(target * log_prob, axis=axis) + + +def binary_crossentropy(target, output, from_logits=False): + target = np.array(target) + output = np.array(output) + + if target.shape != output.shape: + raise ValueError( + "Arguments `target` and `output` must have the same shape. 
" + "Received: " + f"target.shape={target.shape}, output.shape={output.shape}" + ) + + if from_logits: + output = sigmoid(output) + + output = np.clip(output, epsilon(), 1.0 - epsilon()) + bce = target * np.log(output) + bce += (1.0 - target) * np.log(1.0 - output) + return -bce diff --git a/keras_core/backend/numpy/numpy.py b/keras_core/backend/numpy/numpy.py new file mode 100644 index 000000000..4ba804a2c --- /dev/null +++ b/keras_core/backend/numpy/numpy.py @@ -0,0 +1,571 @@ +import numpy as np + + +def add(x1, x2): + return np.add(x1, x2) + + +def einsum(subscripts, *operands, **kwargs): + return np.einsum(subscripts, *operands, **kwargs) + + +def subtract(x1, x2): + return np.subtract(x1, x2) + + +def matmul(x1, x2): + return np.matmul(x1, x2) + + +def multiply(x1, x2): + return np.multiply(x1, x2) + + +def mean(x, axis=None, keepdims=False): + axis = tuple(axis) if isinstance(axis, list) else axis + return np.mean(x, axis=axis, keepdims=keepdims) + + +def max(x, axis=None, keepdims=False, initial=None): + axis = tuple(axis) if isinstance(axis, list) else axis + return np.max(x, axis=axis, keepdims=keepdims, initial=initial) + + +def ones(shape, dtype="float32"): + return np.ones(shape, dtype=dtype) + + +def zeros(shape, dtype="float32"): + return np.zeros(shape, dtype=dtype) + + +def absolute(x): + return np.absolute(x) + + +def abs(x): + return absolute(x) + + +def all(x, axis=None, keepdims=False): + axis = tuple(axis) if isinstance(axis, list) else axis + return np.all(x, axis=axis, keepdims=keepdims) + + +def any(x, axis=None, keepdims=False): + axis = tuple(axis) if isinstance(axis, list) else axis + return np.any(x, axis=axis, keepdims=keepdims) + + +def amax(x, axis=None, keepdims=False): + axis = tuple(axis) if isinstance(axis, list) else axis + return np.amax(x, axis=axis, keepdims=keepdims) + + +def amin(x, axis=None, keepdims=False): + axis = tuple(axis) if isinstance(axis, list) else axis + return np.amin(x, axis=axis, keepdims=keepdims) + + +def 
append( + x1, + x2, + axis=None, +): + axis = tuple(axis) if isinstance(axis, list) else axis + return np.append(x1, x2, axis=axis) + + +def arange(start, stop=None, step=None, dtype=None): + return np.arange(start, stop, step=step, dtype=dtype) + + +def arccos(x): + return np.arccos(x) + + +def arcsin(x): + return np.arcsin(x) + + +def arctan(x): + return np.arctan(x) + + +def arctan2(x1, x2): + return np.arctan2(x1, x2) + + +def argmax(x, axis=None): + axis = tuple(axis) if isinstance(axis, list) else axis + return np.argmax(x, axis=axis) + + +def argmin(x, axis=None): + axis = tuple(axis) if isinstance(axis, list) else axis + return np.argmin(x, axis=axis) + + +def argsort(x, axis=-1): + axis = tuple(axis) if isinstance(axis, list) else axis + return np.argsort(x, axis=axis) + + +def array(x, dtype=None): + return np.array(x, dtype=dtype) + + +def average(x, axis=None, weights=None): + axis = tuple(axis) if isinstance(axis, list) else axis + return np.average(x, weights=weights, axis=axis) + + +def bincount(x, weights=None, minlength=0): + return np.bincount(x, weights, minlength) + + +def broadcast_to(x, shape): + return np.broadcast_to(x, shape) + + +def ceil(x): + return np.ceil(x) + + +def clip(x, x_min, x_max): + return np.clip(x, x_min, x_max) + + +def concatenate(xs, axis=0): + axis = tuple(axis) if isinstance(axis, list) else axis + return np.concatenate(xs, axis=axis) + + +def conjugate(x): + return np.conjugate(x) + + +def conj(x): + return conjugate(x) + + +def copy(x): + return np.copy(x) + + +def cos(x): + return np.cos(x) + + +def count_nonzero(x, axis=None): + axis = tuple(axis) if isinstance(axis, list) else axis + return np.count_nonzero(x, axis=axis) + + +def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None): + axis = tuple(axis) if isinstance(axis, list) else axis + return np.cross( + x1, + x2, + axisa=axisa, + axisb=axisb, + axisc=axisc, + axis=axis, + ) + + +def cumprod(x, axis=None): + axis = tuple(axis) if isinstance(axis, list) else 
axis + return np.cumprod(x, axis=axis) + + +def cumsum(x, axis=None): + axis = tuple(axis) if isinstance(axis, list) else axis + return np.cumsum(x, axis=axis) + + +def diag(x, k=0): + return np.diag(x, k=k) + + +def diagonal(x, offset=0, axis1=0, axis2=1): + axis1 = tuple(axis1) if isinstance(axis1, list) else axis1 + axis2 = tuple(axis2) if isinstance(axis2, list) else axis2 + return np.diagonal( + x, + offset=offset, + axis1=axis1, + axis2=axis2, + ) + + +def dot(x, y): + return np.dot(x, y) + + +def empty(shape, dtype="float32"): + return np.empty(shape, dtype=dtype) + + +def equal(x1, x2): + return np.equal(x1, x2) + + +def exp(x): + return np.exp(x) + + +def expand_dims(x, axis): + axis = tuple(axis) if isinstance(axis, list) else axis + return np.expand_dims(x, axis) + + +def expm1(x): + return np.expm1(x) + + +def flip(x, axis=None): + axis = tuple(axis) if isinstance(axis, list) else axis + return np.flip(x, axis=axis) + + +def floor(x): + return np.floor(x) + + +def full(shape, fill_value, dtype=None): + return np.full(shape, fill_value, dtype=dtype) + + +def full_like(x, fill_value, dtype=None): + return np.full_like(x, fill_value, dtype=dtype) + + +def greater(x1, x2): + return np.greater(x1, x2) + + +def greater_equal(x1, x2): + return np.greater_equal(x1, x2) + + +def hstack(xs): + return np.hstack(xs) + + +def identity(n, dtype="float32"): + return np.identity(n, dtype=dtype) + + +def imag(x): + return np.imag(x) + + +def isclose(x1, x2): + return np.isclose(x1, x2) + + +def isfinite(x): + return np.isfinite(x) + + +def isinf(x): + return np.isinf(x) + + +def isnan(x): + return np.isnan(x) + + +def less(x1, x2): + return np.less(x1, x2) + + +def less_equal(x1, x2): + return np.less_equal(x1, x2) + + +def linspace( + start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0 +): + axis = tuple(axis) if isinstance(axis, list) else axis + return np.linspace( + start, + stop, + num=num, + endpoint=endpoint, + retstep=retstep, + dtype=dtype, + 
axis=axis, + ) + + +def log(x): + return np.log(x) + + +def log10(x): + return np.log10(x) + + +def log1p(x): + return np.log1p(x) + + +def log2(x): + return np.log2(x) + + +def logaddexp(x1, x2): + return np.logaddexp(x1, x2) + + +def logical_and(x1, x2): + return np.logical_and(x1, x2) + + +def logical_not(x): + return np.logical_not(x) + + +def logical_or(x1, x2): + return np.logical_or(x1, x2) + + +def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0): + return np.logspace( + start, + stop, + num=num, + endpoint=endpoint, + base=base, + dtype=dtype, + axis=axis, + ) + + +def maximum(x1, x2): + return np.maximum(x1, x2) + + +def meshgrid(*x, indexing="xy"): + return np.meshgrid(*x, indexing=indexing) + + +def min(x, axis=None, keepdims=False, initial=None): + axis = tuple(axis) if isinstance(axis, list) else axis + return np.min(x, axis=axis, keepdims=keepdims, initial=initial) + + +def minimum(x1, x2): + return np.minimum(x1, x2) + + +def mod(x1, x2): + return np.mod(x1, x2) + + +def moveaxis(x, source, destination): + return np.moveaxis(x, source=source, destination=destination) + + +def nan_to_num(x): + return np.nan_to_num(x) + + +def ndim(x): + return np.ndim(x) + + +def nonzero(x): + return np.nonzero(x) + + +def not_equal(x1, x2): + return np.not_equal(x1, x2) + + +def zeros_like(x, dtype=None): + return np.zeros_like(x, dtype=dtype) + + +def ones_like(x, dtype=None): + return np.ones_like(x, dtype=dtype) + + +def outer(x1, x2): + return np.outer(x1, x2) + + +def pad(x, pad_width, mode="constant"): + return np.pad(x, pad_width, mode=mode) + + +def prod(x, axis=None, keepdims=False, dtype=None): + axis = tuple(axis) if isinstance(axis, list) else axis + return np.prod(x, axis=axis, keepdims=keepdims, dtype=dtype) + + +def ravel(x): + return np.ravel(x) + + +def real(x): + return np.real(x) + + +def reciprocal(x): + return np.reciprocal(x) + + +def repeat(x, repeats, axis=None): + return np.repeat(x, repeats, axis=axis) + + +def 
reshape(x, new_shape): + return np.reshape(x, new_shape) + + +def roll(x, shift, axis=None): + return np.roll(x, shift, axis=axis) + + +def sign(x): + return np.sign(x) + + +def sin(x): + return np.sin(x) + + +def size(x): + return np.size(x) + + +def sort(x, axis=-1): + axis = tuple(axis) if isinstance(axis, list) else axis + return np.sort(x, axis=axis) + + +def split(x, indices_or_sections, axis=0): + axis = tuple(axis) if isinstance(axis, list) else axis + return np.split(x, indices_or_sections, axis=axis) + + +def stack(x, axis=0): + axis = tuple(axis) if isinstance(axis, list) else axis + return np.stack(x, axis=axis) + + +def std(x, axis=None, keepdims=False): + axis = tuple(axis) if isinstance(axis, list) else axis + return np.std(x, axis=axis, keepdims=keepdims) + + +def swapaxes(x, axis1, axis2): + return np.swapaxes(x, axis1=axis1, axis2=axis2) + + +def take(x, indices, axis=None): + axis = tuple(axis) if isinstance(axis, list) else axis + return np.take(x, indices, axis=axis) + + +def take_along_axis(x, indices, axis=None): + axis = tuple(axis) if isinstance(axis, list) else axis + return np.take_along_axis(x, indices, axis=axis) + + +def tan(x): + return np.tan(x) + + +def tensordot(x1, x2, axes=2): + axes = tuple(axes) if isinstance(axes, list) else axes + return np.tensordot(x1, x2, axes=axes) + + +def round(x, decimals=0): + return np.round(x, decimals=decimals) + + +def tile(x, repeats): + return np.tile(x, repeats) + + +def trace(x, offset=0, axis1=0, axis2=1): + axis1 = tuple(axis1) if isinstance(axis1, list) else axis1 + axis2 = tuple(axis2) if isinstance(axis2, list) else axis2 + return np.trace(x, offset=offset, axis1=axis1, axis2=axis2) + + +def tri(N, M=None, k=0, dtype="float32"): + return np.tri(N, M=M, k=k, dtype=dtype) + + +def tril(x, k=0): + return np.tril(x, k=k) + + +def triu(x, k=0): + return np.triu(x, k=k) + + +def vdot(x1, x2): + return np.vdot(x1, x2) + + +def vstack(xs): + return np.vstack(xs) + + +def where(condition, x1, x2): 
+ return np.where(condition, x1, x2) + + +def divide(x1, x2): + return np.divide(x1, x2) + + +def true_divide(x1, x2): + return np.true_divide(x1, x2) + + +def power(x1, x2): + return np.power(x1, x2) + + +def negative(x): + return np.negative(x) + + +def square(x): + return np.square(x) + + +def sqrt(x): + return np.sqrt(x) + + +def squeeze(x, axis=None): + axis = tuple(axis) if isinstance(axis, list) else axis + return np.squeeze(x, axis=axis) + + +def transpose(x, axes=None): + axes = tuple(axes) if isinstance(axes, list) else axes + return np.transpose(x, axes=axes) + + +def var(x, axis=None, keepdims=False): + axis = tuple(axis) if isinstance(axis, list) else axis + return np.var(x, axis=axis, keepdims=keepdims) + + +def sum(x, axis=None, keepdims=False): + axis = tuple(axis) if isinstance(axis, list) else axis + return np.sum(x, axis=axis, keepdims=keepdims) + + +def eye(N, M=None, k=0, dtype="float32"): + return np.eye(N, M=M, k=k, dtype=dtype) diff --git a/keras_core/backend/numpy/random.py b/keras_core/backend/numpy/random.py new file mode 100644 index 000000000..d794e5de1 --- /dev/null +++ b/keras_core/backend/numpy/random.py @@ -0,0 +1,88 @@ +import numpy as np + +from keras_core.backend.config import floatx +from keras_core.backend.numpy.nn import softmax +from keras_core.random.seed_generator import SeedGenerator +from keras_core.random.seed_generator import draw_seed +from keras_core.random.seed_generator import make_default_seed + + +def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): + dtype = dtype or floatx() + seed = draw_seed(seed) + rng = np.random.default_rng(seed) + return rng.normal(size=shape, loc=mean, scale=stddev).astype(dtype) + + +def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None): + dtype = dtype or floatx() + seed = draw_seed(seed) + rng = np.random.default_rng(seed) + return rng.uniform(size=shape, low=minval, high=maxval).astype(dtype) + + +def categorical(logits, num_samples, dtype="int64", seed=None): + 
def randint(shape, minval, maxval, dtype="int32", seed=None):
    """Draw random integers from the half-open range `[minval, maxval)`.

    Args:
        shape: Shape of the returned array.
        minval: Inclusive lower bound.
        maxval: Exclusive upper bound.
        dtype: Integer dtype of the result.
        seed: Passed through `draw_seed`; the result seeds
            `np.random.default_rng`.

    Returns:
        A NumPy integer array of the requested shape and dtype.
    """
    seed = draw_seed(seed)
    rng = np.random.default_rng(seed)
    output = rng.integers(low=minval, high=maxval, size=shape, dtype=dtype)
    return output


def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
    """Draw normal samples restricted to two standard deviations.

    Samples are produced by rejection: batches are drawn from
    `N(mean, stddev)` and values outside
    `[mean - 2 * stddev, mean + 2 * stddev]` are discarded until enough
    samples remain to fill `shape`.

    Args:
        shape: Shape of the returned array. May be `()` for a scalar
            array, or contain a zero dimension for an empty array.
        mean: Mean of the underlying normal distribution.
        stddev: Standard deviation of the underlying normal distribution.
        dtype: Dtype of the result; defaults to `floatx()`.
        seed: Passed through `draw_seed`; the result seeds
            `np.random.default_rng`.

    Returns:
        A NumPy array of the requested shape and dtype.
    """
    dtype = dtype or floatx()
    seed = draw_seed(seed)
    rng = np.random.default_rng(seed)

    lower_bound = mean - 2 * stddev
    upper_bound = mean + 2 * stddev

    # `np.prod(())` is the float 1.0, which is rejected both as a `size=`
    # argument and as a slice bound, so coerce the count to a Python int.
    flat_size = int(np.prod(shape))

    # Collect the accepted chunks and join them once at the end, rather
    # than re-copying the whole buffer with `np.append` each iteration.
    accepted = []
    num_accepted = 0

    # Loop until we have enough valid numbers to fill the desired shape.
    while num_accepted < flat_size:
        # Generate a batch of random numbers from a normal distribution.
        batch = rng.normal(loc=mean, scale=stddev, size=flat_size)

        # Keep only the numbers within the specified bounds.
        valid = batch[(batch >= lower_bound) & (batch <= upper_bound)]
        accepted.append(valid)
        num_accepted += valid.shape[0]

    random_numbers = np.concatenate(accepted) if accepted else np.empty(0)

    # Truncate the result array to the desired size and reshape it.
    return random_numbers[:flat_size].astype(dtype).reshape(shape)
np.broadcast_to(mask, inputs.shape) + return np.where(mask, inputs / keep_prob, np.zeros_like(inputs)) diff --git a/keras_core/backend/numpy/rnn.py b/keras_core/backend/numpy/rnn.py new file mode 100644 index 000000000..c0d14576e --- /dev/null +++ b/keras_core/backend/numpy/rnn.py @@ -0,0 +1,236 @@ +import numpy as np +import tree + +from keras_core.utils.nest import pack_sequence_as + + +def rnn( + step_function, + inputs, + initial_states, + go_backwards=False, + mask=None, + constants=None, + unroll=False, + input_length=None, + time_major=False, + zero_output_for_mask=False, + return_all_outputs=True, +): + def swap_batch_timestep(input_t): + # Swap the batch and timestep dim for the incoming tensor. + axes = list(range(len(input_t.shape))) + axes[0], axes[1] = 1, 0 + return np.transpose(input_t, axes) + + if not time_major: + inputs = tree.map_structure(swap_batch_timestep, inputs) + + flattened_inputs = tree.flatten(inputs) + time_steps = flattened_inputs[0].shape[0] + + if mask is not None: + if mask.dtype != "bool": + mask = mask.astype("bool") + if len(mask.shape) == 2: + mask = np.expand_dims(mask, axis=-1) + if not time_major: + mask = swap_batch_timestep(mask) + + if constants is None: + constants = [] + + def _expand_mask(mask_t, input_t, fixed_dim=1): + if tree.is_nested(mask_t): + raise ValueError( + f"mask_t is expected to be tensor, but got {mask_t}" + ) + if tree.is_nested(input_t): + raise ValueError( + f"input_t is expected to be tensor, but got {input_t}" + ) + rank_diff = len(input_t.shape) - len(mask_t.shape) + for _ in range(rank_diff): + mask_t = np.expand_dims(mask_t, -1) + multiples = [1] * fixed_dim + list(input_t.shape[fixed_dim:]) + return np.tile(mask_t, multiples) + + if unroll: + if not time_steps: + raise ValueError("Unrolling requires a fixed number of timesteps.") + states = tuple(initial_states) + successive_states = [] + successive_outputs = [] + + # Process the input tensors. 
The input tensor need to be split on the + # time_step dim, and reverse if go_backwards is True. In the case of + # nested input, the input is flattened and then transformed + # individually. The result of this will be a tuple of lists, each of + # the item in tuple is list of the tensor with shape (batch, feature) + def _process_single_input_t(input_t): + input_t = unstack(input_t) # unstack for time_step dim + if go_backwards: + input_t.reverse() + return input_t + + if tree.is_nested(inputs): + processed_input = tree.map_structure( + _process_single_input_t, inputs + ) + else: + processed_input = (_process_single_input_t(inputs),) + + def _get_input_tensor(time): + inp = [t_[time] for t_ in processed_input] + return pack_sequence_as(inputs, inp) + + if mask is not None: + mask_list = unstack(mask) + if go_backwards: + mask_list.reverse() + + for i in range(time_steps): + inp = _get_input_tensor(i) + mask_t = mask_list[i] + output, new_states = step_function( + inp, tuple(states) + tuple(constants) + ) + tiled_mask_t = _expand_mask(mask_t, output) + + if not successive_outputs: + prev_output = np.zeros_like(output) + else: + prev_output = successive_outputs[-1] + + output = np.where(tiled_mask_t, output, prev_output) + + flat_states = tree.flatten(states) + flat_new_states = tree.flatten(new_states) + tiled_mask_t = tuple( + _expand_mask(mask_t, s) for s in flat_states + ) + flat_final_states = tuple( + np.where(m, s, ps) + for m, s, ps in zip( + tiled_mask_t, flat_new_states, flat_states + ) + ) + states = pack_sequence_as(states, flat_final_states) + + if return_all_outputs: + successive_outputs.append(output) + successive_states.append(states) + else: + successive_outputs = [output] + successive_states = [states] + last_output = successive_outputs[-1] + new_states = successive_states[-1] + outputs = np.stack(successive_outputs) + + else: # mask is None + for i in range(time_steps): + inp = _get_input_tensor(i) + output, states = step_function( + inp, 
tuple(states) + tuple(constants) + ) + if return_all_outputs: + successive_outputs.append(output) + successive_states.append(states) + else: + successive_outputs = [output] + successive_states = [states] + last_output = successive_outputs[-1] + new_states = successive_states[-1] + outputs = np.stack(successive_outputs) + + else: # Unroll == False + if mask is not None: + + def _step(states, current_input): + current_input, current_mask = current_input + is_masked = np.all( + np.logical_not(current_mask), axis=-1, keepdims=True + ) + + output_t, new_states = step_function(current_input, states) + + if zero_output_for_mask: + masked_outs = np.where( + is_masked, np.zeros_like(output_t), output_t + ) + else: + # Assume the first state is the previous output. + output_tm1 = states[0] + masked_outs = np.where(is_masked, output_tm1, output_t) + + new_states = [ + np.where(is_masked, s, ns) + for s, ns in zip(states, new_states) + ] + return (new_states, masked_outs) + + scan_xs = (inputs, mask) + + else: + + def _step(states, current_input): + output_t, new_states = step_function(current_input, states) + return new_states, output_t + + scan_xs = inputs + + new_states, outputs = numpy_scan( + f=_step, + init=initial_states, + xs=scan_xs, + reverse=go_backwards, + mask=mask, + ) + + if go_backwards: + outputs = np.flip(outputs, axis=0) + last_output = outputs[-1] + + if not time_major: + outputs = tree.map_structure(swap_batch_timestep, outputs) + + return last_output, outputs, new_states + + +def lstm(*args, **kwargs): + raise NotImplementedError + + +def gru(*args, **kwargs): + raise NotImplementedError + + +def unstack(x, axis=0): + return [x.take(i, axis) for i in range(x.shape[axis])] + + +def numpy_scan(f, init, xs, reverse=False, mask=None): + states = init + outputs = [] + + if mask is not None: + x, mask = xs + x = np.flip(x, axis=0) if reverse else x + mask = np.flip(mask, axis=0) if reverse else mask + + for each_x, each_mask in zip(x, mask): + states, output = 
class NumpyTrainer:
    """Placeholder trainer for the NumPy backend.

    Every training / evaluation / inference entry point raises
    `NotImplementedError`. The methods accept arbitrary positional and
    keyword arguments so that a normal call such as `model.fit(x, y)`
    reports the missing trainer rather than a `TypeError` about
    unexpected arguments.
    """

    def fit(self, *args, **kwargs):
        raise NotImplementedError("Trainer not implemented for NumPy backend.")

    def predict(self, *args, **kwargs):
        raise NotImplementedError("Trainer not implemented for NumPy backend.")

    def evaluate(self, *args, **kwargs):
        raise NotImplementedError("Trainer not implemented for NumPy backend.")

    def train_on_batch(self, *args, **kwargs):
        raise NotImplementedError("Trainer not implemented for NumPy backend.")

    def test_on_batch(self, *args, **kwargs):
        raise NotImplementedError("Trainer not implemented for NumPy backend.")

    def predict_on_batch(self, *args, **kwargs):
        raise NotImplementedError("Trainer not implemented for NumPy backend.")
keras_core import callbacks from keras_core import initializers @@ -19,6 +20,7 @@ BATCH_SIZE = 4 class CSVLoggerTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_CSVLogger(self): OUTPUT_DIM = 1 np.random.seed(1337) @@ -126,6 +128,7 @@ class CSVLoggerTest(testing.TestCase): else: self.assertEqual(row["val_loss"], "NA") + @pytest.mark.requires_trainable_backend def test_stop_training_csv(self): # Test that using the CSVLogger callback with the TerminateOnNaN # callback does not result in invalid CSVs. diff --git a/keras_core/callbacks/early_stopping_test.py b/keras_core/callbacks/early_stopping_test.py index d9924da47..19d346894 100644 --- a/keras_core/callbacks/early_stopping_test.py +++ b/keras_core/callbacks/early_stopping_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from keras_core import callbacks from keras_core import layers @@ -7,6 +8,7 @@ from keras_core import testing class EarlyStoppingTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_early_stopping(self): x_train = np.random.random((10, 5)) y_train = np.random.random((10, 1)) @@ -48,6 +50,7 @@ class EarlyStoppingTest(testing.TestCase): verbose=0, ) + @pytest.mark.requires_trainable_backend def test_early_stopping_patience(self): cases = [0, 1, 2, 3] losses = [10.0, 9.0, 8.0, 9.0, 8.9, 8.8, 8.7, 8.6, 8.5] @@ -65,6 +68,7 @@ class EarlyStoppingTest(testing.TestCase): self.assertEqual(stopper.stopped_epoch, max(patience, 1) + 2) + @pytest.mark.requires_trainable_backend def test_early_stopping_reuse(self): patience = 3 data = np.random.random((100, 1)) @@ -91,6 +95,7 @@ class EarlyStoppingTest(testing.TestCase): ) assert len(hist.epoch) >= patience + @pytest.mark.requires_trainable_backend def test_early_stopping_with_baseline(self): baseline = 0.6 x_train = np.random.random((10, 5)) @@ -169,6 +174,7 @@ class EarlyStoppingTest(testing.TestCase): self.assertEqual(epochs_trained, 5) self.assertEqual(early_stop.model.get_weights(), 2) + 
@pytest.mark.requires_trainable_backend def test_early_stopping_with_start_from_epoch(self): x_train = np.random.random((10, 5)) y_train = np.random.random((10, 1)) diff --git a/keras_core/callbacks/lambda_callback_test.py b/keras_core/callbacks/lambda_callback_test.py index 18c75f784..082f0f48f 100644 --- a/keras_core/callbacks/lambda_callback_test.py +++ b/keras_core/callbacks/lambda_callback_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from absl import logging from keras_core import callbacks @@ -10,6 +11,7 @@ from keras_core.models.sequential import Sequential class LambdaCallbackTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_LambdaCallback(self): BATCH_SIZE = 4 model = Sequential( diff --git a/keras_core/callbacks/learning_rate_scheduler_test.py b/keras_core/callbacks/learning_rate_scheduler_test.py index d8eb13223..276468f4b 100644 --- a/keras_core/callbacks/learning_rate_scheduler_test.py +++ b/keras_core/callbacks/learning_rate_scheduler_test.py @@ -1,3 +1,5 @@ +import pytest + from keras_core import callbacks from keras_core import layers from keras_core import optimizers @@ -29,6 +31,7 @@ class LearningRateSchedulerTest(testing.TestCase): self.x_train = x_train self.y_train = y_train + @pytest.mark.requires_trainable_backend def test_updates_learning_rate(self): lr_scheduler = callbacks.LearningRateScheduler( lambda step: 1.0 / (2.0 + step), verbose=1 @@ -43,6 +46,7 @@ class LearningRateSchedulerTest(testing.TestCase): self.assertEqual(self.model.optimizer.learning_rate.value, 0.5) + @pytest.mark.requires_trainable_backend def test_verbose_logging(self): lr_scheduler = callbacks.LearningRateScheduler( lambda step: 1.0 / (1.0 + step), verbose=1 @@ -59,6 +63,7 @@ class LearningRateSchedulerTest(testing.TestCase): expected_log = "LearningRateScheduler setting learning rate to 1.0" self.assertTrue(any(expected_log in log for log in logs.output)) + @pytest.mark.requires_trainable_backend def 
test_schedule_dependent_on_previous_learning_rate(self): lr_scheduler = callbacks.LearningRateScheduler(lambda step, lr: lr / 2) @@ -78,6 +83,7 @@ class LearningRateSchedulerTest(testing.TestCase): self.model.optimizer.learning_rate.value, initial_lr / 4.0 ) + @pytest.mark.requires_trainable_backend def test_throws_when_optimizer_has_schedule(self): lr_scheduler = callbacks.LearningRateScheduler(lambda step, lr: lr / 2) diff --git a/keras_core/callbacks/model_checkpoint_test.py b/keras_core/callbacks/model_checkpoint_test.py index 40cef3935..08f17072c 100644 --- a/keras_core/callbacks/model_checkpoint_test.py +++ b/keras_core/callbacks/model_checkpoint_test.py @@ -31,6 +31,7 @@ class ModelCheckpointTest(testing.TestCase): h5py is None, reason="`h5py` is a required dependency for `ModelCheckpoint` tests.", ) + @pytest.mark.requires_trainable_backend def test_model_checkpoint_options(self): def get_model(): model = Sequential( @@ -445,6 +446,7 @@ class ModelCheckpointTest(testing.TestCase): h5py is None, reason="`h5py` is a required dependency for `ModelCheckpoint` tests.", ) + @pytest.mark.requires_trainable_backend def test_model_checkpoint_loading(self): def get_model(): inputs = layers.Input(shape=(INPUT_DIM,), batch_size=2) diff --git a/keras_core/callbacks/reduce_lr_on_plateau_test.py b/keras_core/callbacks/reduce_lr_on_plateau_test.py index d97de6209..73e531dd3 100644 --- a/keras_core/callbacks/reduce_lr_on_plateau_test.py +++ b/keras_core/callbacks/reduce_lr_on_plateau_test.py @@ -1,3 +1,5 @@ +import pytest + from keras_core import callbacks from keras_core import layers from keras_core import optimizers @@ -32,6 +34,7 @@ class ReduceLROnPlateauTest(testing.TestCase): self.y_train = y_train self.y_test = y_test + @pytest.mark.requires_trainable_backend def test_reduces_lr_with_model_fit(self): reduce_lr = callbacks.ReduceLROnPlateau( patience=1, factor=0.1, monitor="val_loss", min_delta=100 @@ -47,6 +50,7 @@ class ReduceLROnPlateauTest(testing.TestCase): 
self.assertEqual(self.model.optimizer.learning_rate.value, 0.01) + @pytest.mark.requires_trainable_backend def test_throws_when_optimizer_has_schedule(self): reduce_lr = callbacks.ReduceLROnPlateau( patience=1, factor=0.1, monitor="val_loss", min_delta=100 @@ -73,6 +77,7 @@ class ReduceLROnPlateauTest(testing.TestCase): epochs=2, ) + @pytest.mark.requires_trainable_backend def test_verbose_logging(self): reduce_lr = callbacks.ReduceLROnPlateau( patience=1, factor=0.1, monitor="val_loss", min_delta=100, verbose=1 @@ -90,6 +95,7 @@ class ReduceLROnPlateauTest(testing.TestCase): expected_log = "ReduceLROnPlateau reducing learning rate to 0.01" self.assertTrue(any(expected_log in log for log in logs.output)) + @pytest.mark.requires_trainable_backend def test_honors_min_lr(self): reduce_lr = callbacks.ReduceLROnPlateau( patience=1, @@ -109,6 +115,7 @@ class ReduceLROnPlateauTest(testing.TestCase): self.assertEqual(self.model.optimizer.learning_rate.value, 0.005) + @pytest.mark.requires_trainable_backend def test_cooldown(self): reduce_lr = callbacks.ReduceLROnPlateau( patience=1, diff --git a/keras_core/callbacks/remote_monitor_test.py b/keras_core/callbacks/remote_monitor_test.py index 55c708c62..a53fb7cab 100644 --- a/keras_core/callbacks/remote_monitor_test.py +++ b/keras_core/callbacks/remote_monitor_test.py @@ -3,6 +3,7 @@ from unittest import mock import numpy as np +from keras_core import backend from keras_core import callbacks from keras_core import layers from keras_core import testing @@ -60,6 +61,8 @@ class TerminateOnNaNTest(testing.TestCase): if requests is None: self.skipTest("`requests` required to run this test") + if backend.backend() == "numpy": + self.skipTest("Trainer not implemented from NumPy backend.") TRAIN_SAMPLES = 10 TEST_SAMPLES = 10 INPUT_DIM = 3 diff --git a/keras_core/callbacks/tensorboard_test.py b/keras_core/callbacks/tensorboard_test.py index 109e292a8..eecad42b5 100644 --- a/keras_core/callbacks/tensorboard_test.py +++ 
b/keras_core/callbacks/tensorboard_test.py @@ -143,6 +143,7 @@ class TestTensorBoardV2(testing.TestCase): model.compile("sgd", "mse") return model + @pytest.mark.requires_trainable_backend def test_TensorBoard_basic(self): model = self._get_model() x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1)) @@ -171,6 +172,7 @@ class TestTensorBoardV2(testing.TestCase): }, ) + @pytest.mark.requires_trainable_backend def test_TensorBoard_across_invocations(self): """Regression test for summary writer resource use-after-free.""" model = self._get_model() @@ -201,6 +203,7 @@ class TestTensorBoardV2(testing.TestCase): }, ) + @pytest.mark.requires_trainable_backend def test_TensorBoard_no_spurious_event_files(self): model = self._get_model() x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1)) @@ -214,6 +217,7 @@ class TestTensorBoardV2(testing.TestCase): events_file_run_basenames.add(os.path.basename(dirpath)) self.assertEqual(events_file_run_basenames, {"train"}) + @pytest.mark.requires_trainable_backend def test_TensorBoard_batch_metrics(self): model = self._get_model() x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1)) @@ -243,6 +247,7 @@ class TestTensorBoardV2(testing.TestCase): }, ) + @pytest.mark.requires_trainable_backend def test_TensorBoard_learning_rate_schedules(self): model = self._get_model(compile_model=False) opt = optimizers.SGD(schedules.CosineDecay(0.01, 1)) @@ -267,6 +272,7 @@ class TestTensorBoardV2(testing.TestCase): }, ) + @pytest.mark.requires_trainable_backend def test_TensorBoard_global_step(self): model = self._get_model(compile_model=False) opt = optimizers.SGD(schedules.CosineDecay(0.01, 1)) @@ -306,6 +312,7 @@ class TestTensorBoardV2(testing.TestCase): }, ) + @pytest.mark.requires_trainable_backend def test_TensorBoard_weight_histograms(self): model = self._get_model() x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1)) @@ -338,6 +345,7 @@ class TestTensorBoardV2(testing.TestCase): {_ObservedSummary(logdir=train_dir, tag="histogram")}, ) + 
@pytest.mark.requires_trainable_backend def test_TensorBoard_weight_images(self): model = self._get_model() x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1)) @@ -392,6 +400,7 @@ class TestTensorBoardV2(testing.TestCase): expected_image_summaries, ) + @pytest.mark.requires_trainable_backend def test_TensorBoard_projector_callback(self): model = models.Sequential( [ @@ -433,6 +442,7 @@ class TestTensorBoardV2(testing.TestCase): ], ) + @pytest.mark.requires_trainable_backend def test_custom_summary(self): def scalar_v2_mock(name, data, step=None): """A reimplementation of the scalar plugin to avoid circular @@ -592,6 +602,7 @@ class TestTensorBoardV2(testing.TestCase): backend.backend() == "torch", reason="Torch backend requires blocking numpy conversion.", ) + @pytest.mark.requires_trainable_backend def test_TensorBoard_non_blocking(self): logdir, _, _ = self._get_log_dirs() model = models.Sequential([layers.Dense(1)]) diff --git a/keras_core/callbacks/terminate_on_nan_test.py b/keras_core/callbacks/terminate_on_nan_test.py index 16b3e9a90..8d61171c9 100644 --- a/keras_core/callbacks/terminate_on_nan_test.py +++ b/keras_core/callbacks/terminate_on_nan_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from keras_core import callbacks from keras_core import initializers @@ -9,6 +10,7 @@ from keras_core.utils import numerical_utils class TerminateOnNaNTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_TerminateOnNaN(self): TRAIN_SAMPLES = 10 TEST_SAMPLES = 10 diff --git a/keras_core/layers/activations/activation_test.py b/keras_core/layers/activations/activation_test.py index 49be3cf35..b03ffce07 100644 --- a/keras_core/layers/activations/activation_test.py +++ b/keras_core/layers/activations/activation_test.py @@ -1,9 +1,12 @@ +import pytest + from keras_core import activations from keras_core import layers from keras_core import testing class ActivationTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def 
test_activation_basics(self): self.run_layer_test( layers.Activation, diff --git a/keras_core/layers/activations/elu_test.py b/keras_core/layers/activations/elu_test.py index 2eb264331..6b85c1903 100644 --- a/keras_core/layers/activations/elu_test.py +++ b/keras_core/layers/activations/elu_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import tensorflow as tf from keras_core import testing @@ -10,6 +11,7 @@ class ELUTest(testing.TestCase): elu_layer = elu.ELU() self.run_class_serialization_test(elu_layer) + @pytest.mark.requires_trainable_backend def test_elu(self): self.run_layer_test( elu.ELU, diff --git a/keras_core/layers/activations/leaky_relu_test.py b/keras_core/layers/activations/leaky_relu_test.py index 546286990..18e24033e 100644 --- a/keras_core/layers/activations/leaky_relu_test.py +++ b/keras_core/layers/activations/leaky_relu_test.py @@ -1,10 +1,12 @@ import numpy as np +import pytest from keras_core import testing from keras_core.layers.activations import leaky_relu class LeakyReLUTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_leaky_relu(self): self.run_layer_test( leaky_relu.LeakyReLU, diff --git a/keras_core/layers/activations/prelu_test.py b/keras_core/layers/activations/prelu_test.py index 65a4222f6..ea4f79559 100644 --- a/keras_core/layers/activations/prelu_test.py +++ b/keras_core/layers/activations/prelu_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import tensorflow as tf from keras_core import testing @@ -6,6 +7,7 @@ from keras_core.layers.activations import prelu class PReLUTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_prelu(self): self.run_layer_test( prelu.PReLU, diff --git a/keras_core/layers/activations/relu_test.py b/keras_core/layers/activations/relu_test.py index 061dab5db..c00c2d301 100644 --- a/keras_core/layers/activations/relu_test.py +++ b/keras_core/layers/activations/relu_test.py @@ -1,10 +1,12 @@ import numpy as np +import pytest from keras_core 
import testing from keras_core.layers.activations import relu class ReLUTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_relu(self): self.run_layer_test( relu.ReLU, diff --git a/keras_core/layers/activations/softmax_test.py b/keras_core/layers/activations/softmax_test.py index 106829de4..466538827 100644 --- a/keras_core/layers/activations/softmax_test.py +++ b/keras_core/layers/activations/softmax_test.py @@ -1,10 +1,12 @@ import numpy as np +import pytest from keras_core import testing from keras_core.layers.activations import softmax class SoftmaxTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_softmax(self): self.run_layer_test( softmax.Softmax, diff --git a/keras_core/layers/attention/multi_head_attention_test.py b/keras_core/layers/attention/multi_head_attention_test.py index 52d8d3ca1..61a76788d 100644 --- a/keras_core/layers/attention/multi_head_attention_test.py +++ b/keras_core/layers/attention/multi_head_attention_test.py @@ -1,6 +1,8 @@ import numpy as np +import pytest from absl.testing import parameterized +from keras_core import backend from keras_core import initializers from keras_core import layers from keras_core import testing @@ -164,6 +166,10 @@ class MultiHeadAttentionTest(testing.TestCase, parameterized.TestCase): layer._output_dense.kernel, ) + @pytest.mark.skipif( + backend.backend() == "numpy", + reason="Numpy backend does not support masking.", + ) def test_query_mask_progagation(self): """Test automatic propagation of the query's mask.""" layer = layers.MultiHeadAttention(num_heads=2, key_dim=2) @@ -175,6 +181,10 @@ class MultiHeadAttentionTest(testing.TestCase, parameterized.TestCase): self.assertAllClose(masked_query._keras_mask, output._keras_mask) @parameterized.named_parameters(("causal", True), ("not_causal", 0)) + @pytest.mark.skipif( + backend.backend() == "numpy", + reason="Numpy backend does not support masking.", + ) def test_masking(self, use_causal_mask): """Test that the 
value and causal masks are taken into account.""" layer = layers.MultiHeadAttention(num_heads=2, key_dim=2) diff --git a/keras_core/layers/convolutional/conv_test.py b/keras_core/layers/convolutional/conv_test.py index 50e1b76e8..014078db9 100644 --- a/keras_core/layers/convolutional/conv_test.py +++ b/keras_core/layers/convolutional/conv_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import tensorflow as tf from absl.testing import parameterized @@ -53,6 +54,7 @@ class ConvBasicTest(testing.TestCase, parameterized.TestCase): "output_shape": (3, 2, 6), }, ) + @pytest.mark.requires_trainable_backend def test_conv1d_basic( self, filters, @@ -119,6 +121,7 @@ class ConvBasicTest(testing.TestCase, parameterized.TestCase): "output_shape": (3, 2, 4, 6), }, ) + @pytest.mark.requires_trainable_backend def test_conv2d_basic( self, filters, @@ -185,6 +188,7 @@ class ConvBasicTest(testing.TestCase, parameterized.TestCase): "output_shape": (3, 2, 4, 2, 6), }, ) + @pytest.mark.requires_trainable_backend def test_conv3d_basic( self, filters, diff --git a/keras_core/layers/convolutional/conv_transpose_test.py b/keras_core/layers/convolutional/conv_transpose_test.py index 00070b848..21476cb8c 100644 --- a/keras_core/layers/convolutional/conv_transpose_test.py +++ b/keras_core/layers/convolutional/conv_transpose_test.py @@ -44,6 +44,7 @@ class ConvTransposeBasicTest(testing.TestCase, parameterized.TestCase): "output_shape": (2, 16, 6), }, ) + @pytest.mark.requires_trainable_backend def test_conv1d_transpose_basic( self, filters, @@ -121,6 +122,7 @@ class ConvTransposeBasicTest(testing.TestCase, parameterized.TestCase): "output_shape": (1, 224, 224, 2), }, ) + @pytest.mark.requires_trainable_backend def test_conv2d_transpose_basic( self, filters, @@ -193,6 +195,7 @@ class ConvTransposeBasicTest(testing.TestCase, parameterized.TestCase): "output_shape": (2, 16, 9, 17, 6), }, ) + @pytest.mark.requires_trainable_backend def test_conv3d_transpose_basic( self, filters, diff 
--git a/keras_core/layers/convolutional/depthwise_conv_test.py b/keras_core/layers/convolutional/depthwise_conv_test.py index ea066ed1a..814912b42 100644 --- a/keras_core/layers/convolutional/depthwise_conv_test.py +++ b/keras_core/layers/convolutional/depthwise_conv_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import tensorflow as tf from absl.testing import parameterized @@ -39,6 +40,7 @@ class DepthwiseConvBasicTest(testing.TestCase, parameterized.TestCase): "output_shape": (3, 2, 24), }, ) + @pytest.mark.requires_trainable_backend def test_depthwise_conv1d_basic( self, depth_multiplier, @@ -100,6 +102,7 @@ class DepthwiseConvBasicTest(testing.TestCase, parameterized.TestCase): "output_shape": (3, 2, 2, 24), }, ) + @pytest.mark.requires_trainable_backend def test_depthwise_conv2d_basic( self, depth_multiplier, diff --git a/keras_core/layers/convolutional/separable_conv_test.py b/keras_core/layers/convolutional/separable_conv_test.py index b523d5c6a..e9cbebb78 100644 --- a/keras_core/layers/convolutional/separable_conv_test.py +++ b/keras_core/layers/convolutional/separable_conv_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import tensorflow as tf from absl.testing import parameterized @@ -42,6 +43,7 @@ class SeparableConvBasicTest(testing.TestCase, parameterized.TestCase): "output_shape": (3, 2, 6), }, ) + @pytest.mark.requires_trainable_backend def test_separable_conv1d_basic( self, depth_multiplier, @@ -108,6 +110,7 @@ class SeparableConvBasicTest(testing.TestCase, parameterized.TestCase): "output_shape": (3, 2, 2, 6), }, ) + @pytest.mark.requires_trainable_backend def test_separable_conv2d_basic( self, depth_multiplier, diff --git a/keras_core/layers/core/dense_test.py b/keras_core/layers/core/dense_test.py index a8f40a966..b47f371ae 100644 --- a/keras_core/layers/core/dense_test.py +++ b/keras_core/layers/core/dense_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from keras_core import layers from keras_core import testing 
@@ -6,6 +7,7 @@ from keras_core.backend.common import keras_tensor class DenseTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_dense_basics(self): # 2D case, no bias. self.run_layer_test( diff --git a/keras_core/layers/core/einsum_dense_test.py b/keras_core/layers/core/einsum_dense_test.py index 593316264..5fcab7afc 100644 --- a/keras_core/layers/core/einsum_dense_test.py +++ b/keras_core/layers/core/einsum_dense_test.py @@ -1,3 +1,4 @@ +import pytest from absl.testing import parameterized from keras_core import layers @@ -227,6 +228,7 @@ class EinsumDenseTest(testing.TestCase, parameterized.TestCase): "expected_output_shape": (2, 3, 4, 2), }, ) + @pytest.mark.requires_trainable_backend def test_einsum_dense_basics( self, equation, diff --git a/keras_core/layers/core/embedding_test.py b/keras_core/layers/core/embedding_test.py index 705fcd636..a617e0bb1 100644 --- a/keras_core/layers/core/embedding_test.py +++ b/keras_core/layers/core/embedding_test.py @@ -1,10 +1,12 @@ import numpy as np +import pytest from keras_core import layers from keras_core.testing import test_case class EmbeddingTest(test_case.TestCase): + @pytest.mark.requires_trainable_backend def test_embedding_basics(self): self.run_layer_test( layers.Embedding, diff --git a/keras_core/layers/core/identity_test.py b/keras_core/layers/core/identity_test.py index d1d010b93..6dbc0503e 100644 --- a/keras_core/layers/core/identity_test.py +++ b/keras_core/layers/core/identity_test.py @@ -1,8 +1,11 @@ +import pytest + from keras_core import layers from keras_core import testing class IdentityTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_identity_basics(self): self.run_layer_test( layers.Identity, diff --git a/keras_core/layers/core/lambda_layer_test.py b/keras_core/layers/core/lambda_layer_test.py index 32a989df1..34ffc46d9 100644 --- a/keras_core/layers/core/lambda_layer_test.py +++ b/keras_core/layers/core/lambda_layer_test.py @@ -1,4 +1,5 @@ import 
numpy as np +import pytest from keras_core import layers from keras_core import ops @@ -6,6 +7,7 @@ from keras_core import testing class LambdaTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_lambda_basics(self): self.run_layer_test( layers.Lambda, diff --git a/keras_core/layers/core/masking_test.py b/keras_core/layers/core/masking_test.py index 947ad1570..7cfff6aea 100644 --- a/keras_core/layers/core/masking_test.py +++ b/keras_core/layers/core/masking_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from keras_core import layers from keras_core import models @@ -6,6 +7,7 @@ from keras_core import testing class MaskingTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_masking_basics(self): self.run_layer_test( layers.Masking, @@ -19,6 +21,7 @@ class MaskingTest(testing.TestCase): supports_masking=True, ) + @pytest.mark.requires_trainable_backend def test_masking_correctness(self): x = np.array( [ diff --git a/keras_core/layers/core/wrapper_test.py b/keras_core/layers/core/wrapper_test.py index ae1556d6d..1db8a048f 100644 --- a/keras_core/layers/core/wrapper_test.py +++ b/keras_core/layers/core/wrapper_test.py @@ -1,3 +1,5 @@ +import pytest + from keras_core import layers from keras_core import testing @@ -10,6 +12,7 @@ class ExampleWrapper(layers.Wrapper): class WrapperTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_wrapper_basics(self): self.run_layer_test( ExampleWrapper, diff --git a/keras_core/layers/layer.py b/keras_core/layers/layer.py index a7fdf0c11..9b76f0c99 100644 --- a/keras_core/layers/layer.py +++ b/keras_core/layers/layer.py @@ -44,6 +44,8 @@ elif backend.backend() == "jax": from keras_core.backend.jax.layer import JaxLayer as BackendLayer elif backend.backend() == "torch": from keras_core.backend.torch.layer import TorchLayer as BackendLayer +elif backend.backend() == "numpy": + from keras_core.backend.numpy.layer import NumpyLayer as BackendLayer else: raise 
RuntimeError( f"Backend '{backend.backend()}' must implement a layer mixin class." @@ -1235,6 +1237,12 @@ class Layer(BackendLayer, Operation): for tensor, mask in zip(flat_outputs, flat_masks): if getattr(tensor, "_keras_mask", None) is None: try: + # Numpy backend does not support masking. + if backend.backend() == "numpy": + warnings.warn( + "The NumPy backend does not support masking at this" + " time. Masks will be ignored." + ) tensor._keras_mask = mask except AttributeError: # It's a C type. diff --git a/keras_core/layers/layer_test.py b/keras_core/layers/layer_test.py index 75b2aa51c..c65044c73 100644 --- a/keras_core/layers/layer_test.py +++ b/keras_core/layers/layer_test.py @@ -264,6 +264,7 @@ class LayerTest(testing.TestCase): layer(layers.Input(batch_shape=(2, 2))) self.assertLen(layer.losses, 0) + @pytest.mark.requires_trainable_backend def test_add_loss(self): class LossLayer(layers.Layer): def call(self, x): @@ -378,6 +379,10 @@ class LayerTest(testing.TestCase): self.assertEqual(backend.standardize_dtype(y.dtype), "float16") self.assertEqual(layer.kernel.dtype, "float32") + @pytest.mark.skipif( + backend.backend() == "numpy", + reason="Numpy backend does not support masking.", + ) def test_masking(self): class BasicMaskedLayer(layers.Layer): def __init__(self): diff --git a/keras_core/layers/merging/merging_test.py b/keras_core/layers/merging/merging_test.py index 5b7087062..23ff4aa6d 100644 --- a/keras_core/layers/merging/merging_test.py +++ b/keras_core/layers/merging/merging_test.py @@ -8,6 +8,7 @@ from keras_core import ops from keras_core import testing +@pytest.mark.requires_trainable_backend class MergingLayersTest(testing.TestCase): def test_add_basic(self): self.run_layer_test( diff --git a/keras_core/layers/normalization/batch_normalization_test.py b/keras_core/layers/normalization/batch_normalization_test.py index 6bbb8a9e8..2eebac3af 100644 --- a/keras_core/layers/normalization/batch_normalization_test.py +++ 
b/keras_core/layers/normalization/batch_normalization_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from absl.testing import parameterized from keras_core import backend @@ -7,6 +8,7 @@ from keras_core import testing class BatchNormalizationTest(testing.TestCase, parameterized.TestCase): + @pytest.mark.requires_trainable_backend def test_bn_basics(self): # vector case self.run_layer_test( diff --git a/keras_core/layers/normalization/group_normalization_test.py b/keras_core/layers/normalization/group_normalization_test.py index ce52f411e..a1017cc13 100644 --- a/keras_core/layers/normalization/group_normalization_test.py +++ b/keras_core/layers/normalization/group_normalization_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from keras_core import constraints from keras_core import layers @@ -7,6 +8,7 @@ from keras_core import testing class GroupNormalizationTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_groupnorm(self): self.run_layer_test( layers.GroupNormalization, diff --git a/keras_core/layers/normalization/layer_normalization_test.py b/keras_core/layers/normalization/layer_normalization_test.py index ec5fcdbb7..16b1741c5 100644 --- a/keras_core/layers/normalization/layer_normalization_test.py +++ b/keras_core/layers/normalization/layer_normalization_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from keras_core import layers from keras_core import ops @@ -7,6 +8,7 @@ from keras_core import testing class LayerNormalizationTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_ln_basics(self): self.run_layer_test( layers.LayerNormalization, diff --git a/keras_core/layers/normalization/spectral_normalization_test.py b/keras_core/layers/normalization/spectral_normalization_test.py index 4740a192e..b0981f7fd 100644 --- a/keras_core/layers/normalization/spectral_normalization_test.py +++ b/keras_core/layers/normalization/spectral_normalization_test.py @@ -1,4 +1,5 @@ import numpy as np 
+import pytest from keras_core import initializers from keras_core import layers @@ -6,6 +7,7 @@ from keras_core import testing class SpectralNormalizationTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_basic_spectralnorm(self): self.run_layer_test( layers.SpectralNormalization, diff --git a/keras_core/layers/normalization/unit_normalization_test.py b/keras_core/layers/normalization/unit_normalization_test.py index 354592f1c..8a3e6b027 100644 --- a/keras_core/layers/normalization/unit_normalization_test.py +++ b/keras_core/layers/normalization/unit_normalization_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from keras_core import backend from keras_core import layers @@ -11,6 +12,7 @@ def squared_l2_norm(x): class UnitNormalizationTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_un_basics(self): self.run_layer_test( layers.UnitNormalization, diff --git a/keras_core/layers/pooling/average_pooling_test.py b/keras_core/layers/pooling/average_pooling_test.py index 11a9b6cd4..8e8b1e407 100644 --- a/keras_core/layers/pooling/average_pooling_test.py +++ b/keras_core/layers/pooling/average_pooling_test.py @@ -8,6 +8,7 @@ from keras_core import layers from keras_core import testing +@pytest.mark.requires_trainable_backend class AveragePoolingBasicTest(testing.TestCase, parameterized.TestCase): @parameterized.parameters( (2, 1, "valid", "channels_last", (3, 5, 4), (3, 4, 4)), diff --git a/keras_core/layers/pooling/global_average_pooling_test.py b/keras_core/layers/pooling/global_average_pooling_test.py index f52242bdc..76f5afcb2 100644 --- a/keras_core/layers/pooling/global_average_pooling_test.py +++ b/keras_core/layers/pooling/global_average_pooling_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import tensorflow as tf from absl.testing import parameterized @@ -6,6 +7,7 @@ from keras_core import layers from keras_core import testing +@pytest.mark.requires_trainable_backend class 
GlobalAveragePoolingBasicTest(testing.TestCase, parameterized.TestCase): @parameterized.parameters( ("channels_last", False, (3, 5, 4), (3, 4)), diff --git a/keras_core/layers/pooling/global_max_pooling_test.py b/keras_core/layers/pooling/global_max_pooling_test.py index e2c0d254c..327b261a2 100644 --- a/keras_core/layers/pooling/global_max_pooling_test.py +++ b/keras_core/layers/pooling/global_max_pooling_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import tensorflow as tf from absl.testing import parameterized @@ -6,6 +7,7 @@ from keras_core import layers from keras_core import testing +@pytest.mark.requires_trainable_backend class GlobalMaxPoolingBasicTest(testing.TestCase, parameterized.TestCase): @parameterized.parameters( ("channels_last", False, (3, 5, 4), (3, 4)), diff --git a/keras_core/layers/pooling/max_pooling_test.py b/keras_core/layers/pooling/max_pooling_test.py index 4df64ab8a..9a3168a72 100644 --- a/keras_core/layers/pooling/max_pooling_test.py +++ b/keras_core/layers/pooling/max_pooling_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import tensorflow as tf from absl.testing import parameterized @@ -6,6 +7,7 @@ from keras_core import layers from keras_core import testing +@pytest.mark.requires_trainable_backend class MaxPoolingBasicTest(testing.TestCase, parameterized.TestCase): @parameterized.parameters( (2, 1, "valid", "channels_last", (3, 5, 4), (3, 4, 4)), diff --git a/keras_core/layers/preprocessing/center_crop_test.py b/keras_core/layers/preprocessing/center_crop_test.py index 52e56d7d5..425a8f85a 100644 --- a/keras_core/layers/preprocessing/center_crop_test.py +++ b/keras_core/layers/preprocessing/center_crop_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import tensorflow as tf from absl.testing import parameterized @@ -7,6 +8,7 @@ from keras_core import testing class CenterCropTest(testing.TestCase, parameterized.TestCase): + @pytest.mark.requires_trainable_backend def test_center_crop_basics(self): 
self.run_layer_test( layers.CenterCrop, diff --git a/keras_core/layers/preprocessing/normalization_test.py b/keras_core/layers/preprocessing/normalization_test.py index a72ed3472..a5c075440 100644 --- a/keras_core/layers/preprocessing/normalization_test.py +++ b/keras_core/layers/preprocessing/normalization_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import tensorflow as tf from absl.testing import parameterized @@ -8,6 +9,7 @@ from keras_core import testing class NormalizationTest(testing.TestCase, parameterized.TestCase): + @pytest.mark.requires_trainable_backend def test_normalization_basics(self): self.run_layer_test( layers.Normalization, diff --git a/keras_core/layers/preprocessing/random_brightness_test.py b/keras_core/layers/preprocessing/random_brightness_test.py index 3f68fd691..02ab589df 100644 --- a/keras_core/layers/preprocessing/random_brightness_test.py +++ b/keras_core/layers/preprocessing/random_brightness_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import tensorflow as tf from keras_core import backend @@ -7,6 +8,7 @@ from keras_core import testing class RandomBrightnessTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_layer(self): self.run_layer_test( layers.RandomBrightness, diff --git a/keras_core/layers/preprocessing/random_contrast_test.py b/keras_core/layers/preprocessing/random_contrast_test.py index 8c07c3ed7..0b4fa979e 100644 --- a/keras_core/layers/preprocessing/random_contrast_test.py +++ b/keras_core/layers/preprocessing/random_contrast_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import tensorflow as tf from keras_core import backend @@ -7,6 +8,7 @@ from keras_core import testing class RandomContrastTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_layer(self): self.run_layer_test( layers.RandomContrast, diff --git a/keras_core/layers/preprocessing/rescaling_test.py b/keras_core/layers/preprocessing/rescaling_test.py index e3a85a1ee..6bdf61cff 
100644 --- a/keras_core/layers/preprocessing/rescaling_test.py +++ b/keras_core/layers/preprocessing/rescaling_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import tensorflow as tf from keras_core import layers @@ -6,6 +7,7 @@ from keras_core import testing class RescalingTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_rescaling_basics(self): self.run_layer_test( layers.Rescaling, @@ -19,6 +21,7 @@ class RescalingTest(testing.TestCase): supports_masking=True, ) + @pytest.mark.requires_trainable_backend def test_rescaling_dtypes(self): # int scale self.run_layer_test( diff --git a/keras_core/layers/regularization/activity_regularization_test.py b/keras_core/layers/regularization/activity_regularization_test.py index f240f2aa0..075d5c19f 100644 --- a/keras_core/layers/regularization/activity_regularization_test.py +++ b/keras_core/layers/regularization/activity_regularization_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from keras_core import layers from keras_core.testing import test_case @@ -11,6 +12,7 @@ class ActivityRegularizationTest(test_case.TestCase): self.assertLen(layer.losses, 1) self.assertAllClose(layer.losses[0], 4 * 0.3 + 2 * 0.2) + @pytest.mark.requires_trainable_backend def test_activity_regularization_basics(self): self.run_layer_test( layers.ActivityRegularization, diff --git a/keras_core/layers/regularization/dropout_test.py b/keras_core/layers/regularization/dropout_test.py index 2f1839dcd..cb75c0b94 100644 --- a/keras_core/layers/regularization/dropout_test.py +++ b/keras_core/layers/regularization/dropout_test.py @@ -7,6 +7,7 @@ from keras_core import testing class DropoutTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_dropout_basics(self): self.run_layer_test( layers.Dropout, diff --git a/keras_core/layers/regularization/gaussian_dropout_test.py b/keras_core/layers/regularization/gaussian_dropout_test.py index 2f912145a..852db925d 100644 --- 
a/keras_core/layers/regularization/gaussian_dropout_test.py +++ b/keras_core/layers/regularization/gaussian_dropout_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from keras_core import backend from keras_core import layers @@ -6,6 +7,7 @@ from keras_core import testing class GaussianDropoutTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_gaussian_dropout_basics(self): self.run_layer_test( layers.GaussianDropout, diff --git a/keras_core/layers/regularization/gaussian_noise_test.py b/keras_core/layers/regularization/gaussian_noise_test.py index 6023d0149..27066dbb4 100644 --- a/keras_core/layers/regularization/gaussian_noise_test.py +++ b/keras_core/layers/regularization/gaussian_noise_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from keras_core import backend from keras_core import layers @@ -6,6 +7,7 @@ from keras_core import testing class GaussianNoiseTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_gaussian_noise_basics(self): self.run_layer_test( layers.GaussianNoise, diff --git a/keras_core/layers/regularization/spatial_dropout_test.py b/keras_core/layers/regularization/spatial_dropout_test.py index bb9751062..a6681621b 100644 --- a/keras_core/layers/regularization/spatial_dropout_test.py +++ b/keras_core/layers/regularization/spatial_dropout_test.py @@ -1,10 +1,12 @@ import numpy as np +import pytest from keras_core import layers from keras_core.testing import test_case class SpatialDropoutTest(test_case.TestCase): + @pytest.mark.requires_trainable_backend def test_spatial_dropout_1d(self): self.run_layer_test( layers.SpatialDropout1D, @@ -20,6 +22,7 @@ class SpatialDropoutTest(test_case.TestCase): input_shape=(2, 3, 4), ) + @pytest.mark.requires_trainable_backend def test_spatial_dropout_2d(self): self.run_layer_test( layers.SpatialDropout2D, @@ -35,6 +38,7 @@ class SpatialDropoutTest(test_case.TestCase): input_shape=(2, 3, 4, 5), ) + @pytest.mark.requires_trainable_backend def 
test_spatial_dropout_3d(self): self.run_layer_test( layers.SpatialDropout3D, diff --git a/keras_core/layers/reshaping/cropping1d_test.py b/keras_core/layers/reshaping/cropping1d_test.py index 8d22ddca1..67862843d 100644 --- a/keras_core/layers/reshaping/cropping1d_test.py +++ b/keras_core/layers/reshaping/cropping1d_test.py @@ -1,13 +1,13 @@ import numpy as np import pytest -from keras_core import backend from keras_core import layers from keras_core import ops from keras_core import testing class Cropping1DTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_cropping_1d(self): inputs = np.random.rand(3, 5, 7) @@ -47,10 +47,7 @@ class Cropping1DTest(testing.TestCase): expected_output=ops.convert_to_tensor(inputs[:, 1:5, :]), ) - @pytest.mark.skipif( - not backend.DYNAMIC_SHAPES_OK, - reason="Backend does not support dynamic shapes", - ) + @pytest.mark.requires_trainable_backend def test_cropping_1d_with_dynamic_spatial_dim(self): input_layer = layers.Input(batch_shape=(1, None, 7)) cropped = layers.Cropping1D((1, 2))(input_layer) diff --git a/keras_core/layers/reshaping/cropping2d_test.py b/keras_core/layers/reshaping/cropping2d_test.py index 704510358..e36b06db6 100644 --- a/keras_core/layers/reshaping/cropping2d_test.py +++ b/keras_core/layers/reshaping/cropping2d_test.py @@ -33,6 +33,7 @@ class Cropping2DTest(testing.TestCase, parameterized.TestCase): {"data_format": "channels_last"}, ), ) + @pytest.mark.requires_trainable_backend def test_cropping_2d(self, cropping, data_format, expected_ranges): if data_format == "channels_first": inputs = np.random.rand(3, 5, 7, 9) diff --git a/keras_core/layers/reshaping/cropping3d_test.py b/keras_core/layers/reshaping/cropping3d_test.py index e6f95d418..7566bdafd 100644 --- a/keras_core/layers/reshaping/cropping3d_test.py +++ b/keras_core/layers/reshaping/cropping3d_test.py @@ -30,6 +30,7 @@ class Cropping3DTest(testing.TestCase, parameterized.TestCase): {"data_format": "channels_last"}, ), ) + 
@pytest.mark.requires_trainable_backend def test_cropping_3d( self, dim1_cropping, @@ -88,6 +89,7 @@ class Cropping3DTest(testing.TestCase, parameterized.TestCase): {"data_format": "channels_last"}, ), ) + @pytest.mark.requires_trainable_backend def test_cropping_3d_with_same_cropping( self, cropping, data_format, expected ): diff --git a/keras_core/layers/reshaping/flatten_test.py b/keras_core/layers/reshaping/flatten_test.py index 7043eb849..339a7180f 100644 --- a/keras_core/layers/reshaping/flatten_test.py +++ b/keras_core/layers/reshaping/flatten_test.py @@ -8,6 +8,7 @@ from keras_core import testing class FlattenTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_flatten(self): inputs = np.random.random((10, 3, 5, 5)).astype("float32") @@ -39,6 +40,7 @@ class FlattenTest(testing.TestCase): expected_output=expected_output, ) + @pytest.mark.requires_trainable_backend def test_flatten_with_scalar_channels(self): inputs = np.random.random((10,)).astype("float32") expected_output = ops.convert_to_tensor(np.expand_dims(inputs, -1)) diff --git a/keras_core/layers/reshaping/permute_test.py b/keras_core/layers/reshaping/permute_test.py index 0b2dbedea..f6f72e6fe 100644 --- a/keras_core/layers/reshaping/permute_test.py +++ b/keras_core/layers/reshaping/permute_test.py @@ -8,6 +8,7 @@ from keras_core import testing class PermuteTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_permute(self): inputs = np.random.random((2, 3, 5)).astype("float32") expected_output = ops.convert_to_tensor( diff --git a/keras_core/layers/reshaping/repeat_vector_test.py b/keras_core/layers/reshaping/repeat_vector_test.py index 9fd9b55bc..b608799aa 100644 --- a/keras_core/layers/reshaping/repeat_vector_test.py +++ b/keras_core/layers/reshaping/repeat_vector_test.py @@ -8,6 +8,7 @@ from keras_core import testing class FlattenTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_repeat_vector(self): inputs = np.random.random((2, 
5)).astype("float32") expected_output = ops.convert_to_tensor( diff --git a/keras_core/layers/reshaping/reshape_test.py b/keras_core/layers/reshaping/reshape_test.py index d71262e73..eea5a4e50 100644 --- a/keras_core/layers/reshaping/reshape_test.py +++ b/keras_core/layers/reshaping/reshape_test.py @@ -6,6 +6,7 @@ from keras_core import testing class ReshapeTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_reshape(self): self.run_layer_test( layers.Reshape, diff --git a/keras_core/layers/reshaping/up_sampling1d_test.py b/keras_core/layers/reshaping/up_sampling1d_test.py index e482f5d66..e92ea64dc 100644 --- a/keras_core/layers/reshaping/up_sampling1d_test.py +++ b/keras_core/layers/reshaping/up_sampling1d_test.py @@ -8,6 +8,7 @@ from keras_core.backend.common.keras_tensor import KerasTensor class UpSamplingTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_upsampling_1d(self): self.run_layer_test( layers.UpSampling1D, diff --git a/keras_core/layers/reshaping/up_sampling2d_test.py b/keras_core/layers/reshaping/up_sampling2d_test.py index e53f923c7..b87e8174c 100644 --- a/keras_core/layers/reshaping/up_sampling2d_test.py +++ b/keras_core/layers/reshaping/up_sampling2d_test.py @@ -14,6 +14,7 @@ class UpSampling2dTest(testing.TestCase, parameterized.TestCase): length_row=[2], length_col=[2, 3], ) + @pytest.mark.requires_trainable_backend def test_upsampling_2d(self, data_format, length_row, length_col): num_samples = 2 stack_size = 2 @@ -64,6 +65,7 @@ class UpSampling2dTest(testing.TestCase, parameterized.TestCase): length_row=[2], length_col=[2, 3], ) + @pytest.mark.requires_trainable_backend def test_upsampling_2d_bilinear(self, data_format, length_row, length_col): num_samples = 2 stack_size = 2 diff --git a/keras_core/layers/reshaping/up_sampling3d_test.py b/keras_core/layers/reshaping/up_sampling3d_test.py index 6b7b10a6f..c747dff9f 100644 --- a/keras_core/layers/reshaping/up_sampling3d_test.py +++ 
b/keras_core/layers/reshaping/up_sampling3d_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from absl.testing import parameterized from keras_core import backend @@ -13,6 +14,7 @@ class UpSampling3dTest(testing.TestCase, parameterized.TestCase): length_dim2=[2], length_dim3=[3], ) + @pytest.mark.requires_trainable_backend def test_upsampling_3d( self, data_format, length_dim1, length_dim2, length_dim3 ): diff --git a/keras_core/layers/rnn/bidirectional_test.py b/keras_core/layers/rnn/bidirectional_test.py index 13b5b3269..0f326a211 100644 --- a/keras_core/layers/rnn/bidirectional_test.py +++ b/keras_core/layers/rnn/bidirectional_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from keras_core import initializers from keras_core import layers @@ -6,6 +7,7 @@ from keras_core import testing class SimpleRNNTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_basics(self): self.run_layer_test( layers.Bidirectional, diff --git a/keras_core/layers/rnn/conv_lstm1d_test.py b/keras_core/layers/rnn/conv_lstm1d_test.py index e24b3108f..948b637c9 100644 --- a/keras_core/layers/rnn/conv_lstm1d_test.py +++ b/keras_core/layers/rnn/conv_lstm1d_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from keras_core import initializers from keras_core import layers @@ -6,6 +7,7 @@ from keras_core import testing class ConvLSTM1DTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_basics(self): self.run_layer_test( layers.ConvLSTM1D, diff --git a/keras_core/layers/rnn/conv_lstm2d_test.py b/keras_core/layers/rnn/conv_lstm2d_test.py index 298d80fe7..88e825071 100644 --- a/keras_core/layers/rnn/conv_lstm2d_test.py +++ b/keras_core/layers/rnn/conv_lstm2d_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from keras_core import initializers from keras_core import layers @@ -6,6 +7,7 @@ from keras_core import testing class ConvLSTM2DTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_basics(self): 
self.run_layer_test( layers.ConvLSTM2D, diff --git a/keras_core/layers/rnn/conv_lstm3d_test.py b/keras_core/layers/rnn/conv_lstm3d_test.py index fb108085f..7d5a33420 100644 --- a/keras_core/layers/rnn/conv_lstm3d_test.py +++ b/keras_core/layers/rnn/conv_lstm3d_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from keras_core import initializers from keras_core import layers @@ -6,6 +7,7 @@ from keras_core import testing class ConvLSTM1DTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_basics(self): self.run_layer_test( layers.ConvLSTM3D, diff --git a/keras_core/layers/rnn/dropout_rnn_cell_test.py b/keras_core/layers/rnn/dropout_rnn_cell_test.py index 4953dd4e5..6b000e894 100644 --- a/keras_core/layers/rnn/dropout_rnn_cell_test.py +++ b/keras_core/layers/rnn/dropout_rnn_cell_test.py @@ -1,3 +1,5 @@ +import pytest + from keras_core import backend from keras_core import layers from keras_core import ops @@ -50,6 +52,7 @@ class DropoutRNNCellTest(testing.TestCase): layer = layers.RNN(cell) self.assertEqual(len(layer.non_trainable_variables), 1) + @pytest.mark.requires_trainable_backend def test_basics(self): self.run_layer_test( layers.RNN, diff --git a/keras_core/layers/rnn/gru_test.py b/keras_core/layers/rnn/gru_test.py index 0603e5d04..16fab6c35 100644 --- a/keras_core/layers/rnn/gru_test.py +++ b/keras_core/layers/rnn/gru_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from absl.testing import parameterized from keras_core import initializers @@ -7,6 +8,7 @@ from keras_core import testing class GRUTest(testing.TestCase, parameterized.TestCase): + @pytest.mark.requires_trainable_backend def test_basics(self): self.run_layer_test( layers.GRU, diff --git a/keras_core/layers/rnn/lstm_test.py b/keras_core/layers/rnn/lstm_test.py index 3ef23677a..95d0af326 100644 --- a/keras_core/layers/rnn/lstm_test.py +++ b/keras_core/layers/rnn/lstm_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from absl.testing import 
parameterized from keras_core import initializers @@ -7,6 +8,7 @@ from keras_core import testing class LSTMTest(testing.TestCase, parameterized.TestCase): + @pytest.mark.requires_trainable_backend def test_basics(self): self.run_layer_test( layers.LSTM, diff --git a/keras_core/layers/rnn/rnn_test.py b/keras_core/layers/rnn/rnn_test.py index c24b54cd1..87e0456f1 100644 --- a/keras_core/layers/rnn/rnn_test.py +++ b/keras_core/layers/rnn/rnn_test.py @@ -69,6 +69,7 @@ class TwoStatesRNNCell(layers.Layer): class RNNTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_basics(self): self.run_layer_test( layers.RNN, diff --git a/keras_core/layers/rnn/simple_rnn_test.py b/keras_core/layers/rnn/simple_rnn_test.py index 2b174cdc9..6d99e0b03 100644 --- a/keras_core/layers/rnn/simple_rnn_test.py +++ b/keras_core/layers/rnn/simple_rnn_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from keras_core import initializers from keras_core import layers @@ -6,6 +7,7 @@ from keras_core import testing class SimpleRNNTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_basics(self): self.run_layer_test( layers.SimpleRNN, diff --git a/keras_core/layers/rnn/stacked_rnn_cells_test.py b/keras_core/layers/rnn/stacked_rnn_cells_test.py index 1d62175d1..e9a0e0f71 100644 --- a/keras_core/layers/rnn/stacked_rnn_cells_test.py +++ b/keras_core/layers/rnn/stacked_rnn_cells_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from keras_core import layers from keras_core import testing @@ -7,6 +8,7 @@ from keras_core.layers.rnn.rnn_test import TwoStatesRNNCell class StackedRNNTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_basics(self): self.run_layer_test( layers.RNN, diff --git a/keras_core/layers/rnn/time_distributed_test.py b/keras_core/layers/rnn/time_distributed_test.py index b2c3638dd..312afd88f 100644 --- a/keras_core/layers/rnn/time_distributed_test.py +++ b/keras_core/layers/rnn/time_distributed_test.py @@ 
-1,4 +1,5 @@ import numpy as np +import pytest from keras_core import initializers from keras_core import layers @@ -7,6 +8,7 @@ from keras_core import testing class TimeDistributedTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_basics(self): self.run_layer_test( layers.TimeDistributed, diff --git a/keras_core/losses/loss_test.py b/keras_core/losses/loss_test.py index 2eb5d1f83..57f62e1f4 100644 --- a/keras_core/losses/loss_test.py +++ b/keras_core/losses/loss_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from keras_core import backend from keras_core import losses as losses_module @@ -39,6 +40,10 @@ class LossTest(testing.TestCase): with self.assertRaisesRegex(ValueError, "Invalid value for argument"): ExampleLoss(reduction="abc") + @pytest.mark.skipif( + backend.backend() == "numpy", + reason="Numpy backend does not support masking.", + ) def test_mask(self): mask = np.array([True, False, True, True]) y_true = np.array([1.0, 0.0, 1.0, 0.0]) @@ -84,6 +89,10 @@ class LossTest(testing.TestCase): self.assertEqual(backend.standardize_dtype(loss.dtype), "float32") self.assertAllClose(loss, 0) # No NaN. 
+ @pytest.mark.skipif( + backend.backend() == "numpy", + reason="Numpy backend does not support masking.", + ) def test_mask_and_sample_weight(self): sample_weight = np.array([0.4, 0.3, 0.2, 0.1]) y_true = np.array([1.0, 0.0, 1.0, 0.0]) @@ -111,6 +120,10 @@ class LossTest(testing.TestCase): # @testing.parametrize( # "uprank", ["mask", "sample_weight", "y_true", "y_pred"]) # TODO: use parameterization decorator + @pytest.mark.skipif( + backend.backend() == "numpy", + reason="Numpy backend does not support masking.", + ) def test_rank_adjustment(self): for uprank in ["mask", "sample_weight", "ys"]: sample_weight = np.array([0.4, 0.3, 0.2, 0.1]) diff --git a/keras_core/metrics/accuracy_metrics.py b/keras_core/metrics/accuracy_metrics.py index 2678e7419..bab1963ca 100644 --- a/keras_core/metrics/accuracy_metrics.py +++ b/keras_core/metrics/accuracy_metrics.py @@ -9,7 +9,10 @@ def accuracy(y_true, y_pred): y_pred = ops.convert_to_tensor(y_pred) y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype) y_true, y_pred = squeeze_to_same_rank(y_true, y_pred) - return ops.cast(ops.equal(y_true, y_pred), dtype=backend.floatx()) + return ops.mean( + ops.cast(ops.equal(y_true, y_pred), dtype=backend.floatx()), + axis=-1, + ) @keras_core_export("keras_core.metrics.Accuracy") @@ -135,7 +138,7 @@ def categorical_accuracy(y_true, y_pred): and (y_pred_rank is not None) and (len(y_true.shape) == len(y_pred.shape)) ): - y_true = ops.squeeze(y_true, [-1]) + y_true = ops.squeeze(y_true, -1) reshape_matches = True y_pred = ops.argmax(y_pred, axis=-1) @@ -218,7 +221,7 @@ def sparse_categorical_accuracy(y_true, y_pred): and (y_pred_rank is not None) and (len(y_true.shape) == len(y_pred.shape)) ): - y_true = ops.squeeze(y_true, [-1]) + y_true = ops.squeeze(y_true, -1) reshape_matches = True y_pred = ops.argmax(y_pred, axis=-1) @@ -231,7 +234,7 @@ def sparse_categorical_accuracy(y_true, y_pred): matches = ops.reshape(matches, new_shape=y_true_org_shape) # if shape is (num_samples, 1) 
squeeze if len(matches.shape) > 1 and matches.shape[-1] == 1: - matches = ops.squeeze(matches, [-1]) + matches = ops.squeeze(matches, -1) return matches diff --git a/keras_core/metrics/confusion_metrics_test.py b/keras_core/metrics/confusion_metrics_test.py index 845993d65..aac522438 100644 --- a/keras_core/metrics/confusion_metrics_test.py +++ b/keras_core/metrics/confusion_metrics_test.py @@ -1,6 +1,7 @@ import json import numpy as np +import pytest from absl import logging from absl.testing import parameterized from tensorflow.python.ops.numpy_ops import np_config @@ -756,7 +757,7 @@ class SensitivityAtSpecificityTest(testing.TestCase, parameterized.TestCase): label_values = [0, 0, 0, 0, 0, 2, 2, 2, 2, 2] y_pred = ops.transpose(np.array([pred_values] * 3)) - y_true = ops.one_hot(label_values, num_classes=3) + y_true = ops.one_hot(np.array(label_values), num_classes=3) self.assertAlmostEqual(0.6, s_obj(y_true, y_pred)) @@ -845,7 +846,7 @@ class SpecificityAtSensitivityTest(testing.TestCase, parameterized.TestCase): label_values = [0, 0, 0, 0, 0, 2, 2, 2, 2, 2] y_pred = ops.transpose(np.array([pred_values] * 3)) - y_true = ops.one_hot(label_values, num_classes=3) + y_true = ops.one_hot(np.array(label_values), num_classes=3) self.assertAlmostEqual(0.6, s_obj(y_true, y_pred)) @@ -931,7 +932,7 @@ class PrecisionAtRecallTest(testing.TestCase, parameterized.TestCase): label_values = [0, 0, 0, 0, 0, 2, 2, 2, 2, 2] y_pred = ops.transpose(np.array([pred_values] * 3)) - y_true = ops.one_hot(label_values, num_classes=3) + y_true = ops.one_hot(np.array(label_values), num_classes=3) # For 0.2 < decision threshold < 0.5. self.assertAlmostEqual(0.75, s_obj(y_true, y_pred)) @@ -1067,7 +1068,7 @@ class RecallAtPrecisionTest(testing.TestCase, parameterized.TestCase): # recalls: [1, 1, 5/6, 5/6, 5/6, 5/6, 2/3, 1/2, 1/2, 1/3, 1/6, # 1/6]. 
y_pred = ops.transpose(np.array([pred_values] * 3)) - y_true = ops.one_hot(label_values, num_classes=3) + y_true = ops.one_hot(np.array(label_values), num_classes=3) # The precision 5/7 can be reached at thresholds 00.3<=t<0.35. self.assertAlmostEqual(5.0 / 6, s_obj(y_true, y_pred)) @@ -1645,6 +1646,7 @@ class MultiAUCTest(testing.TestCase): # PR AUCs are 0.939 and 1.0 respectively self.assertAllClose(good_result, (0.939 + 1.0) / 2.0, 1e-1) + @pytest.mark.requires_trainable_backend def test_keras_model_compiles(self): inputs = layers.Input(shape=(10,), batch_size=1) output = layers.Dense(3, activation="sigmoid")(inputs) diff --git a/keras_core/models/cloning_test.py b/keras_core/models/cloning_test.py index 17652db92..0df8e12a1 100644 --- a/keras_core/models/cloning_test.py +++ b/keras_core/models/cloning_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from absl.testing import parameterized from keras_core import layers @@ -41,6 +42,7 @@ def get_subclassed_model(): return ExampleModel() +@pytest.mark.requires_trainable_backend class CloneModelTest(testing.TestCase, parameterized.TestCase): @parameterized.named_parameters( ("functional", get_functional_model), diff --git a/keras_core/models/functional_test.py b/keras_core/models/functional_test.py index fec237f36..d667b921c 100644 --- a/keras_core/models/functional_test.py +++ b/keras_core/models/functional_test.py @@ -1,6 +1,7 @@ import warnings import numpy as np +import pytest from keras_core import backend from keras_core import layers @@ -12,6 +13,7 @@ from keras_core.models import Model class FunctionalTest(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_basic_flow_multi_input(self): input_a = Input(shape=(3,), batch_size=2, name="input_a") input_b = Input(shape=(3,), batch_size=2, name="input_b") @@ -37,6 +39,7 @@ class FunctionalTest(testing.TestCase): out_val = model(in_val) self.assertEqual(out_val.shape, (2, 4)) + @pytest.mark.requires_trainable_backend def 
test_scalar_input(self): input_a = Input(shape=(3,), batch_size=2, name="input_a") input_b = Input(shape=(), batch_size=2, name="input_b") @@ -48,6 +51,7 @@ class FunctionalTest(testing.TestCase): out_val = model(in_val) self.assertAllClose(out_val, np.ones((2, 3))) + @pytest.mark.requires_trainable_backend def test_basic_flow_multi_output(self): inputs = Input(shape=(3,), batch_size=2, name="input") x = layers.Dense(5)(inputs) @@ -70,6 +74,7 @@ class FunctionalTest(testing.TestCase): self.assertEqual(out_val[0].shape, (2, 4)) self.assertEqual(out_val[1].shape, (2, 5)) + @pytest.mark.requires_trainable_backend def test_basic_flow_dict_io(self): input_a = Input(shape=(3,), batch_size=2, name="a") input_b = Input(shape=(3,), batch_size=2, name="b") @@ -101,6 +106,7 @@ class FunctionalTest(testing.TestCase): out_val = model(in_val) self.assertEqual(out_val.shape, (2, 4)) + @pytest.mark.requires_trainable_backend def test_named_input_dict_io(self): input_a = Input(shape=(3,), batch_size=2, name="a") x = layers.Dense(5)(input_a) @@ -119,6 +125,7 @@ class FunctionalTest(testing.TestCase): out_val = model(in_val) self.assertEqual(out_val.shape, (2, 4)) + @pytest.mark.requires_trainable_backend def test_input_dict_with_extra_field(self): input_a = Input(shape=(3,), batch_size=2, name="a") x = input_a * 5 @@ -145,6 +152,7 @@ class FunctionalTest(testing.TestCase): out_val = model(in_val) self.assertEqual(out_val.shape, (2, 3)) + @pytest.mark.requires_trainable_backend def test_layer_getters(self): # Test mixing ops and layers input_a = Input(shape=(3,), batch_size=2, name="input_a") @@ -162,6 +170,7 @@ class FunctionalTest(testing.TestCase): self.assertEqual(model.get_layer(index=3).name, "dense_2") self.assertEqual(model.get_layer(name="dense_1").name, "dense_1") + @pytest.mark.requires_trainable_backend def test_training_arg(self): class Canary(layers.Layer): def call(self, x, training=False): @@ -180,6 +189,7 @@ class FunctionalTest(testing.TestCase): # TODO pass + 
@pytest.mark.requires_trainable_backend def test_passing_inputs_by_name(self): input_a = Input(shape=(3,), batch_size=2, name="input_a") input_b = Input(shape=(3,), batch_size=2, name="input_b") @@ -203,6 +213,7 @@ class FunctionalTest(testing.TestCase): out_val = model(in_val) self.assertEqual(out_val.shape, (2, 4)) + @pytest.mark.requires_trainable_backend def test_rank_standardization(self): # Downranking inputs = Input(shape=(3,), batch_size=2) @@ -218,6 +229,7 @@ class FunctionalTest(testing.TestCase): out_val = model(np.random.random((2, 3))) self.assertEqual(out_val.shape, (2, 3, 3)) + @pytest.mark.requires_trainable_backend def test_dtype_standardization(self): float_input = Input(shape=(2,), dtype="float16") int_input = Input(shape=(2,), dtype="int32") @@ -229,6 +241,7 @@ class FunctionalTest(testing.TestCase): self.assertEqual(backend.standardize_dtype(float_data.dtype), "float16") self.assertEqual(backend.standardize_dtype(int_data.dtype), "int32") + @pytest.mark.requires_trainable_backend def test_serialization(self): # Test basic model inputs = Input(shape=(3,), batch_size=2) @@ -269,6 +282,7 @@ class FunctionalTest(testing.TestCase): model = Functional({"a": input_a, "b": input_b}, outputs) self.run_class_serialization_test(model) + @pytest.mark.requires_trainable_backend def test_bad_input_spec(self): # Single input inputs = Input(shape=(4,)) @@ -303,6 +317,7 @@ class FunctionalTest(testing.TestCase): ): model({"a": np.zeros((2, 3)), "b": np.zeros((2, 4))}) + @pytest.mark.requires_trainable_backend def test_manual_input_spec(self): inputs = Input(shape=(None, 3)) outputs = layers.Dense(2)(inputs) diff --git a/keras_core/models/model.py b/keras_core/models/model.py index 1a9aa98bf..0c60f6948 100644 --- a/keras_core/models/model.py +++ b/keras_core/models/model.py @@ -22,6 +22,8 @@ elif backend.backend() == "jax": from keras_core.backend.jax.trainer import JAXTrainer as Trainer elif backend.backend() == "torch": from keras_core.backend.torch.trainer 
import TorchTrainer as Trainer +elif backend.backend() == "numpy": + from keras_core.backend.numpy.trainer import NumpyTrainer as Trainer else: raise RuntimeError( f"Backend '{backend.backend()}' must implement the Trainer class." diff --git a/keras_core/models/model_test.py b/keras_core/models/model_test.py index 4ef479eb8..f07fe55d5 100644 --- a/keras_core/models/model_test.py +++ b/keras_core/models/model_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from absl.testing import parameterized from keras_core import layers @@ -64,6 +65,7 @@ def _get_model_multi_outputs_dict(): return model +@pytest.mark.requires_trainable_backend class ModelTest(testing.TestCase, parameterized.TestCase): def test_functional_rerouting(self): model = _get_model() diff --git a/keras_core/models/sequential_test.py b/keras_core/models/sequential_test.py index 8ad079a3d..c1b81d20e 100644 --- a/keras_core/models/sequential_test.py +++ b/keras_core/models/sequential_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from keras_core import backend from keras_core import layers @@ -8,6 +9,7 @@ from keras_core.models.functional import Functional from keras_core.models.sequential import Sequential +@pytest.mark.requires_trainable_backend class SequentialTest(testing.TestCase): def test_basic_flow_with_input(self): model = Sequential(name="seq") diff --git a/keras_core/ops/core_test.py b/keras_core/ops/core_test.py index 1e3209314..22eb92737 100644 --- a/keras_core/ops/core_test.py +++ b/keras_core/ops/core_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from keras_core import layers from keras_core import losses @@ -162,7 +163,7 @@ class CoreOpsCorrectnessTest(testing.TestCase): def test_slice_update(self): # Test 1D. 
inputs = np.array([0, 0, 0, 0, 0, 0, 0, 0]) - start_indices = [1] + start_indices = np.array([1]) updates = np.array([9, 10, 11, 12]) self.assertAllClose( core.slice_update(inputs, start_indices, updates), @@ -204,6 +205,7 @@ class CoreOpsCorrectnessTest(testing.TestCase): self.assertAllClose(x, np.ones((2, 3)) * 6) self.assertAllClose(y, np.ones((3, 2)) * 6) + @pytest.mark.requires_trainable_backend def test_stop_gradient(self): class ExampleLayer(layers.Layer): def __init__(self): diff --git a/keras_core/optimizers/adam_test.py b/keras_core/optimizers/adam_test.py index f9dc4e1fa..f8b5c1066 100644 --- a/keras_core/optimizers/adam_test.py +++ b/keras_core/optimizers/adam_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import keras_core from keras_core import backend @@ -75,6 +76,7 @@ class AdamTest(testing.TestCase): clipped_grad = optimizer._clip_gradients(grad) self.assertAllClose(clipped_grad[0], [1.0, 1.0]) + @pytest.mark.requires_trainable_backend def test_ema(self): # TODO: test correctness model = keras_core.Sequential([keras_core.layers.Dense(10)]) diff --git a/keras_core/optimizers/schedules/learning_rate_schedule_test.py b/keras_core/optimizers/schedules/learning_rate_schedule_test.py index 52505400c..dd9388d6f 100644 --- a/keras_core/optimizers/schedules/learning_rate_schedule_test.py +++ b/keras_core/optimizers/schedules/learning_rate_schedule_test.py @@ -3,6 +3,7 @@ import math import numpy as np +import pytest from keras_core import backend from keras_core import layers @@ -13,6 +14,7 @@ from keras_core.optimizers import schedules class TestFitLRSchedulesFlow(testing.TestCase): + @pytest.mark.requires_trainable_backend def test_fit_lr_correctness(self): model = Sequential( [ diff --git a/keras_core/saving/legacy/legacy_h5_format_test.py b/keras_core/saving/legacy/legacy_h5_format_test.py index a3c6171e1..9c046c466 100644 --- a/keras_core/saving/legacy/legacy_h5_format_test.py +++ b/keras_core/saving/legacy/legacy_h5_format_test.py @@ -1,6 
+1,7 @@ import os import numpy as np +import pytest import tensorflow as tf import keras_core @@ -52,6 +53,7 @@ def get_subclassed_model(keras): return model +@pytest.mark.requires_trainable_backend class LegacyH5LoadingTest(testing.TestCase): def _check_reloading(self, ref_input, model, tf_keras_model): ref_output = tf_keras_model(ref_input) diff --git a/keras_core/saving/saving_lib_test.py b/keras_core/saving/saving_lib_test.py index 170d4dd21..0fdc0961d 100644 --- a/keras_core/saving/saving_lib_test.py +++ b/keras_core/saving/saving_lib_test.py @@ -7,6 +7,7 @@ from pathlib import Path from unittest import mock import numpy as np +import pytest import keras_core from keras_core import ops @@ -130,6 +131,7 @@ def my_mean_squared_error(y_true, y_pred): return ops.mean(ops.square(y_pred - y_true), axis=-1) +@pytest.mark.requires_trainable_backend class SavingTest(testing.TestCase): def _get_subclassed_model(self, compile=True): subclassed_model = CustomModelX() diff --git a/keras_core/saving/serialization_lib_test.py b/keras_core/saving/serialization_lib_test.py index c6e685f89..8519f8bac 100644 --- a/keras_core/saving/serialization_lib_test.py +++ b/keras_core/saving/serialization_lib_test.py @@ -3,6 +3,7 @@ import json import numpy as np +import pytest import keras_core from keras_core import ops @@ -188,6 +189,7 @@ class SerializationLibTest(testing.TestCase): # y2 = new_lmbda(x) # self.assertAllClose(y1, y2, atol=1e-5) + @pytest.mark.requires_trainable_backend def test_dict_inputs_outputs(self): input_foo = keras_core.Input((2,), name="foo") input_bar = keras_core.Input((2,), name="bar") @@ -223,6 +225,7 @@ class SerializationLibTest(testing.TestCase): self.assertIs(model.layers[2], model.layers[3].layer) self.assertIs(new_model.layers[2], new_model.layers[3].layer) + @pytest.mark.requires_trainable_backend def test_functional_subclass(self): class PlainFunctionalSubclass(keras_core.Model): pass diff --git a/keras_core/trainers/trainer_test.py 
b/keras_core/trainers/trainer_test.py index 4e9356c60..040458a19 100644 --- a/keras_core/trainers/trainer_test.py +++ b/keras_core/trainers/trainer_test.py @@ -21,6 +21,8 @@ elif backend.backend() == "tensorflow": from keras_core.backend.tensorflow.trainer import ( TensorFlowTrainer as Trainer, ) +elif backend.backend() == "numpy": + from keras_core.backend.numpy.trainer import NumpyTrainer as Trainer else: raise ImportError(f"Invalid backend: {backend.backend()}") @@ -70,6 +72,7 @@ class TrainingTestingLayer(layers.Layer, Trainer): return x * 0 +@pytest.mark.requires_trainable_backend class TestTrainer(testing.TestCase, parameterized.TestCase): def test_metric_tracking(self): class ModelWithMetric(layers.Dense, Trainer): diff --git a/keras_core/utils/backend_utils.py b/keras_core/utils/backend_utils.py index 10b73f3a4..2aa7018c1 100644 --- a/keras_core/utils/backend_utils.py +++ b/keras_core/utils/backend_utils.py @@ -49,3 +49,11 @@ class DynamicBackend: from keras_core.backend import torch as torch_backend return getattr(torch_backend, name) + if self._backend == "numpy": + # TODO (ariG23498): + # The import `from keras_core.backend import numpy as numpy_backend` + # is not working. This is a temporary fix. 
+ # The import is redirected to `keras_core.backend.numpy.numpy.py` + from keras_core import backend as numpy_backend + + return getattr(numpy_backend, name) diff --git a/keras_core/utils/numerical_utils.py b/keras_core/utils/numerical_utils.py index ca412afd8..79c03409a 100644 --- a/keras_core/utils/numerical_utils.py +++ b/keras_core/utils/numerical_utils.py @@ -29,6 +29,9 @@ def normalize(x, axis=-1, order=2): # NumPy input norm = np.atleast_1d(np.linalg.norm(x, order, axis)) norm[norm == 0] = 1 + + # axis cannot be `None` + axis = axis or -1 return x / np.expand_dims(norm, axis) # Backend tensor input diff --git a/keras_core/utils/rng_utils_test.py b/keras_core/utils/rng_utils_test.py index 8ae629c5c..534207c1b 100644 --- a/keras_core/utils/rng_utils_test.py +++ b/keras_core/utils/rng_utils_test.py @@ -1,12 +1,18 @@ import numpy as np +import pytest import tensorflow as tf import keras_core +from keras_core import backend from keras_core.testing import test_case from keras_core.utils import rng_utils class TestRandomSeedSetting(test_case.TestCase): + @pytest.mark.skipif( + backend.backend() == "numpy", + reason="Numpy backend does not support random seed setting.", + ) def test_set_random_seed(self): def get_model_output(): model = keras_core.Sequential( diff --git a/keras_core/utils/traceback_utils.py b/keras_core/utils/traceback_utils.py index b050c09b8..afd75be43 100644 --- a/keras_core/utils/traceback_utils.py +++ b/keras_core/utils/traceback_utils.py @@ -230,6 +230,8 @@ def format_argument_value(value): tensor_cls = "jnp.ndarray" elif backend.backend() == "torch": tensor_cls = "torch.Tensor" + elif backend.backend() == "numpy": + tensor_cls = "np.ndarray" else: tensor_cls = "array"