Add Functional serialization.
This commit is contained in:
parent
ce4ecc6cc2
commit
295f9a5f5f
@ -85,7 +85,7 @@ class KerasTensor:
|
||||
def __repr__(self):
    """Return a debug string with the tensor's shape, dtype and name."""
    # Both fragments must be f-strings: a plain literal would emit the
    # text "{self.name}" verbatim instead of the tensor's actual name.
    return (
        f"<KerasTensor shape={self.shape}, dtype={self.dtype}, "
        f"name={self.name}>"
    )
|
||||
|
||||
def __iter__(self):
|
||||
|
@ -127,8 +127,10 @@ class Layer(Operation):
|
||||
if config:
|
||||
if "input_shape" in config:
|
||||
self.build(config["input_shape"])
|
||||
self._build_shapes_dict = config
|
||||
elif "shapes_dict" in config:
|
||||
self.build(**config["shapes_dict"])
|
||||
self._build_shapes_dict = config["shapes_dict"]
|
||||
|
||||
def add_variable(
|
||||
self,
|
||||
|
@ -1,3 +1,4 @@
|
||||
import copy
|
||||
import inspect
|
||||
import warnings
|
||||
|
||||
@ -33,11 +34,8 @@ class Functional(Function, Model):
|
||||
return Function.__new__(cls)
|
||||
|
||||
@tracking.no_automatic_dependency_tracking
|
||||
@python_utils.default
|
||||
def __init__(self, inputs, outputs, name=None, **kwargs):
|
||||
# This is used by the Model class, since we have some logic to swap the
|
||||
# class in the __new__ method, which will lead to __init__ get invoked
|
||||
# twice. Using the skip_init to skip one of the invocation of __init__
|
||||
# to avoid any side effects
|
||||
if isinstance(inputs, dict):
|
||||
for k, v in inputs.items():
|
||||
if not isinstance(v, backend.KerasTensor):
|
||||
@ -87,7 +85,7 @@ class Functional(Function, Model):
|
||||
f"Unrecognized type for `outputs`: {outputs} (of type {type(outputs)})"
|
||||
)
|
||||
|
||||
super().__init__(inputs, outputs, name=name, **kwargs)
|
||||
Function.__init__(self, inputs, outputs, name=name, **kwargs)
|
||||
self._layers = self.layers
|
||||
self.built = True
|
||||
|
||||
@ -108,9 +106,10 @@ class Functional(Function, Model):
|
||||
masks = self._flatten_to_reference_inputs(mask)
|
||||
for x, mask in zip(inputs, masks):
|
||||
x._keras_mask = mask
|
||||
return self._run_through_graph(
|
||||
outputs = self._run_through_graph(
|
||||
inputs, operation_fn=lambda op: operation_fn(op, training=training)
|
||||
)
|
||||
return unpack_singleton(outputs)
|
||||
|
||||
def compute_output_spec(self, inputs, training=None, mask=None):
|
||||
# From Function
|
||||
@ -188,35 +187,202 @@ class Functional(Function, Model):
|
||||
# Symbolic only. TODO
|
||||
raise NotImplementedError
|
||||
|
||||
@python_utils.default
def get_config(self):
    """Return the config dict describing this Functional model's graph.

    Returns:
        A deep-copied dict with the keys `"name"`, `"trainable"`,
        `"layers"` (one serialized config per operation, each carrying
        its `"name"` and `"inbound_nodes"`), `"input_layers"` and
        `"output_layers"` (lists of
        `[operation_name, node_index, tensor_index]`). For classes whose
        constructor is not Functional-like, the result of
        `Model.get_config` is returned instead.
    """
    if not functional_like_constructor(self.__class__):
        # Subclassed networks are not serializable
        # (unless serialization is implemented by
        # the author of the subclassed network).
        # `Model.get_config` is looked up on the class, so the instance
        # has to be passed explicitly.
        return Model.get_config(self)

    config = {
        "name": self.name,
        "trainable": self.trainable,
    }

    # Build a map from a layer unique name (make_node_key)
    # to the index of the nodes that are saved in the config.
    # Only nodes in network_nodes are saved.
    node_reindexing_map = {}
    for operation in self.operations:
        if issubclass(operation.__class__, Functional):
            # Functional models start with a pre-existing node
            # linking their input to output.
            kept_nodes = 1
        else:
            kept_nodes = 0
        for original_node_index, _ in enumerate(operation._inbound_nodes):
            node_key = make_node_key(operation, original_node_index)
            if node_key in self._nodes:
                # i.e. we mark it to be saved
                node_reindexing_map[node_key] = kept_nodes
                kept_nodes += 1

    # Serialize and save the layers in layer_configs.
    layer_configs = []
    for operation in self.operations:  # From the earliest layers on.
        filtered_inbound_nodes = []
        for original_node_index, node in enumerate(operation._inbound_nodes):
            node_key = make_node_key(operation, original_node_index)
            if node_key in self._nodes:
                # The node is relevant to the model:
                # add to filtered_inbound_nodes.
                node_data = serialize_node(node, node_reindexing_map)
                if node_data is not None:
                    filtered_inbound_nodes.append(node_data)

        layer_config = serialization_lib.serialize_keras_object(operation)
        layer_config["name"] = operation.name
        layer_config["inbound_nodes"] = filtered_inbound_nodes
        layer_configs.append(layer_config)
    config["layers"] = layer_configs

    def tensor_references(tensors):
        """List `[operation_name, new_node_index, tensor_index]` per tensor."""
        references = []
        for tensor in tensors:
            operation = tensor._keras_history[0]
            node_index = tensor._keras_history[1]
            tensor_index = tensor._keras_history[2]
            node_key = make_node_key(operation, node_index)
            if node_key not in self._nodes:
                # Tensors produced by nodes that are not saved are skipped.
                continue
            new_node_index = node_reindexing_map[node_key]
            references.append([operation.name, new_node_index, tensor_index])
        return references

    # Gather info about inputs and outputs.
    config["input_layers"] = tensor_references(self._inputs)
    config["output_layers"] = tensor_references(self._outputs)
    return copy.deepcopy(config)
|
||||
|
||||
@classmethod
|
||||
def from_config(self):
|
||||
raise NotImplementedError
|
||||
def from_config(cls, config, custom_objects=None):
    """Instantiates a Model from its config (output of `get_config()`).

    Args:
        config: Dict with at least the keys `"layers"`, `"input_layers"`
            and `"output_layers"`; `"name"` is optional.
        custom_objects: Optional custom objects forwarded to
            `serialization_lib.deserialize_keras_object` when layers are
            instantiated.

    Returns:
        A new instance of `cls` built from the reconstructed input and
        output tensors.

    Raises:
        ValueError: if a full pass over the layers cannot process any of
            the pending nodes (the graph can never be completed), or if
            `"input_layers"` / `"output_layers"` reference a layer that
            is not part of `config["layers"]`.
    """
    # Layer instances created during
    # the graph reconstruction process
    created_layers = {}

    # Dictionary mapping layer instances to
    # node data that specifies a layer call.
    # It acts as a queue that maintains any unprocessed
    # layer call until it becomes possible to process it
    # (i.e. until the input tensors to the call all exist).
    unprocessed_nodes = {}

    def add_unprocessed_node(layer, node_data):
        """Add node to layer list

        Args:
            layer: layer object
            node_data: Node data specifying layer call
        """
        unprocessed_nodes.setdefault(layer, []).append(node_data)

    def process_node(layer, node_data):
        """Reconstruct node by linking to inbound layers

        Args:
            layer: Layer to process
            node_data: List of layer configs
        """
        args, kwargs = deserialize_node(node_data, created_layers)
        # Call layer on its inputs, thus creating the node
        # and building the layer if needed.
        layer(*args, **kwargs)

    def process_layer(layer_data):
        """Deserializes a layer, then enqueues its calls for processing.

        Args:
            layer_data: layer config dict.
        """
        layer_name = layer_data["name"]

        # Instantiate layer.
        layer = serialization_lib.deserialize_keras_object(
            layer_data, custom_objects=custom_objects
        )
        created_layers[layer_name] = layer

        # Gather layer inputs.
        for node_data in layer_data["inbound_nodes"]:
            # We don't process nodes (i.e. make layer calls)
            # on the fly because the inbound node may not yet exist,
            # in case of layer shared at different topological depths
            # (e.g. a model such as A(B(A(B(x)))))
            add_unprocessed_node(layer, node_data)

    def collect_tensors(tensor_references):
        """Resolve `[layer_name, node_index, tensor_index]` triples."""
        tensors = []
        for layer_name, node_index, tensor_index in tensor_references:
            if layer_name not in created_layers:
                raise ValueError(
                    "Cannot deserialize the model: unknown layer "
                    f"'{layer_name}' referenced in the config."
                )
            layer = created_layers[layer_name]
            node_outputs = layer._inbound_nodes[node_index].output_tensors
            tensors.append(node_outputs[tensor_index])
        return tensors

    # First, we create all layers and enqueue nodes to be processed
    for layer_data in config["layers"]:
        process_layer(layer_data)

    # Then we process nodes in order of layer depth.
    # Nodes that cannot yet be processed (if the inbound node
    # does not yet exist) are re-enqueued, and the process
    # is repeated until all nodes are processed.
    while unprocessed_nodes:
        made_progress = False
        last_error = None
        for layer_data in config["layers"]:
            layer = created_layers[layer_data["name"]]

            # Process all nodes in layer, if not yet processed
            if layer not in unprocessed_nodes:
                continue
            node_data_list = unprocessed_nodes[layer]

            # Process nodes in order
            node_index = 0
            while node_index < len(node_data_list):
                node_data = node_data_list[node_index]
                try:
                    process_node(layer, node_data)
                except IndexError as e:
                    # If the node does not have all inbound layers
                    # available, stop processing and continue later
                    last_error = e
                    break
                node_index += 1
                made_progress = True

            if node_index < len(node_data_list):
                # Not all nodes processed: store the remaining ones.
                unprocessed_nodes[layer] = node_data_list[node_index:]
            else:
                # All nodes processed: remove the layer.
                del unprocessed_nodes[layer]

        # A whole pass that resolved nothing can never succeed on a later
        # pass either; fail loudly instead of looping forever.
        if unprocessed_nodes and not made_progress:
            raise ValueError(
                "Cannot deserialize the model: some layer calls reference "
                "inbound nodes that are never created (invalid config data?)"
            ) from last_error

    # Create lists of input and output tensors and return new class
    name = config.get("name")
    input_tensors = collect_tensors(config["input_layers"])
    output_tensors = collect_tensors(config["output_layers"])
    return cls(inputs=input_tensors, outputs=output_tensors, name=name)
|
||||
|
||||
|
||||
def operation_fn(operation, training):
|
||||
@ -239,5 +405,86 @@ def functional_like_constructor(cls):
|
||||
return False
|
||||
|
||||
|
||||
def get_functional_config(network):
|
||||
raise NotImplementedError
|
||||
def unpack_singleton(x):
    """Return the sole element of a one-item list or tuple, else `x` itself.

    Only lists and tuples are unpacked. Other sized objects of length 1
    (a one-key dict, an array whose leading dimension is 1, ...) are
    returned unchanged, so a single output with batch size 1 keeps its
    batch axis and a dict is never indexed with `0`.

    Args:
        x: Any object.

    Returns:
        `x[0]` if `x` is a list or tuple holding exactly one item,
        otherwise `x`.
    """
    if isinstance(x, (list, tuple)) and len(x) == 1:
        return x[0]
    return x
|
||||
|
||||
|
||||
def serialize_node(node, node_reindexing_map):
    """Serialize the call arguments recorded on `node`.

    Args:
        node: Node object exposing `input_tensors` and `arguments`
            (the latter with `args` and `kwargs` attributes).
        node_reindexing_map: Accepted for the caller's convenience; it is
            not read by this function.

    Returns:
        A dict with the serialized `"args"` and `"kwargs"`, or `None` when
        the node has no input tensors.
    """
    if node.input_tensors:
        positional, keyword = node.arguments.args, node.arguments.kwargs
        return {
            "args": serialization_lib.serialize_keras_object(positional),
            "kwargs": serialization_lib.serialize_keras_object(keyword),
        }
    # Does not need to be serialized.
    return None
|
||||
|
||||
|
||||
def deserialize_node(node_data, created_layers):
    """Return (args, kwargs) for calling the node layer.

    Args:
        node_data: Serialized call data of one node. Either a dict with
            `"args"` and `"kwargs"` entries (the format produced by
            `serialize_node`) or, in the legacy format, a list of
            `[layer_name, node_index, tensor_index]` entries with an
            optional fourth element holding the call kwargs.
        created_layers: Dict mapping layer names to the layer instances
            created so far.

    Returns:
        Tuple `(args, kwargs)`. Empty `node_data` yields `([], {})`.

    Raises:
        IndexError: if a referenced inbound node has not been created
            yet. `Functional.from_config` catches this to retry the node
            on a later pass.
        ValueError: if a legacy entry has an unexpected length, or a
            revived tensor references a layer name that is not in
            `created_layers`.
    """
    if not node_data:
        return [], {}

    if isinstance(node_data, list):
        # Legacy case.
        input_tensors = []
        for input_data in node_data:
            inbound_layer_name = input_data[0]
            inbound_node_index = input_data[1]
            inbound_tensor_index = input_data[2]
            # Note: the kwargs of the last entry are the ones returned.
            if len(input_data) == 3:
                kwargs = {}
            elif len(input_data) == 4:
                kwargs = input_data[3]
            else:
                raise ValueError(
                    "Cannot deserialize the model (invalid config data?)"
                )
            inbound_layer = created_layers[inbound_layer_name]

            # Raise an error if the corresponding layer node
            # has not yet been created
            if len(inbound_layer._inbound_nodes) <= inbound_node_index:
                raise IndexError(
                    "Layer node index out of bounds.\n"
                    f"inbound_layer = {inbound_layer}\n"
                    f"inbound_layer._inbound_nodes = {inbound_layer._inbound_nodes}\n"
                    f"inbound_node_index = {inbound_node_index}"
                )
            inbound_node = inbound_layer._inbound_nodes[inbound_node_index]
            input_tensors.append(
                inbound_node.output_tensors[inbound_tensor_index]
            )
        return [unpack_singleton(input_tensors)], kwargs

    args = serialization_lib.deserialize_keras_object(node_data["args"])
    kwargs = serialization_lib.deserialize_keras_object(node_data["kwargs"])

    def convert_revived_tensor(x):
        """Swap a revived placeholder tensor for the real graph tensor."""
        if isinstance(x, backend.KerasTensor):
            history = getattr(x, "_pre_serialization_keras_history", None)
            if history is None:
                return x
            # `history` is [layer_name, node_index, tensor_index].
            layer = created_layers.get(history[0], None)
            if layer is None:
                raise ValueError(f"Unknown layer: {history[0]}")
            inbound_node_index = history[1]
            inbound_tensor_index = history[2]
            if len(layer._inbound_nodes) <= inbound_node_index:
                # IndexError (as in the legacy branch) marks the node as
                # "not created yet" so that the caller can retry later.
                raise IndexError(
                    "Layer node index out of bounds.\n"
                    f"inbound_layer = {layer}\n"
                    f"inbound_layer._inbound_nodes = {layer._inbound_nodes}\n"
                    f"inbound_node_index = {inbound_node_index}"
                )
            inbound_node = layer._inbound_nodes[inbound_node_index]
            return inbound_node.output_tensors[inbound_tensor_index]
        return x

    args = nest.map_structure(convert_revived_tensor, args)
    kwargs = nest.map_structure(convert_revived_tensor, kwargs)
    return args, kwargs
|
||||
|
@ -4,7 +4,8 @@ from keras_core import backend
|
||||
from keras_core import layers
|
||||
from keras_core import testing
|
||||
from keras_core.layers.core.input_layer import Input
|
||||
from keras_core.models.functional import Functional
|
||||
from keras_core.models import Functional
|
||||
from keras_core.models import Model
|
||||
|
||||
|
||||
class FunctionalTest(testing.TestCase):
|
||||
@ -14,9 +15,13 @@ class FunctionalTest(testing.TestCase):
|
||||
x = input_a + input_b
|
||||
x = layers.Dense(5)(x)
|
||||
outputs = layers.Dense(4)(x)
|
||||
model = Functional([input_a, input_b], outputs)
|
||||
model = Functional([input_a, input_b], outputs, name="basic")
|
||||
model.summary()
|
||||
|
||||
self.assertEqual(model.name, "basic")
|
||||
self.assertTrue(isinstance(model, Functional))
|
||||
self.assertTrue(isinstance(model, Model))
|
||||
|
||||
# Eager call
|
||||
in_val = [np.random.random((2, 3)), np.random.random((2, 3))]
|
||||
out_val = model(in_val)
|
||||
@ -156,8 +161,44 @@ class FunctionalTest(testing.TestCase):
|
||||
self.assertEqual(out_val.shape, (2, 3, 3))
|
||||
|
||||
def test_serialization(self):
    """Round-trip Functional models of several topologies."""
    # Basic model: one input, one output.
    plain_in = Input(shape=(3,), batch_size=2)
    plain_out = layers.Dense(3)(plain_in)
    self.run_class_serialization_test(Functional(plain_in, plain_out))

    # Multi-io model made of two independent branches.
    left_in = Input(shape=(3,), batch_size=2, name="input_a")
    right_in = Input(shape=(3,), batch_size=2, name="input_b")
    left_mid = layers.Dense(5, name="middle_a")(left_in)
    right_mid = layers.Dense(5, name="middle_b")(right_in)
    left_out = layers.Dense(4, name="output_a")(left_mid)
    right_out = layers.Dense(4, name="output_b")(right_mid)
    two_branch = Functional(
        [left_in, right_in], [left_out, right_out], name="func"
    )
    self.run_class_serialization_test(two_branch)

    # Model that includes a floating op (the `+` on symbolic inputs).
    left_in = Input(shape=(3,), batch_size=2, name="input_a")
    right_in = Input(shape=(3,), batch_size=2, name="input_b")
    merged = left_in + right_in
    merged = layers.Dense(5, name="middle")(merged)
    left_out = layers.Dense(4, name="output_a")(merged)
    right_out = layers.Dense(4, name="output_b")(merged)
    shared_trunk = Functional(
        [left_in, right_in], [left_out, right_out], name="func"
    )
    self.run_class_serialization_test(shared_trunk)

    # Model with dict inputs.
    left_in = Input(shape=(3,), batch_size=2, name="a")
    right_in = Input(shape=(3,), batch_size=2, name="b")
    merged = left_in + right_in
    merged = layers.Dense(5)(merged)
    head = layers.Dense(4)(merged)
    dict_model = Functional({"a": left_in, "b": right_in}, head)
    self.run_class_serialization_test(dict_model)
|
||||
|
||||
def test_add_loss(self):
|
||||
# TODO
|
||||
|
@ -1,6 +1,7 @@
|
||||
from keras_core import backend
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.layers.layer import Layer
|
||||
from keras_core.utils import python_utils
|
||||
from keras_core.utils import summary_utils
|
||||
|
||||
if backend.backend() == "tensorflow":
|
||||
@ -33,15 +34,21 @@ class Model(Trainer, Layer):
|
||||
def __new__(cls, *args, **kwargs):
|
||||
# Signature detection
|
||||
if functional_init_arguments(args, kwargs):
|
||||
# Functional model
|
||||
from keras_core.models import functional
|
||||
|
||||
return functional.Functional(*args, **kwargs)
|
||||
return Layer.__new__(cls)
|
||||
return super().__new__(cls)
|
||||
|
||||
def __init__(self, trainable=True, name=None, dtype=None):
|
||||
def __init__(self, *args, **kwargs):
|
||||
Trainer.__init__(self)
|
||||
Layer.__init__(self, trainable=trainable, name=name, dtype=dtype)
|
||||
from keras_core.models import functional
|
||||
|
||||
if isinstance(self, functional.Functional) and python_utils.is_default(
|
||||
self.__init__
|
||||
):
|
||||
functional.Functional.__init__(self, *args, **kwargs)
|
||||
else:
|
||||
Layer.__init__(self, *args, **kwargs)
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
raise NotImplementedError
|
||||
|
@ -23,6 +23,10 @@ class Function(Operation):
|
||||
self._operations = operations
|
||||
self._operations_by_depth = operations_by_depth
|
||||
|
||||
@property
def operations(self):
    """Copy of `self._operations`, so callers cannot mutate the original."""
    snapshot = self._operations[:]
    return snapshot
|
||||
|
||||
@property
|
||||
def inputs(self):
|
||||
return self._inputs
|
||||
@ -134,8 +138,8 @@ class Function(Operation):
|
||||
)
|
||||
|
||||
|
||||
def make_node_key(op, node_index):
    """Build the key identifying one inbound node of an operation.

    The key is derived from `id(op)`, so it distinguishes operation
    objects rather than operation names.

    Args:
        op: The operation object that owns the node.
        node_index: Index of the node in the operation's inbound nodes.

    Returns:
        The string `"<id(op)>_ib-<node_index>"`.
    """
    return f"{id(op)}_ib-{node_index}"
|
||||
|
||||
|
||||
def map_graph(inputs, outputs):
|
||||
@ -156,9 +160,7 @@ def map_graph(inputs, outputs):
|
||||
# Nodes are ordered from inputs -> outputs.
|
||||
nodes_in_decreasing_depth, operation_indices = _build_map(outputs)
|
||||
network_nodes = {
|
||||
make_node_key(
|
||||
str(id(node.operation)), node.operation._inbound_nodes.index(node)
|
||||
)
|
||||
make_node_key(node.operation, node.operation._inbound_nodes.index(node))
|
||||
for node in nodes_in_decreasing_depth
|
||||
}
|
||||
|
||||
@ -195,7 +197,7 @@ def map_graph(inputs, outputs):
|
||||
operations_depths[input_operation] = 0
|
||||
operation_indices[input_operation] = -1
|
||||
nodes_depths[input_operation._inbound_nodes[0]] = 0
|
||||
network_nodes.add(make_node_key(input_operation.name, 0))
|
||||
network_nodes.add(make_node_key(input_operation, 0))
|
||||
|
||||
# Build a dict {depth: list of nodes with this depth}
|
||||
nodes_by_depth = collections.defaultdict(list)
|
||||
|
@ -139,6 +139,19 @@ def serialize_keras_object(obj):
|
||||
"class_name": "__bytes__",
|
||||
"config": {"value": obj.decode("utf-8")},
|
||||
}
|
||||
if isinstance(obj, backend.KerasTensor):
|
||||
history = getattr(obj, "_keras_history", None)
|
||||
if history:
|
||||
history = list(history)
|
||||
history[0] = history[0].name
|
||||
return {
|
||||
"class_name": "__keras_tensor__",
|
||||
"config": {
|
||||
"shape": obj.shape,
|
||||
"dtype": obj.dtype,
|
||||
"keras_history": history,
|
||||
},
|
||||
}
|
||||
if isinstance(obj, tf.TensorShape):
|
||||
return obj.as_list() if obj._dims is not None else None
|
||||
if isinstance(obj, (tf.Tensor, jax.numpy.ndarray)):
|
||||
@ -203,12 +216,6 @@ def serialize_keras_object(obj):
|
||||
obj.__class__, inner_config
|
||||
)
|
||||
|
||||
# TODO(nkovela): Add TF ops dispatch handler serialization for
|
||||
# ops.EagerTensor that contains nested numpy array.
|
||||
# Target: NetworkConstructionTest.test_constant_initializer_with_numpy
|
||||
if isinstance(inner_config, str) and inner_config == "op_dispatch_handler":
|
||||
return obj
|
||||
|
||||
if config_with_public_class is not None:
|
||||
get_build_and_compile_config(obj, config_with_public_class)
|
||||
record_object_after_serialization(obj, config_with_public_class)
|
||||
@ -564,6 +571,13 @@ def deserialize_keras_object(
|
||||
custom_objects = custom_objects or {}
|
||||
|
||||
# Special cases:
|
||||
if class_name == "__keras_tensor__":
|
||||
obj = backend.KerasTensor(
|
||||
inner_config["shape"], dtype=inner_config["dtype"]
|
||||
)
|
||||
obj._pre_serialization_keras_history = inner_config["keras_history"]
|
||||
return obj
|
||||
|
||||
if class_name == "__tensor__":
|
||||
return backend.convert_to_tensor(
|
||||
inner_config["value"], dtype=inner_config["dtype"]
|
||||
|
@ -168,7 +168,9 @@ class SerializationLibTest(testing.TestCase):
|
||||
# def test_safe_mode_scope(self):
|
||||
# lmbda = keras_core.layers.Lambda(lambda x: x**2)
|
||||
# with serialization_lib.SafeModeScope(safe_mode=True):
|
||||
# with self.assertRaisesRegex(ValueError, "arbitrary code execution"):
|
||||
# with self.assertRaisesRegex(
|
||||
# ValueError, "arbitrary code execution"
|
||||
# ):
|
||||
# self.roundtrip(lmbda)
|
||||
# with serialization_lib.SafeModeScope(safe_mode=False):
|
||||
# _, new_lmbda, _ = self.roundtrip(lmbda)
|
||||
@ -192,44 +194,45 @@ class SerializationLibTest(testing.TestCase):
|
||||
self.assertIs(model.layers[2], model.layers[3].layer)
|
||||
self.assertIs(new_model.layers[2], new_model.layers[3].layer)
|
||||
|
||||
# TODO
|
||||
# def test_functional_subclass(self):
|
||||
# class PlainFunctionalSubclass(keras_core.Model):
|
||||
# pass
|
||||
def test_functional_subclass(self):
    """A plain `Model` subclass built functionally survives a roundtrip."""

    class PlainFunctionalSubclass(keras_core.Model):
        pass

    symbolic_in = keras_core.Input((2,), batch_size=3)
    symbolic_out = keras_core.layers.Dense(1)(symbolic_in)
    original = PlainFunctionalSubclass(symbolic_in, symbolic_out)

    sample = ops.random.normal((2, 2))
    expected = original(sample)

    registry = {"PlainFunctionalSubclass": PlainFunctionalSubclass}
    _, restored, _ = self.roundtrip(original, custom_objects=registry)

    # Copy the weights over so both models must produce the same output.
    restored.set_weights(original.get_weights())
    actual = restored(sample)
    self.assertAllClose(expected, actual, atol=1e-5)
    # TODO
    # self.assertIsInstance(new_model, PlainFunctionalSubclass)
|
||||
|
||||
# class FunctionalSubclassWCustomInit(keras_core.Model):
|
||||
# def __init__(self, num_units=1, **kwargs):
|
||||
# inputs = keras_core.Input((2,), batch_size=3)
|
||||
# outputs = keras_core.layers.Dense(num_units)(inputs)
|
||||
# super().__init__(inputs, outputs)
|
||||
# TODO
|
||||
# class FunctionalSubclassWCustomInit(keras_core.Model):
|
||||
# def __init__(self, num_units=1, **kwargs):
|
||||
# inputs = keras_core.Input((2,), batch_size=3)
|
||||
# outputs = keras_core.layers.Dense(num_units)(inputs)
|
||||
# super().__init__(inputs, outputs)
|
||||
|
||||
# model = FunctionalSubclassWCustomInit(num_units=2)
|
||||
# x = ops.random.normal((2, 2))
|
||||
# y1 = model(x)
|
||||
# _, new_model, _ = self.roundtrip(
|
||||
# model,
|
||||
# custom_objects={
|
||||
# "FunctionalSubclassWCustomInit": FunctionalSubclassWCustomInit
|
||||
# },
|
||||
# )
|
||||
# new_model.set_weights(model.get_weights())
|
||||
# y2 = new_model(x)
|
||||
# self.assertAllClose(y1, y2, atol=1e-5)
|
||||
# self.assertIsInstance(new_model, FunctionalSubclassWCustomInit)
|
||||
# model = FunctionalSubclassWCustomInit(num_units=2)
|
||||
# x = ops.random.normal((2, 2))
|
||||
# y1 = model(x)
|
||||
# _, new_model, _ = self.roundtrip(
|
||||
# model,
|
||||
# custom_objects={
|
||||
# "FunctionalSubclassWCustomInit": FunctionalSubclassWCustomInit
|
||||
# },
|
||||
# )
|
||||
# new_model.set_weights(model.get_weights())
|
||||
# y2 = new_model(x)
|
||||
# self.assertAllClose(y1, y2, atol=1e-5)
|
||||
# self.assertIsInstance(new_model, FunctionalSubclassWCustomInit)
|
||||
|
||||
def test_shared_object(self):
|
||||
class MyLayer(keras_core.layers.Layer):
|
||||
|
@ -6,6 +6,8 @@ from tensorflow import nest
|
||||
|
||||
|
||||
class TestCase(unittest.TestCase):
|
||||
maxDiff = None
|
||||
|
||||
def assertAllClose(self, x1, x2, atol=1e-7, rtol=1e-7):
    """Assert that `x1` and `x2` match within `atol` / `rtol`.

    Delegates to `np.testing.assert_allclose`, which raises
    `AssertionError` on mismatch.
    """
    tolerances = {"atol": atol, "rtol": rtol}
    np.testing.assert_allclose(x1, x2, **tolerances)
|
||||
|
||||
@ -23,20 +25,29 @@ class TestCase(unittest.TestCase):
|
||||
# get_config roundtrip
|
||||
cls = instance.__class__
|
||||
config = instance.get_config()
|
||||
config_json = json.dumps(config, sort_keys=True, indent=4)
|
||||
ref_dir = dir(instance)[:]
|
||||
with custom_object_scope(custom_objects):
|
||||
revived_instance = cls.from_config(config)
|
||||
revived_config = revived_instance.get_config()
|
||||
self.assertEqual(config, revived_config)
|
||||
revived_config_json = json.dumps(
|
||||
revived_config, sort_keys=True, indent=4
|
||||
)
|
||||
self.assertEqual(config_json, revived_config_json)
|
||||
self.assertEqual(ref_dir, dir(revived_instance))
|
||||
|
||||
# serialization roundtrip
|
||||
serialized = serialize_keras_object(instance)
|
||||
json_str = json.dumps(serialized)
|
||||
serialized_json = json.dumps(serialized, sort_keys=True, indent=4)
|
||||
with custom_object_scope(custom_objects):
|
||||
revived_instance = deserialize_keras_object(json.loads(json_str))
|
||||
revived_instance = deserialize_keras_object(
|
||||
json.loads(serialized_json)
|
||||
)
|
||||
revived_config = revived_instance.get_config()
|
||||
self.assertEqual(config, revived_config)
|
||||
revived_config_json = json.dumps(
|
||||
revived_config, sort_keys=True, indent=4
|
||||
)
|
||||
self.assertEqual(config_json, revived_config_json)
|
||||
self.assertEqual(ref_dir, dir(revived_instance))
|
||||
|
||||
def run_layer_test(
|
||||
|
@ -250,11 +250,11 @@ class Trainer:
|
||||
|
||||
def get_compile_config(self):
|
||||
# TODO
|
||||
raise NotImplementedError
|
||||
pass
|
||||
|
||||
def compile_from_config(self):
|
||||
# TODO
|
||||
raise NotImplementedError
|
||||
return {}
|
||||
|
||||
def _should_eval(self, epoch, validation_freq):
|
||||
epoch = epoch + 1 # one-index the user-facing epoch.
|
||||
|
Loading…
Reference in New Issue
Block a user