345 lines
11 KiB
Python
345 lines
11 KiB
Python
"""Tests for serialization_lib."""
|
|
|
|
import json
|
|
|
|
import numpy as np
|
|
|
|
import keras_core
|
|
from keras_core import operations as ops
|
|
from keras_core import testing
|
|
from keras_core.saving import serialization_lib
|
|
|
|
|
|
def custom_fn(x):
|
|
return x**2
|
|
|
|
|
|
class CustomLayer(keras_core.layers.Layer):
|
|
def __init__(self, factor):
|
|
super().__init__()
|
|
self.factor = factor
|
|
|
|
def call(self, x):
|
|
return x * self.factor
|
|
|
|
def get_config(self):
|
|
return {"factor": self.factor}
|
|
|
|
|
|
class NestedCustomLayer(keras_core.layers.Layer):
|
|
def __init__(self, factor, dense=None, activation=None):
|
|
super().__init__()
|
|
self.factor = factor
|
|
|
|
if dense is None:
|
|
self.dense = keras_core.layers.Dense(1, activation=custom_fn)
|
|
else:
|
|
self.dense = serialization_lib.deserialize_keras_object(dense)
|
|
self.activation = serialization_lib.deserialize_keras_object(activation)
|
|
|
|
def call(self, x):
|
|
return self.dense(x * self.factor)
|
|
|
|
def get_config(self):
|
|
return {
|
|
"factor": self.factor,
|
|
"dense": self.dense,
|
|
"activation": self.activation,
|
|
}
|
|
|
|
|
|
class WrapperLayer(keras_core.layers.Layer):
|
|
def __init__(self, layer, **kwargs):
|
|
super().__init__(**kwargs)
|
|
self.layer = layer
|
|
|
|
def call(self, x):
|
|
return self.layer(x)
|
|
|
|
def get_config(self):
|
|
config = super().get_config()
|
|
return {"layer": self.layer, **config}
|
|
|
|
|
|
class SerializationLibTest(testing.TestCase):
|
|
def roundtrip(self, obj, custom_objects=None, safe_mode=True):
|
|
serialized = serialization_lib.serialize_keras_object(obj)
|
|
json_data = json.dumps(serialized)
|
|
json_data = json.loads(json_data)
|
|
deserialized = serialization_lib.deserialize_keras_object(
|
|
json_data, custom_objects=custom_objects, safe_mode=safe_mode
|
|
)
|
|
reserialized = serialization_lib.serialize_keras_object(deserialized)
|
|
return serialized, deserialized, reserialized
|
|
|
|
def test_simple_objects(self):
|
|
for obj in [
|
|
"hello",
|
|
b"hello",
|
|
np.array([0, 1]),
|
|
np.array([0.0, 1.0]),
|
|
np.float32(1.0),
|
|
["hello", 0, "world", 1.0, True],
|
|
{"1": "hello", "2": 0, "3": True},
|
|
{"1": "hello", "2": [True, False]},
|
|
]:
|
|
serialized, _, reserialized = self.roundtrip(obj)
|
|
self.assertEqual(serialized, reserialized)
|
|
|
|
def test_builtin_layers(self):
|
|
serialized, _, reserialized = self.roundtrip(keras_core.layers.Dense(3))
|
|
self.assertEqual(serialized, reserialized)
|
|
|
|
def test_tensors_and_shapes(self):
|
|
x = ops.random.normal((2, 2), dtype="float64")
|
|
obj = {"x": x}
|
|
_, new_obj, _ = self.roundtrip(obj)
|
|
self.assertAllClose(x, new_obj["x"], atol=1e-5)
|
|
|
|
obj = {"x.shape": x.shape}
|
|
_, new_obj, _ = self.roundtrip(obj)
|
|
self.assertEqual(tuple(x.shape), tuple(new_obj["x.shape"]))
|
|
|
|
def test_custom_fn(self):
|
|
obj = {"activation": custom_fn}
|
|
serialized, _, reserialized = self.roundtrip(
|
|
obj, custom_objects={"custom_fn": custom_fn}
|
|
)
|
|
self.assertEqual(serialized, reserialized)
|
|
|
|
# Test inside layer
|
|
dense = keras_core.layers.Dense(1, activation=custom_fn)
|
|
dense.build((None, 2))
|
|
_, new_dense, _ = self.roundtrip(
|
|
dense, custom_objects={"custom_fn": custom_fn}
|
|
)
|
|
x = ops.random.normal((2, 2))
|
|
y1 = dense(x)
|
|
_ = new_dense(x)
|
|
new_dense.set_weights(dense.get_weights())
|
|
y2 = new_dense(x)
|
|
self.assertAllClose(y1, y2, atol=1e-5)
|
|
|
|
def test_custom_layer(self):
|
|
layer = CustomLayer(factor=2)
|
|
x = ops.random.normal((2, 2))
|
|
y1 = layer(x)
|
|
_, new_layer, _ = self.roundtrip(
|
|
layer, custom_objects={"CustomLayer": CustomLayer}
|
|
)
|
|
y2 = new_layer(x)
|
|
self.assertAllClose(y1, y2, atol=1e-5)
|
|
|
|
layer = NestedCustomLayer(factor=2)
|
|
x = ops.random.normal((2, 2))
|
|
y1 = layer(x)
|
|
_, new_layer, _ = self.roundtrip(
|
|
layer,
|
|
custom_objects={
|
|
"NestedCustomLayer": NestedCustomLayer,
|
|
"custom_fn": custom_fn,
|
|
},
|
|
)
|
|
_ = new_layer(x)
|
|
new_layer.set_weights(layer.get_weights())
|
|
y2 = new_layer(x)
|
|
self.assertAllClose(y1, y2, atol=1e-5)
|
|
|
|
def test_lambda_fn(self):
|
|
obj = {"activation": lambda x: x**2}
|
|
with self.assertRaisesRegex(ValueError, "arbitrary code execution"):
|
|
self.roundtrip(obj, safe_mode=True)
|
|
|
|
_, new_obj, _ = self.roundtrip(obj, safe_mode=False)
|
|
self.assertEqual(obj["activation"](3), new_obj["activation"](3))
|
|
|
|
# TODO
|
|
# def test_lambda_layer(self):
|
|
# lmbda = keras_core.layers.Lambda(lambda x: x**2)
|
|
# with self.assertRaisesRegex(ValueError, "arbitrary code execution"):
|
|
# self.roundtrip(lmbda, safe_mode=True)
|
|
|
|
# _, new_lmbda, _ = self.roundtrip(lmbda, safe_mode=False)
|
|
# x = ops.random.normal((2, 2))
|
|
# y1 = lmbda(x)
|
|
# y2 = new_lmbda(x)
|
|
# self.assertAllClose(y1, y2, atol=1e-5)
|
|
|
|
# def test_safe_mode_scope(self):
|
|
# lmbda = keras_core.layers.Lambda(lambda x: x**2)
|
|
# with serialization_lib.SafeModeScope(safe_mode=True):
|
|
# with self.assertRaisesRegex(
|
|
# ValueError, "arbitrary code execution"
|
|
# ):
|
|
# self.roundtrip(lmbda)
|
|
# with serialization_lib.SafeModeScope(safe_mode=False):
|
|
# _, new_lmbda, _ = self.roundtrip(lmbda)
|
|
# x = ops.random.normal((2, 2))
|
|
# y1 = lmbda(x)
|
|
# y2 = new_lmbda(x)
|
|
# self.assertAllClose(y1, y2, atol=1e-5)
|
|
|
|
def shared_inner_layer(self):
|
|
input_1 = keras_core.Input((2,))
|
|
input_2 = keras_core.Input((2,))
|
|
shared_layer = keras_core.layers.Dense(1)
|
|
output_1 = shared_layer(input_1)
|
|
wrapper_layer = WrapperLayer(shared_layer)
|
|
output_2 = wrapper_layer(input_2)
|
|
model = keras_core.Model([input_1, input_2], [output_1, output_2])
|
|
_, new_model, _ = self.roundtrip(
|
|
model, custom_objects={"WrapperLayer": WrapperLayer}
|
|
)
|
|
|
|
self.assertIs(model.layers[2], model.layers[3].layer)
|
|
self.assertIs(new_model.layers[2], new_model.layers[3].layer)
|
|
|
|
def test_functional_subclass(self):
|
|
class PlainFunctionalSubclass(keras_core.Model):
|
|
pass
|
|
|
|
inputs = keras_core.Input((2,), batch_size=3)
|
|
outputs = keras_core.layers.Dense(1)(inputs)
|
|
model = PlainFunctionalSubclass(inputs, outputs)
|
|
x = ops.random.normal((2, 2))
|
|
y1 = model(x)
|
|
_, new_model, _ = self.roundtrip(
|
|
model,
|
|
custom_objects={"PlainFunctionalSubclass": PlainFunctionalSubclass},
|
|
)
|
|
new_model.set_weights(model.get_weights())
|
|
y2 = new_model(x)
|
|
self.assertAllClose(y1, y2, atol=1e-5)
|
|
self.assertIsInstance(new_model, PlainFunctionalSubclass)
|
|
|
|
class FunctionalSubclassWCustomInit(keras_core.Model):
|
|
def __init__(self, num_units=2):
|
|
inputs = keras_core.Input((2,), batch_size=3)
|
|
outputs = keras_core.layers.Dense(num_units)(inputs)
|
|
super().__init__(inputs, outputs)
|
|
self.num_units = num_units
|
|
|
|
def get_config(self):
|
|
return {"num_units": self.num_units}
|
|
|
|
model = FunctionalSubclassWCustomInit(num_units=3)
|
|
x = ops.random.normal((2, 2))
|
|
y1 = model(x)
|
|
_, new_model, _ = self.roundtrip(
|
|
model,
|
|
custom_objects={
|
|
"FunctionalSubclassWCustomInit": FunctionalSubclassWCustomInit
|
|
},
|
|
)
|
|
new_model.set_weights(model.get_weights())
|
|
y2 = new_model(x)
|
|
self.assertAllClose(y1, y2, atol=1e-5)
|
|
self.assertIsInstance(new_model, FunctionalSubclassWCustomInit)
|
|
|
|
def test_shared_object(self):
|
|
class MyLayer(keras_core.layers.Layer):
|
|
def __init__(self, activation, **kwargs):
|
|
super().__init__(**kwargs)
|
|
if isinstance(activation, dict):
|
|
self.activation = (
|
|
serialization_lib.deserialize_keras_object(activation)
|
|
)
|
|
else:
|
|
self.activation = activation
|
|
|
|
def call(self, x):
|
|
return self.activation(x)
|
|
|
|
def get_config(self):
|
|
config = super().get_config()
|
|
config["activation"] = self.activation
|
|
return config
|
|
|
|
class SharedActivation:
|
|
def __call__(self, x):
|
|
return x**2
|
|
|
|
def get_config(self):
|
|
return {}
|
|
|
|
@classmethod
|
|
def from_config(cls, config):
|
|
return cls()
|
|
|
|
shared_act = SharedActivation()
|
|
layer_1 = MyLayer(activation=shared_act)
|
|
layer_2 = MyLayer(activation=shared_act)
|
|
layers = [layer_1, layer_2]
|
|
|
|
with serialization_lib.ObjectSharingScope():
|
|
serialized, new_layers, reserialized = self.roundtrip(
|
|
layers,
|
|
custom_objects={
|
|
"MyLayer": MyLayer,
|
|
"SharedActivation": SharedActivation,
|
|
},
|
|
)
|
|
self.assertIn("shared_object_id", serialized[0]["config"]["activation"])
|
|
obj_id = serialized[0]["config"]["activation"]
|
|
self.assertIn("shared_object_id", serialized[1]["config"]["activation"])
|
|
self.assertEqual(obj_id, serialized[1]["config"]["activation"])
|
|
self.assertIs(layers[0].activation, layers[1].activation)
|
|
self.assertIs(new_layers[0].activation, new_layers[1].activation)
|
|
|
|
|
|
@keras_core.saving.register_keras_serializable()
|
|
class MyDense(keras_core.layers.Layer):
|
|
def __init__(
|
|
self,
|
|
units,
|
|
*,
|
|
kernel_regularizer=None,
|
|
kernel_initializer=None,
|
|
**kwargs
|
|
):
|
|
super().__init__(**kwargs)
|
|
self._units = units
|
|
self._kernel_regularizer = kernel_regularizer
|
|
self._kernel_initializer = kernel_initializer
|
|
|
|
def get_config(self):
|
|
return dict(
|
|
units=self._units,
|
|
kernel_initializer=self._kernel_initializer,
|
|
kernel_regularizer=self._kernel_regularizer,
|
|
**super().get_config()
|
|
)
|
|
|
|
def build(self, input_shape):
|
|
_, input_units = input_shape
|
|
self._kernel = self.add_weight(
|
|
name="kernel",
|
|
shape=[input_units, self._units],
|
|
dtype="float32",
|
|
regularizer=self._kernel_regularizer,
|
|
initializer=self._kernel_initializer,
|
|
)
|
|
|
|
def call(self, inputs):
|
|
return ops.matmul(inputs, self._kernel)
|
|
|
|
|
|
@keras_core.saving.register_keras_serializable()
|
|
class MyWrapper(keras_core.layers.Layer):
|
|
def __init__(self, wrapped, **kwargs):
|
|
super().__init__(**kwargs)
|
|
self._wrapped = wrapped
|
|
|
|
def get_config(self):
|
|
return dict(wrapped=self._wrapped, **super().get_config())
|
|
|
|
@classmethod
|
|
def from_config(cls, config):
|
|
config["wrapped"] = keras_core.saving.deserialize_keras_object(
|
|
config["wrapped"]
|
|
)
|
|
return cls(**config)
|
|
|
|
def call(self, inputs):
|
|
return self._wrapped(inputs)
|