2023-04-23 18:07:50 +00:00
|
|
|
import inspect
|
2023-04-12 21:27:30 +00:00
|
|
|
import warnings
|
|
|
|
|
|
|
|
from tensorflow import nest
|
|
|
|
|
2023-04-12 22:41:35 +00:00
|
|
|
from keras_core import backend
|
2023-04-12 21:27:30 +00:00
|
|
|
from keras_core import operations as ops
|
2023-04-09 19:35:32 +00:00
|
|
|
from keras_core.layers.layer import Layer
|
2023-04-12 18:31:58 +00:00
|
|
|
from keras_core.models.model import Model
|
|
|
|
from keras_core.operations.function import Function
|
2023-04-23 18:07:50 +00:00
|
|
|
from keras_core.utils import python_utils
|
2023-04-12 22:20:56 +00:00
|
|
|
from keras_core.utils import tracking
|
2023-04-09 19:21:45 +00:00
|
|
|
|
|
|
|
|
|
|
|
class Functional(Function, Model):
    """A `Model` whose computation is a graph traced from `inputs` to `outputs`.

    `Functional` reuses the graph machinery of `Function` (tracing the
    operations that connect symbolic `KerasTensor` inputs to outputs) and
    exposes it through the `Model` API.

    Development notes carried over from the original docstring:
    - Add support for extra call arguments compared to Function:
      training, masks.
    - Add support for arg standardization: list/dict duality, upranking.
    - Override .layers.
    - Symbolic add_loss.
    """

    def __new__(cls, *args, **kwargs):
        # Skip Model.__new__ and allocate via Function directly. Presumably
        # this avoids the class-swapping logic in Model.__new__, which per the
        # original notes could otherwise cause __init__ to run twice.
        return Function.__new__(cls)

    @tracking.no_automatic_dependency_tracking
    def __init__(self, inputs, outputs, name=None, **kwargs):
        """Validate the graph endpoints and build the functional model.

        Args:
            inputs: A `KerasTensor`, or a flat list/tuple/dict of
                `KerasTensor`s. When a dict, each key must equal the `name`
                of the tensor it maps to.
            outputs: A `KerasTensor`, or a flat list/tuple/dict of
                `KerasTensor`s.
            name: Optional model name, forwarded to the parent constructor.
            **kwargs: Forwarded to the parent constructor.

        Raises:
            ValueError: If `inputs`/`outputs` contain anything other than
                `KerasTensor`s, or a key of a dict `inputs` does not match
                the name of its tensor.
        """
        self._validate_tensor_structure("inputs", inputs, check_dict_keys=True)
        self._validate_tensor_structure("outputs", outputs)

        super().__init__(inputs, outputs, name=name, **kwargs)
        # Cache the list of Layer operations found in the traced graph.
        self._layers = self.layers
        # The graph is fully specified at construction time.
        self.built = True

    @staticmethod
    def _validate_tensor_structure(arg_name, struct, check_dict_keys=False):
        """Check that `struct` is a KerasTensor or a flat list/tuple/dict of them.

        Args:
            arg_name: Argument name used in error messages (`"inputs"` or
                `"outputs"`).
            struct: The value to validate.
            check_dict_keys: If True and `struct` is a dict, additionally
                require every key to equal the `name` of its tensor.

        Raises:
            ValueError: On the first offending entry.
        """
        if isinstance(struct, dict):
            for key, value in struct.items():
                if not isinstance(value, backend.KerasTensor):
                    raise ValueError(
                        f"When providing `{arg_name}` as a dict, all values in the dict "
                        f"must be KerasTensors. Received: {arg_name}={struct} including "
                        f"invalid value {value} of type {type(value)}"
                    )
                if check_dict_keys and key != value.name:
                    # TODO: maybe make this a warning
                    raise ValueError(
                        f"When providing `{arg_name}` as a dict, all keys in the dict "
                        "must match the names of the corresponding tensors. "
                        f"Received key '{key}' mapping to value {value} which has name '{value.name}'. "
                        f"Change the tensor name to '{key}' (via `Input(..., name='{key}')`)"
                    )
        elif isinstance(struct, (list, tuple)):
            for value in struct:
                if not isinstance(value, backend.KerasTensor):
                    raise ValueError(
                        f"When providing `{arg_name}` as a list/tuple, all values in the list/tuple "
                        f"must be KerasTensors. Received: {arg_name}={struct} including "
                        f"invalid value {value} of type {type(value)}"
                    )
        elif not isinstance(struct, backend.KerasTensor):
            raise ValueError(
                f"Unrecognized type for `{arg_name}`: {struct} (of type {type(struct)})"
            )

    @property
    def layers(self):
        """The `Layer` instances among the graph's operations.

        Returned in the order they appear in `self._operations`; non-Layer
        operations are skipped.
        """
        return [op for op in self._operations if isinstance(op, Layer)]

    def call(self, inputs, training=None, mask=None):
        """Run `inputs` through the traced graph.

        Args:
            inputs: Tensor(s) for the model inputs. Dicts are matched to the
                model inputs by key (see `_flatten_to_reference_inputs`), and
                ranks are reconciled by `_adjust_input_rank`.
            training: Passed (via `operation_fn`) to each operation that
                reports accepting a `training` argument.
            mask: Optional mask(s) structured like `inputs`. When given, each
                mask is attached to its standardized input as `_keras_mask`.
                When omitted, masks already carried by the inputs are left
                untouched.

        Returns:
            Whatever `self._run_through_graph` produces for the outputs.
        """
        inputs = self._standardize_inputs(inputs)
        if mask is not None:
            masks = self._flatten_to_reference_inputs(mask)
            for x, x_mask in zip(inputs, masks):
                x._keras_mask = x_mask
        return self._run_through_graph(
            inputs, operation_fn=lambda op: operation_fn(op, training=training)
        )

    def compute_output_spec(self, inputs, training=None, mask=None):
        """Delegate to `Function.compute_output_spec`.

        `training` and `mask` are accepted but currently not used.
        """
        return super().compute_output_spec(inputs)

    def _assert_input_compatibility(self, *args):
        # Bypass Model's override: resolve to the next implementation after
        # Model in the MRO.
        return super(Model, self)._assert_input_compatibility(*args)

    def _flatten_to_reference_inputs(self, inputs, allow_extra_keys=True):
        """Flatten `inputs` into a list ordered like the model's inputs.

        A dict `inputs` is matched by key. If the model was built from a
        dict of inputs, the keys are that dict's keys. Otherwise the keys
        are the names of the operations that produced each reference input.
        Any non-dict structure is flattened with `nest.flatten` and assumed
        to already be in reference order.

        Args:
            inputs: A tensor, or a list/tuple/dict of tensors.
            allow_extra_keys: When True (default), emit a warning listing
                dict keys that match no model input. When False, such keys
                are dropped silently.

        Returns:
            A flat list of tensors.

        Raises:
            KeyError: If a dict `inputs` lacks the key for some model input.
        """
        if not isinstance(inputs, dict):
            # Both ref inputs and inputs will already be in same order.
            return nest.flatten(inputs)

        ref_inputs = self._inputs_struct
        if not nest.is_nested(ref_inputs):
            # A single reference input: wrap it so it can be iterated below.
            ref_inputs = [ref_inputs]

        if isinstance(ref_inputs, dict):
            # The graph was constructed with dict input tensors: map data by
            # the original dict keys. `model.inputs` is produced with
            # `nest.flatten`, which orders dict entries by sorted key.
            ref_input_names = sorted(ref_inputs.keys())
        else:
            ref_input_names = [
                inp._keras_history.operation.name for inp in ref_inputs
            ]

        if allow_extra_keys:
            known_names = set(ref_input_names)
            extra_keys = [n for n in inputs.keys() if n not in known_names]
            if extra_keys:
                warnings.warn(
                    "Input dict contained keys {} which did not match any "
                    "model input. They will be ignored by the model.".format(
                        extra_keys
                    ),
                    stacklevel=2,
                )
        # Flatten in the order `Input`s were passed during Model construction.
        return [inputs[n] for n in ref_input_names]

    def _adjust_input_rank(self, flat_inputs):
        """Reconcile the rank of each input with its reference input.

        An input one rank higher than its reference is squeezed along its
        last axis if that axis has size 1. An input one rank lower is
        expanded with a new last axis if the reference's last dimension is
        1. `_keras_history` / `_keras_mask` are copied onto any tensor
        replaced this way.

        Args:
            flat_inputs: Flat list of input tensors, in reference order.

        Returns:
            A list of tensors with ranks matching `self._inputs`.

        Raises:
            ValueError: If the number of inputs differs from the number of
                model inputs, or an input's rank cannot be reconciled.
        """
        flat_ref_shapes = [x.shape for x in self._inputs]
        if len(flat_inputs) != len(flat_ref_shapes):
            raise ValueError(
                f"Expected {len(flat_ref_shapes)} input(s), but received "
                f"{len(flat_inputs)} input(s)."
            )

        adjusted = []
        for x, ref_shape in zip(flat_inputs, flat_ref_shapes):
            x_rank = len(x.shape)
            ref_rank = len(ref_shape)
            if x_rank == ref_rank:
                adjusted.append(x)
            elif x_rank == ref_rank + 1 and x.shape[-1] == 1:
                adjusted.append(ops.squeeze(x, axis=-1))
            elif x_rank == ref_rank - 1 and ref_shape[-1] == 1:
                adjusted.append(ops.expand_dims(x, axis=-1))
            else:
                raise ValueError(
                    f"Invalid input shape for input {x}. Expected shape "
                    f"{ref_shape}, but input has incompatible shape {x.shape}"
                )

        # Add back metadata on tensors that were replaced above.
        for original, new in zip(flat_inputs, adjusted):
            if new is original:
                continue
            if hasattr(original, "_keras_history"):
                new._keras_history = original._keras_history
            if hasattr(original, "_keras_mask"):
                new._keras_mask = original._keras_mask
        return adjusted

    def _standardize_inputs(self, inputs):
        """Flatten `inputs` to reference order, then reconcile their ranks."""
        flat_inputs = self._flatten_to_reference_inputs(inputs)
        return self._adjust_input_rank(flat_inputs)

    def add_loss(self, loss):
        # Symbolic only. TODO
        raise NotImplementedError

    @python_utils.default
    def get_config(self):
        """Return a serializable config for this model.

        Always starts from `name` and `trainable`. If the class keeps a
        Functional-compatible constructor (per `functional_like_constructor`),
        the result of `get_functional_config` is merged in. Otherwise, the
        auto-captured constructor arguments are merged in when
        `_auto_get_config` is set. In that case `name`/`trainable` are
        dropped unless the subclass `__init__` can accept them, either
        explicitly or via `**kwargs`.

        NOTE(review): in this module `functional_like_constructor` and
        `get_functional_config` are placeholders that raise
        NotImplementedError, so this method currently raises.
        """
        config = {
            "name": self.name,
            "trainable": self.trainable,
        }
        if functional_like_constructor(self.__class__):
            # Only return a Functional config if the constructor is the same
            # as that of a Functional model. This excludes subclassed
            # Functional models with a custom __init__.
            config = {**config, **get_functional_config(self)}
        else:
            # Try to autogenerate config.
            base_keys = set(config.keys())
            if getattr(self, "_auto_get_config", False):
                config.update(self._auto_config.config)
            # Remove base args the constructor cannot accept. Any `**` catch-all
            # (whatever its name) accepts them, so only strip when there is none.
            argspec = inspect.getfullargspec(self.__init__)
            if argspec.varkw is None:
                # args[0] is `self` for the bound __init__.
                accepted = set(argspec.args[1:]) | set(argspec.kwonlyargs)
                for key in base_keys - accepted:
                    config.pop(key, None)
        return config
|
|
|
|
|
2023-04-09 19:21:45 +00:00
|
|
|
|
|
|
|
def operation_fn(operation, training):
    """Wrap `operation` so that `training` is injected when it is supported.

    Each call checks whether `operation` has a `_call_has_training_arg`
    method and whether that method returns a truthy value. Only then is
    `training=training` added to the keyword arguments.

    Args:
        operation: The callable graph node to wrap.
        training: Value passed as the `training` keyword when supported.

    Returns:
        A callable that forwards its arguments to `operation`.
    """

    def call(*args, **kwargs):
        if not hasattr(operation, "_call_has_training_arg"):
            return operation(*args, **kwargs)
        if operation._call_has_training_arg():
            kwargs["training"] = training
        return operation(*args, **kwargs)

    return call
|
2023-04-23 18:07:50 +00:00
|
|
|
|
|
|
|
|
|
|
|
def functional_like_constructor(cls):
    """Placeholder: report whether `cls` keeps a Functional-style constructor.

    `Functional.get_config` uses this check to decide between a Functional
    graph config and an auto-generated one. Not implemented yet; it always
    raises.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
|
|
|
|
|
|
|
|
|
|
|
|
def get_functional_config(model):
    """Placeholder: return the graph-specific config entries for `model`.

    `Functional.get_config` merges the result into its base config. Not
    implemented yet; it always raises.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
|