Add Functional serialization.

Francois Chollet 2023-04-24 13:46:31 -07:00
parent ce4ecc6cc2
commit 295f9a5f5f
10 changed files with 420 additions and 93 deletions

@ -85,7 +85,7 @@ class KerasTensor:
def __repr__(self):
return (
f"<KerasTensor shape={self.shape}, dtype={self.dtype}, "
"name={self.name}>"
f"name={self.name}>"
)
def __iter__(self):

@ -127,8 +127,10 @@ class Layer(Operation):
if config:
if "input_shape" in config:
self.build(config["input_shape"])
self._build_shapes_dict = config
elif "shapes_dict" in config:
self.build(**config["shapes_dict"])
self._build_shapes_dict = config["shapes_dict"]
def add_variable(
self,

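For reference, a minimal sketch of the two build-config shapes the `Layer` branch above accepts; the keys come from the diff, while the concrete shapes and the dict names are made up:

# Hypothetical build configs a layer could be restored from:
config_a = {"input_shape": (2, 4)}
# -> self.build((2, 4))

config_b = {"shapes_dict": {"query_shape": (2, 8), "value_shape": (2, 8)}}
# -> self.build(query_shape=(2, 8), value_shape=(2, 8))
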
@ -1,3 +1,4 @@
import copy
import inspect
import warnings
@ -33,11 +34,8 @@ class Functional(Function, Model):
return Function.__new__(cls)
@tracking.no_automatic_dependency_tracking
@python_utils.default
def __init__(self, inputs, outputs, name=None, **kwargs):
# This is used by the Model class, since we have some logic to swap the
# class in the __new__ method, which will lead to __init__ getting invoked
# twice. Use skip_init to skip one of the invocations of __init__
# to avoid any side effects.
if isinstance(inputs, dict):
for k, v in inputs.items():
if not isinstance(v, backend.KerasTensor):
@ -87,7 +85,7 @@ class Functional(Function, Model):
f"Unrecognized type for `outputs`: {outputs} (of type {type(outputs)})"
)
super().__init__(inputs, outputs, name=name, **kwargs)
Function.__init__(self, inputs, outputs, name=name, **kwargs)
self._layers = self.layers
self.built = True
@ -108,9 +106,10 @@ class Functional(Function, Model):
masks = self._flatten_to_reference_inputs(mask)
for x, mask in zip(inputs, masks):
x._keras_mask = mask
return self._run_through_graph(
outputs = self._run_through_graph(
inputs, operation_fn=lambda op: operation_fn(op, training=training)
)
return unpack_singleton(outputs)
def compute_output_spec(self, inputs, training=None, mask=None):
# From Function
@ -188,35 +187,202 @@ class Functional(Function, Model):
# Symbolic only. TODO
raise NotImplementedError
@python_utils.default
def get_config(self):
# Prepare base arguments
if not functional_like_constructor(self.__class__):
# Subclassed networks are not serializable
# (unless serialization is implemented by
# the author of the subclassed network).
return Model.get_config(self)
config = {
"name": self.name,
"trainable": self.trainable,
}
# Check whether the class has a constructor compatible with a Functional
# model or if it has a custom constructor.
if functional_like_constructor(self.__class__):
# Only return a Functional config if the constructor is the same
# as that of a Functional model. This excludes subclassed Functional
# models with a custom __init__.
config = {**config, **get_functional_config(self)}
else:
# Try to autogenerate config
xtra_args = set(config.keys())
if getattr(self, "_auto_get_config", False):
config.update(self._auto_config.config)
# Remove args not explicitly supported
argspec = inspect.getfullargspec(self.__init__)
if argspec.varkw != "kwargs":
for key in xtra_args - xtra_args.intersection(argspec.args[1:]):
config.pop(key, None)
return config
# Build a map from a layer unique name (make_node_key)
# to the index of the nodes that are saved in the config.
# Only nodes in network_nodes are saved.
node_reindexing_map = {}
for operation in self.operations:
if issubclass(operation.__class__, Functional):
# Functional models start with a pre-existing node
# linking their input to output.
kept_nodes = 1
else:
kept_nodes = 0
for original_node_index, node in enumerate(
operation._inbound_nodes
):
node_key = make_node_key(operation, original_node_index)
if node_key in self._nodes:
# i.e. we mark it to be saved
node_reindexing_map[node_key] = kept_nodes
kept_nodes += 1
# Serialize and save the layers in layer_configs.
layer_configs = []
for operation in self.operations: # From the earliest layers on.
filtered_inbound_nodes = []
for original_node_index, node in enumerate(
operation._inbound_nodes
):
node_key = make_node_key(operation, original_node_index)
if node_key in self._nodes:
# The node is relevant to the model:
# add to filtered_inbound_nodes.
node_data = serialize_node(node, node_reindexing_map)
if node_data is not None:
filtered_inbound_nodes.append(node_data)
layer_config = serialization_lib.serialize_keras_object(operation)
layer_config["name"] = operation.name
layer_config["inbound_nodes"] = filtered_inbound_nodes
layer_configs.append(layer_config)
config["layers"] = layer_configs
# Gather info about inputs and outputs.
model_inputs = []
for tensor in self._inputs:
operation = tensor._keras_history[0]
node_index = tensor._keras_history[1]
tensor_index = tensor._keras_history[2]
node_key = make_node_key(operation, node_index)
if node_key not in self._nodes:
continue
new_node_index = node_reindexing_map[node_key]
model_inputs.append([operation.name, new_node_index, tensor_index])
config["input_layers"] = model_inputs
model_outputs = []
for tensor in self._outputs:
operation = tensor._keras_history[0]
node_index = tensor._keras_history[1]
tensor_index = tensor._keras_history[2]
node_key = make_node_key(operation, node_index)
if node_key not in self._nodes:
continue
new_node_index = node_reindexing_map[node_key]
model_outputs.append([operation.name, new_node_index, tensor_index])
config["output_layers"] = model_outputs
return copy.deepcopy(config)
@classmethod
def from_config(self):
raise NotImplementedError
def from_config(cls, config, custom_objects=None):
"""Instantiates a Model from its config (output of `get_config()`)."""
# Layer instances created during
# the graph reconstruction process
created_layers = {}
# Dictionary mapping layer instances to
# node data that specifies a layer call.
# It acts as a queue that maintains any unprocessed
# layer call until it becomes possible to process it
# (i.e. until the input tensors to the call all exist).
unprocessed_nodes = {}
def add_unprocessed_node(layer, node_data):
"""Add node to layer list
Arg:
layer: layer object
node_data: Node data specifying layer call
"""
if layer not in unprocessed_nodes:
unprocessed_nodes[layer] = [node_data]
else:
unprocessed_nodes[layer].append(node_data)
def process_node(layer, node_data):
"""Reconstruct node by linking to inbound layers
Args:
layer: Layer to process
node_data: List of layer configs
"""
args, kwargs = deserialize_node(node_data, created_layers)
# Call layer on its inputs, thus creating the node
# and building the layer if needed.
layer(*args, **kwargs)
def process_layer(layer_data):
"""Deserializes a layer, then call it on appropriate inputs.
Args:
layer_data: layer config dict.
"""
layer_name = layer_data["name"]
# Instantiate layer.
layer = serialization_lib.deserialize_keras_object(
layer_data, custom_objects=custom_objects
)
created_layers[layer_name] = layer
# Gather layer inputs.
inbound_nodes_data = layer_data["inbound_nodes"]
for node_data in inbound_nodes_data:
# We don't process nodes (i.e. make layer calls)
# on the fly because the inbound node may not yet exist,
# in the case of layers shared at different topological depths
# (e.g. a model such as A(B(A(B(x)))))
add_unprocessed_node(layer, node_data)
# First, we create all layers and enqueue nodes to be processed
for layer_data in config["layers"]:
process_layer(layer_data)
# Then we process nodes in order of layer depth.
# Nodes that cannot yet be processed (if the inbound node
# does not yet exist) are re-enqueued, and the process
# is repeated until all nodes are processed.
while unprocessed_nodes:
for layer_data in config["layers"]:
layer = created_layers[layer_data["name"]]
# Process all nodes in layer, if not yet processed
if layer in unprocessed_nodes:
node_data_list = unprocessed_nodes[layer]
# Process nodes in order
node_index = 0
while node_index < len(node_data_list):
node_data = node_data_list[node_index]
try:
process_node(layer, node_data)
# If the node does not have all inbound layers
# available, stop processing and continue later
except IndexError:
break
node_index += 1
# If not all nodes were processed, store the unprocessed nodes
if node_index < len(node_data_list):
unprocessed_nodes[layer] = node_data_list[node_index:]
# If all nodes were processed, remove the layer
else:
del unprocessed_nodes[layer]
# Create lists of input and output tensors and return the new model.
name = config.get("name")
input_tensors = []
output_tensors = []
for layer_data in config["input_layers"]:
layer_name, node_index, tensor_index = layer_data
assert layer_name in created_layers
layer = created_layers[layer_name]
layer_output_tensors = layer._inbound_nodes[
node_index
].output_tensors
input_tensors.append(layer_output_tensors[tensor_index])
for layer_data in config["output_layers"]:
layer_name, node_index, tensor_index = layer_data
assert layer_name in created_layers
layer = created_layers[layer_name]
layer_output_tensors = layer._inbound_nodes[
node_index
].output_tensors
output_tensors.append(layer_output_tensors[tensor_index])
return cls(inputs=input_tensors, outputs=output_tensors, name=name)
def operation_fn(operation, training):
@ -239,5 +405,86 @@ def functional_like_constructor(cls):
return False
def get_functional_config(network):
raise NotImplementedError
def unpack_singleton(x):
if len(x) == 1:
return x[0]
return x
def serialize_node(node, node_reindexing_map):
if not node.input_tensors:
# Does not need to be serialized.
return
args = node.arguments.args
kwargs = node.arguments.kwargs
return {
"args": serialization_lib.serialize_keras_object(args),
"kwargs": serialization_lib.serialize_keras_object(kwargs),
}
def deserialize_node(node_data, created_layers):
"""Return (args, kwargs) for calling the node layer."""
if not node_data:
return [], {}
if isinstance(node_data, list):
# Legacy case.
input_tensors = []
for input_data in node_data:
inbound_layer_name = input_data[0]
inbound_node_index = input_data[1]
inbound_tensor_index = input_data[2]
if len(input_data) == 3:
kwargs = {}
elif len(input_data) == 4:
kwargs = input_data[3]
else:
raise ValueError(
"Cannot deserialize the model (invalid config data?)"
)
inbound_layer = created_layers[inbound_layer_name]
# Raise an error if the corresponding layer node
# has not yet been created
if len(inbound_layer._inbound_nodes) <= inbound_node_index:
raise IndexError(
"Layer node index out of bounds.\n"
f"inbound_layer = {inbound_layer}\n"
f"inbound_layer._inbound_nodes = {inbound_layer._inbound_nodes}\n"
f"inbound_node_index = {inbound_node_index}"
)
inbound_node = inbound_layer._inbound_nodes[inbound_node_index]
input_tensors.append(
inbound_node.output_tensors[inbound_tensor_index]
)
return [unpack_singleton(input_tensors)], kwargs
args = serialization_lib.deserialize_keras_object(node_data["args"])
kwargs = serialization_lib.deserialize_keras_object(node_data["kwargs"])
def convert_revived_tensor(x):
if isinstance(x, backend.KerasTensor):
history = x._pre_serialization_keras_history
if history is None:
return x
layer = created_layers.get(history[0], None)
if layer is None:
raise ValueError(f"Unknown layer: {history[0]}")
inbound_node_index = history[1]
inbound_tensor_index = history[2]
if len(layer._inbound_nodes) <= inbound_node_index:
raise ValueError(
"Layer node index out of bounds.\n"
f"inbound_layer = {layer}\n"
f"inbound_layer._inbound_nodes = {layer._inbound_nodes}\n"
f"inbound_node_index = {inbound_node_index}"
)
inbound_node = layer._inbound_nodes[inbound_node_index]
return inbound_node.output_tensors[inbound_tensor_index]
return x
args = nest.map_structure(convert_revived_tensor, args)
kwargs = nest.map_structure(convert_revived_tensor, kwargs)
return args, kwargs

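Taken together, the new `get_config()`/`from_config()` pair gives Functional models a config round trip. A minimal sketch of what this enables, using the same imports as the tests below; the layer sizes and the model name are arbitrary:

import numpy as np

from keras_core import layers
from keras_core.layers.core.input_layer import Input
from keras_core.models import Functional

inputs = Input(shape=(3,), batch_size=2)
x = layers.Dense(5)(inputs)
outputs = layers.Dense(4)(x)
model = Functional(inputs, outputs, name="demo")

# The config now carries "layers", "input_layers" and "output_layers".
config = model.get_config()
clone = Functional.from_config(config)

x_val = np.random.random((2, 3))
assert clone(x_val).shape == model(x_val).shape
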
@ -4,7 +4,8 @@ from keras_core import backend
from keras_core import layers
from keras_core import testing
from keras_core.layers.core.input_layer import Input
from keras_core.models.functional import Functional
from keras_core.models import Functional
from keras_core.models import Model
class FunctionalTest(testing.TestCase):
@ -14,9 +15,13 @@ class FunctionalTest(testing.TestCase):
x = input_a + input_b
x = layers.Dense(5)(x)
outputs = layers.Dense(4)(x)
model = Functional([input_a, input_b], outputs)
model = Functional([input_a, input_b], outputs, name="basic")
model.summary()
self.assertEqual(model.name, "basic")
self.assertTrue(isinstance(model, Functional))
self.assertTrue(isinstance(model, Model))
# Eager call
in_val = [np.random.random((2, 3)), np.random.random((2, 3))]
out_val = model(in_val)
@ -156,8 +161,44 @@ class FunctionalTest(testing.TestCase):
self.assertEqual(out_val.shape, (2, 3, 3))
def test_serialization(self):
# TODO
pass
# Test basic model
inputs = Input(shape=(3,), batch_size=2)
outputs = layers.Dense(3)(inputs)
model = Functional(inputs, outputs)
self.run_class_serialization_test(model)
# Test multi-io model
input_a = Input(shape=(3,), batch_size=2, name="input_a")
input_b = Input(shape=(3,), batch_size=2, name="input_b")
xa = layers.Dense(5, name="middle_a")(input_a)
xb = layers.Dense(5, name="middle_b")(input_b)
output_a = layers.Dense(4, name="output_a")(xa)
output_b = layers.Dense(4, name="output_b")(xb)
model = Functional(
[input_a, input_b], [output_a, output_b], name="func"
)
self.run_class_serialization_test(model)
# Test model that includes floating ops
input_a = Input(shape=(3,), batch_size=2, name="input_a")
input_b = Input(shape=(3,), batch_size=2, name="input_b")
x = input_a + input_b
x = layers.Dense(5, name="middle")(x)
output_a = layers.Dense(4, name="output_a")(x)
output_b = layers.Dense(4, name="output_b")(x)
model = Functional(
[input_a, input_b], [output_a, output_b], name="func"
)
self.run_class_serialization_test(model)
# Test model with dict i/o
input_a = Input(shape=(3,), batch_size=2, name="a")
input_b = Input(shape=(3,), batch_size=2, name="b")
x = input_a + input_b
x = layers.Dense(5)(x)
outputs = layers.Dense(4)(x)
model = Functional({"a": input_a, "b": input_b}, outputs)
self.run_class_serialization_test(model)
def test_add_loss(self):
# TODO

@ -1,6 +1,7 @@
from keras_core import backend
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
from keras_core.utils import python_utils
from keras_core.utils import summary_utils
if backend.backend() == "tensorflow":
@ -33,15 +34,21 @@ class Model(Trainer, Layer):
def __new__(cls, *args, **kwargs):
# Signature detection
if functional_init_arguments(args, kwargs):
# Functional model
from keras_core.models import functional
return functional.Functional(*args, **kwargs)
return Layer.__new__(cls)
return super().__new__(cls)
def __init__(self, trainable=True, name=None, dtype=None):
def __init__(self, *args, **kwargs):
Trainer.__init__(self)
Layer.__init__(self, trainable=trainable, name=name, dtype=dtype)
from keras_core.models import functional
if isinstance(self, functional.Functional) and python_utils.is_default(
self.__init__
):
functional.Functional.__init__(self, *args, **kwargs)
else:
Layer.__init__(self, *args, **kwargs)
def call(self, inputs, training=False):
raise NotImplementedError

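The signature detection above means `Model(...)` transparently builds a Functional when given graph inputs/outputs, while subclasses keep the plain `Layer` initialization path. A hedged sketch of the two paths (the `Double` subclass is a made-up example):

from keras_core import layers
from keras_core.layers.core.input_layer import Input
from keras_core.models import Functional
from keras_core.models import Model

# Functional-style arguments: __new__ swaps in the Functional class.
inputs = Input(shape=(3,), batch_size=2)
outputs = layers.Dense(2)(inputs)
model = Model(inputs, outputs)
assert isinstance(model, Functional)

# Subclass-style usage: stays a plain (non-Functional) Model.
class Double(Model):
    def call(self, inputs, training=False):
        return inputs * 2.0

assert not isinstance(Double(), Functional)
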
@ -23,6 +23,10 @@ class Function(Operation):
self._operations = operations
self._operations_by_depth = operations_by_depth
@property
def operations(self):
return self._operations[:]
@property
def inputs(self):
return self._inputs
@ -134,8 +138,8 @@ class Function(Operation):
)
def make_node_key(op_name, node_index):
return op_name + "_ib-" + str(node_index)
def make_node_key(op, node_index):
return str(id(op)) + "_ib-" + str(node_index)
def map_graph(inputs, outputs):
@ -156,9 +160,7 @@ def map_graph(inputs, outputs):
# Nodes are ordered from inputs -> outputs.
nodes_in_decreasing_depth, operation_indices = _build_map(outputs)
network_nodes = {
make_node_key(
str(id(node.operation)), node.operation._inbound_nodes.index(node)
)
make_node_key(node.operation, node.operation._inbound_nodes.index(node))
for node in nodes_in_decreasing_depth
}
@ -195,7 +197,7 @@ def map_graph(inputs, outputs):
operations_depths[input_operation] = 0
operation_indices[input_operation] = -1
nodes_depths[input_operation._inbound_nodes[0]] = 0
network_nodes.add(make_node_key(input_operation.name, 0))
network_nodes.add(make_node_key(input_operation, 0))
# Build a dict {depth: list of nodes with this depth}
nodes_by_depth = collections.defaultdict(list)

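The `make_node_key` change above switches node keys from operation names, which can collide when two layers share a name, to object identity. A small self-contained sketch of the effect (`FakeOp` is a hypothetical stand-in):

def make_node_key(op, node_index):
    return str(id(op)) + "_ib-" + str(node_index)

class FakeOp:  # hypothetical stand-in for a layer/operation
    def __init__(self, name):
        self.name = name

a, b = FakeOp("dense"), FakeOp("dense")
# id-based keys stay distinct even though the names are equal;
# the old name-based scheme would have produced "dense_ib-0" for both.
assert make_node_key(a, 0) != make_node_key(b, 0)
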
@ -139,6 +139,19 @@ def serialize_keras_object(obj):
"class_name": "__bytes__",
"config": {"value": obj.decode("utf-8")},
}
if isinstance(obj, backend.KerasTensor):
history = getattr(obj, "_keras_history", None)
if history:
history = list(history)
history[0] = history[0].name
return {
"class_name": "__keras_tensor__",
"config": {
"shape": obj.shape,
"dtype": obj.dtype,
"keras_history": history,
},
}
if isinstance(obj, tf.TensorShape):
return obj.as_list() if obj._dims is not None else None
if isinstance(obj, (tf.Tensor, jax.numpy.ndarray)):
@ -203,12 +216,6 @@ def serialize_keras_object(obj):
obj.__class__, inner_config
)
# TODO(nkovela): Add TF ops dispatch handler serialization for
# ops.EagerTensor that contains nested numpy array.
# Target: NetworkConstructionTest.test_constant_initializer_with_numpy
if isinstance(inner_config, str) and inner_config == "op_dispatch_handler":
return obj
if config_with_public_class is not None:
get_build_and_compile_config(obj, config_with_public_class)
record_object_after_serialization(obj, config_with_public_class)
@ -564,6 +571,13 @@ def deserialize_keras_object(
custom_objects = custom_objects or {}
# Special cases:
if class_name == "__keras_tensor__":
obj = backend.KerasTensor(
inner_config["shape"], dtype=inner_config["dtype"]
)
obj._pre_serialization_keras_history = inner_config["keras_history"]
return obj
if class_name == "__tensor__":
return backend.convert_to_tensor(
inner_config["value"], dtype=inner_config["dtype"]

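For reference, a sketch of the payload the new `__keras_tensor__` branch produces for a symbolic tensor; the shape, dtype and layer name are illustrative, and the structure mirrors the code above:

# serialize_keras_object(kt) for a KerasTensor produced by a layer named
# "dense" returns something like:
{
    "class_name": "__keras_tensor__",
    "config": {
        "shape": (2, 4),
        "dtype": "float32",
        "keras_history": ["dense", 0, 0],  # [layer name, node idx, tensor idx]
    },
}
# deserialize_keras_object() rebuilds a KerasTensor from shape/dtype and stores
# the history on `_pre_serialization_keras_history` for graph reconstruction.
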
@ -168,7 +168,9 @@ class SerializationLibTest(testing.TestCase):
# def test_safe_mode_scope(self):
# lmbda = keras_core.layers.Lambda(lambda x: x**2)
# with serialization_lib.SafeModeScope(safe_mode=True):
# with self.assertRaisesRegex(ValueError, "arbitrary code execution"):
# with self.assertRaisesRegex(
# ValueError, "arbitrary code execution"
# ):
# self.roundtrip(lmbda)
# with serialization_lib.SafeModeScope(safe_mode=False):
# _, new_lmbda, _ = self.roundtrip(lmbda)
@ -192,44 +194,45 @@ class SerializationLibTest(testing.TestCase):
self.assertIs(model.layers[2], model.layers[3].layer)
self.assertIs(new_model.layers[2], new_model.layers[3].layer)
# TODO
# def test_functional_subclass(self):
# class PlainFunctionalSubclass(keras_core.Model):
# pass
def test_functional_subclass(self):
class PlainFunctionalSubclass(keras_core.Model):
pass
# inputs = keras_core.Input((2,), batch_size=3)
# outputs = keras_core.layers.Dense(1)(inputs)
# model = PlainFunctionalSubclass(inputs, outputs)
# x = ops.random.normal((2, 2))
# y1 = model(x)
# _, new_model, _ = self.roundtrip(
# model,
# custom_objects={"PlainFunctionalSubclass": PlainFunctionalSubclass},
# )
# new_model.set_weights(model.get_weights())
# y2 = new_model(x)
# self.assertAllClose(y1, y2, atol=1e-5)
# self.assertIsInstance(new_model, PlainFunctionalSubclass)
inputs = keras_core.Input((2,), batch_size=3)
outputs = keras_core.layers.Dense(1)(inputs)
model = PlainFunctionalSubclass(inputs, outputs)
x = ops.random.normal((2, 2))
y1 = model(x)
_, new_model, _ = self.roundtrip(
model,
custom_objects={"PlainFunctionalSubclass": PlainFunctionalSubclass},
)
new_model.set_weights(model.get_weights())
y2 = new_model(x)
self.assertAllClose(y1, y2, atol=1e-5)
# TODO
# self.assertIsInstance(new_model, PlainFunctionalSubclass)
# class FunctionalSubclassWCustomInit(keras_core.Model):
# def __init__(self, num_units=1, **kwargs):
# inputs = keras_core.Input((2,), batch_size=3)
# outputs = keras_core.layers.Dense(num_units)(inputs)
# super().__init__(inputs, outputs)
# TODO
# class FunctionalSubclassWCustomInit(keras_core.Model):
# def __init__(self, num_units=1, **kwargs):
# inputs = keras_core.Input((2,), batch_size=3)
# outputs = keras_core.layers.Dense(num_units)(inputs)
# super().__init__(inputs, outputs)
# model = FunctionalSubclassWCustomInit(num_units=2)
# x = ops.random.normal((2, 2))
# y1 = model(x)
# _, new_model, _ = self.roundtrip(
# model,
# custom_objects={
# "FunctionalSubclassWCustomInit": FunctionalSubclassWCustomInit
# },
# )
# new_model.set_weights(model.get_weights())
# y2 = new_model(x)
# self.assertAllClose(y1, y2, atol=1e-5)
# self.assertIsInstance(new_model, FunctionalSubclassWCustomInit)
# model = FunctionalSubclassWCustomInit(num_units=2)
# x = ops.random.normal((2, 2))
# y1 = model(x)
# _, new_model, _ = self.roundtrip(
# model,
# custom_objects={
# "FunctionalSubclassWCustomInit": FunctionalSubclassWCustomInit
# },
# )
# new_model.set_weights(model.get_weights())
# y2 = new_model(x)
# self.assertAllClose(y1, y2, atol=1e-5)
# self.assertIsInstance(new_model, FunctionalSubclassWCustomInit)
def test_shared_object(self):
class MyLayer(keras_core.layers.Layer):

@ -6,6 +6,8 @@ from tensorflow import nest
class TestCase(unittest.TestCase):
maxDiff = None
def assertAllClose(self, x1, x2, atol=1e-7, rtol=1e-7):
np.testing.assert_allclose(x1, x2, atol=atol, rtol=rtol)
@ -23,20 +25,29 @@ class TestCase(unittest.TestCase):
# get_config roundtrip
cls = instance.__class__
config = instance.get_config()
config_json = json.dumps(config, sort_keys=True, indent=4)
ref_dir = dir(instance)[:]
with custom_object_scope(custom_objects):
revived_instance = cls.from_config(config)
revived_config = revived_instance.get_config()
self.assertEqual(config, revived_config)
revived_config_json = json.dumps(
revived_config, sort_keys=True, indent=4
)
self.assertEqual(config_json, revived_config_json)
self.assertEqual(ref_dir, dir(revived_instance))
# serialization roundtrip
serialized = serialize_keras_object(instance)
json_str = json.dumps(serialized)
serialized_json = json.dumps(serialized, sort_keys=True, indent=4)
with custom_object_scope(custom_objects):
revived_instance = deserialize_keras_object(json.loads(json_str))
revived_instance = deserialize_keras_object(
json.loads(serialized_json)
)
revived_config = revived_instance.get_config()
self.assertEqual(config, revived_config)
revived_config_json = json.dumps(
revived_config, sort_keys=True, indent=4
)
self.assertEqual(config_json, revived_config_json)
self.assertEqual(ref_dir, dir(revived_instance))
def run_layer_test(

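Comparing JSON dumps instead of raw config dicts sidesteps benign container-type differences introduced by a round trip (e.g. tuples coming back as lists). A tiny sketch of the idea:

import json

config = {"shape": (2, 3)}          # tuple before serialization
revived_config = {"shape": [2, 3]}  # list after a JSON round trip

assert config != revived_config     # raw dict comparison is too strict
assert json.dumps(config, sort_keys=True) == json.dumps(
    revived_config, sort_keys=True
)
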
@ -250,11 +250,11 @@ class Trainer:
def get_compile_config(self):
# TODO
raise NotImplementedError
pass
def compile_from_config(self):
# TODO
raise NotImplementedError
return {}
def _should_eval(self, epoch, validation_freq):
epoch = epoch + 1 # one-index the user-facing epoch.