Add layer testing infra.

This commit is contained in:
Francois Chollet 2023-04-22 18:03:15 -07:00
parent 3f7fdc2310
commit bb1d4eeb18
11 changed files with 330 additions and 49 deletions

@ -1,27 +1,25 @@
import types
from keras_core.api_export import keras_core_export
from keras_core.saving import serialization_lib
from keras_core.saving import object_registration
from keras_core.activations.activations import relu
from keras_core.activations.activations import leaky_relu
from keras_core.activations.activations import relu6
from keras_core.activations.activations import softmax
from keras_core.activations.activations import elu
from keras_core.activations.activations import exponential
from keras_core.activations.activations import gelu
from keras_core.activations.activations import hard_sigmoid
from keras_core.activations.activations import leaky_relu
from keras_core.activations.activations import linear
from keras_core.activations.activations import log_softmax
from keras_core.activations.activations import mish
from keras_core.activations.activations import relu
from keras_core.activations.activations import relu6
from keras_core.activations.activations import selu
from keras_core.activations.activations import sigmoid
from keras_core.activations.activations import silu
from keras_core.activations.activations import softmax
from keras_core.activations.activations import softplus
from keras_core.activations.activations import softsign
from keras_core.activations.activations import silu
from keras_core.activations.activations import gelu
from keras_core.activations.activations import tanh
from keras_core.activations.activations import sigmoid
from keras_core.activations.activations import exponential
from keras_core.activations.activations import hard_sigmoid
from keras_core.activations.activations import linear
from keras_core.activations.activations import mish
from keras_core.activations.activations import log_softmax
from keras_core.api_export import keras_core_export
from keras_core.saving import object_registration
from keras_core.saving import serialization_lib
ALL_OBJECTS = {
relu,

@ -39,8 +39,17 @@ def relu(x, negative_slope=0.0, max_value=None, threshold=0.0):
A tensor with the same shape and dtype as input `x`.
"""
if backend.any_symbolic_tensors((x,)):
return ReLU(negative_slope=negative_slope, max_value=max_value, threshold=threshold)(x)
return ReLU.static_call(x, negative_slope=negative_slope, max_value=max_value, threshold=threshold)
return ReLU(
negative_slope=negative_slope,
max_value=max_value,
threshold=threshold,
)(x)
return ReLU.static_call(
x,
negative_slope=negative_slope,
max_value=max_value,
threshold=threshold,
)
class ReLU(ops.Operation):
@ -93,7 +102,6 @@ class ReLU(ops.Operation):
return x
@keras_core_export("keras_core.activations.leaky_relu")
def leaky_relu(x, negative_slope=0.2):
"""Leaky relu activation function.
@ -236,7 +244,9 @@ def softsign(x):
return ops.softsign(x)
@keras_core_export(["keras_core.activations.silu", "keras_core.activations.swish"])
@keras_core_export(
["keras_core.activations.silu", "keras_core.activations.swish"]
)
def silu(x):
"""Swish (or Silu) activation function.

@ -0,0 +1 @@
# TODO

@ -1,9 +1,11 @@
import numpy as np
import tensorflow as tf
from jax import numpy as jnp
from keras_core import operations as ops
from keras_core import testing
from keras_core.backend import keras_tensor
from keras_core import operations as ops
import numpy as np
from jax import numpy as jnp
import tensorflow as tf
class KerasTensorTest(testing.TestCase):
def test_attributes(self):
@ -30,11 +32,17 @@ class KerasTensorTest(testing.TestCase):
def test_invalid_usage(self):
x = keras_tensor.KerasTensor(shape=(3,), dtype="float32")
with self.assertRaisesRegex(ValueError, "doesn't have any actual numerical value"):
with self.assertRaisesRegex(
ValueError, "doesn't have any actual numerical value"
):
np.array(x)
with self.assertRaisesRegex(ValueError, "cannot be used as input to a JAX function"):
with self.assertRaisesRegex(
ValueError, "cannot be used as input to a JAX function"
):
jnp.array(x)
with self.assertRaisesRegex(ValueError, "cannot be used as input to a TensorFlow function"):
with self.assertRaisesRegex(
ValueError, "cannot be used as input to a TensorFlow function"
):
tf.convert_to_tensor(x)

@ -49,6 +49,7 @@ class Dense(Layer):
initializer=self.bias_initializer,
regularizer=self.bias_regularizer,
)
self.built = True
def call(self, inputs):
x = ops.matmul(inputs, self.kernel)
@ -66,7 +67,9 @@ class Dense(Layer):
self.kernel_initializer
),
"bias_initializer": initializers.serialize(self.bias_initializer),
"kernel_regularizer": regularizers.serialize(self.kernel_regularizer),
"kernel_regularizer": regularizers.serialize(
self.kernel_regularizer
),
"bias_regularizer": regularizers.serialize(self.bias_regularizer),
"kernel_constraint": constraints.serialize(self.kernel_constraint),
"bias_constraint": constraints.serialize(self.bias_constraint),

@ -0,0 +1,47 @@
from keras_core import testing
from keras_core.layers.core.dense import Dense
class DenseTest(testing.TestCase):
def test_basics(self):
# 2D case, no bias.
self.run_layer_test(
Dense,
init_kwargs={
"units": 4,
"activation": "relu",
"kernel_initializer": "random_uniform",
"bias_initializer": "ones",
"use_bias": False,
},
input_shape=(2, 3),
expected_output_shape=(2, 4),
expected_output=None,
expected_num_trainable_weights=1,
expected_num_non_trainable_weights=0,
expected_num_seed_generators=0,
expected_num_losses=0,
supports_masking=False,
)
# 3D case, some regularizers.
self.run_layer_test(
Dense,
init_kwargs={
"units": 5,
"activation": "sigmoid",
"kernel_regularizer": "l2",
"bias_regularizer": "l2",
},
input_shape=(2, 3, 4),
expected_output_shape=(2, 3, 5),
expected_output=None,
expected_num_trainable_weights=2,
expected_num_non_trainable_weights=0,
expected_num_seed_generators=0,
expected_num_losses=2, # we have 2 regularizers.
supports_masking=False,
)
def test_correctness(self):
# TODO
pass

@ -1074,7 +1074,7 @@ class ExpandDims(Operation):
axis = len(x.shape) + 1 + self.axis
else:
axis = self.axis
output_shape = x_shape[: axis] + [1] + x_shape[axis :]
output_shape = x_shape[:axis] + [1] + x_shape[axis:]
return KerasTensor(output_shape, dtype=x.dtype)

@ -1,4 +1,5 @@
from keras_core.saving.object_registration import CustomObjectScope
from keras_core.saving.object_registration import custom_object_scope
from keras_core.saving.object_registration import get_custom_objects
from keras_core.saving.object_registration import get_registered_name
from keras_core.saving.object_registration import get_registered_object

@ -38,17 +38,16 @@ class CustomObjectScope:
```
Args:
*args: Dictionary or dictionaries of `{name: object}` pairs.
custom_objects: Dictionary of `{name: object}` pairs.
"""
def __init__(self, *args):
self.custom_objects = args
def __init__(self, custom_objects):
self.custom_objects = custom_objects or {}
self.backup = None
def __enter__(self):
self.backup = _THREAD_LOCAL_CUSTOM_OBJECTS.__dict__.copy()
for objects in self.custom_objects:
_THREAD_LOCAL_CUSTOM_OBJECTS.__dict__.update(objects)
_THREAD_LOCAL_CUSTOM_OBJECTS.__dict__.update(self.custom_objects)
return self
def __exit__(self, *args, **kwargs):
@ -56,6 +55,10 @@ class CustomObjectScope:
_THREAD_LOCAL_CUSTOM_OBJECTS.__dict__.update(self.backup)
# Alias.
custom_object_scope = CustomObjectScope
@keras_core_export(
"keras_core.saving.get_custom_objects",
)

@ -1,6 +1,8 @@
import json
import unittest
import numpy as np
from tensorflow import nest
class TestCase(unittest.TestCase):
@ -12,3 +14,211 @@ class TestCase(unittest.TestCase):
def assertLen(self, iterable, expected_len):
np.testing.assert_equal(len(iterable), expected_len)
def run_class_serialization_test(self, instance, custom_objects=None):
from keras_core.saving import custom_object_scope
from keras_core.saving import deserialize_keras_object
from keras_core.saving import serialize_keras_object
# get_config roundtrip
cls = instance.__class__
config = instance.get_config()
ref_dir = dir(instance)[:]
with custom_object_scope(custom_objects):
revived_instance = cls.from_config(config)
revived_config = revived_instance.get_config()
self.assertEqual(config, revived_config)
self.assertEqual(ref_dir, dir(revived_instance))
# serialization roundtrip
serialized = serialize_keras_object(instance)
json_str = json.dumps(serialized)
with custom_object_scope(custom_objects):
revived_instance = deserialize_keras_object(json.loads(json_str))
revived_config = revived_instance.get_config()
self.assertEqual(config, revived_config)
self.assertEqual(ref_dir, dir(revived_instance))
def run_layer_test(
self,
layer_cls,
init_kwargs,
input_shape,
input_dtype="float32",
input_data=None,
call_kwargs=None,
expected_output_shape=None,
expected_output_dtype=None,
expected_output=None,
expected_num_trainable_weights=None,
expected_num_non_trainable_weights=None,
expected_num_seed_generators=None,
expected_num_losses=None,
supports_masking=None,
expected_mask_shape=None,
custom_objects=None,
):
"""Run basic checks on a layer.
Args:
layer_cls: The class of the layer to test.
init_kwargs: Dict of arguments to be used to
instantiate the layer.
input_shape: Shape tuple (or list/dict of shape tuples)
to call the layer on.
input_dtype: Corresponding input dtype.
input_data: Tensor (or list/dict of tensors)
to call the layer on.
call_kwargs: Dict of arguments to use when calling the
layer (does not include the first input tensor argument)
expected_output_shape: Shape tuple
(or list/dict of shape tuples)
expected as output.
expected_output_dtype: dtype expected as output.
expected_output: Expected output tensor -- only
to be specified if input_data is provided.
expected_num_trainable_weights: Expected number
of trainable weights of the layer once built.
expected_num_non_trainable_weights: Expected number
of non-trainable weights of the layer once built.
expected_num_seed_generators: Expected number of
SeedGenerators objects of the layer once built.
expected_num_losses: Expected number of loss tensors
produced when calling the layer.
supports_masking: If True, will check that the layer
supports masking.
expected_mask_shape: Expected mask shape tuple
returned by compute_mask() (only supports 1 shape).
custom_objects: Dict of any custom objects to be
considered during deserialization.
"""
if input_shape is not None and input_data is not None:
raise ValueError(
"input_shape and input_data cannot be passed "
"at the same time."
)
if expected_output_shape is not None and expected_output is not None:
raise ValueError(
"expected_output_shape and expected_output cannot be passed "
"at the same time."
)
if expected_output is not None and input_data is None:
raise ValueError(
"In order to use expected_output, input_data must be provided."
)
if expected_mask_shape is not None and supports_masking is not True:
raise ValueError(
"In order to use expected_mask_shape, supports_masking must be True."
)
init_kwargs = init_kwargs or {}
call_kwargs = call_kwargs or {}
# Serialization test.
layer = layer_cls(**init_kwargs)
self.run_class_serialization_test(layer, custom_objects)
# Basic masking test.
if supports_masking is not None:
self.assertEqual(layer.supports_masking, supports_masking)
def run_build_asserts(layer):
self.assertTrue(layer.built)
if expected_num_trainable_weights is not None:
self.assertLen(
layer.trainable_weights, expected_num_trainable_weights
)
if expected_num_non_trainable_weights is not None:
self.assertLen(
layer.non_trainable_weights,
expected_num_non_trainable_weights,
)
if expected_num_seed_generators is not None:
self.assertLen(
layer._seed_generators, expected_num_seed_generators
)
def run_output_asserts(layer, output, eager=False):
if expected_output_shape is not None:
if isinstance(expected_output_shape, tuple):
self.assertEqual(expected_output_shape, output.shape)
elif isinstance(expected_output_shape, dict):
self.assertTrue(isinstance(output, dict))
self.assertEqual(
set(output.keys()), set(expected_output_shape.keys())
)
output_shape = {
k: v.shape for k, v in expected_output_shape.items()
}
self.assertEqual(expected_output_shape, output_shape)
elif isinstance(expected_output_shape, list):
self.assertTrue(isinstance(output, list))
self.assertEqual(
len(output.keys()), len(expected_output_shape.keys())
)
output_shape = [v.shape for v in expected_output_shape]
self.assertEqual(expected_output_shape, output_shape)
if expected_output_dtype is not None:
output_dtype = nest.flatten(output)[0].dtype
self.assertEqual(expected_output_dtype, output_dtype)
if eager:
if expected_output is not None:
self.assertEqual(type(expected_output), type(output))
for ref_v, v in zip(
nest.flatten(expected_output), nest.flatten(output)
):
self.assertAllClose(ref_v, v)
if expected_num_losses is not None:
self.assertLen(layer.losses, expected_num_losses)
# Build test.
if input_shape is not None:
layer = layer_cls(**init_kwargs)
layer.build(input_shape)
run_build_asserts(layer)
# Symbolic call test.
keras_tensor_inputs = create_keras_tensors(input_shape, input_dtype)
layer = layer_cls(**init_kwargs)
keras_tensor_outputs = layer(keras_tensor_inputs, **call_kwargs)
run_build_asserts(layer)
run_output_asserts(layer, keras_tensor_outputs, eager=False)
if expected_mask_shape is not None:
output_mask = layer.compute_mask(keras_tensor_inputs)
self.assertEqual(expected_mask_shape, output_mask.shape)
# Eager call test.
if input_data is not None or input_shape is not None:
if input_data is None:
input_data = create_eager_tensors(input_shape, input_dtype)
layer = layer_cls(**init_kwargs)
output_data = layer(input_data, **call_kwargs)
run_output_asserts(layer, output_data, eager=True)
def create_keras_tensors(input_shape, dtype):
from keras_core.backend import keras_tensor
if isinstance(input_shape, tuple):
return keras_tensor.KerasTensor(input_shape, dtype=dtype)
if isinstance(input_shape, list):
return [keras_tensor.KerasTensor(s, dtype=dtype) for s in input_shape]
if isinstance(input_shape, dict):
return {
k: keras_tensor.KerasTensor(v, dtype=dtype)
for k, v in input_shape.items()
}
def create_eager_tensors(input_shape, dtype):
from keras_core.backend import random
if isinstance(input_shape, tuple):
return random.uniform(input_shape, dtype=dtype)
if isinstance(input_shape, list):
return [random.uniform(s, dtype=dtype) for s in input_shape]
if isinstance(input_shape, dict):
return {
k: random.uniform(v, dtype=dtype) for k, v in input_shape.items()
}