keras/keras_core/models/functional.py

import inspect
import warnings

from tensorflow import nest

from keras_core import backend
from keras_core import operations as ops
from keras_core.layers.layer import Layer
from keras_core.models.model import Model
from keras_core.operations.function import Function
from keras_core.utils import python_utils
from keras_core.utils import tracking


class Functional(Function, Model):
"""
Add support for extra call arguments compared to Function:
training, masks
Add support for arg standardization:
- list/dict duality
- upranking
Override .layers
Symbolic add_loss
"""
2023-04-23 18:05:04 +00:00
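    # Usage sketch (illustrative only; names such as "features" are
    # hypothetical, and the standard `keras_core.Input`/`layers.Dense` APIs
    # are assumed). This class is normally reached via `Model(inputs,
    # outputs)`, which swaps in `Functional` through `__new__`:
    #
    #     inputs = Input(shape=(3,), name="features")
    #     outputs = layers.Dense(2)(inputs)
    #     model = Model(inputs, outputs)
    #     preds = model(batch, training=False)
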
    def __new__(cls, *args, **kwargs):
        # Skip Model.__new__.
        return Function.__new__(cls)

    @tracking.no_automatic_dependency_tracking
    def __init__(self, inputs, outputs, name=None, **kwargs):
        # This is used by the Model class, since we have some logic to swap
        # the class in the `__new__` method, which can cause `__init__` to be
        # invoked twice. `skip_init` is used to skip one of the invocations
        # in order to avoid side effects.
        if isinstance(inputs, dict):
            for k, v in inputs.items():
                if not isinstance(v, backend.KerasTensor):
                    raise ValueError(
                        "When providing `inputs` as a dict, all values in the"
                        f" dict must be KerasTensors. Received: inputs={inputs}"
                        f" including invalid value {v} of type {type(v)}"
                    )
                if k != v.name:
                    # TODO: maybe make this a warning
                    raise ValueError(
                        "When providing `inputs` as a dict, all keys in the"
                        " dict must match the names of the corresponding"
                        f" tensors. Received key '{k}' mapping to value {v}"
                        f" which has name '{v.name}'. Change the tensor name"
                        f" to '{k}' (via `Input(..., name='{k}')`)"
                    )
        elif isinstance(inputs, (list, tuple)):
            for x in inputs:
                if not isinstance(x, backend.KerasTensor):
                    raise ValueError(
                        "When providing `inputs` as a list/tuple, all values"
                        " in the list/tuple must be KerasTensors. Received:"
                        f" inputs={inputs} including invalid value {x} of"
                        f" type {type(x)}"
                    )
        elif not isinstance(inputs, backend.KerasTensor):
            raise ValueError(
                f"Unrecognized type for `inputs`: {inputs} "
                f"(of type {type(inputs)})"
            )
        if isinstance(outputs, dict):
            for k, v in outputs.items():
                if not isinstance(v, backend.KerasTensor):
                    raise ValueError(
                        "When providing `outputs` as a dict, all values in"
                        " the dict must be KerasTensors. Received:"
                        f" outputs={outputs} including invalid value {v} of"
                        f" type {type(v)}"
                    )
        elif isinstance(outputs, (list, tuple)):
            for x in outputs:
                if not isinstance(x, backend.KerasTensor):
                    raise ValueError(
                        "When providing `outputs` as a list/tuple, all values"
                        " in the list/tuple must be KerasTensors. Received:"
                        f" outputs={outputs} including invalid value {x} of"
                        f" type {type(x)}"
                    )
        elif not isinstance(outputs, backend.KerasTensor):
            raise ValueError(
                f"Unrecognized type for `outputs`: {outputs} "
                f"(of type {type(outputs)})"
            )
        super().__init__(inputs, outputs, name=name, **kwargs)
        self._layers = self.layers
        self.built = True

    @property
    def layers(self):
        layers = []
        for operation in self._operations:
            if isinstance(operation, Layer):
                layers.append(operation)
        return layers

    def call(self, inputs, training=None, mask=None):
        # Add support for training, masking.
        inputs = self._standardize_inputs(inputs)
        if mask is None:
            masks = [None] * len(inputs)
        else:
            masks = self._flatten_to_reference_inputs(mask)
        for x, mask in zip(inputs, masks):
            x._keras_mask = mask
        return self._run_through_graph(
            inputs, operation_fn=lambda op: operation_fn(op, training=training)
        )

    def compute_output_spec(self, inputs, training=None, mask=None):
        # From Function.
        return super().compute_output_spec(inputs)

    def _assert_input_compatibility(self, *args):
        return super(Model, self)._assert_input_compatibility(*args)

    def _flatten_to_reference_inputs(self, inputs, allow_extra_keys=True):
        if isinstance(inputs, dict):
            ref_inputs = self._inputs_struct
            if not nest.is_nested(ref_inputs):
                ref_inputs = [self._inputs_struct]
            if isinstance(ref_inputs, dict):
                # In the case that the graph is constructed with dict input
                # tensors, we will use the original dict key to map with the
                # keys in the input data. Note that the model.inputs is using
                # nest.flatten to process the input tensors, which means the
                # dict input tensors are ordered by their keys.
                ref_input_names = sorted(ref_inputs.keys())
            else:
                ref_input_names = [
                    inp._keras_history.operation.name for inp in ref_inputs
                ]
            # Raise a warning if there is more input data than there are
            # input tensors.
            if allow_extra_keys and len(inputs) > len(ref_input_names):
                warnings.warn(
                    "Input dict contained keys {} which did not match any "
                    "model input. They will be ignored by the model.".format(
                        [n for n in inputs.keys() if n not in ref_input_names]
                    ),
                    stacklevel=2,
                )
            # Flatten in the order `Input`s were passed during Model
            # construction.
            return [inputs[n] for n in ref_input_names]
        # Otherwise both ref inputs and inputs will already be in the same
        # order.
        return nest.flatten(inputs)

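    # Sketch of the dict path above (key names are illustrative): for a model
    # built with inputs `{"a": Input(..., name="a"), "b": Input(..., name="b")}`,
    # calling with `{"b": xb, "a": xa, "extra": xe}` warns about "extra" and
    # flattens to `[xa, xb]` (keys sorted alphabetically); list/tuple inputs
    # are flattened positionally by `nest.flatten`.
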
    def _adjust_input_rank(self, flat_inputs):
        flat_ref_shapes = [x.shape for x in self._inputs]
        adjusted = []
        for x, ref_shape in zip(flat_inputs, flat_ref_shapes):
            x_rank = len(x.shape)
            ref_rank = len(ref_shape)
            if x_rank == ref_rank:
                adjusted.append(x)
                continue
            if x_rank == ref_rank + 1:
                if x.shape[-1] == 1:
                    adjusted.append(ops.squeeze(x, axis=-1))
                    continue
            if x_rank == ref_rank - 1:
                if ref_shape[-1] == 1:
                    adjusted.append(ops.expand_dims(x, axis=-1))
                    continue
            raise ValueError(
                f"Invalid input shape for input {x}. Expected shape "
                f"{ref_shape}, but input has incompatible shape {x.shape}"
            )
        # Add back metadata.
        for i in range(len(flat_inputs)):
            if hasattr(flat_inputs[i], "_keras_history"):
                adjusted[i]._keras_history = flat_inputs[i]._keras_history
            if hasattr(flat_inputs[i], "_keras_mask"):
                adjusted[i]._keras_mask = flat_inputs[i]._keras_mask
        return adjusted

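    # Upranking sketch (shapes are illustrative): if a reference input was
    # declared with shape (None, 10, 1) and the provided tensor has shape
    # (32, 10), it is expanded to (32, 10, 1); conversely, a trailing
    # dimension of size 1 is squeezed when the input is one rank higher than
    # the reference. Any other rank mismatch raises a ValueError.
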
    def _standardize_inputs(self, inputs):
        flat_inputs = self._flatten_to_reference_inputs(inputs)
        return self._adjust_input_rank(flat_inputs)

    def add_loss(self, loss):
        # Symbolic only. TODO
        raise NotImplementedError

    @python_utils.default
    def get_config(self):
        # Prepare base arguments.
        config = {
            "name": self.name,
            "trainable": self.trainable,
        }
        # Check whether the class has a constructor compatible with a
        # Functional model, or whether it has a custom constructor.
        if functional_like_constructor(self.__class__):
            # Only return a Functional config if the constructor is the same
            # as that of a Functional model. This excludes subclassed
            # Functional models with a custom `__init__`.
            config = {**config, **get_functional_config(self)}
        else:
            # Try to autogenerate the config.
            xtra_args = set(config.keys())
            if getattr(self, "_auto_get_config", False):
                config.update(self._auto_config.config)
            # Remove args not explicitly supported.
            argspec = inspect.getfullargspec(self.__init__)
            if argspec.varkw != "kwargs":
                for key in xtra_args - xtra_args.intersection(
                    argspec.args[1:]
                ):
                    config.pop(key, None)
        return config


def operation_fn(operation, training):
    def call(*args, **kwargs):
        if (
            hasattr(operation, "_call_has_training_arg")
            and operation._call_has_training_arg()
        ):
            kwargs["training"] = training
        return operation(*args, **kwargs)

    return call

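# Sketch of how the wrapper above is used (names are illustrative): given a
# layer `dense = layers.Dense(2)`, `operation_fn(dense, training=False)`
# returns a callable that forwards to `dense(*args, **kwargs)` and injects
# `training=False` only if the layer's `call` signature accepts `training`.
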

def functional_like_constructor(cls):
    raise NotImplementedError


def get_functional_config(model):
    raise NotImplementedError