keras/keras_core/testing/test_case.py
2023-04-26 20:22:03 -07:00

260 lines
10 KiB
Python

import json
import shutil
import tempfile
import unittest
import numpy as np
from tensorflow import nest
from keras_core import operations as ops
class TestCase(unittest.TestCase):
    """Base test case with numpy-backed assertions and layer-test helpers."""

    maxDiff = None

    def get_temp_dir(self):
        """Create a temporary directory that is removed after the test.

        Returns:
            Path (str) of a freshly created temporary directory.
        """
        temp_dir = tempfile.mkdtemp()
        # Cleanups registered here run after the test, even when it fails.
        self.addCleanup(shutil.rmtree, temp_dir)
        return temp_dir

    def assertAllClose(self, x1, x2, atol=1e-6, rtol=1e-6):
        """Assert `x1` and `x2` are element-wise equal within tolerances."""
        np.testing.assert_allclose(x1, x2, atol=atol, rtol=rtol)

    def assertAlmostEqual(self, x1, x2, decimal=3):
        """Assert `x1` and `x2` agree to `decimal` decimal places.

        Note: this overrides `unittest.TestCase.assertAlmostEqual` with a
        numpy-based check whose keyword is `decimal` (not `places`).
        """
        np.testing.assert_almost_equal(x1, x2, decimal=decimal)

    def assertLen(self, iterable, expected_len):
        """Assert that `len(iterable) == expected_len`."""
        np.testing.assert_equal(len(iterable), expected_len)

    def run_class_serialization_test(self, instance, custom_objects=None):
        """Check that `instance` survives config and serialization round trips.

        Two round trips are verified:
          1. `instance.get_config()` -> `cls.from_config(config)`.
          2. `serialize_keras_object(instance)` -> JSON string ->
             `deserialize_keras_object(...)`.
        After each, the revived object must produce the same JSON-encoded
        config and expose the same attribute names (`dir()`) as `instance`.

        Args:
            instance: Object to round-trip; must implement `get_config()`.
            custom_objects: Optional dict of custom objects made available
                (via `custom_object_scope`) during revival.

        Returns:
            The instance revived by the serialization round trip.
        """
        from keras_core.saving import custom_object_scope
        from keras_core.saving import deserialize_keras_object
        from keras_core.saving import serialize_keras_object

        def to_json(obj):
            # sort_keys makes the comparison independent of dict order.
            return json.dumps(obj, sort_keys=True, indent=4)

        # get_config roundtrip
        cls = instance.__class__
        config = instance.get_config()
        config_json = to_json(config)
        ref_dir = dir(instance)  # dir() already returns a fresh list.
        with custom_object_scope(custom_objects):
            revived_instance = cls.from_config(config)
        revived_config_json = to_json(revived_instance.get_config())
        self.assertEqual(config_json, revived_config_json)
        self.assertEqual(ref_dir, dir(revived_instance))

        # serialization roundtrip (goes through an actual JSON string)
        serialized_json = to_json(serialize_keras_object(instance))
        with custom_object_scope(custom_objects):
            revived_instance = deserialize_keras_object(
                json.loads(serialized_json)
            )
        revived_config_json = to_json(revived_instance.get_config())
        self.assertEqual(config_json, revived_config_json)
        self.assertEqual(ref_dir, dir(revived_instance))
        return revived_instance

    def _assert_output_shape(self, expected_output_shape, output):
        """Assert `output` matches the structure/shapes given.

        Args:
            expected_output_shape: A shape tuple (single output), a dict of
                shape tuples (dict output), or a list of shape tuples (list
                output). Any other type is not checked.
            output: The layer output; each leaf must expose `.shape`.
        """
        if isinstance(expected_output_shape, tuple):
            self.assertEqual(expected_output_shape, output.shape)
        elif isinstance(expected_output_shape, dict):
            self.assertIsInstance(output, dict)
            self.assertEqual(
                set(output.keys()), set(expected_output_shape.keys())
            )
            # Compare against the shapes of the *actual* outputs.
            output_shape = {k: v.shape for k, v in output.items()}
            self.assertEqual(expected_output_shape, output_shape)
        elif isinstance(expected_output_shape, list):
            self.assertIsInstance(output, list)
            self.assertEqual(len(output), len(expected_output_shape))
            output_shape = [v.shape for v in output]
            self.assertEqual(expected_output_shape, output_shape)

    def run_layer_test(
        self,
        layer_cls,
        init_kwargs,
        input_shape,
        input_dtype="float32",
        input_data=None,
        call_kwargs=None,
        expected_output_shape=None,
        expected_output_dtype=None,
        expected_output=None,
        expected_num_trainable_weights=None,
        expected_num_non_trainable_weights=None,
        expected_num_seed_generators=None,
        expected_num_losses=None,
        supports_masking=None,
        expected_mask_shape=None,
        custom_objects=None,
    ):
        """Run basic checks on a layer.

        Steps: serialization round trip; optional `supports_masking` check;
        when `input_shape` is given, an explicit `build()` and a symbolic
        call on KerasTensors; and, when `input_shape` or `input_data` is
        given, an eager call.

        Args:
            layer_cls: The class of the layer to test.
            init_kwargs: Dict of arguments to be used to
                instantiate the layer.
            input_shape: Shape tuple (or list/dict of shape tuples)
                to call the layer on.
            input_dtype: Corresponding input dtype.
            input_data: Tensor (or list/dict of tensors)
                to call the layer on.
            call_kwargs: Dict of arguments to use when calling the
                layer (does not include the first input tensor argument)
            expected_output_shape: Shape tuple
                (or list/dict of shape tuples)
                expected as output.
            expected_output_dtype: dtype expected as output.
            expected_output: Expected output tensor -- only
                to be specified if input_data is provided.
            expected_num_trainable_weights: Expected number
                of trainable weights of the layer once built.
            expected_num_non_trainable_weights: Expected number
                of non-trainable weights of the layer once built.
            expected_num_seed_generators: Expected number of
                SeedGenerators objects of the layer once built.
            expected_num_losses: Expected number of loss tensors
                produced when calling the layer (checked after the
                eager call).
            supports_masking: If True, will check that the layer
                supports masking.
            expected_mask_shape: Expected mask shape tuple
                returned by compute_mask() (only supports 1 shape).
                Checked in the symbolic test, i.e. only when
                `input_shape` is given.
            custom_objects: Dict of any custom objects to be
                considered during deserialization.

        Raises:
            ValueError: On incompatible argument combinations.
        """
        if input_shape is not None and input_data is not None:
            raise ValueError(
                "input_shape and input_data cannot be passed "
                "at the same time."
            )
        if expected_output_shape is not None and expected_output is not None:
            raise ValueError(
                "expected_output_shape and expected_output cannot be passed "
                "at the same time."
            )
        if expected_output is not None and input_data is None:
            raise ValueError(
                "In order to use expected_output, input_data must be provided."
            )
        if expected_mask_shape is not None and supports_masking is not True:
            raise ValueError(
                "In order to use expected_mask_shape, supports_masking "
                "must be True."
            )

        init_kwargs = init_kwargs or {}
        call_kwargs = call_kwargs or {}

        # Serialization test.
        layer = layer_cls(**init_kwargs)
        self.run_class_serialization_test(layer, custom_objects)

        # Basic masking test.
        if supports_masking is not None:
            self.assertEqual(layer.supports_masking, supports_masking)

        def run_build_asserts(layer):
            """Check built state and weight / seed-generator counts."""
            self.assertTrue(layer.built)
            if expected_num_trainable_weights is not None:
                self.assertLen(
                    layer.trainable_weights, expected_num_trainable_weights
                )
            if expected_num_non_trainable_weights is not None:
                self.assertLen(
                    layer.non_trainable_weights,
                    expected_num_non_trainable_weights,
                )
            if expected_num_seed_generators is not None:
                self.assertLen(
                    layer._seed_generators, expected_num_seed_generators
                )

        def run_output_asserts(layer, output, eager=False):
            """Check output shape/dtype, and values/losses when eager."""
            if expected_output_shape is not None:
                self._assert_output_shape(expected_output_shape, output)
            if expected_output_dtype is not None:
                output_dtype = nest.flatten(output)[0].dtype
                self.assertEqual(expected_output_dtype, output_dtype)
            if eager:
                if expected_output is not None:
                    self.assertEqual(type(expected_output), type(output))
                    flat_expected = nest.flatten(expected_output)
                    flat_output = nest.flatten(output)
                    # zip() alone would silently ignore extra entries.
                    self.assertLen(flat_output, len(flat_expected))
                    for ref_v, v in zip(flat_expected, flat_output):
                        self.assertAllClose(ref_v, v)
                if expected_num_losses is not None:
                    self.assertLen(layer.losses, expected_num_losses)

        if input_shape is not None:
            # Build test.
            layer = layer_cls(**init_kwargs)
            layer.build(input_shape)
            run_build_asserts(layer)

            # Symbolic call test: KerasTensors can only be created from
            # shapes, so this is skipped when only input_data is given.
            keras_tensor_inputs = create_keras_tensors(
                input_shape, input_dtype
            )
            layer = layer_cls(**init_kwargs)
            keras_tensor_outputs = layer(keras_tensor_inputs, **call_kwargs)
            run_build_asserts(layer)
            run_output_asserts(layer, keras_tensor_outputs, eager=False)

            if expected_mask_shape is not None:
                output_mask = layer.compute_mask(keras_tensor_inputs)
                self.assertEqual(expected_mask_shape, output_mask.shape)

        # Eager call test.
        if input_data is not None or input_shape is not None:
            if input_data is None:
                input_data = create_eager_tensors(input_shape, input_dtype)
            layer = layer_cls(**init_kwargs)
            output_data = layer(input_data, **call_kwargs)
            run_output_asserts(layer, output_data, eager=True)
def create_keras_tensors(input_shape, dtype):
    """Create symbolic `KerasTensor` inputs mirroring `input_shape`.

    Args:
        input_shape: A shape tuple, a list of shape tuples, or a dict
            mapping names to shape tuples.
        dtype: dtype assigned to every created tensor.

    Returns:
        A `KerasTensor` for a tuple, or a list/dict of `KerasTensor`s with
        the same structure as `input_shape`.

    Raises:
        ValueError: If `input_shape` is not a tuple, list, or dict.
    """
    # Validate before the (lazy) project import so misuse fails with a clear
    # message instead of silently returning None.
    if not isinstance(input_shape, (tuple, list, dict)):
        raise ValueError(
            "input_shape must be a tuple, a list of tuples, or a dict of "
            f"tuples. Received: input_shape={input_shape}"
        )
    from keras_core.backend import keras_tensor

    def make(shape):
        return keras_tensor.KerasTensor(shape, dtype=dtype)

    if isinstance(input_shape, tuple):
        return make(input_shape)
    if isinstance(input_shape, list):
        return [make(s) for s in input_shape]
    return {k: make(v) for k, v in input_shape.items()}
def create_eager_tensors(input_shape, dtype):
    """Create random eager tensors mirroring `input_shape`.

    Float dtypes are sampled directly with the backend's `random.uniform`
    (default bounds). Int dtypes are `random.uniform(..., "float32") * 3`
    cast to `dtype`; if uniform defaults to [0, 1) this yields values
    0-2 (NOTE(review): confirm uniform's default range).

    Args:
        input_shape: A shape tuple, a list of shape tuples, or a dict
            mapping names to shape tuples.
        dtype: One of "float16", "float32", "float64", "int16", "int32",
            "int64".

    Returns:
        A tensor for a tuple, or a list/dict of tensors with the same
        structure as `input_shape`.

    Raises:
        ValueError: If `dtype` is unsupported, or if `input_shape` is not a
            tuple, list, or dict.
    """
    float_dtypes = ("float16", "float32", "float64")
    int_dtypes = ("int16", "int32", "int64")
    # The dtype check comes first, as before; both checks run before the
    # lazy project import so bad arguments fail fast and clearly.
    if dtype not in float_dtypes and dtype not in int_dtypes:
        raise ValueError(
            "dtype must be a standard float or int dtype. "
            f"Received: dtype={dtype}"
        )
    if not isinstance(input_shape, (tuple, list, dict)):
        raise ValueError(
            "input_shape must be a tuple, a list of tuples, or a dict of "
            f"tuples. Received: input_shape={input_shape}"
        )
    from keras_core.backend import random

    if dtype in float_dtypes:
        create_fn = random.uniform
    else:

        def create_fn(shape, dtype):
            return ops.cast(
                random.uniform(shape, dtype="float32") * 3, dtype=dtype
            )

    if isinstance(input_shape, tuple):
        return create_fn(input_shape, dtype=dtype)
    if isinstance(input_shape, list):
        return [create_fn(s, dtype=dtype) for s in input_shape]
    return {k: create_fn(v, dtype=dtype) for k, v in input_shape.items()}