Adding: NumPy Backend (#483)

* chore: adding numpy backend

* review comments

* review comments

* chore: adding math

* chore: adding random module

* chore: adding random in init

* review comments

* chore: adding numpy and nn for numpy backend

* chore: adding generic pool, max, and average pool

* chore: adding the conv ops

* chore: reformat code and using jax for conv and pool

* chore: added self value

* chore: activation tests pass

* chore: adding post build method

* chore: adding necessary methods to the numpy trainer

* chore: fixing utils test

* chore: fixing losses test suite

* chore: fix backend tests

* chore: fixing initializers test

* chore: fixing accuracy metrics test

* chore: fixing ops test

* chore: review comments

* chore: init with image and fixing random tests

* chore: skipping random seed set for numpy backend

* chore: adding single resize image method

* chore: skipping tests for applications and layers

* chore: skipping tests for models

* chore: skipping tests for saving

* chore: skipping tests for trainers

* chore: fixing one hot

* chore: fixing vmap in numpy and metrics test

* chore: adding a wrapper to numpy sum, started fixing layer tests

* fix: is_tensor now accepts numpy scalars

* chore: adding draw seed

* fix: warn message for numpy masking

* fix: checking whether kernels are tensors

* chore: adding rnn

* chore: adding dynamic backend for numpy

* fix: axis cannot be None for normalize

* chore: adding jax resize for numpy image

* chore: adding rnn implementation in numpy

* chore: using pytest fixtures

* change: numpy import string

* chore: review comments

* chore: adding numpy to backend list of github actions

* chore: remove debug print statements
Authored by Aritra Roy Gosthipaty on 2023-07-19 01:08:48 +05:30, committed by Francois Chollet
parent 860b1ca4da
commit 2481069ed4
104 changed files with 2089 additions and 14 deletions

@ -5,3 +5,24 @@ try:
import torch # noqa: F401
except ImportError:
pass
import pytest
from keras_core.backend import backend
def pytest_configure(config):
config.addinivalue_line(
"markers",
"requires_trainable_backend: mark test for trainable backend only",
)
def pytest_collection_modifyitems(config, items):
requires_trainable_backend = pytest.mark.skipif(
backend() == "numpy",
reason="Trainer not implemented for NumPy backend.",
)
for item in items:
if "requires_trainable_backend" in item.keywords:
item.add_marker(requires_trainable_backend)
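For reference, a test opts in via this marker; the collection hook above then attaches the skip when KERAS_BACKEND=numpy. A minimal illustrative test (not part of this diff):

import pytest

from keras_core import testing


class ExampleTest(testing.TestCase):
    @pytest.mark.requires_trainable_backend
    def test_fit_smoke(self):
        ...  # any body that relies on model.fit() / model.evaluate()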

@ -107,6 +107,7 @@ def _get_elephant(target_size):
os.environ.get("SKIP_APPLICATIONS_TESTS"),
reason="Env variable set to skip.",
)
@pytest.mark.requires_trainable_backend
class ApplicationsTest(testing.TestCase, parameterized.TestCase):
@parameterized.named_parameters(MODEL_LIST)
def test_application_notop_variable_input_channels(self, app, last_dim, _):

@ -1,4 +1,5 @@
import numpy as np
import pytest
from absl.testing import parameterized
import keras_core as keras
@ -74,6 +75,7 @@ class TestImageNetUtils(testing.TestCase, parameterized.TestCase):
{"testcase_name": "mode_caffe", "mode": "caffe"},
]
)
@pytest.mark.requires_trainable_backend
def test_preprocess_input_symbolic(self, mode):
# Test image batch
x = np.random.uniform(0, 255, (2, 10, 10, 3))

@ -37,5 +37,12 @@ elif backend() == "jax":
elif backend() == "torch":
print_msg("Using PyTorch backend.")
from keras_core.backend.torch import * # noqa: F403
elif backend() == "numpy":
print_msg(
"Using NumPy backend.\nThe NumPy backend does not support "
"training. It should only be used for inference, evaluation, "
"and debugging."
)
from keras_core.backend.numpy import * # noqa: F403
else:
raise ValueError(f"Unable to import backend : {backend()}")
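The backend is chosen once at import time, so selecting NumPy is a matter of setting the KERAS_BACKEND environment variable before the first import. A minimal sketch:

import os

os.environ["KERAS_BACKEND"] = "numpy"  # must be set before keras_core is imported

import keras_core  # logs "Using NumPy backend. ..." via print_msg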

@ -0,0 +1,20 @@
from keras_core.backend.numpy import core
from keras_core.backend.numpy import image
from keras_core.backend.numpy import math
from keras_core.backend.numpy import nn
from keras_core.backend.numpy import numpy
from keras_core.backend.numpy import random
from keras_core.backend.numpy.core import DYNAMIC_SHAPES_OK
from keras_core.backend.numpy.core import Variable
from keras_core.backend.numpy.core import cast
from keras_core.backend.numpy.core import compute_output_spec
from keras_core.backend.numpy.core import cond
from keras_core.backend.numpy.core import convert_to_numpy
from keras_core.backend.numpy.core import convert_to_tensor
from keras_core.backend.numpy.core import is_tensor
from keras_core.backend.numpy.core import name_scope
from keras_core.backend.numpy.core import shape
from keras_core.backend.numpy.core import vectorized_map
from keras_core.backend.numpy.rnn import gru
from keras_core.backend.numpy.rnn import lstm
from keras_core.backend.numpy.rnn import rnn

@ -0,0 +1,212 @@
from contextlib import nullcontext
import numpy as np
from tensorflow import nest
from keras_core.backend.common import KerasVariable
from keras_core.backend.common import standardize_dtype
from keras_core.backend.common.keras_tensor import KerasTensor
from keras_core.backend.common.stateless_scope import StatelessScope
DYNAMIC_SHAPES_OK = True
class Variable(KerasVariable):
def _initialize(self, value):
self._value = np.array(value, dtype=self._dtype)
def _direct_assign(self, value):
self._value = np.array(value, dtype=self._dtype)
def _convert_to_tensor(self, value, dtype=None):
return convert_to_tensor(value, dtype=dtype)
# Overload native accessor.
def __array__(self):
return self.value
def convert_to_tensor(x, dtype=None):
if dtype is not None:
dtype = standardize_dtype(dtype)
if isinstance(x, Variable):
if dtype and dtype != x.dtype:
return x.value.astype(dtype)
return x.value
return np.array(x, dtype=dtype)
def convert_to_numpy(x):
return np.array(x)
def is_tensor(x):
if isinstance(x, (np.generic, np.ndarray)):
return True
return False
def shape(x):
return x.shape
def cast(x, dtype):
return convert_to_tensor(x, dtype=dtype)
def cond(pred, true_fn, false_fn):
if pred:
return true_fn()
return false_fn()
def name_scope(name):
# There is no need for a named context for NumPy.
return nullcontext()
def vectorized_map(function, elements):
if len(elements) == 1:
return function(elements)
else:
batch_size = elements[0].shape[0]
output_store = list()
for index in range(batch_size):
output_store.append(function([x[index] for x in elements]))
return np.stack(output_store)
# Shape / dtype inference util
def compute_output_spec(fn, *args, **kwargs):
with StatelessScope():
def has_none_shape(x):
if isinstance(x, KerasTensor):
return None in x.shape
return False
none_in_shape = any(map(has_none_shape, nest.flatten((args, kwargs))))
def convert_keras_tensor_to_numpy(x, fill_value=None):
if isinstance(x, KerasTensor):
shape = list(x.shape)
if fill_value:
for i, e in enumerate(shape):
if e is None:
shape[i] = fill_value
return np.empty(
shape=shape,
dtype=x.dtype,
)
return x
args_1, kwargs_1 = nest.map_structure(
lambda x: convert_keras_tensor_to_numpy(x, fill_value=83),
(args, kwargs),
)
outputs_1 = fn(*args_1, **kwargs_1)
outputs = outputs_1
if none_in_shape:
args_2, kwargs_2 = nest.map_structure(
lambda x: convert_keras_tensor_to_numpy(x, fill_value=89),
(args, kwargs),
)
outputs_2 = fn(*args_2, **kwargs_2)
flat_out_1 = nest.flatten(outputs_1)
flat_out_2 = nest.flatten(outputs_2)
flat_out = []
for x1, x2 in zip(flat_out_1, flat_out_2):
shape = list(x1.shape)
for i, e in enumerate(x2.shape):
if e != shape[i]:
shape[i] = None
flat_out.append(KerasTensor(shape, standardize_dtype(x1.dtype)))
outputs = nest.pack_sequence_as(outputs_1, flat_out)
def convert_numpy_to_keras_tensor(x):
if is_tensor(x):
return KerasTensor(x.shape, standardize_dtype(x.dtype))
return x
output_spec = nest.map_structure(convert_numpy_to_keras_tensor, outputs)
return output_spec
def scatter(indices, values, shape):
indices = convert_to_tensor(indices)
values = convert_to_tensor(values)
zeros = np.zeros(shape, dtype=values.dtype)
index_length = indices.shape[-1]
value_shape = shape[index_length:]
indices = np.reshape(indices, [-1, index_length])
values = np.reshape(values, [-1] + list(value_shape))
for i in range(indices.shape[0]):
index = indices[i]
zeros[tuple(index)] += values[i]
return zeros
def scatter_update(inputs, indices, updates):
indices = np.array(indices)
indices = np.transpose(indices)
inputs[tuple(indices)] = updates
return inputs
def slice(inputs, start_indices, lengths):
# Validate inputs
assert len(start_indices) == len(lengths)
# Generate list of indices arrays for each dimension
indices = [
np.arange(start, start + length)
for start, length in zip(start_indices, lengths)
]
# Use np.ix_ to create a multidimensional index array
mesh = np.ix_(*indices)
return inputs[mesh]
def slice_update(inputs, start_indices, updates):
# Generate list of indices arrays for each dimension
indices = [
np.arange(start, start + length)
for start, length in zip(start_indices, updates.shape)
]
# Use np.ix_ to create a multidimensional index array
mesh = np.ix_(*indices)
inputs[mesh] = updates
return inputs
def while_loop(
cond,
body,
loop_vars,
maximum_iterations=None,
):
current_iter = 0
iteration_check = (
lambda iter: maximum_iterations is None or iter < maximum_iterations
)
loop_vars = tuple([convert_to_tensor(v) for v in loop_vars])
while cond(*loop_vars) and iteration_check(current_iter):
loop_vars = body(*loop_vars)
if not isinstance(loop_vars, (list, tuple)):
loop_vars = (loop_vars,)
loop_vars = tuple(loop_vars)
current_iter += 1
return loop_vars
def stop_gradient(x):
    # No gradients exist in the NumPy backend, so this is the identity.
    return x
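A quick sketch of the scatter and slice semantics defined above, assuming the functions are in scope (duplicate scatter indices accumulate):

import numpy as np

out = scatter(np.array([[0], [2], [2]]), np.array([1.0, 2.0, 3.0]), shape=(4,))
# out == [1., 0., 5., 0.]  -- both updates at index 2 are summed

x = np.arange(12.0).reshape(3, 4)
sub = slice(x, start_indices=(1, 0), lengths=(2, 2))
# sub == [[4., 5.], [8., 9.]]  -- a (2, 2) window starting at (row 1, col 0)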

@ -0,0 +1,45 @@
import jax
import numpy as np
RESIZE_METHODS = (
"bilinear",
"nearest",
"lanczos3",
"lanczos5",
"bicubic",
)
def resize(
image, size, method="bilinear", antialias=False, data_format="channels_last"
):
if method not in RESIZE_METHODS:
        raise ValueError(
            "Invalid value for argument `method`. Expected one of "
            f"{RESIZE_METHODS}. Received: method={method}"
        )
if not len(size) == 2:
raise ValueError(
"Argument `size` must be a tuple of two elements "
f"(height, width). Received: size={size}"
)
size = tuple(size)
if len(image.shape) == 4:
if data_format == "channels_last":
size = (image.shape[0],) + size + (image.shape[-1],)
else:
size = (image.shape[0], image.shape[1]) + size
elif len(image.shape) == 3:
if data_format == "channels_last":
size = size + (image.shape[-1],)
else:
size = (image.shape[0],) + size
else:
raise ValueError(
"Invalid input rank: expected rank 3 (single image) "
"or rank 4 (batch of images). Received input with shape: "
f"image.shape={image.shape}"
)
return np.array(
jax.image.resize(image, size, method=method, antialias=antialias)
)
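A usage sketch for the resize above (channels-last, single image; rank-4 batches work the same way):

import numpy as np

img = np.random.uniform(size=(64, 64, 3))
small = resize(img, size=(32, 32), method="bilinear")
# small.shape == (32, 32, 3)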

@ -0,0 +1,3 @@
class NumpyLayer:
def _post_build(self):
pass

@ -0,0 +1,76 @@
import numpy as np
def segment_sum(data, segment_ids, num_segments=None, sorted=False):
if num_segments is None:
num_segments = np.amax(segment_ids) + 1
valid_indices = segment_ids >= 0 # Ignore segment_ids that are -1
valid_data = data[valid_indices]
valid_segment_ids = segment_ids[valid_indices]
data_shape = list(valid_data.shape)
data_shape[
0
] = num_segments # Replace first dimension (which corresponds to segments)
if sorted:
result = np.zeros(data_shape, dtype=valid_data.dtype)
np.add.at(result, valid_segment_ids, valid_data)
else:
sort_indices = np.argsort(valid_segment_ids)
sorted_segment_ids = valid_segment_ids[sort_indices]
sorted_data = valid_data[sort_indices]
result = np.zeros(data_shape, dtype=valid_data.dtype)
np.add.at(result, sorted_segment_ids, sorted_data)
return result
def top_k(x, k, sorted=False):
sorted_indices = np.argsort(x, axis=-1)[..., ::-1]
sorted_values = np.sort(x, axis=-1)[..., ::-1]
if sorted:
# Take the k largest values.
top_k_values = sorted_values[..., :k]
top_k_indices = sorted_indices[..., :k]
else:
# Partition the array such that all values larger than the k-th
# largest value are to the right of it.
top_k_values = np.partition(x, -k, axis=-1)[..., -k:]
top_k_indices = np.argpartition(x, -k, axis=-1)[..., -k:]
# Get the indices in sorted order.
idx = np.argsort(-top_k_values, axis=-1)
# Get the top k values and their indices.
top_k_values = np.take_along_axis(top_k_values, idx, axis=-1)
top_k_indices = np.take_along_axis(top_k_indices, idx, axis=-1)
return top_k_values, top_k_indices
def in_top_k(targets, predictions, k):
targets = targets[:, None]
topk_values = top_k(predictions, k)[0]
targets_values = np.take_along_axis(predictions, targets, axis=-1)
mask = targets_values >= topk_values
return np.any(mask, axis=-1)
def logsumexp(x, axis=None, keepdims=False):
max_x = np.max(x, axis=axis, keepdims=True)
result = np.log(np.sum(np.exp(x - max_x), axis=axis, keepdims=True)) + max_x
return np.squeeze(result) if not keepdims else result
def qr(x, mode="reduced"):
if mode not in {"reduced", "complete"}:
raise ValueError(
"`mode` argument value not supported. "
"Expected one of {'reduced', 'complete'}. "
f"Received: mode={mode}"
)
return np.linalg.qr(x, mode=mode)
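Two quick examples of the helpers above, assuming they are in scope:

import numpy as np

segment_sum(np.array([1.0, 2.0, 3.0, 4.0]), np.array([0, 0, 1, -1]), num_segments=2)
# -> [3., 3.]  -- the entry with segment id -1 is ignored

values, indices = top_k(np.array([1.0, 4.0, 2.0, 8.0]), k=2, sorted=True)
# values == [8., 4.], indices == [3, 1]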

@ -0,0 +1,517 @@
import jax
import numpy as np
from jax import lax
from jax import numpy as jnp
from keras_core.backend import standardize_data_format
from keras_core.backend.common.backend_utils import (
compute_conv_transpose_padding,
)
from keras_core.backend.config import epsilon
from keras_core.backend.numpy.core import is_tensor
def relu(x):
return np.maximum(x, 0.0)
def relu6(x):
return np.clip(x, 0.0, 6.0)
def sigmoid(x):
return 1.0 / (1.0 + np.exp(-x))
def tanh(x):
return np.tanh(x)
def softplus(x):
return np.log(1.0 + np.exp(x))
def softsign(x):
return x / (1.0 + np.abs(x))
def silu(x):
return x * (1.0 / (1.0 + np.exp(-x)))
def swish(x):
return x * (1.0 / (1.0 + np.exp(-x)))
def log_sigmoid(x):
return np.log(1.0 / (1.0 + np.exp(-x)))
def leaky_relu(x, negative_slope=0.2):
return np.maximum(x, negative_slope * x)
def hard_sigmoid(x):
x = (x / 6.0) + 0.5
return np.where(x <= 0.0, 0.0, np.where(x >= 1.0, 1.0, x))
def elu(x, alpha=1.0):
return np.where(x >= 0.0, x, alpha * (np.exp(x) - 1.0))
def selu(
x,
alpha=1.6732632423543772848170429916717,
scale=1.0507009873554804934193349852946,
):
return scale * np.where(x >= 0.0, x, alpha * (np.exp(x) - 1.0))
def gelu(x, approximate=True):
if approximate:
return (
0.5
* x
* (
1.0
+ np.tanh(
np.sqrt(2.0 / np.pi) * (x + 0.044715 * np.power(x, 3))
)
)
)
else:
from scipy.stats import norm
return x * norm.cdf(x)
def softmax(x, axis=None):
exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
def log_softmax(x, axis=None):
max_x = np.max(x, axis=axis, keepdims=True)
logsumexp = np.log(np.exp(x - max_x).sum(axis=axis, keepdims=True))
return x - max_x - logsumexp
def _convert_to_spatial_operand(
x,
num_spatial_dims,
data_format="channels_last",
include_batch_and_channels=True,
):
# Helper function that converts an operand to a spatial operand.
x = (x,) * num_spatial_dims if isinstance(x, int) else x
if not include_batch_and_channels:
return x
if data_format == "channels_last":
x = (1,) + x + (1,)
else:
x = (1,) + (1,) + x
return x
def _pool(
inputs,
initial_value,
reduce_fn,
pool_size,
strides=None,
padding="valid",
):
"""Helper function to define pooling functions.
Args:
        inputs: input data of rank `N+2` (batch and channel dims included).
initial_value: the initial value for the reduction.
reduce_fn: a reduce function of the form `(T, T) -> T`.
pool_size: a sequence of `N` integers, representing the window size to
reduce over.
strides: a sequence of `N` integers, representing the inter-window
strides (default: `(1, ..., 1)`).
padding: either the string `same` or `valid`.
Returns:
The output of the reduction for each window slice.
"""
if padding not in ("same", "valid"):
raise ValueError(
f"Invalid padding '{padding}', must be 'same' or 'valid'."
)
padding = padding.upper()
return np.array(
lax.reduce_window(
inputs,
initial_value,
reduce_fn,
pool_size,
strides,
padding,
)
)
def max_pool(
inputs,
pool_size,
strides=None,
padding="valid",
data_format=None,
):
data_format = standardize_data_format(data_format)
num_spatial_dims = inputs.ndim - 2
pool_size = _convert_to_spatial_operand(
pool_size, num_spatial_dims, data_format
)
strides = pool_size if strides is None else strides
strides = _convert_to_spatial_operand(
strides, num_spatial_dims, data_format
)
return _pool(inputs, -jnp.inf, lax.max, pool_size, strides, padding)
def average_pool(
inputs,
pool_size,
strides,
padding,
data_format=None,
):
data_format = standardize_data_format(data_format)
num_spatial_dims = inputs.ndim - 2
pool_size = _convert_to_spatial_operand(
pool_size, num_spatial_dims, data_format
)
strides = pool_size if strides is None else strides
strides = _convert_to_spatial_operand(
strides, num_spatial_dims, data_format
)
pooled = _pool(inputs, 0.0, lax.add, pool_size, strides, padding)
if padding == "valid":
# Avoid the extra reduce_window.
return pooled / np.prod(pool_size)
else:
# Count the number of valid entries at each input point, then use that
# for computing average. Assumes that any two arrays of same shape will
# be padded the same. Avoid broadcasting on axis where pooling is
# skipped.
shape = [
(a if b != 1 else 1) for (a, b) in zip(inputs.shape, pool_size)
]
window_counts = _pool(
jnp.ones(shape, inputs.dtype),
0.0,
lax.add,
pool_size,
strides,
padding,
)
return pooled / window_counts
def _convert_to_lax_conv_dimension_numbers(
num_spatial_dims,
data_format="channels_last",
transpose=False,
):
"""Create a `lax.ConvDimensionNumbers` for the given inputs."""
num_dims = num_spatial_dims + 2
if data_format == "channels_last":
spatial_dims = tuple(range(1, num_dims - 1))
inputs_dn = (0, num_dims - 1) + spatial_dims
else:
spatial_dims = tuple(range(2, num_dims))
inputs_dn = (0, 1) + spatial_dims
if transpose:
kernel_dn = (num_dims - 2, num_dims - 1) + tuple(range(num_dims - 2))
else:
kernel_dn = (num_dims - 1, num_dims - 2) + tuple(range(num_dims - 2))
return lax.ConvDimensionNumbers(
lhs_spec=inputs_dn, rhs_spec=kernel_dn, out_spec=inputs_dn
)
def conv(
inputs,
kernel,
strides=1,
padding="valid",
data_format=None,
dilation_rate=1,
):
data_format = standardize_data_format(data_format)
num_spatial_dims = inputs.ndim - 2
dimension_numbers = _convert_to_lax_conv_dimension_numbers(
num_spatial_dims,
data_format,
transpose=False,
)
strides = _convert_to_spatial_operand(
strides,
num_spatial_dims,
data_format,
include_batch_and_channels=False,
)
dilation_rate = _convert_to_spatial_operand(
dilation_rate,
num_spatial_dims,
data_format,
include_batch_and_channels=False,
)
if data_format == "channels_last":
channels = inputs.shape[-1]
else:
channels = inputs.shape[1]
kernel_in_channels = kernel.shape[-2]
if channels % kernel_in_channels > 0:
raise ValueError(
"The number of input channels must be evenly divisible by "
f"kernel's in_channels. Received input channels {channels} and "
f"kernel in_channels {kernel_in_channels}. "
)
feature_group_count = channels // kernel_in_channels
return np.array(
jax.lax.conv_general_dilated(
inputs,
kernel if is_tensor(kernel) else kernel.numpy(),
strides,
padding,
rhs_dilation=dilation_rate,
dimension_numbers=dimension_numbers,
feature_group_count=feature_group_count,
)
)
def depthwise_conv(
inputs,
kernel,
strides=1,
padding="valid",
data_format=None,
dilation_rate=1,
):
data_format = standardize_data_format(data_format)
num_spatial_dims = inputs.ndim - 2
dimension_numbers = _convert_to_lax_conv_dimension_numbers(
num_spatial_dims,
data_format,
transpose=False,
)
strides = _convert_to_spatial_operand(
strides,
num_spatial_dims,
data_format,
include_batch_and_channels=False,
)
dilation_rate = _convert_to_spatial_operand(
dilation_rate,
num_spatial_dims,
data_format,
include_batch_and_channels=False,
)
feature_group_count = (
inputs.shape[-1] if data_format == "channels_last" else inputs.shape[1]
)
kernel = jnp.reshape(
kernel if is_tensor(kernel) else kernel.numpy(),
kernel.shape[:-2] + (1, feature_group_count * kernel.shape[-1]),
)
return np.array(
jax.lax.conv_general_dilated(
inputs,
kernel,
strides,
padding,
rhs_dilation=dilation_rate,
dimension_numbers=dimension_numbers,
feature_group_count=feature_group_count,
)
)
def separable_conv(
inputs,
depthwise_kernel,
pointwise_kernel,
strides=1,
padding="valid",
data_format=None,
dilation_rate=1,
):
data_format = standardize_data_format(data_format)
depthwise_conv_output = depthwise_conv(
inputs,
depthwise_kernel,
strides,
padding,
data_format,
dilation_rate,
)
return conv(
depthwise_conv_output,
pointwise_kernel,
strides=1,
padding="valid",
data_format=data_format,
dilation_rate=dilation_rate,
)
def conv_transpose(
inputs,
kernel,
strides=1,
padding="valid",
output_padding=None,
data_format=None,
dilation_rate=1,
):
data_format = standardize_data_format(data_format)
num_spatial_dims = inputs.ndim - 2
padding_values = compute_conv_transpose_padding(
inputs.shape,
kernel.shape,
strides,
padding,
output_padding,
data_format,
dilation_rate,
)
dimension_numbers = _convert_to_lax_conv_dimension_numbers(
num_spatial_dims,
data_format,
transpose=False,
)
strides = _convert_to_spatial_operand(
strides,
num_spatial_dims,
data_format,
include_batch_and_channels=False,
)
dilation_rate = _convert_to_spatial_operand(
dilation_rate,
num_spatial_dims,
data_format,
include_batch_and_channels=False,
)
return np.array(
jax.lax.conv_transpose(
inputs,
kernel if is_tensor(kernel) else kernel.numpy(),
strides,
padding=padding_values,
rhs_dilation=dilation_rate,
dimension_numbers=dimension_numbers,
transpose_kernel=True,
)
)
def one_hot(x, num_classes, axis=-1, dtype="float32"):
input_shape = x.shape
# Shrink the last dimension if the shape is (..., 1).
if input_shape and input_shape[-1] == 1 and len(input_shape) > 1:
input_shape = tuple(input_shape[:-1])
x = x.reshape(-1)
if not num_classes:
num_classes = np.max(x) + 1
batch_size = x.shape[0]
categorical = np.zeros((batch_size, num_classes), dtype=dtype)
valid_indices = x >= 0
categorical[np.arange(batch_size)[valid_indices], x[valid_indices]] = 1
# First, reshape the array with the extra dimension at the end
output_shape = input_shape + (num_classes,)
categorical = np.reshape(categorical, output_shape)
# Then, move this new dimension to the right place (according to axis)
if axis != -1:
categorical = np.moveaxis(categorical, -1, axis)
return categorical
def categorical_crossentropy(target, output, from_logits=False, axis=-1):
target = np.array(target)
output = np.array(output)
if target.shape != output.shape:
raise ValueError(
"Arguments `target` and `output` must have the same shape. "
"Received: "
f"target.shape={target.shape}, output.shape={output.shape}"
)
if len(target.shape) < 1:
raise ValueError(
"Arguments `target` and `output` must be at least rank 1. "
"Received: "
f"target.shape={target.shape}, output.shape={output.shape}"
)
if from_logits:
log_prob = log_softmax(output, axis=axis)
else:
output = output / np.sum(output, axis, keepdims=True)
output = np.clip(output, epsilon(), 1.0 - epsilon())
log_prob = np.log(output)
return -np.sum(target * log_prob, axis=axis)
def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
target = np.array(target, dtype="int32")
output = np.array(output)
if len(target.shape) == len(output.shape) and target.shape[-1] == 1:
target = np.squeeze(target, axis=-1)
if len(output.shape) < 1:
raise ValueError(
"Argument `output` must be at least rank 1. "
"Received: "
f"output.shape={output.shape}"
)
if target.shape != output.shape[:-1]:
raise ValueError(
"Arguments `target` and `output` must have the same shape "
"up until the last dimension: "
f"target.shape={target.shape}, output.shape={output.shape}"
)
if from_logits:
log_prob = log_softmax(output, axis=axis)
else:
output = output / np.sum(output, axis, keepdims=True)
output = np.clip(output, epsilon(), 1.0 - epsilon())
log_prob = np.log(output)
target = one_hot(target, output.shape[axis], axis=axis)
return -np.sum(target * log_prob, axis=axis)
def binary_crossentropy(target, output, from_logits=False):
target = np.array(target)
output = np.array(output)
if target.shape != output.shape:
raise ValueError(
"Arguments `target` and `output` must have the same shape. "
"Received: "
f"target.shape={target.shape}, output.shape={output.shape}"
)
if from_logits:
output = sigmoid(output)
output = np.clip(output, epsilon(), 1.0 - epsilon())
bce = target * np.log(output)
bce += (1.0 - target) * np.log(1.0 - output)
return -bce
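For example, the one_hot above leaves the row all-zero for a negative index:

import numpy as np

one_hot(np.array([0, 2, -1]), num_classes=3)
# -> [[1., 0., 0.],
#     [0., 0., 1.],
#     [0., 0., 0.]]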

@ -0,0 +1,571 @@
import numpy as np
def add(x1, x2):
return np.add(x1, x2)
def einsum(subscripts, *operands, **kwargs):
return np.einsum(subscripts, *operands, **kwargs)
def subtract(x1, x2):
return np.subtract(x1, x2)
def matmul(x1, x2):
return np.matmul(x1, x2)
def multiply(x1, x2):
return np.multiply(x1, x2)
def mean(x, axis=None, keepdims=False):
axis = tuple(axis) if isinstance(axis, list) else axis
return np.mean(x, axis=axis, keepdims=keepdims)
def max(x, axis=None, keepdims=False, initial=None):
axis = tuple(axis) if isinstance(axis, list) else axis
return np.max(x, axis=axis, keepdims=keepdims, initial=initial)
def ones(shape, dtype="float32"):
return np.ones(shape, dtype=dtype)
def zeros(shape, dtype="float32"):
return np.zeros(shape, dtype=dtype)
def absolute(x):
return np.absolute(x)
def abs(x):
return absolute(x)
def all(x, axis=None, keepdims=False):
axis = tuple(axis) if isinstance(axis, list) else axis
return np.all(x, axis=axis, keepdims=keepdims)
def any(x, axis=None, keepdims=False):
axis = tuple(axis) if isinstance(axis, list) else axis
return np.any(x, axis=axis, keepdims=keepdims)
def amax(x, axis=None, keepdims=False):
axis = tuple(axis) if isinstance(axis, list) else axis
return np.amax(x, axis=axis, keepdims=keepdims)
def amin(x, axis=None, keepdims=False):
axis = tuple(axis) if isinstance(axis, list) else axis
return np.amin(x, axis=axis, keepdims=keepdims)
def append(
x1,
x2,
axis=None,
):
axis = tuple(axis) if isinstance(axis, list) else axis
return np.append(x1, x2, axis=axis)
def arange(start, stop=None, step=None, dtype=None):
return np.arange(start, stop, step=step, dtype=dtype)
def arccos(x):
return np.arccos(x)
def arcsin(x):
return np.arcsin(x)
def arctan(x):
return np.arctan(x)
def arctan2(x1, x2):
return np.arctan2(x1, x2)
def argmax(x, axis=None):
axis = tuple(axis) if isinstance(axis, list) else axis
return np.argmax(x, axis=axis)
def argmin(x, axis=None):
axis = tuple(axis) if isinstance(axis, list) else axis
return np.argmin(x, axis=axis)
def argsort(x, axis=-1):
axis = tuple(axis) if isinstance(axis, list) else axis
return np.argsort(x, axis=axis)
def array(x, dtype=None):
return np.array(x, dtype=dtype)
def average(x, axis=None, weights=None):
axis = tuple(axis) if isinstance(axis, list) else axis
return np.average(x, weights=weights, axis=axis)
def bincount(x, weights=None, minlength=0):
return np.bincount(x, weights, minlength)
def broadcast_to(x, shape):
return np.broadcast_to(x, shape)
def ceil(x):
return np.ceil(x)
def clip(x, x_min, x_max):
return np.clip(x, x_min, x_max)
def concatenate(xs, axis=0):
axis = tuple(axis) if isinstance(axis, list) else axis
return np.concatenate(xs, axis=axis)
def conjugate(x):
return np.conjugate(x)
def conj(x):
return conjugate(x)
def copy(x):
return np.copy(x)
def cos(x):
return np.cos(x)
def count_nonzero(x, axis=None):
axis = tuple(axis) if isinstance(axis, list) else axis
return np.count_nonzero(x, axis=axis)
def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None):
axis = tuple(axis) if isinstance(axis, list) else axis
return np.cross(
x1,
x2,
axisa=axisa,
axisb=axisb,
axisc=axisc,
axis=axis,
)
def cumprod(x, axis=None):
axis = tuple(axis) if isinstance(axis, list) else axis
return np.cumprod(x, axis=axis)
def cumsum(x, axis=None):
axis = tuple(axis) if isinstance(axis, list) else axis
return np.cumsum(x, axis=axis)
def diag(x, k=0):
return np.diag(x, k=k)
def diagonal(x, offset=0, axis1=0, axis2=1):
axis1 = tuple(axis1) if isinstance(axis1, list) else axis1
axis2 = tuple(axis2) if isinstance(axis2, list) else axis2
return np.diagonal(
x,
offset=offset,
axis1=axis1,
axis2=axis2,
)
def dot(x, y):
return np.dot(x, y)
def empty(shape, dtype="float32"):
return np.empty(shape, dtype=dtype)
def equal(x1, x2):
return np.equal(x1, x2)
def exp(x):
return np.exp(x)
def expand_dims(x, axis):
axis = tuple(axis) if isinstance(axis, list) else axis
return np.expand_dims(x, axis)
def expm1(x):
return np.expm1(x)
def flip(x, axis=None):
axis = tuple(axis) if isinstance(axis, list) else axis
return np.flip(x, axis=axis)
def floor(x):
return np.floor(x)
def full(shape, fill_value, dtype=None):
return np.full(shape, fill_value, dtype=dtype)
def full_like(x, fill_value, dtype=None):
return np.full_like(x, fill_value, dtype=dtype)
def greater(x1, x2):
return np.greater(x1, x2)
def greater_equal(x1, x2):
return np.greater_equal(x1, x2)
def hstack(xs):
return np.hstack(xs)
def identity(n, dtype="float32"):
return np.identity(n, dtype=dtype)
def imag(x):
return np.imag(x)
def isclose(x1, x2):
return np.isclose(x1, x2)
def isfinite(x):
return np.isfinite(x)
def isinf(x):
return np.isinf(x)
def isnan(x):
return np.isnan(x)
def less(x1, x2):
return np.less(x1, x2)
def less_equal(x1, x2):
return np.less_equal(x1, x2)
def linspace(
start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0
):
axis = tuple(axis) if isinstance(axis, list) else axis
return np.linspace(
start,
stop,
num=num,
endpoint=endpoint,
retstep=retstep,
dtype=dtype,
axis=axis,
)
def log(x):
return np.log(x)
def log10(x):
return np.log10(x)
def log1p(x):
return np.log1p(x)
def log2(x):
return np.log2(x)
def logaddexp(x1, x2):
return np.logaddexp(x1, x2)
def logical_and(x1, x2):
return np.logical_and(x1, x2)
def logical_not(x):
return np.logical_not(x)
def logical_or(x1, x2):
return np.logical_or(x1, x2)
def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0):
return np.logspace(
start,
stop,
num=num,
endpoint=endpoint,
base=base,
dtype=dtype,
axis=axis,
)
def maximum(x1, x2):
return np.maximum(x1, x2)
def meshgrid(*x, indexing="xy"):
return np.meshgrid(*x, indexing=indexing)
def min(x, axis=None, keepdims=False, initial=None):
axis = tuple(axis) if isinstance(axis, list) else axis
return np.min(x, axis=axis, keepdims=keepdims, initial=initial)
def minimum(x1, x2):
return np.minimum(x1, x2)
def mod(x1, x2):
return np.mod(x1, x2)
def moveaxis(x, source, destination):
return np.moveaxis(x, source=source, destination=destination)
def nan_to_num(x):
return np.nan_to_num(x)
def ndim(x):
return np.ndim(x)
def nonzero(x):
return np.nonzero(x)
def not_equal(x1, x2):
return np.not_equal(x1, x2)
def zeros_like(x, dtype=None):
return np.zeros_like(x, dtype=dtype)
def ones_like(x, dtype=None):
return np.ones_like(x, dtype=dtype)
def outer(x1, x2):
return np.outer(x1, x2)
def pad(x, pad_width, mode="constant"):
return np.pad(x, pad_width, mode=mode)
def prod(x, axis=None, keepdims=False, dtype=None):
axis = tuple(axis) if isinstance(axis, list) else axis
return np.prod(x, axis=axis, keepdims=keepdims, dtype=dtype)
def ravel(x):
return np.ravel(x)
def real(x):
return np.real(x)
def reciprocal(x):
return np.reciprocal(x)
def repeat(x, repeats, axis=None):
return np.repeat(x, repeats, axis=axis)
def reshape(x, new_shape):
return np.reshape(x, new_shape)
def roll(x, shift, axis=None):
return np.roll(x, shift, axis=axis)
def sign(x):
return np.sign(x)
def sin(x):
return np.sin(x)
def size(x):
return np.size(x)
def sort(x, axis=-1):
axis = tuple(axis) if isinstance(axis, list) else axis
return np.sort(x, axis=axis)
def split(x, indices_or_sections, axis=0):
axis = tuple(axis) if isinstance(axis, list) else axis
return np.split(x, indices_or_sections, axis=axis)
def stack(x, axis=0):
axis = tuple(axis) if isinstance(axis, list) else axis
return np.stack(x, axis=axis)
def std(x, axis=None, keepdims=False):
axis = tuple(axis) if isinstance(axis, list) else axis
return np.std(x, axis=axis, keepdims=keepdims)
def swapaxes(x, axis1, axis2):
return np.swapaxes(x, axis1=axis1, axis2=axis2)
def take(x, indices, axis=None):
axis = tuple(axis) if isinstance(axis, list) else axis
return np.take(x, indices, axis=axis)
def take_along_axis(x, indices, axis=None):
axis = tuple(axis) if isinstance(axis, list) else axis
return np.take_along_axis(x, indices, axis=axis)
def tan(x):
return np.tan(x)
def tensordot(x1, x2, axes=2):
axes = tuple(axes) if isinstance(axes, list) else axes
return np.tensordot(x1, x2, axes=axes)
def round(x, decimals=0):
return np.round(x, decimals=decimals)
def tile(x, repeats):
return np.tile(x, repeats)
def trace(x, offset=0, axis1=0, axis2=1):
axis1 = tuple(axis1) if isinstance(axis1, list) else axis1
axis2 = tuple(axis2) if isinstance(axis2, list) else axis2
return np.trace(x, offset=offset, axis1=axis1, axis2=axis2)
def tri(N, M=None, k=0, dtype="float32"):
return np.tri(N, M=M, k=k, dtype=dtype)
def tril(x, k=0):
return np.tril(x, k=k)
def triu(x, k=0):
return np.triu(x, k=k)
def vdot(x1, x2):
return np.vdot(x1, x2)
def vstack(xs):
return np.vstack(xs)
def where(condition, x1, x2):
return np.where(condition, x1, x2)
def divide(x1, x2):
return np.divide(x1, x2)
def true_divide(x1, x2):
return np.true_divide(x1, x2)
def power(x1, x2):
return np.power(x1, x2)
def negative(x):
return np.negative(x)
def square(x):
return np.square(x)
def sqrt(x):
return np.sqrt(x)
def squeeze(x, axis=None):
axis = tuple(axis) if isinstance(axis, list) else axis
return np.squeeze(x, axis=axis)
def transpose(x, axes=None):
axes = tuple(axes) if isinstance(axes, list) else axes
return np.transpose(x, axes=axes)
def var(x, axis=None, keepdims=False):
axis = tuple(axis) if isinstance(axis, list) else axis
return np.var(x, axis=axis, keepdims=keepdims)
def sum(x, axis=None, keepdims=False):
axis = tuple(axis) if isinstance(axis, list) else axis
return np.sum(x, axis=axis, keepdims=keepdims)
def eye(N, M=None, k=0, dtype="float32"):
return np.eye(N, M=M, k=k, dtype=dtype)
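The recurring `axis = tuple(axis) if isinstance(axis, list) else axis` line normalizes list axes coming from Keras ops into the tuple form NumPy reducers expect; e.g. with the `sum` wrapper above:

import numpy as np

sum(np.ones((2, 3, 4)), axis=[0, 2]).shape
# -> (3,)  -- the list axis is coerced to the tuple (0, 2)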

@ -0,0 +1,88 @@
import numpy as np
from keras_core.backend.config import floatx
from keras_core.backend.numpy.nn import softmax
from keras_core.random.seed_generator import SeedGenerator
from keras_core.random.seed_generator import draw_seed
from keras_core.random.seed_generator import make_default_seed
def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
dtype = dtype or floatx()
seed = draw_seed(seed)
rng = np.random.default_rng(seed)
return rng.normal(size=shape, loc=mean, scale=stddev).astype(dtype)
def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
dtype = dtype or floatx()
seed = draw_seed(seed)
rng = np.random.default_rng(seed)
return rng.uniform(size=shape, low=minval, high=maxval).astype(dtype)
def categorical(logits, num_samples, dtype="int64", seed=None):
seed = draw_seed(seed)
rng = np.random.default_rng(seed)
output = []
for logits_instance in logits:
probabilities = softmax(logits_instance)
classes = np.arange(logits_instance.shape[-1])
samples = rng.choice(classes, size=num_samples, p=probabilities)
output.append(samples)
return np.array(output).astype(dtype)
def randint(shape, minval, maxval, dtype="int32", seed=None):
seed = draw_seed(seed)
rng = np.random.default_rng(seed)
output = rng.integers(low=minval, high=maxval, size=shape, dtype=dtype)
return output
def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
dtype = dtype or floatx()
seed = draw_seed(seed)
rng = np.random.default_rng(seed)
lower_bound = mean - 2 * stddev
upper_bound = mean + 2 * stddev
flat_shape = np.prod(shape)
random_numbers = np.empty(0)
# loop until we have enough valid numbers to fill our desired shape
while random_numbers.shape[0] < flat_shape:
# Generate a batch of random numbers from a normal distribution
batch = rng.normal(loc=mean, scale=stddev, size=flat_shape)
# Filter the numbers to keep only those within the specified bounds
valid = batch[(batch >= lower_bound) & (batch <= upper_bound)]
# Append the valid numbers to the result array
random_numbers = np.append(random_numbers, valid)
# Truncate the result array to the desired size and reshape it
return random_numbers[:flat_shape].astype(dtype).reshape(shape)
def dropout(inputs, rate, noise_shape=None, seed=None):
seed = draw_seed(seed)
keep_prob = 1.0 - rate
# If noise_shape is not provided, use the shape of inputs
if noise_shape is None:
noise_shape = inputs.shape
else:
# If noise_shape is provided, replace None with corresponding
# input shape
noise_shape = [
n if n is not None else inputs.shape[i]
for i, n in enumerate(noise_shape)
]
rng = np.random.default_rng(seed)
mask = rng.uniform(size=noise_shape) < keep_prob
mask = np.broadcast_to(mask, inputs.shape)
return np.where(mask, inputs / keep_prob, np.zeros_like(inputs))
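The dropout above is the standard inverted form: kept entries are scaled by 1 / keep_prob so the expected activation is unchanged. A sketch with an integer seed:

import numpy as np

x = np.ones((4, 4), dtype="float32")
y = dropout(x, rate=0.5, seed=0)
# each entry of y is either 0.0 (dropped) or 2.0 (kept and rescaled)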

@ -0,0 +1,236 @@
import numpy as np
import tree
from keras_core.utils.nest import pack_sequence_as
def rnn(
step_function,
inputs,
initial_states,
go_backwards=False,
mask=None,
constants=None,
unroll=False,
input_length=None,
time_major=False,
zero_output_for_mask=False,
return_all_outputs=True,
):
def swap_batch_timestep(input_t):
# Swap the batch and timestep dim for the incoming tensor.
axes = list(range(len(input_t.shape)))
axes[0], axes[1] = 1, 0
return np.transpose(input_t, axes)
if not time_major:
inputs = tree.map_structure(swap_batch_timestep, inputs)
flattened_inputs = tree.flatten(inputs)
time_steps = flattened_inputs[0].shape[0]
if mask is not None:
if mask.dtype != "bool":
mask = mask.astype("bool")
if len(mask.shape) == 2:
mask = np.expand_dims(mask, axis=-1)
if not time_major:
mask = swap_batch_timestep(mask)
if constants is None:
constants = []
def _expand_mask(mask_t, input_t, fixed_dim=1):
if tree.is_nested(mask_t):
raise ValueError(
f"mask_t is expected to be tensor, but got {mask_t}"
)
if tree.is_nested(input_t):
raise ValueError(
f"input_t is expected to be tensor, but got {input_t}"
)
rank_diff = len(input_t.shape) - len(mask_t.shape)
for _ in range(rank_diff):
mask_t = np.expand_dims(mask_t, -1)
multiples = [1] * fixed_dim + list(input_t.shape[fixed_dim:])
return np.tile(mask_t, multiples)
if unroll:
if not time_steps:
raise ValueError("Unrolling requires a fixed number of timesteps.")
states = tuple(initial_states)
successive_states = []
successive_outputs = []
        # Process the input tensors. The input tensor needs to be split on the
        # time_step dim, and reversed if go_backwards is True. In the case of
        # nested input, the input is flattened and then transformed
        # individually. The result is a tuple of lists, where each item in the
        # tuple is a list of tensors with shape (batch, feature).
def _process_single_input_t(input_t):
input_t = unstack(input_t) # unstack for time_step dim
if go_backwards:
input_t.reverse()
return input_t
if tree.is_nested(inputs):
processed_input = tree.map_structure(
_process_single_input_t, inputs
)
else:
processed_input = (_process_single_input_t(inputs),)
def _get_input_tensor(time):
inp = [t_[time] for t_ in processed_input]
return pack_sequence_as(inputs, inp)
if mask is not None:
mask_list = unstack(mask)
if go_backwards:
mask_list.reverse()
for i in range(time_steps):
inp = _get_input_tensor(i)
mask_t = mask_list[i]
output, new_states = step_function(
inp, tuple(states) + tuple(constants)
)
tiled_mask_t = _expand_mask(mask_t, output)
if not successive_outputs:
prev_output = np.zeros_like(output)
else:
prev_output = successive_outputs[-1]
output = np.where(tiled_mask_t, output, prev_output)
flat_states = tree.flatten(states)
flat_new_states = tree.flatten(new_states)
tiled_mask_t = tuple(
_expand_mask(mask_t, s) for s in flat_states
)
flat_final_states = tuple(
np.where(m, s, ps)
for m, s, ps in zip(
tiled_mask_t, flat_new_states, flat_states
)
)
states = pack_sequence_as(states, flat_final_states)
if return_all_outputs:
successive_outputs.append(output)
successive_states.append(states)
else:
successive_outputs = [output]
successive_states = [states]
last_output = successive_outputs[-1]
new_states = successive_states[-1]
outputs = np.stack(successive_outputs)
else: # mask is None
for i in range(time_steps):
inp = _get_input_tensor(i)
output, states = step_function(
inp, tuple(states) + tuple(constants)
)
if return_all_outputs:
successive_outputs.append(output)
successive_states.append(states)
else:
successive_outputs = [output]
successive_states = [states]
last_output = successive_outputs[-1]
new_states = successive_states[-1]
outputs = np.stack(successive_outputs)
else: # Unroll == False
if mask is not None:
def _step(states, current_input):
current_input, current_mask = current_input
is_masked = np.all(
np.logical_not(current_mask), axis=-1, keepdims=True
)
output_t, new_states = step_function(current_input, states)
if zero_output_for_mask:
masked_outs = np.where(
is_masked, np.zeros_like(output_t), output_t
)
else:
# Assume the first state is the previous output.
output_tm1 = states[0]
masked_outs = np.where(is_masked, output_tm1, output_t)
new_states = [
np.where(is_masked, s, ns)
for s, ns in zip(states, new_states)
]
return (new_states, masked_outs)
scan_xs = (inputs, mask)
else:
def _step(states, current_input):
output_t, new_states = step_function(current_input, states)
return new_states, output_t
scan_xs = inputs
new_states, outputs = numpy_scan(
f=_step,
init=initial_states,
xs=scan_xs,
reverse=go_backwards,
mask=mask,
)
if go_backwards:
outputs = np.flip(outputs, axis=0)
last_output = outputs[-1]
if not time_major:
outputs = tree.map_structure(swap_batch_timestep, outputs)
return last_output, outputs, new_states
def lstm(*args, **kwargs):
raise NotImplementedError
def gru(*args, **kwargs):
raise NotImplementedError
def unstack(x, axis=0):
return [x.take(i, axis) for i in range(x.shape[axis])]
def numpy_scan(f, init, xs, reverse=False, mask=None):
states = init
outputs = []
if mask is not None:
x, mask = xs
x = np.flip(x, axis=0) if reverse else x
mask = np.flip(mask, axis=0) if reverse else mask
for each_x, each_mask in zip(x, mask):
states, output = f(states, (each_x, each_mask))
outputs.append(output)
else:
xs = np.flip(xs, axis=0) if reverse else xs
for x in xs:
states, output = f(states, x)
outputs.append(output)
outputs = np.array(outputs)
if reverse:
outputs = np.flip(outputs, axis=0)
return states, outputs
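A minimal sketch of driving this rnn helper with a cumulative-sum step function (illustrative step, not part of the diff):

import numpy as np

def step(inputs, states):
    new = states[0] + inputs  # running sum as both output and sole state
    return new, [new]

x = np.ones((2, 3, 4))  # (batch, time, features)
last, outputs, states = rnn(step, x, initial_states=[np.zeros((2, 4))])
# outputs.shape == (2, 3, 4); outputs[:, t] holds the sum after t + 1 steps
# last == outputs[:, -1]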

@ -0,0 +1,18 @@
class NumpyTrainer:
def fit(self):
raise NotImplementedError("Trainer not implemented for NumPy backend.")
def predict(self):
raise NotImplementedError("Trainer not implemented for NumPy backend.")
def evaluate(self):
raise NotImplementedError("Trainer not implemented for NumPy backend.")
def train_on_batch(self):
raise NotImplementedError("Trainer not implemented for NumPy backend.")
def test_on_batch(self):
raise NotImplementedError("Trainer not implemented for NumPy backend.")
def predict_on_batch(self):
raise NotImplementedError("Trainer not implemented for NumPy backend.")
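With this stub, any fit/evaluate/predict call raises immediately, while direct forward calls still work. A sketch (hypothetical model):

import numpy as np
from keras_core import layers, models

model = models.Sequential([layers.Dense(2)])
y = model(np.ones((1, 3)))  # direct forward pass works on the NumPy backend
# model.fit(...)  # would raise NotImplementedError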

@ -1,4 +1,5 @@
import numpy as np
import pytest
from keras_core import models
from keras_core import testing
@ -6,6 +7,7 @@ from keras_core.callbacks.callback import Callback
class CallbackTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_model_state_is_current_on_epoch_end(self):
class TestModel(models.Model):
def __init__(self):

@ -4,6 +4,7 @@ import re
import tempfile
import numpy as np
import pytest
from keras_core import callbacks
from keras_core import initializers
@ -19,6 +20,7 @@ BATCH_SIZE = 4
class CSVLoggerTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_CSVLogger(self):
OUTPUT_DIM = 1
np.random.seed(1337)
@ -126,6 +128,7 @@ class CSVLoggerTest(testing.TestCase):
else:
self.assertEqual(row["val_loss"], "NA")
@pytest.mark.requires_trainable_backend
def test_stop_training_csv(self):
# Test that using the CSVLogger callback with the TerminateOnNaN
# callback does not result in invalid CSVs.

@ -1,4 +1,5 @@
import numpy as np
import pytest
from keras_core import callbacks
from keras_core import layers
@ -7,6 +8,7 @@ from keras_core import testing
class EarlyStoppingTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_early_stopping(self):
x_train = np.random.random((10, 5))
y_train = np.random.random((10, 1))
@ -48,6 +50,7 @@ class EarlyStoppingTest(testing.TestCase):
verbose=0,
)
@pytest.mark.requires_trainable_backend
def test_early_stopping_patience(self):
cases = [0, 1, 2, 3]
losses = [10.0, 9.0, 8.0, 9.0, 8.9, 8.8, 8.7, 8.6, 8.5]
@ -65,6 +68,7 @@ class EarlyStoppingTest(testing.TestCase):
self.assertEqual(stopper.stopped_epoch, max(patience, 1) + 2)
@pytest.mark.requires_trainable_backend
def test_early_stopping_reuse(self):
patience = 3
data = np.random.random((100, 1))
@ -91,6 +95,7 @@ class EarlyStoppingTest(testing.TestCase):
)
assert len(hist.epoch) >= patience
@pytest.mark.requires_trainable_backend
def test_early_stopping_with_baseline(self):
baseline = 0.6
x_train = np.random.random((10, 5))
@ -169,6 +174,7 @@ class EarlyStoppingTest(testing.TestCase):
self.assertEqual(epochs_trained, 5)
self.assertEqual(early_stop.model.get_weights(), 2)
@pytest.mark.requires_trainable_backend
def test_early_stopping_with_start_from_epoch(self):
x_train = np.random.random((10, 5))
y_train = np.random.random((10, 1))

@ -1,4 +1,5 @@
import numpy as np
import pytest
from absl import logging
from keras_core import callbacks
@ -10,6 +11,7 @@ from keras_core.models.sequential import Sequential
class LambdaCallbackTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_LambdaCallback(self):
BATCH_SIZE = 4
model = Sequential(

@ -1,3 +1,5 @@
import pytest
from keras_core import callbacks
from keras_core import layers
from keras_core import optimizers
@ -29,6 +31,7 @@ class LearningRateSchedulerTest(testing.TestCase):
self.x_train = x_train
self.y_train = y_train
@pytest.mark.requires_trainable_backend
def test_updates_learning_rate(self):
lr_scheduler = callbacks.LearningRateScheduler(
lambda step: 1.0 / (2.0 + step), verbose=1
@ -43,6 +46,7 @@ class LearningRateSchedulerTest(testing.TestCase):
self.assertEqual(self.model.optimizer.learning_rate.value, 0.5)
@pytest.mark.requires_trainable_backend
def test_verbose_logging(self):
lr_scheduler = callbacks.LearningRateScheduler(
lambda step: 1.0 / (1.0 + step), verbose=1
@ -59,6 +63,7 @@ class LearningRateSchedulerTest(testing.TestCase):
expected_log = "LearningRateScheduler setting learning rate to 1.0"
self.assertTrue(any(expected_log in log for log in logs.output))
@pytest.mark.requires_trainable_backend
def test_schedule_dependent_on_previous_learning_rate(self):
lr_scheduler = callbacks.LearningRateScheduler(lambda step, lr: lr / 2)
@ -78,6 +83,7 @@ class LearningRateSchedulerTest(testing.TestCase):
self.model.optimizer.learning_rate.value, initial_lr / 4.0
)
@pytest.mark.requires_trainable_backend
def test_throws_when_optimizer_has_schedule(self):
lr_scheduler = callbacks.LearningRateScheduler(lambda step, lr: lr / 2)

@ -31,6 +31,7 @@ class ModelCheckpointTest(testing.TestCase):
h5py is None,
reason="`h5py` is a required dependency for `ModelCheckpoint` tests.",
)
@pytest.mark.requires_trainable_backend
def test_model_checkpoint_options(self):
def get_model():
model = Sequential(
@ -445,6 +446,7 @@ class ModelCheckpointTest(testing.TestCase):
h5py is None,
reason="`h5py` is a required dependency for `ModelCheckpoint` tests.",
)
@pytest.mark.requires_trainable_backend
def test_model_checkpoint_loading(self):
def get_model():
inputs = layers.Input(shape=(INPUT_DIM,), batch_size=2)

@ -1,3 +1,5 @@
import pytest
from keras_core import callbacks
from keras_core import layers
from keras_core import optimizers
@ -32,6 +34,7 @@ class ReduceLROnPlateauTest(testing.TestCase):
self.y_train = y_train
self.y_test = y_test
@pytest.mark.requires_trainable_backend
def test_reduces_lr_with_model_fit(self):
reduce_lr = callbacks.ReduceLROnPlateau(
patience=1, factor=0.1, monitor="val_loss", min_delta=100
@ -47,6 +50,7 @@ class ReduceLROnPlateauTest(testing.TestCase):
self.assertEqual(self.model.optimizer.learning_rate.value, 0.01)
@pytest.mark.requires_trainable_backend
def test_throws_when_optimizer_has_schedule(self):
reduce_lr = callbacks.ReduceLROnPlateau(
patience=1, factor=0.1, monitor="val_loss", min_delta=100
@ -73,6 +77,7 @@ class ReduceLROnPlateauTest(testing.TestCase):
epochs=2,
)
@pytest.mark.requires_trainable_backend
def test_verbose_logging(self):
reduce_lr = callbacks.ReduceLROnPlateau(
patience=1, factor=0.1, monitor="val_loss", min_delta=100, verbose=1
@ -90,6 +95,7 @@ class ReduceLROnPlateauTest(testing.TestCase):
expected_log = "ReduceLROnPlateau reducing learning rate to 0.01"
self.assertTrue(any(expected_log in log for log in logs.output))
@pytest.mark.requires_trainable_backend
def test_honors_min_lr(self):
reduce_lr = callbacks.ReduceLROnPlateau(
patience=1,
@ -109,6 +115,7 @@ class ReduceLROnPlateauTest(testing.TestCase):
self.assertEqual(self.model.optimizer.learning_rate.value, 0.005)
@pytest.mark.requires_trainable_backend
def test_cooldown(self):
reduce_lr = callbacks.ReduceLROnPlateau(
patience=1,

@ -3,6 +3,7 @@ from unittest import mock
import numpy as np
from keras_core import backend
from keras_core import callbacks
from keras_core import layers
from keras_core import testing
@ -60,6 +61,8 @@ class TerminateOnNaNTest(testing.TestCase):
if requests is None:
self.skipTest("`requests` required to run this test")
if backend.backend() == "numpy":
            self.skipTest("Trainer not implemented for NumPy backend.")
TRAIN_SAMPLES = 10
TEST_SAMPLES = 10
INPUT_DIM = 3

@ -143,6 +143,7 @@ class TestTensorBoardV2(testing.TestCase):
model.compile("sgd", "mse")
return model
@pytest.mark.requires_trainable_backend
def test_TensorBoard_basic(self):
model = self._get_model()
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
@ -171,6 +172,7 @@ class TestTensorBoardV2(testing.TestCase):
},
)
@pytest.mark.requires_trainable_backend
def test_TensorBoard_across_invocations(self):
"""Regression test for summary writer resource use-after-free."""
model = self._get_model()
@ -201,6 +203,7 @@ class TestTensorBoardV2(testing.TestCase):
},
)
@pytest.mark.requires_trainable_backend
def test_TensorBoard_no_spurious_event_files(self):
model = self._get_model()
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
@ -214,6 +217,7 @@ class TestTensorBoardV2(testing.TestCase):
events_file_run_basenames.add(os.path.basename(dirpath))
self.assertEqual(events_file_run_basenames, {"train"})
@pytest.mark.requires_trainable_backend
def test_TensorBoard_batch_metrics(self):
model = self._get_model()
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
@ -243,6 +247,7 @@ class TestTensorBoardV2(testing.TestCase):
},
)
@pytest.mark.requires_trainable_backend
def test_TensorBoard_learning_rate_schedules(self):
model = self._get_model(compile_model=False)
opt = optimizers.SGD(schedules.CosineDecay(0.01, 1))
@ -267,6 +272,7 @@ class TestTensorBoardV2(testing.TestCase):
},
)
@pytest.mark.requires_trainable_backend
def test_TensorBoard_global_step(self):
model = self._get_model(compile_model=False)
opt = optimizers.SGD(schedules.CosineDecay(0.01, 1))
@ -306,6 +312,7 @@ class TestTensorBoardV2(testing.TestCase):
},
)
@pytest.mark.requires_trainable_backend
def test_TensorBoard_weight_histograms(self):
model = self._get_model()
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
@ -338,6 +345,7 @@ class TestTensorBoardV2(testing.TestCase):
{_ObservedSummary(logdir=train_dir, tag="histogram")},
)
@pytest.mark.requires_trainable_backend
def test_TensorBoard_weight_images(self):
model = self._get_model()
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
@ -392,6 +400,7 @@ class TestTensorBoardV2(testing.TestCase):
expected_image_summaries,
)
@pytest.mark.requires_trainable_backend
def test_TensorBoard_projector_callback(self):
model = models.Sequential(
[
@ -433,6 +442,7 @@ class TestTensorBoardV2(testing.TestCase):
],
)
@pytest.mark.requires_trainable_backend
def test_custom_summary(self):
def scalar_v2_mock(name, data, step=None):
"""A reimplementation of the scalar plugin to avoid circular
@ -592,6 +602,7 @@ class TestTensorBoardV2(testing.TestCase):
backend.backend() == "torch",
reason="Torch backend requires blocking numpy conversion.",
)
@pytest.mark.requires_trainable_backend
def test_TensorBoard_non_blocking(self):
logdir, _, _ = self._get_log_dirs()
model = models.Sequential([layers.Dense(1)])

@ -1,4 +1,5 @@
import numpy as np
import pytest
from keras_core import callbacks
from keras_core import initializers
@ -9,6 +10,7 @@ from keras_core.utils import numerical_utils
class TerminateOnNaNTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_TerminateOnNaN(self):
TRAIN_SAMPLES = 10
TEST_SAMPLES = 10

@ -1,9 +1,12 @@
import pytest
from keras_core import activations
from keras_core import layers
from keras_core import testing
class ActivationTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_activation_basics(self):
self.run_layer_test(
layers.Activation,

@ -1,4 +1,5 @@
import numpy as np
import pytest
import tensorflow as tf
from keras_core import testing
@ -10,6 +11,7 @@ class ELUTest(testing.TestCase):
elu_layer = elu.ELU()
self.run_class_serialization_test(elu_layer)
@pytest.mark.requires_trainable_backend
def test_elu(self):
self.run_layer_test(
elu.ELU,

@ -1,10 +1,12 @@
import numpy as np
import pytest
from keras_core import testing
from keras_core.layers.activations import leaky_relu
class LeakyReLUTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_leaky_relu(self):
self.run_layer_test(
leaky_relu.LeakyReLU,

@ -1,4 +1,5 @@
import numpy as np
import pytest
import tensorflow as tf
from keras_core import testing
@ -6,6 +7,7 @@ from keras_core.layers.activations import prelu
class PReLUTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_prelu(self):
self.run_layer_test(
prelu.PReLU,

@ -1,10 +1,12 @@
import numpy as np
import pytest
from keras_core import testing
from keras_core.layers.activations import relu
class ReLUTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_relu(self):
self.run_layer_test(
relu.ReLU,

@ -1,10 +1,12 @@
import numpy as np
import pytest
from keras_core import testing
from keras_core.layers.activations import softmax
class SoftmaxTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_softmax(self):
self.run_layer_test(
softmax.Softmax,

@ -1,6 +1,8 @@
import numpy as np
import pytest
from absl.testing import parameterized
from keras_core import backend
from keras_core import initializers
from keras_core import layers
from keras_core import testing
@ -164,6 +166,10 @@ class MultiHeadAttentionTest(testing.TestCase, parameterized.TestCase):
layer._output_dense.kernel,
)
@pytest.mark.skipif(
backend.backend() == "numpy",
reason="Numpy backend does not support masking.",
)
def test_query_mask_progagation(self):
"""Test automatic propagation of the query's mask."""
layer = layers.MultiHeadAttention(num_heads=2, key_dim=2)
@ -175,6 +181,10 @@ class MultiHeadAttentionTest(testing.TestCase, parameterized.TestCase):
self.assertAllClose(masked_query._keras_mask, output._keras_mask)
@parameterized.named_parameters(("causal", True), ("not_causal", 0))
@pytest.mark.skipif(
backend.backend() == "numpy",
reason="Numpy backend does not support masking.",
)
def test_masking(self, use_causal_mask):
"""Test that the value and causal masks are taken into account."""
layer = layers.MultiHeadAttention(num_heads=2, key_dim=2)

@ -1,4 +1,5 @@
import numpy as np
import pytest
import tensorflow as tf
from absl.testing import parameterized
@ -53,6 +54,7 @@ class ConvBasicTest(testing.TestCase, parameterized.TestCase):
"output_shape": (3, 2, 6),
},
)
@pytest.mark.requires_trainable_backend
def test_conv1d_basic(
self,
filters,
@ -119,6 +121,7 @@ class ConvBasicTest(testing.TestCase, parameterized.TestCase):
"output_shape": (3, 2, 4, 6),
},
)
@pytest.mark.requires_trainable_backend
def test_conv2d_basic(
self,
filters,
@ -185,6 +188,7 @@ class ConvBasicTest(testing.TestCase, parameterized.TestCase):
"output_shape": (3, 2, 4, 2, 6),
},
)
@pytest.mark.requires_trainable_backend
def test_conv3d_basic(
self,
filters,

@ -44,6 +44,7 @@ class ConvTransposeBasicTest(testing.TestCase, parameterized.TestCase):
"output_shape": (2, 16, 6),
},
)
@pytest.mark.requires_trainable_backend
def test_conv1d_transpose_basic(
self,
filters,
@ -121,6 +122,7 @@ class ConvTransposeBasicTest(testing.TestCase, parameterized.TestCase):
"output_shape": (1, 224, 224, 2),
},
)
@pytest.mark.requires_trainable_backend
def test_conv2d_transpose_basic(
self,
filters,
@ -193,6 +195,7 @@ class ConvTransposeBasicTest(testing.TestCase, parameterized.TestCase):
"output_shape": (2, 16, 9, 17, 6),
},
)
@pytest.mark.requires_trainable_backend
def test_conv3d_transpose_basic(
self,
filters,

@ -1,4 +1,5 @@
import numpy as np
import pytest
import tensorflow as tf
from absl.testing import parameterized
@ -39,6 +40,7 @@ class DepthwiseConvBasicTest(testing.TestCase, parameterized.TestCase):
"output_shape": (3, 2, 24),
},
)
@pytest.mark.requires_trainable_backend
def test_depthwise_conv1d_basic(
self,
depth_multiplier,
@ -100,6 +102,7 @@ class DepthwiseConvBasicTest(testing.TestCase, parameterized.TestCase):
"output_shape": (3, 2, 2, 24),
},
)
@pytest.mark.requires_trainable_backend
def test_depthwise_conv2d_basic(
self,
depth_multiplier,

@ -1,4 +1,5 @@
import numpy as np
import pytest
import tensorflow as tf
from absl.testing import parameterized
@ -42,6 +43,7 @@ class SeparableConvBasicTest(testing.TestCase, parameterized.TestCase):
"output_shape": (3, 2, 6),
},
)
@pytest.mark.requires_trainable_backend
def test_separable_conv1d_basic(
self,
depth_multiplier,
@ -108,6 +110,7 @@ class SeparableConvBasicTest(testing.TestCase, parameterized.TestCase):
"output_shape": (3, 2, 2, 6),
},
)
@pytest.mark.requires_trainable_backend
def test_separable_conv2d_basic(
self,
depth_multiplier,

@ -1,4 +1,5 @@
import numpy as np
import pytest
from keras_core import layers
from keras_core import testing
@ -6,6 +7,7 @@ from keras_core.backend.common import keras_tensor
class DenseTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_dense_basics(self):
# 2D case, no bias.
self.run_layer_test(

@ -1,3 +1,4 @@
import pytest
from absl.testing import parameterized
from keras_core import layers
@ -227,6 +228,7 @@ class EinsumDenseTest(testing.TestCase, parameterized.TestCase):
"expected_output_shape": (2, 3, 4, 2),
},
)
@pytest.mark.requires_trainable_backend
def test_einsum_dense_basics(
self,
equation,

@ -1,10 +1,12 @@
import numpy as np
import pytest
from keras_core import layers
from keras_core.testing import test_case
class EmbeddingTest(test_case.TestCase):
@pytest.mark.requires_trainable_backend
def test_embedding_basics(self):
self.run_layer_test(
layers.Embedding,

@ -1,8 +1,11 @@
import pytest
from keras_core import layers
from keras_core import testing
class IdentityTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_identity_basics(self):
self.run_layer_test(
layers.Identity,

@ -1,4 +1,5 @@
import numpy as np
import pytest
from keras_core import layers
from keras_core import ops
@ -6,6 +7,7 @@ from keras_core import testing
class LambdaTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_lambda_basics(self):
self.run_layer_test(
layers.Lambda,

@ -1,4 +1,5 @@
import numpy as np
import pytest
from keras_core import layers
from keras_core import models
@ -6,6 +7,7 @@ from keras_core import testing
class MaskingTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_masking_basics(self):
self.run_layer_test(
layers.Masking,
@ -19,6 +21,7 @@ class MaskingTest(testing.TestCase):
supports_masking=True,
)
@pytest.mark.requires_trainable_backend
def test_masking_correctness(self):
x = np.array(
[

@ -1,3 +1,5 @@
import pytest
from keras_core import layers
from keras_core import testing
@ -10,6 +12,7 @@ class ExampleWrapper(layers.Wrapper):
class WrapperTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_wrapper_basics(self):
self.run_layer_test(
ExampleWrapper,

@ -44,6 +44,8 @@ elif backend.backend() == "jax":
from keras_core.backend.jax.layer import JaxLayer as BackendLayer
elif backend.backend() == "torch":
from keras_core.backend.torch.layer import TorchLayer as BackendLayer
elif backend.backend() == "numpy":
from keras_core.backend.numpy.layer import NumpyLayer as BackendLayer
else:
raise RuntimeError(
f"Backend '{backend.backend()}' must implement a layer mixin class."
@ -1235,6 +1237,12 @@ class Layer(BackendLayer, Operation):
for tensor, mask in zip(flat_outputs, flat_masks):
if getattr(tensor, "_keras_mask", None) is None:
try:
# The NumPy backend does not support masking.
if backend.backend() == "numpy":
    warnings.warn(
        "The NumPy backend does not support masking at this "
        "time. Masks will be ignored."
    )
tensor._keras_mask = mask
except AttributeError:
# It's a C type.
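The new NumPy branch above mirrors the JAX and Torch imports. For orientation, the mixin itself can stay essentially empty; a minimal sketch of what keras_core/backend/numpy/layer.py could contain, assuming no backend-specific hooks are required (illustrative, not the verbatim file):

class NumpyLayer:
    # Plain ndarrays carry no device placement or autograd state,
    # so the mixin only needs to exist for the backend dispatch above.
    pass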

@ -264,6 +264,7 @@ class LayerTest(testing.TestCase):
layer(layers.Input(batch_shape=(2, 2)))
self.assertLen(layer.losses, 0)
@pytest.mark.requires_trainable_backend
def test_add_loss(self):
class LossLayer(layers.Layer):
def call(self, x):
@ -378,6 +379,10 @@ class LayerTest(testing.TestCase):
self.assertEqual(backend.standardize_dtype(y.dtype), "float16")
self.assertEqual(layer.kernel.dtype, "float32")
@pytest.mark.skipif(
backend.backend() == "numpy",
reason="Numpy backend does not support masking.",
)
def test_masking(self):
class BasicMaskedLayer(layers.Layer):
def __init__(self):

@ -8,6 +8,7 @@ from keras_core import ops
from keras_core import testing
@pytest.mark.requires_trainable_backend
class MergingLayersTest(testing.TestCase):
def test_add_basic(self):
self.run_layer_test(

@ -1,4 +1,5 @@
import numpy as np
import pytest
from absl.testing import parameterized
from keras_core import backend
@ -7,6 +8,7 @@ from keras_core import testing
class BatchNormalizationTest(testing.TestCase, parameterized.TestCase):
@pytest.mark.requires_trainable_backend
def test_bn_basics(self):
# vector case
self.run_layer_test(

@ -1,4 +1,5 @@
import numpy as np
import pytest
from keras_core import constraints
from keras_core import layers
@ -7,6 +8,7 @@ from keras_core import testing
class GroupNormalizationTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_groupnorm(self):
self.run_layer_test(
layers.GroupNormalization,

@ -1,4 +1,5 @@
import numpy as np
import pytest
from keras_core import layers
from keras_core import ops
@ -7,6 +8,7 @@ from keras_core import testing
class LayerNormalizationTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_ln_basics(self):
self.run_layer_test(
layers.LayerNormalization,

@ -1,4 +1,5 @@
import numpy as np
import pytest
from keras_core import initializers
from keras_core import layers
@ -6,6 +7,7 @@ from keras_core import testing
class SpectralNormalizationTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_basic_spectralnorm(self):
self.run_layer_test(
layers.SpectralNormalization,

@ -1,4 +1,5 @@
import numpy as np
import pytest
from keras_core import backend
from keras_core import layers
@ -11,6 +12,7 @@ def squared_l2_norm(x):
class UnitNormalizationTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_un_basics(self):
self.run_layer_test(
layers.UnitNormalization,

@ -8,6 +8,7 @@ from keras_core import layers
from keras_core import testing
@pytest.mark.requires_trainable_backend
class AveragePoolingBasicTest(testing.TestCase, parameterized.TestCase):
@parameterized.parameters(
(2, 1, "valid", "channels_last", (3, 5, 4), (3, 4, 4)),

@ -1,4 +1,5 @@
import numpy as np
import pytest
import tensorflow as tf
from absl.testing import parameterized
@ -6,6 +7,7 @@ from keras_core import layers
from keras_core import testing
@pytest.mark.requires_trainable_backend
class GlobalAveragePoolingBasicTest(testing.TestCase, parameterized.TestCase):
@parameterized.parameters(
("channels_last", False, (3, 5, 4), (3, 4)),

@ -1,4 +1,5 @@
import numpy as np
import pytest
import tensorflow as tf
from absl.testing import parameterized
@ -6,6 +7,7 @@ from keras_core import layers
from keras_core import testing
@pytest.mark.requires_trainable_backend
class GlobalMaxPoolingBasicTest(testing.TestCase, parameterized.TestCase):
@parameterized.parameters(
("channels_last", False, (3, 5, 4), (3, 4)),

@ -1,4 +1,5 @@
import numpy as np
import pytest
import tensorflow as tf
from absl.testing import parameterized
@ -6,6 +7,7 @@ from keras_core import layers
from keras_core import testing
@pytest.mark.requires_trainable_backend
class MaxPoolingBasicTest(testing.TestCase, parameterized.TestCase):
@parameterized.parameters(
(2, 1, "valid", "channels_last", (3, 5, 4), (3, 4, 4)),

@ -1,4 +1,5 @@
import numpy as np
import pytest
import tensorflow as tf
from absl.testing import parameterized
@ -7,6 +8,7 @@ from keras_core import testing
class CenterCropTest(testing.TestCase, parameterized.TestCase):
@pytest.mark.requires_trainable_backend
def test_center_crop_basics(self):
self.run_layer_test(
layers.CenterCrop,

@ -1,4 +1,5 @@
import numpy as np
import pytest
import tensorflow as tf
from absl.testing import parameterized
@ -8,6 +9,7 @@ from keras_core import testing
class NormalizationTest(testing.TestCase, parameterized.TestCase):
@pytest.mark.requires_trainable_backend
def test_normalization_basics(self):
self.run_layer_test(
layers.Normalization,

@ -1,4 +1,5 @@
import numpy as np
import pytest
import tensorflow as tf
from keras_core import backend
@ -7,6 +8,7 @@ from keras_core import testing
class RandomBrightnessTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_layer(self):
self.run_layer_test(
layers.RandomBrightness,

@ -1,4 +1,5 @@
import numpy as np
import pytest
import tensorflow as tf
from keras_core import backend
@ -7,6 +8,7 @@ from keras_core import testing
class RandomContrastTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_layer(self):
self.run_layer_test(
layers.RandomContrast,

@ -1,4 +1,5 @@
import numpy as np
import pytest
import tensorflow as tf
from keras_core import layers
@ -6,6 +7,7 @@ from keras_core import testing
class RescalingTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_rescaling_basics(self):
self.run_layer_test(
layers.Rescaling,
@ -19,6 +21,7 @@ class RescalingTest(testing.TestCase):
supports_masking=True,
)
@pytest.mark.requires_trainable_backend
def test_rescaling_dtypes(self):
# int scale
self.run_layer_test(

@ -1,4 +1,5 @@
import numpy as np
import pytest
from keras_core import layers
from keras_core.testing import test_case
@ -11,6 +12,7 @@ class ActivityRegularizationTest(test_case.TestCase):
self.assertLen(layer.losses, 1)
self.assertAllClose(layer.losses[0], 4 * 0.3 + 2 * 0.2)
@pytest.mark.requires_trainable_backend
def test_activity_regularization_basics(self):
self.run_layer_test(
layers.ActivityRegularization,

@ -7,6 +7,7 @@ from keras_core import testing
class DropoutTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_dropout_basics(self):
self.run_layer_test(
layers.Dropout,

@ -1,4 +1,5 @@
import numpy as np
import pytest
from keras_core import backend
from keras_core import layers
@ -6,6 +7,7 @@ from keras_core import testing
class GaussianDropoutTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_gaussian_dropout_basics(self):
self.run_layer_test(
layers.GaussianDropout,

@ -1,4 +1,5 @@
import numpy as np
import pytest
from keras_core import backend
from keras_core import layers
@ -6,6 +7,7 @@ from keras_core import testing
class GaussianNoiseTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_gaussian_noise_basics(self):
self.run_layer_test(
layers.GaussianNoise,

@ -1,10 +1,12 @@
import numpy as np
import pytest
from keras_core import layers
from keras_core.testing import test_case
class SpatialDropoutTest(test_case.TestCase):
@pytest.mark.requires_trainable_backend
def test_spatial_dropout_1d(self):
self.run_layer_test(
layers.SpatialDropout1D,
@ -20,6 +22,7 @@ class SpatialDropoutTest(test_case.TestCase):
input_shape=(2, 3, 4),
)
@pytest.mark.requires_trainable_backend
def test_spatial_dropout_2d(self):
self.run_layer_test(
layers.SpatialDropout2D,
@ -35,6 +38,7 @@ class SpatialDropoutTest(test_case.TestCase):
input_shape=(2, 3, 4, 5),
)
@pytest.mark.requires_trainable_backend
def test_spatial_dropout_3d(self):
self.run_layer_test(
layers.SpatialDropout3D,

@ -1,13 +1,13 @@
import numpy as np
import pytest
from keras_core import backend
from keras_core import layers
from keras_core import ops
from keras_core import testing
class Cropping1DTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_cropping_1d(self):
inputs = np.random.rand(3, 5, 7)
@ -47,10 +47,7 @@ class Cropping1DTest(testing.TestCase):
expected_output=ops.convert_to_tensor(inputs[:, 1:5, :]),
)
@pytest.mark.skipif(
not backend.DYNAMIC_SHAPES_OK,
reason="Backend does not support dynamic shapes",
)
@pytest.mark.requires_trainable_backend
def test_cropping_1d_with_dynamic_spatial_dim(self):
input_layer = layers.Input(batch_shape=(1, None, 7))
cropped = layers.Cropping1D((1, 2))(input_layer)

@ -33,6 +33,7 @@ class Cropping2DTest(testing.TestCase, parameterized.TestCase):
{"data_format": "channels_last"},
),
)
@pytest.mark.requires_trainable_backend
def test_cropping_2d(self, cropping, data_format, expected_ranges):
if data_format == "channels_first":
inputs = np.random.rand(3, 5, 7, 9)

@ -30,6 +30,7 @@ class Cropping3DTest(testing.TestCase, parameterized.TestCase):
{"data_format": "channels_last"},
),
)
@pytest.mark.requires_trainable_backend
def test_cropping_3d(
self,
dim1_cropping,
@ -88,6 +89,7 @@ class Cropping3DTest(testing.TestCase, parameterized.TestCase):
{"data_format": "channels_last"},
),
)
@pytest.mark.requires_trainable_backend
def test_cropping_3d_with_same_cropping(
self, cropping, data_format, expected
):

@ -8,6 +8,7 @@ from keras_core import testing
class FlattenTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_flatten(self):
inputs = np.random.random((10, 3, 5, 5)).astype("float32")
@ -39,6 +40,7 @@ class FlattenTest(testing.TestCase):
expected_output=expected_output,
)
@pytest.mark.requires_trainable_backend
def test_flatten_with_scalar_channels(self):
inputs = np.random.random((10,)).astype("float32")
expected_output = ops.convert_to_tensor(np.expand_dims(inputs, -1))

@ -8,6 +8,7 @@ from keras_core import testing
class PermuteTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_permute(self):
inputs = np.random.random((2, 3, 5)).astype("float32")
expected_output = ops.convert_to_tensor(

@ -8,6 +8,7 @@ from keras_core import testing
class RepeatVectorTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_repeat_vector(self):
inputs = np.random.random((2, 5)).astype("float32")
expected_output = ops.convert_to_tensor(

@ -6,6 +6,7 @@ from keras_core import testing
class ReshapeTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_reshape(self):
self.run_layer_test(
layers.Reshape,

@ -8,6 +8,7 @@ from keras_core.backend.common.keras_tensor import KerasTensor
class UpSamplingTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_upsampling_1d(self):
self.run_layer_test(
layers.UpSampling1D,

@ -14,6 +14,7 @@ class UpSampling2dTest(testing.TestCase, parameterized.TestCase):
length_row=[2],
length_col=[2, 3],
)
@pytest.mark.requires_trainable_backend
def test_upsampling_2d(self, data_format, length_row, length_col):
num_samples = 2
stack_size = 2
@ -64,6 +65,7 @@ class UpSampling2dTest(testing.TestCase, parameterized.TestCase):
length_row=[2],
length_col=[2, 3],
)
@pytest.mark.requires_trainable_backend
def test_upsampling_2d_bilinear(self, data_format, length_row, length_col):
num_samples = 2
stack_size = 2

@ -1,4 +1,5 @@
import numpy as np
import pytest
from absl.testing import parameterized
from keras_core import backend
@ -13,6 +14,7 @@ class UpSampling3dTest(testing.TestCase, parameterized.TestCase):
length_dim2=[2],
length_dim3=[3],
)
@pytest.mark.requires_trainable_backend
def test_upsampling_3d(
self, data_format, length_dim1, length_dim2, length_dim3
):

@ -1,4 +1,5 @@
import numpy as np
import pytest
from keras_core import initializers
from keras_core import layers
@ -6,6 +7,7 @@ from keras_core import testing
class BidirectionalTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_basics(self):
self.run_layer_test(
layers.Bidirectional,

@ -1,4 +1,5 @@
import numpy as np
import pytest
from keras_core import initializers
from keras_core import layers
@ -6,6 +7,7 @@ from keras_core import testing
class ConvLSTM1DTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_basics(self):
self.run_layer_test(
layers.ConvLSTM1D,

@ -1,4 +1,5 @@
import numpy as np
import pytest
from keras_core import initializers
from keras_core import layers
@ -6,6 +7,7 @@ from keras_core import testing
class ConvLSTM2DTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_basics(self):
self.run_layer_test(
layers.ConvLSTM2D,

@ -1,4 +1,5 @@
import numpy as np
import pytest
from keras_core import initializers
from keras_core import layers
@ -6,6 +7,7 @@ from keras_core import testing
class ConvLSTM3DTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_basics(self):
self.run_layer_test(
layers.ConvLSTM3D,

@ -1,3 +1,5 @@
import pytest
from keras_core import backend
from keras_core import layers
from keras_core import ops
@ -50,6 +52,7 @@ class DropoutRNNCellTest(testing.TestCase):
layer = layers.RNN(cell)
self.assertEqual(len(layer.non_trainable_variables), 1)
@pytest.mark.requires_trainable_backend
def test_basics(self):
self.run_layer_test(
layers.RNN,

@ -1,4 +1,5 @@
import numpy as np
import pytest
from absl.testing import parameterized
from keras_core import initializers
@ -7,6 +8,7 @@ from keras_core import testing
class GRUTest(testing.TestCase, parameterized.TestCase):
@pytest.mark.requires_trainable_backend
def test_basics(self):
self.run_layer_test(
layers.GRU,

@ -1,4 +1,5 @@
import numpy as np
import pytest
from absl.testing import parameterized
from keras_core import initializers
@ -7,6 +8,7 @@ from keras_core import testing
class LSTMTest(testing.TestCase, parameterized.TestCase):
@pytest.mark.requires_trainable_backend
def test_basics(self):
self.run_layer_test(
layers.LSTM,

@ -69,6 +69,7 @@ class TwoStatesRNNCell(layers.Layer):
class RNNTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_basics(self):
self.run_layer_test(
layers.RNN,

@ -1,4 +1,5 @@
import numpy as np
import pytest
from keras_core import initializers
from keras_core import layers
@ -6,6 +7,7 @@ from keras_core import testing
class SimpleRNNTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_basics(self):
self.run_layer_test(
layers.SimpleRNN,

@ -1,4 +1,5 @@
import numpy as np
import pytest
from keras_core import layers
from keras_core import testing
@ -7,6 +8,7 @@ from keras_core.layers.rnn.rnn_test import TwoStatesRNNCell
class StackedRNNTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_basics(self):
self.run_layer_test(
layers.RNN,

@ -1,4 +1,5 @@
import numpy as np
import pytest
from keras_core import initializers
from keras_core import layers
@ -7,6 +8,7 @@ from keras_core import testing
class TimeDistributedTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_basics(self):
self.run_layer_test(
layers.TimeDistributed,

@ -1,4 +1,5 @@
import numpy as np
import pytest
from keras_core import backend
from keras_core import losses as losses_module
@ -39,6 +40,10 @@ class LossTest(testing.TestCase):
with self.assertRaisesRegex(ValueError, "Invalid value for argument"):
ExampleLoss(reduction="abc")
@pytest.mark.skipif(
backend.backend() == "numpy",
reason="Numpy backend does not support masking.",
)
def test_mask(self):
mask = np.array([True, False, True, True])
y_true = np.array([1.0, 0.0, 1.0, 0.0])
@ -84,6 +89,10 @@ class LossTest(testing.TestCase):
self.assertEqual(backend.standardize_dtype(loss.dtype), "float32")
self.assertAllClose(loss, 0) # No NaN.
@pytest.mark.skipif(
backend.backend() == "numpy",
reason="Numpy backend does not support masking.",
)
def test_mask_and_sample_weight(self):
sample_weight = np.array([0.4, 0.3, 0.2, 0.1])
y_true = np.array([1.0, 0.0, 1.0, 0.0])
@ -111,6 +120,10 @@ class LossTest(testing.TestCase):
# @testing.parametrize(
# "uprank", ["mask", "sample_weight", "y_true", "y_pred"])
# TODO: use parameterization decorator
@pytest.mark.skipif(
backend.backend() == "numpy",
reason="Numpy backend does not support masking.",
)
def test_rank_adjustment(self):
for uprank in ["mask", "sample_weight", "ys"]:
sample_weight = np.array([0.4, 0.3, 0.2, 0.1])

@ -9,7 +9,10 @@ def accuracy(y_true, y_pred):
y_pred = ops.convert_to_tensor(y_pred)
y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
y_true, y_pred = squeeze_to_same_rank(y_true, y_pred)
return ops.cast(ops.equal(y_true, y_pred), dtype=backend.floatx())
return ops.mean(
ops.cast(ops.equal(y_true, y_pred), dtype=backend.floatx()),
axis=-1,
)
@keras_core_export("keras_core.metrics.Accuracy")
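This change makes accuracy() return a per-sample mean over the last axis instead of the raw elementwise match indicator. A quick NumPy illustration of the new behavior:

import numpy as np

y_true = np.array([[1, 2, 3, 4]])
y_pred = np.array([[1, 2, 3, 5]])
matches = (y_true == y_pred).astype("float32")
# Before this change the elementwise matches were returned: [[1. 1. 1. 0.]]
# Averaging over the last axis now yields one score per sample:
print(matches.mean(axis=-1))  # [0.75]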
@ -135,7 +138,7 @@ def categorical_accuracy(y_true, y_pred):
and (y_pred_rank is not None)
and (len(y_true.shape) == len(y_pred.shape))
):
y_true = ops.squeeze(y_true, [-1])
y_true = ops.squeeze(y_true, -1)
reshape_matches = True
y_pred = ops.argmax(y_pred, axis=-1)
@ -218,7 +221,7 @@ def sparse_categorical_accuracy(y_true, y_pred):
and (y_pred_rank is not None)
and (len(y_true.shape) == len(y_pred.shape))
):
y_true = ops.squeeze(y_true, [-1])
y_true = ops.squeeze(y_true, -1)
reshape_matches = True
y_pred = ops.argmax(y_pred, axis=-1)
@ -231,7 +234,7 @@ def sparse_categorical_accuracy(y_true, y_pred):
matches = ops.reshape(matches, new_shape=y_true_org_shape)
# if shape is (num_samples, 1) squeeze
if len(matches.shape) > 1 and matches.shape[-1] == 1:
matches = ops.squeeze(matches, [-1])
matches = ops.squeeze(matches, -1)
return matches
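The squeeze calls in this file also switch from a single-element list to a plain integer axis, presumably because the list form is not accepted uniformly across backends. The integer form:

import numpy as np

from keras_core import ops

x = np.ones((4, 1))
# A plain int axis squeezes the trailing size-1 dimension.
print(ops.squeeze(x, -1).shape)  # (4,)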

@ -1,6 +1,7 @@
import json
import numpy as np
import pytest
from absl import logging
from absl.testing import parameterized
from tensorflow.python.ops.numpy_ops import np_config
@ -756,7 +757,7 @@ class SensitivityAtSpecificityTest(testing.TestCase, parameterized.TestCase):
label_values = [0, 0, 0, 0, 0, 2, 2, 2, 2, 2]
y_pred = ops.transpose(np.array([pred_values] * 3))
y_true = ops.one_hot(label_values, num_classes=3)
y_true = ops.one_hot(np.array(label_values), num_classes=3)
self.assertAlmostEqual(0.6, s_obj(y_true, y_pred))
@ -845,7 +846,7 @@ class SpecificityAtSensitivityTest(testing.TestCase, parameterized.TestCase):
label_values = [0, 0, 0, 0, 0, 2, 2, 2, 2, 2]
y_pred = ops.transpose(np.array([pred_values] * 3))
y_true = ops.one_hot(label_values, num_classes=3)
y_true = ops.one_hot(np.array(label_values), num_classes=3)
self.assertAlmostEqual(0.6, s_obj(y_true, y_pred))
@ -931,7 +932,7 @@ class PrecisionAtRecallTest(testing.TestCase, parameterized.TestCase):
label_values = [0, 0, 0, 0, 0, 2, 2, 2, 2, 2]
y_pred = ops.transpose(np.array([pred_values] * 3))
y_true = ops.one_hot(label_values, num_classes=3)
y_true = ops.one_hot(np.array(label_values), num_classes=3)
# For 0.2 < decision threshold < 0.5.
self.assertAlmostEqual(0.75, s_obj(y_true, y_pred))
@ -1067,7 +1068,7 @@ class RecallAtPrecisionTest(testing.TestCase, parameterized.TestCase):
# recalls: [1, 1, 5/6, 5/6, 5/6, 5/6, 2/3, 1/2, 1/2, 1/3, 1/6,
# 1/6].
y_pred = ops.transpose(np.array([pred_values] * 3))
y_true = ops.one_hot(label_values, num_classes=3)
y_true = ops.one_hot(np.array(label_values), num_classes=3)
# The precision 5/7 can be reached at thresholds 0.3 <= t < 0.35.
self.assertAlmostEqual(5.0 / 6, s_obj(y_true, y_pred))
@ -1645,6 +1646,7 @@ class MultiAUCTest(testing.TestCase):
# PR AUCs are 0.939 and 1.0 respectively
self.assertAllClose(good_result, (0.939 + 1.0) / 2.0, 1e-1)
@pytest.mark.requires_trainable_backend
def test_keras_model_compiles(self):
inputs = layers.Input(shape=(10,), batch_size=1)
output = layers.Dense(3, activation="sigmoid")(inputs)
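The repeated one_hot fixes above wrap the plain Python label list in np.array before the call, presumably because the NumPy backend does not auto-convert lists the way TensorFlow does. The portable pattern:

import numpy as np

from keras_core import ops

label_values = [0, 0, 0, 0, 0, 2, 2, 2, 2, 2]
# Converting up front keeps ops.one_hot working on every backend.
y_true = ops.one_hot(np.array(label_values), num_classes=3)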

@ -1,4 +1,5 @@
import numpy as np
import pytest
from absl.testing import parameterized
from keras_core import layers
@ -41,6 +42,7 @@ def get_subclassed_model():
return ExampleModel()
@pytest.mark.requires_trainable_backend
class CloneModelTest(testing.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
("functional", get_functional_model),

@ -1,6 +1,7 @@
import warnings
import numpy as np
import pytest
from keras_core import backend
from keras_core import layers
@ -12,6 +13,7 @@ from keras_core.models import Model
class FunctionalTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_basic_flow_multi_input(self):
input_a = Input(shape=(3,), batch_size=2, name="input_a")
input_b = Input(shape=(3,), batch_size=2, name="input_b")
@ -37,6 +39,7 @@ class FunctionalTest(testing.TestCase):
out_val = model(in_val)
self.assertEqual(out_val.shape, (2, 4))
@pytest.mark.requires_trainable_backend
def test_scalar_input(self):
input_a = Input(shape=(3,), batch_size=2, name="input_a")
input_b = Input(shape=(), batch_size=2, name="input_b")
@ -48,6 +51,7 @@ class FunctionalTest(testing.TestCase):
out_val = model(in_val)
self.assertAllClose(out_val, np.ones((2, 3)))
@pytest.mark.requires_trainable_backend
def test_basic_flow_multi_output(self):
inputs = Input(shape=(3,), batch_size=2, name="input")
x = layers.Dense(5)(inputs)
@ -70,6 +74,7 @@ class FunctionalTest(testing.TestCase):
self.assertEqual(out_val[0].shape, (2, 4))
self.assertEqual(out_val[1].shape, (2, 5))
@pytest.mark.requires_trainable_backend
def test_basic_flow_dict_io(self):
input_a = Input(shape=(3,), batch_size=2, name="a")
input_b = Input(shape=(3,), batch_size=2, name="b")
@ -101,6 +106,7 @@ class FunctionalTest(testing.TestCase):
out_val = model(in_val)
self.assertEqual(out_val.shape, (2, 4))
@pytest.mark.requires_trainable_backend
def test_named_input_dict_io(self):
input_a = Input(shape=(3,), batch_size=2, name="a")
x = layers.Dense(5)(input_a)
@ -119,6 +125,7 @@ class FunctionalTest(testing.TestCase):
out_val = model(in_val)
self.assertEqual(out_val.shape, (2, 4))
@pytest.mark.requires_trainable_backend
def test_input_dict_with_extra_field(self):
input_a = Input(shape=(3,), batch_size=2, name="a")
x = input_a * 5
@ -145,6 +152,7 @@ class FunctionalTest(testing.TestCase):
out_val = model(in_val)
self.assertEqual(out_val.shape, (2, 3))
@pytest.mark.requires_trainable_backend
def test_layer_getters(self):
# Test mixing ops and layers
input_a = Input(shape=(3,), batch_size=2, name="input_a")
@ -162,6 +170,7 @@ class FunctionalTest(testing.TestCase):
self.assertEqual(model.get_layer(index=3).name, "dense_2")
self.assertEqual(model.get_layer(name="dense_1").name, "dense_1")
@pytest.mark.requires_trainable_backend
def test_training_arg(self):
class Canary(layers.Layer):
def call(self, x, training=False):
@ -180,6 +189,7 @@ class FunctionalTest(testing.TestCase):
# TODO
pass
@pytest.mark.requires_trainable_backend
def test_passing_inputs_by_name(self):
input_a = Input(shape=(3,), batch_size=2, name="input_a")
input_b = Input(shape=(3,), batch_size=2, name="input_b")
@ -203,6 +213,7 @@ class FunctionalTest(testing.TestCase):
out_val = model(in_val)
self.assertEqual(out_val.shape, (2, 4))
@pytest.mark.requires_trainable_backend
def test_rank_standardization(self):
# Downranking
inputs = Input(shape=(3,), batch_size=2)
@ -218,6 +229,7 @@ class FunctionalTest(testing.TestCase):
out_val = model(np.random.random((2, 3)))
self.assertEqual(out_val.shape, (2, 3, 3))
@pytest.mark.requires_trainable_backend
def test_dtype_standardization(self):
float_input = Input(shape=(2,), dtype="float16")
int_input = Input(shape=(2,), dtype="int32")
@ -229,6 +241,7 @@ class FunctionalTest(testing.TestCase):
self.assertEqual(backend.standardize_dtype(float_data.dtype), "float16")
self.assertEqual(backend.standardize_dtype(int_data.dtype), "int32")
@pytest.mark.requires_trainable_backend
def test_serialization(self):
# Test basic model
inputs = Input(shape=(3,), batch_size=2)
@ -269,6 +282,7 @@ class FunctionalTest(testing.TestCase):
model = Functional({"a": input_a, "b": input_b}, outputs)
self.run_class_serialization_test(model)
@pytest.mark.requires_trainable_backend
def test_bad_input_spec(self):
# Single input
inputs = Input(shape=(4,))
@ -303,6 +317,7 @@ class FunctionalTest(testing.TestCase):
):
model({"a": np.zeros((2, 3)), "b": np.zeros((2, 4))})
@pytest.mark.requires_trainable_backend
def test_manual_input_spec(self):
inputs = Input(shape=(None, 3))
outputs = layers.Dense(2)(inputs)

@ -22,6 +22,8 @@ elif backend.backend() == "jax":
from keras_core.backend.jax.trainer import JAXTrainer as Trainer
elif backend.backend() == "torch":
from keras_core.backend.torch.trainer import TorchTrainer as Trainer
elif backend.backend() == "numpy":
from keras_core.backend.numpy.trainer import NumpyTrainer as Trainer
else:
raise RuntimeError(
f"Backend '{backend.backend()}' must implement the Trainer class."

@ -1,4 +1,5 @@
import numpy as np
import pytest
from absl.testing import parameterized
from keras_core import layers
@ -64,6 +65,7 @@ def _get_model_multi_outputs_dict():
return model
@pytest.mark.requires_trainable_backend
class ModelTest(testing.TestCase, parameterized.TestCase):
def test_functional_rerouting(self):
model = _get_model()

@ -1,4 +1,5 @@
import numpy as np
import pytest
from keras_core import backend
from keras_core import layers
@ -8,6 +9,7 @@ from keras_core.models.functional import Functional
from keras_core.models.sequential import Sequential
@pytest.mark.requires_trainable_backend
class SequentialTest(testing.TestCase):
def test_basic_flow_with_input(self):
model = Sequential(name="seq")

@ -1,4 +1,5 @@
import numpy as np
import pytest
from keras_core import layers
from keras_core import losses
@ -162,7 +163,7 @@ class CoreOpsCorrectnessTest(testing.TestCase):
def test_slice_update(self):
# Test 1D.
inputs = np.array([0, 0, 0, 0, 0, 0, 0, 0])
start_indices = [1]
start_indices = np.array([1])
updates = np.array([9, 10, 11, 12])
self.assertAllClose(
core.slice_update(inputs, start_indices, updates),
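Like the one_hot fixes, start_indices becomes an ndarray here, presumably because the NumPy backend does not convert plain Python lists for slicing ops. A standalone check of the call the test exercises:

import numpy as np

from keras_core.ops import core

inputs = np.array([0, 0, 0, 0, 0, 0, 0, 0])
start_indices = np.array([1])  # ndarray rather than a plain list
updates = np.array([9, 10, 11, 12])
print(core.slice_update(inputs, start_indices, updates))
# [ 0  9 10 11 12  0  0  0]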
@ -204,6 +205,7 @@ class CoreOpsCorrectnessTest(testing.TestCase):
self.assertAllClose(x, np.ones((2, 3)) * 6)
self.assertAllClose(y, np.ones((3, 2)) * 6)
@pytest.mark.requires_trainable_backend
def test_stop_gradient(self):
class ExampleLayer(layers.Layer):
def __init__(self):

@ -1,4 +1,5 @@
import numpy as np
import pytest
import keras_core
from keras_core import backend
@ -75,6 +76,7 @@ class AdamTest(testing.TestCase):
clipped_grad = optimizer._clip_gradients(grad)
self.assertAllClose(clipped_grad[0], [1.0, 1.0])
@pytest.mark.requires_trainable_backend
def test_ema(self):
# TODO: test correctness
model = keras_core.Sequential([keras_core.layers.Dense(10)])

@ -3,6 +3,7 @@
import math
import numpy as np
import pytest
from keras_core import backend
from keras_core import layers
@ -13,6 +14,7 @@ from keras_core.optimizers import schedules
class TestFitLRSchedulesFlow(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_fit_lr_correctness(self):
model = Sequential(
[

@ -1,6 +1,7 @@
import os
import numpy as np
import pytest
import tensorflow as tf
import keras_core
@ -52,6 +53,7 @@ def get_subclassed_model(keras):
return model
@pytest.mark.requires_trainable_backend
class LegacyH5LoadingTest(testing.TestCase):
def _check_reloading(self, ref_input, model, tf_keras_model):
ref_output = tf_keras_model(ref_input)

@ -7,6 +7,7 @@ from pathlib import Path
from unittest import mock
import numpy as np
import pytest
import keras_core
from keras_core import ops
@ -130,6 +131,7 @@ def my_mean_squared_error(y_true, y_pred):
return ops.mean(ops.square(y_pred - y_true), axis=-1)
@pytest.mark.requires_trainable_backend
class SavingTest(testing.TestCase):
def _get_subclassed_model(self, compile=True):
subclassed_model = CustomModelX()

@ -3,6 +3,7 @@
import json
import numpy as np
import pytest
import keras_core
from keras_core import ops
@ -188,6 +189,7 @@ class SerializationLibTest(testing.TestCase):
# y2 = new_lmbda(x)
# self.assertAllClose(y1, y2, atol=1e-5)
@pytest.mark.requires_trainable_backend
def test_dict_inputs_outputs(self):
input_foo = keras_core.Input((2,), name="foo")
input_bar = keras_core.Input((2,), name="bar")
@ -223,6 +225,7 @@ class SerializationLibTest(testing.TestCase):
self.assertIs(model.layers[2], model.layers[3].layer)
self.assertIs(new_model.layers[2], new_model.layers[3].layer)
@pytest.mark.requires_trainable_backend
def test_functional_subclass(self):
class PlainFunctionalSubclass(keras_core.Model):
pass

@ -21,6 +21,8 @@ elif backend.backend() == "tensorflow":
from keras_core.backend.tensorflow.trainer import (
TensorFlowTrainer as Trainer,
)
elif backend.backend() == "numpy":
from keras_core.backend.numpy.trainer import NumpyTrainer as Trainer
else:
raise ImportError(f"Invalid backend: {backend.backend()}")
@ -70,6 +72,7 @@ class TrainingTestingLayer(layers.Layer, Trainer):
return x * 0
@pytest.mark.requires_trainable_backend
class TestTrainer(testing.TestCase, parameterized.TestCase):
def test_metric_tracking(self):
class ModelWithMetric(layers.Dense, Trainer):

Some files were not shown because too many files have changed in this diff Show More