2023-04-21 22:01:17 +00:00
|
|
|
import copy
|
2023-04-23 18:07:50 +00:00
|
|
|
import inspect
|
2023-04-12 21:27:30 +00:00
|
|
|
import warnings
|
|
|
|
|
|
|
|
from tensorflow import nest
|
|
|
|
|
2023-04-12 22:41:35 +00:00
|
|
|
from keras_core import backend
|
2023-04-12 21:27:30 +00:00
|
|
|
from keras_core import operations as ops
|
2023-04-09 19:35:32 +00:00
|
|
|
from keras_core.layers.layer import Layer
|
2023-04-12 18:31:58 +00:00
|
|
|
from keras_core.models.model import Model
|
|
|
|
from keras_core.operations.function import Function
|
2023-04-24 00:00:19 +00:00
|
|
|
from keras_core.operations.function import make_node_key
|
|
|
|
from keras_core.saving import serialization_lib
|
2023-04-12 22:20:56 +00:00
|
|
|
from keras_core.utils import tracking
|
2023-04-09 19:21:45 +00:00
|
|
|
|
|
|
|
|
|
|
|
class Functional(Function, Model):
    """A `Model` defined as a graph of layer calls (the Functional API).

    Extends `Function` with `Model` behavior:

    Add support for extra call arguments compared to Function:
    `training`, `mask`.

    Add support for arg standardization:
    - list/dict duality (inputs may be passed as a dict keyed by
      input tensor name, or as a flat list/tuple)
    - upranking (inputs whose rank is off by one from the reference
      inputs are squeezed/expanded on the last axis)

    Override `.layers` to expose only the `Layer` operations of the graph.

    Symbolic `add_loss` (not implemented yet; raises).
    """
|
|
|
|
|
2023-04-12 22:20:56 +00:00
|
|
|
@tracking.no_automatic_dependency_tracking
def __init__(self, inputs, outputs, name=None, **kwargs):
    """Build a Functional model from symbolic `inputs` and `outputs`.

    Args:
        inputs: A `KerasTensor`, or a dict/list/tuple of `KerasTensor`s.
            When a dict, each key must equal the name of its tensor.
        outputs: A `KerasTensor`, or a dict/list/tuple of `KerasTensor`s.
        name: Optional string name for the model.
        **kwargs: Forwarded to `Function.__init__`.

    Raises:
        ValueError: If `inputs`/`outputs` contain non-`KerasTensor`
            values, or if a dict `inputs` key does not match the
            corresponding tensor's name.
    """

    def _validate_structure(struct, arg_name, check_key_names=False):
        # Shared validation for `inputs` and `outputs`; the key-name
        # check only applies to dict-structured `inputs`.
        if isinstance(struct, dict):
            for k, v in struct.items():
                if not isinstance(v, backend.KerasTensor):
                    raise ValueError(
                        f"When providing `{arg_name}` as a dict, all "
                        "values in the dict "
                        "must be KerasTensors. "
                        f"Received: {arg_name}={struct} including "
                        f"invalid value {v} of type {type(v)}"
                    )
                if check_key_names and k != v.name:
                    # TODO: maybe make this a warning
                    raise ValueError(
                        f"When providing `{arg_name}` as a dict, all "
                        "keys in the dict "
                        "must match the names of the corresponding "
                        "tensors. "
                        f"Received key '{k}' mapping to value {v} "
                        f"which has name '{v.name}'. "
                        f"Change the tensor name to '{k}' "
                        f"(via `Input(..., name='{k}')`)"
                    )
        elif isinstance(struct, (list, tuple)):
            for x in struct:
                if not isinstance(x, backend.KerasTensor):
                    raise ValueError(
                        f"When providing `{arg_name}` as a list/tuple, "
                        "all values in the list/tuple "
                        f"must be KerasTensors. "
                        f"Received: {arg_name}={struct} including "
                        f"invalid value {x} of type {type(x)}"
                    )
        elif not isinstance(struct, backend.KerasTensor):
            raise ValueError(
                f"Unrecognized type for `{arg_name}`: "
                f"{struct} (of type {type(struct)})"
            )

    _validate_structure(inputs, "inputs", check_key_names=True)
    _validate_structure(outputs, "outputs")

    Function.__init__(self, inputs, outputs, name=name, **kwargs)
    # The topology of a Functional model is fully known at
    # construction time, so cache the layer list and mark it built.
    self._layers = self.layers
    self.built = True
|
2023-04-09 19:21:45 +00:00
|
|
|
|
|
|
|
@property
def layers(self):
    """Return the `Layer` instances among this model's operations."""
    return [op for op in self._operations if isinstance(op, Layer)]
|
|
|
|
|
2023-04-13 00:12:57 +00:00
|
|
|
def call(self, inputs, training=None, mask=None):
    """Run the graph on `inputs`, forwarding `training` and input masks."""
    # Add support for training, masking
    inputs = self._standardize_inputs(inputs)
    if mask is None:
        masks = [None] * len(inputs)
    else:
        masks = self._flatten_to_reference_inputs(mask)
    # Attach each mask to its corresponding (flattened) input tensor.
    for tensor, tensor_mask in zip(inputs, masks):
        tensor._keras_mask = tensor_mask
    outputs = self._run_through_graph(
        inputs, operation_fn=lambda op: operation_fn(op, training=training)
    )
    return unpack_singleton(outputs)
|
2023-04-09 19:21:45 +00:00
|
|
|
|
2023-04-13 00:12:57 +00:00
|
|
|
def compute_output_spec(self, inputs, training=None, mask=None):
    """Compute symbolic output specs for `inputs` via the graph.

    `training` and `mask` are accepted for signature parity with
    `call` but are not forwarded; the symbolic pass inherited from
    `Function` ignores them.
    """
    # From Function
    return super().compute_output_spec(inputs)
|
2023-04-12 22:43:56 +00:00
|
|
|
|
2023-04-12 22:41:35 +00:00
|
|
|
def _assert_input_compatibility(self, *args):
    # Deliberately skip `Model`'s implementation: `super(Model, self)`
    # starts MRO lookup after `Model`, so the base `Layer` input-spec
    # check runs instead.
    return super(Model, self)._assert_input_compatibility(*args)
|
2023-04-12 21:27:30 +00:00
|
|
|
|
2023-04-16 01:51:10 +00:00
|
|
|
def _flatten_to_reference_inputs(self, inputs, allow_extra_keys=True):
    """Flatten `inputs` into a list ordered like the model's inputs.

    Args:
        inputs: A tensor, or a dict/list/tuple of tensors. When a dict,
            entries are reordered to match the model's input structure.
        allow_extra_keys: If True, dict keys that match no model input
            only trigger a warning and are dropped.

    Returns:
        A flat list of tensors in reference-input order.
    """
    if isinstance(inputs, dict):
        ref_inputs = self._inputs_struct
        if not nest.is_nested(ref_inputs):
            # Wrap a single reference input so the logic below is
            # uniform. (Fix: this previously read the nonexistent
            # attribute `self._nested_inputs` — presumably a tf.keras
            # leftover; `self._inputs_struct` is the name used above.)
            ref_inputs = [ref_inputs]
        if isinstance(ref_inputs, dict):
            # In the case that the graph is constructed with dict input
            # tensors, We will use the original dict key to map with the
            # keys in the input data. Note that the model.inputs is using
            # nest.flatten to process the input tensors, which means the
            # dict input tensors are ordered by their keys.
            ref_input_names = sorted(ref_inputs.keys())
        else:
            ref_input_names = [
                inp._keras_history.operation.name for inp in ref_inputs
            ]
        # Raise an warning if there are more input data comparing to input
        # tensor
        if allow_extra_keys and len(inputs) > len(ref_input_names):
            warnings.warn(
                "Input dict contained keys {} which did not match any "
                "model input. They will be ignored by the model.".format(
                    [n for n in inputs.keys() if n not in ref_input_names]
                ),
                stacklevel=2,
            )
        # Flatten in the order `Input`s were passed during Model
        # construction.
        return [inputs[n] for n in ref_input_names]
    # Otherwise both ref inputs and inputs will already be in same order.
    return nest.flatten(inputs)
|
2023-04-09 19:21:45 +00:00
|
|
|
|
2023-04-12 21:27:30 +00:00
|
|
|
def _adjust_input_rank(self, flat_inputs):
    """Squeeze/expand each input by one trailing unit axis if needed.

    Each tensor in `flat_inputs` is compared against the shape of the
    corresponding reference input; a rank difference of exactly one is
    repaired on the last axis when that axis has size 1, any other
    mismatch raises. `_keras_history`/`_keras_mask` metadata is carried
    over to the adjusted tensors.
    """
    reference_shapes = [tensor.shape for tensor in self._inputs]
    adjusted = []
    for tensor, ref_shape in zip(flat_inputs, reference_shapes):
        tensor_rank = len(tensor.shape)
        ref_rank = len(ref_shape)
        if tensor_rank == ref_rank:
            adjusted.append(tensor)
            continue
        if tensor_rank == ref_rank + 1 and tensor.shape[-1] == 1:
            adjusted.append(ops.squeeze(tensor, axis=-1))
            continue
        if tensor_rank == ref_rank - 1 and ref_shape[-1] == 1:
            adjusted.append(ops.expand_dims(tensor, axis=-1))
            continue
        raise ValueError(
            f"Invalid input shape for input {tensor}. Expected shape "
            f"{ref_shape}, but input has incompatible shape {tensor.shape}"
        )
    # Add back metadata.
    for i, original in enumerate(flat_inputs):
        if hasattr(original, "_keras_history"):
            adjusted[i]._keras_history = original._keras_history
        if hasattr(original, "_keras_mask"):
            adjusted[i]._keras_mask = original._keras_mask
    return adjusted
|
2023-04-09 19:21:45 +00:00
|
|
|
|
|
|
|
def _standardize_inputs(self, inputs):
    """Flatten `inputs` to reference order, then fix up tensor ranks."""
    return self._adjust_input_rank(
        self._flatten_to_reference_inputs(inputs)
    )
|
2023-04-09 19:21:45 +00:00
|
|
|
|
|
|
|
def add_loss(self, loss):
    """Register a symbolic loss tensor on the model.

    Not implemented yet for Functional models; always raises.

    Raises:
        NotImplementedError: Always.
    """
    # Symbolic only. TODO
    raise NotImplementedError(
        "`add_loss` is not yet supported on Functional models."
    )
|
|
|
|
|
2023-04-23 18:07:50 +00:00
|
|
|
def get_config(self):
    """Return the config dict used to re-create this model.

    The config contains `name`, `trainable`, serialized `layers`
    (with their kept inbound nodes), and the `input_layers` /
    `output_layers` coordinates. Subclassed models that do not have
    a functional-like constructor fall back to `Model.get_config`.
    """
    if not functional_like_constructor(self.__class__):
        # Subclassed networks are not serializable
        # (unless serialization is implemented by
        # the author of the subclassed network).
        return Model.get_config(self)

    config = {
        "name": self.name,
        "trainable": self.trainable,
    }
    # Build a map from a layer unique name (make_node_key)
    # to the index of the nodes that are saved in the config.
    # Only nodes in network_nodes are saved.
    node_reindexing_map = {}
    for operation in self.operations:
        if issubclass(operation.__class__, Functional):
            # Functional models start with a pre-existing node
            # linking their input to output.
            kept_nodes = 1
        else:
            kept_nodes = 0
        for original_node_index, node in enumerate(
            operation._inbound_nodes
        ):
            node_key = make_node_key(operation, original_node_index)
            if node_key in self._nodes:
                # i.e. we mark it to be saved
                node_reindexing_map[node_key] = kept_nodes
                kept_nodes += 1

    # Serialize and save the layers in layer_configs.
    layer_configs = []
    for operation in self.operations:  # From the earliest layers on.
        filtered_inbound_nodes = []
        for original_node_index, node in enumerate(
            operation._inbound_nodes
        ):
            node_key = make_node_key(operation, original_node_index)
            if node_key in self._nodes:
                # The node is relevant to the model:
                # add to filtered_inbound_nodes.
                node_data = serialize_node(node, node_reindexing_map)
                if node_data is not None:
                    filtered_inbound_nodes.append(node_data)

        layer_config = serialization_lib.serialize_keras_object(operation)
        layer_config["name"] = operation.name
        layer_config["inbound_nodes"] = filtered_inbound_nodes
        layer_configs.append(layer_config)
    config["layers"] = layer_configs

    def _tensor_coordinates(tensors):
        # Map each tensor to [operation_name, new_node_index,
        # tensor_index], skipping tensors whose node was not kept.
        # (This logic was previously duplicated for inputs/outputs.)
        coordinates = []
        for tensor in tensors:
            operation = tensor._keras_history[0]
            node_index = tensor._keras_history[1]
            tensor_index = tensor._keras_history[2]
            node_key = make_node_key(operation, node_index)
            if node_key not in self._nodes:
                continue
            new_node_index = node_reindexing_map[node_key]
            coordinates.append(
                [operation.name, new_node_index, tensor_index]
            )
        return coordinates

    # Gather info about inputs and outputs.
    config["input_layers"] = _tensor_coordinates(self._inputs)
    config["output_layers"] = _tensor_coordinates(self._outputs)
    return copy.deepcopy(config)
|
2023-04-24 21:58:38 +00:00
|
|
|
|
|
|
|
@classmethod
def from_config(cls, config, custom_objects=None):
    """Revive a model from a config produced by `get_config`.

    Tries the Functional revival path when the config looks functional
    and the class has a functional-compatible constructor; otherwise
    passes the whole config to the class constructor.
    """
    required_keys = ("name", "layers", "input_layers", "output_layers")
    is_functional_config = all(key in config for key in required_keys)

    argspec = inspect.getfullargspec(cls.__init__)
    functional_init_args = inspect.getfullargspec(
        Functional.__init__
    ).args[1:]
    revivable_as_functional = (
        cls in {Functional, Model}
        or argspec.args[1:] == functional_init_args
        or (argspec.varargs == "args" and argspec.varkw == "kwargs")
    )
    if is_functional_config and revivable_as_functional:
        # Revive Functional model
        # (but not Functional subclasses with a custom __init__)
        return cls._from_config(config, custom_objects=custom_objects)

    # Either the model has a custom __init__, or the config
    # does not contain all the information necessary to
    # revive a Functional model. This happens when the user creates
    # subclassed models where `get_config()` is returning
    # insufficient information to be considered a Functional model.
    # In this case, we fall back to provide all config into the
    # constructor of the class.
    try:
        return cls(**config)
    except TypeError as e:
        raise TypeError(
            "Unable to revive model from config. When overriding "
            "the `get_config()` method, make sure that the "
            "returned config contains all items used as arguments "
            f"in the constructor to {cls}, "
            "which is the default behavior. "
            "You can override this default behavior by defining a "
            "`from_config(cls, config)` class method to specify "
            "how to create an "
            f"instance of {cls.__name__} from its config.\n\n"
            f"Received config={config}\n\n"
            f"Error encountered during deserialization: {e}"
        )
|
|
|
|
|
|
|
|
@classmethod
def _from_config(cls, config, custom_objects=None):
    """Instantiates a Model from its config (output of `get_config()`)."""
    # Layer instances created during
    # the graph reconstruction process.
    created_layers = {}

    # Dictionary mapping layer instances to
    # node data that specifies a layer call.
    # It acts as a queue that maintains any unprocessed
    # layer call until it becomes possible to process it
    # (i.e. until the input tensors to the call all exist).
    unprocessed_nodes = {}

    def add_unprocessed_node(layer, node_data):
        """Enqueue `node_data` (a layer-call spec) for `layer`."""
        unprocessed_nodes.setdefault(layer, []).append(node_data)

    def process_node(layer, node_data):
        """Reconstruct a node by calling `layer` on its inbound tensors.

        Args:
            layer: Layer to call.
            node_data: Serialized node spec (args/kwargs of the call).

        Raises:
            IndexError: If an inbound node does not exist yet (the
                caller uses this to defer processing).
        """
        args, kwargs = deserialize_node(node_data, created_layers)
        # Call layer on its inputs, thus creating the node
        # and building the layer if needed.
        layer(*args, **kwargs)

    def process_layer(layer_data):
        """Deserialize a layer and enqueue its inbound nodes.

        Args:
            layer_data: layer config dict.
        """
        layer_name = layer_data["name"]

        # Instantiate layer.
        layer = serialization_lib.deserialize_keras_object(
            layer_data, custom_objects=custom_objects
        )
        created_layers[layer_name] = layer

        # Gather layer inputs.
        for node_data in layer_data["inbound_nodes"]:
            # We don't process nodes (i.e. make layer calls)
            # on the fly because the inbound node may not yet exist,
            # in case of layer shared at different topological depths
            # (e.g. a model such as A(B(A(B(x)))))
            add_unprocessed_node(layer, node_data)

    # First, we create all layers and enqueue nodes to be processed.
    for layer_data in config["layers"]:
        process_layer(layer_data)

    # Then we process nodes in order of layer depth.
    # Nodes that cannot yet be processed (if the inbound node
    # does not yet exist) are re-enqueued, and the process
    # is repeated until all nodes are processed.
    while unprocessed_nodes:
        for layer_data in config["layers"]:
            layer = created_layers[layer_data["name"]]

            # Process all nodes in layer, if not yet processed.
            if layer in unprocessed_nodes:
                node_data_list = unprocessed_nodes[layer]

                # Process nodes in order.
                node_index = 0
                while node_index < len(node_data_list):
                    node_data = node_data_list[node_index]
                    try:
                        process_node(layer, node_data)
                    # If the node does not have all inbound layers
                    # available, stop processing and continue later.
                    except IndexError:
                        break
                    node_index += 1

                # If not all nodes processed then store unprocessed nodes.
                if node_index < len(node_data_list):
                    unprocessed_nodes[layer] = node_data_list[node_index:]
                # If all nodes processed remove the layer.
                else:
                    del unprocessed_nodes[layer]

    def _get_tensor(layer_name, node_index, tensor_index):
        # Look up one output tensor of a reconstructed layer node.
        # (This lookup was previously duplicated for inputs/outputs.)
        assert layer_name in created_layers
        layer = created_layers[layer_name]
        node = layer._inbound_nodes[node_index]
        return node.output_tensors[tensor_index]

    # Create lists of input and output tensors and return new class.
    input_tensors = [
        _get_tensor(*layer_data) for layer_data in config["input_layers"]
    ]
    output_tensors = [
        _get_tensor(*layer_data) for layer_data in config["output_layers"]
    ]
    return cls(
        inputs=input_tensors,
        outputs=output_tensors,
        name=config.get("name"),
    )
|
2023-04-24 00:00:19 +00:00
|
|
|
|
2023-04-09 19:21:45 +00:00
|
|
|
|
|
|
|
def operation_fn(operation, training):
    """Wrap `operation` so graph execution forwards the `training` flag.

    The returned callable injects `training` into the kwargs only when
    the operation declares (via `_call_has_training_arg`) that its call
    signature accepts it.
    """

    def call(*args, **kwargs):
        accepts_training = hasattr(
            operation, "_call_has_training_arg"
        ) and operation._call_has_training_arg()
        if accepts_training:
            kwargs["training"] = training
        return operation(*args, **kwargs)

    return call
|
2023-04-23 18:07:50 +00:00
|
|
|
|
|
|
|
|
|
|
|
def functional_like_constructor(cls):
    """Return True if `cls.__init__` matches `Functional.__init__`'s args.

    Used to decide whether an instance can be serialized via the
    Functional config path rather than `Model.get_config`.
    """
    init_args = inspect.getfullargspec(cls.__init__).args[1:]
    functional_init_args = inspect.getfullargspec(Functional.__init__).args[1:]
    # Return the comparison directly instead of the redundant
    # `if ...: return True / return False` pattern.
    return init_args == functional_init_args
|
2023-04-23 18:07:50 +00:00
|
|
|
|
|
|
|
|
2023-04-21 22:01:17 +00:00
|
|
|
def unpack_singleton(x):
    """Return the sole element of a length-1 sequence, else the sequence."""
    return x[0] if len(x) == 1 else x
|
|
|
|
|
|
|
|
|
|
|
|
def serialize_node(node, node_reindexing_map):
    """Serialize the call args/kwargs of `node`.

    Returns None for nodes without input tensors (e.g. `Input` nodes),
    which do not need to be serialized.
    """
    if not node.input_tensors:
        # Does not need to be serialized.
        return None

    arguments = node.arguments
    return {
        "args": serialization_lib.serialize_keras_object(arguments.args),
        "kwargs": serialization_lib.serialize_keras_object(arguments.kwargs),
    }
|
|
|
|
|
|
|
|
|
|
|
|
def deserialize_node(node_data, created_layers):
    """Return (args, kwargs) for calling the node layer.

    Args:
        node_data: Serialized node spec — either the legacy list format
            `[[layer_name, node_index, tensor_index, (kwargs)], ...]`
            or a dict with "args"/"kwargs" entries.
        created_layers: Dict mapping layer names to revived layers.

    Returns:
        Tuple `(args, kwargs)` with revived tensors substituted in.

    Raises:
        IndexError: If an inbound node has not been created yet; the
            caller relies on this to defer processing.
    """
    if not node_data:
        return [], {}

    if isinstance(node_data, list):
        # Legacy case.
        input_tensors = []
        for input_data in node_data:
            inbound_layer_name = input_data[0]
            inbound_node_index = input_data[1]
            inbound_tensor_index = input_data[2]
            if len(input_data) == 3:
                kwargs = {}
            elif len(input_data) == 4:
                kwargs = input_data[3]
            else:
                raise ValueError(
                    "Cannot deserialize the model (invalid config data?)"
                )
            inbound_layer = created_layers[inbound_layer_name]

            # Raise an error if the corresponding layer node
            # has not yet been created
            if len(inbound_layer._inbound_nodes) <= inbound_node_index:
                raise IndexError(
                    "Layer node index out of bounds.\n"
                    f"inbound_layer = {inbound_layer}\n"
                    f"inbound_layer._inbound_nodes = {inbound_layer._inbound_nodes}\n"
                    f"inbound_node_index = {inbound_node_index}"
                )
            inbound_node = inbound_layer._inbound_nodes[inbound_node_index]
            input_tensors.append(
                inbound_node.output_tensors[inbound_tensor_index]
            )
        return [unpack_singleton(input_tensors)], kwargs

    args = serialization_lib.deserialize_keras_object(node_data["args"])
    kwargs = serialization_lib.deserialize_keras_object(node_data["kwargs"])

    def convert_revived_tensor(x):
        # Swap serialized KerasTensor stubs for the tensors produced by
        # already-revived layers, using the recorded keras history
        # triple (operation_name, node_index, tensor_index).
        if isinstance(x, backend.KerasTensor):
            history = x._pre_serialization_keras_history
            if history is None:
                return x
            layer = created_layers.get(history[0], None)
            if layer is None:
                raise ValueError(f"Unknown layer: {history[0]}")
            inbound_node_index = history[1]
            # Bug fix: the tensor index lives at history[2]; this
            # previously re-read history[1] (the node index), picking
            # the wrong output tensor whenever the two indices differ.
            inbound_tensor_index = history[2]
            if len(layer._inbound_nodes) <= inbound_node_index:
                raise ValueError(
                    "Layer node index out of bounds.\n"
                    f"inbound_layer = {layer}\n"
                    f"inbound_layer._inbound_nodes = {layer._inbound_nodes}\n"
                    f"inbound_node_index = {inbound_node_index}"
                )
            inbound_node = layer._inbound_nodes[inbound_node_index]
            return inbound_node.output_tensors[inbound_tensor_index]
        return x

    args = nest.map_structure(convert_revived_tensor, args)
    kwargs = nest.map_structure(convert_revived_tensor, kwargs)
    return args, kwargs
|