"""Layer is an Operation with state.

Takes care of:

- Weights / variables (and tracking thereof)
- deferred build
- trainable argument value inference
- masking
- autocasting

And some more magic:

- add_loss
- metric tracking
- RNG seed tracking
- activity regularization
"""
import collections
import inspect

import numpy as np
from tensorflow import nest

from keras_core import backend
from keras_core import initializers
from keras_core import mixed_precision
from keras_core import regularizers
from keras_core import utils
from keras_core.api_export import keras_core_export
from keras_core.backend import KerasTensor
from keras_core.backend.common import global_state
from keras_core.layers import input_spec
from keras_core.metrics.metric import Metric
from keras_core.operations.operation import Operation
from keras_core.utils import summary_utils
from keras_core.utils import traceback_utils
from keras_core.utils.shape_utils import map_shape_structure
from keras_core.utils.tracking import Tracker


@keras_core_export(["keras_core.Layer", "keras_core.layers.Layer"])
class Layer(Operation):
    def __init__(
        self,
        *,
        activity_regularizer=None,
        trainable=True,
        dtype=None,
        autocast=True,
        name=None,
    ):
        self._lock = False
        super().__init__(name=name)
        self.activity_regularizer = regularizers.get(activity_regularizer)

        self.built = False
        self.dtype_policy = mixed_precision.resolve_policy(dtype)
        self.autocast = autocast
        self.input_spec = None

        self._trainable = trainable
        self._layers = []
        self._metrics = []
        self._seed_generators = []
        self._losses = []
        self._trainable_variables = []
        self._non_trainable_variables = []
        self._supports_masking = not utils.is_default(self.compute_mask)
        self._allow_non_tensor_positional_args = False
        self._build_shapes_dict = None
        self._call_signature_parameters = [
            p.name for p in inspect.signature(self.call).parameters.values()
        ]
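
        # The tracker below inspects every attribute assigned on the layer
        # (in `__setattr__`) and routes trackable objects into the lists
        # declared above; each entry maps a collection name to a
        # `(predicate, target_list)` pair.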
        self._tracker = Tracker(
            {
                "trainable_variables": (
                    lambda x: isinstance(x, backend.Variable) and x.trainable,
                    self._trainable_variables,
                ),
                "non_trainable_variables": (
                    lambda x: isinstance(x, backend.Variable)
                    and not x.trainable,
                    self._non_trainable_variables,
                ),
                "metrics": (lambda x: isinstance(x, Metric), self._metrics),
                "layers": (
                    lambda x: isinstance(x, Layer)
                    and not isinstance(x, Metric),
                    self._layers,
                ),
                "seed_generators": (
                    lambda x: isinstance(x, backend.random.SeedGenerator),
                    self._seed_generators,
                ),
            }
        )

    @utils.default
    def build(self, input_shape):
        self.built = True

    def get_build_config(self):
        """Returns a dictionary with the layer's input shape.

        This method returns a config dict that can be used by
        `build_from_config(config)` to create all states (e.g. Variables and
        Lookup tables) needed by the layer.

        By default, the config only contains the input shape that the layer
        was built with. If you're writing a custom layer that creates state in
        an unusual way, you should override this method to make sure this
        state is already created when Keras attempts to load its value upon
        model loading.

        Returns:
            A dict containing the input shape associated with the layer.
        """
        if self._build_shapes_dict is not None:
            if len(self._build_shapes_dict) == 1:
                return {
                    "input_shape": tuple(self._build_shapes_dict.values())[0],
                }
            else:
                return {"shapes_dict": self._build_shapes_dict}

    def build_from_config(self, config):
        """Builds the layer's states with the supplied config dict.

        By default, this method calls the `build(config["input_shape"])`
        method, which creates weights based on the layer's input shape in the
        supplied config. If your config contains other information needed to
        load the layer's state, you should override this method.

        Args:
            config: Dict containing the input shape associated with this
                layer.
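
        Example (a minimal sketch for a hypothetical layer whose state
        depends on more than the input shape; `_vocab` and
        `_build_input_shape` are illustrative attributes):

        ```python
        def get_build_config(self):
            return {
                "input_shape": self._build_input_shape,
                "vocab": self._vocab,
            }

        def build_from_config(self, config):
            self._vocab = config["vocab"]
            self.build(config["input_shape"])
        ```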
        """
        if config:
            if "input_shape" in config:
                self.build(config["input_shape"])
                self._build_shapes_dict = config
            elif "shapes_dict" in config:
                self.build(**config["shapes_dict"])
                self._build_shapes_dict = config["shapes_dict"]
            self.built = True

    def add_variable(
        self,
        shape,
        initializer,
        dtype=None,
        trainable=True,
        regularizer=None,
        constraint=None,
        name=None,
    ):
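        """Creates a `backend.Variable` and tracks it as part of layer state.

        Typically called inside `build()`. A minimal sketch (`_kernel` and
        `units` are illustrative names):

        ```python
        def build(self, input_shape):
            self._kernel = self.add_variable(
                shape=(input_shape[-1], self.units),
                initializer="glorot_uniform",
                name="kernel",
            )
        ```
        """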
        # TODO: handle layout
        self._check_super_called()
        initializer = initializers.get(initializer)
        variable = backend.Variable(
            initializer=initializer,
            shape=shape,
            dtype=dtype or self.variable_dtype,
            trainable=trainable,
            name=name,
        )
        # Will be added to layer.losses
        variable.regularizer = regularizer
        variable.constraint = constraint
        if trainable:
            self._trainable_variables.append(variable)
            # Prevent double-tracking
            self._tracker.stored_ids["trainable_variables"].add(id(variable))
        else:
            self._non_trainable_variables.append(variable)
            # Prevent double-tracking
            self._tracker.stored_ids["non_trainable_variables"].add(
                id(variable)
            )
        return variable

    def add_weight(self, *args, **kwargs):
        return self.add_variable(*args, **kwargs)

    @property
    def trainable(self):
        return self._trainable

    @trainable.setter
    def trainable(self, value):
        """Sets trainable attribute for the layer and its sublayers.

        When this value is changed during training (e.g. with a
        `Callback`) you need to call the parent
        `Model.make_train_function` with `force=True` in order to
        recompile the training graph.

        Args:
            value: Boolean with the desired state for the layer's trainable
                attribute.
        """
        value = bool(value)
        self._trainable = value
        for v in self._trainable_variables:
            v.trainable = value
        for layer in self._layers:
            layer.trainable = value

    @property
    def variables(self):
        # Return only weights/rng state/metric variables
        # of all Layers, recursively.
        # Also deduplicate them.
        variables = []
        seen_ids = set()
        for v in self._trainable_variables + self._non_trainable_variables:
            if id(v) not in seen_ids:
                variables.append(v)
                seen_ids.add(id(v))
        for m in self._metrics:
            variables.extend(m.variables)
        for sg in self._seed_generators:
            variables.append(sg.state)
        for layer in self._layers:
            for v in layer.variables:
                if id(v) not in seen_ids:
                    variables.append(v)
                    seen_ids.add(id(v))
        return variables

    @property
    def trainable_variables(self):
        if not self.trainable:
            return []
        return [v for v in self.variables if v.trainable]

    @property
    def non_trainable_variables(self):
        if not self.trainable:
            return self.variables
        return [v for v in self.variables if not v.trainable]

    @property
    def weights(self):
        # Return only "own weights" of all Layers, recursively.
        # Also deduplicate them.
        weights = []
        seen_ids = set()
        for w in self._trainable_variables + self._non_trainable_variables:
            if id(w) not in seen_ids:
                weights.append(w)
                seen_ids.add(id(w))
        for layer in self._layers:
            for w in layer.weights:
                if id(w) not in seen_ids:
                    weights.append(w)
                    seen_ids.add(id(w))
        return weights

    @property
    def trainable_weights(self):
        if not self.trainable:
            return []
        return [v for v in self.weights if v.trainable]

    @property
    def non_trainable_weights(self):
        if not self.trainable:
            return self.weights
        return [v for v in self.weights if not v.trainable]

    def get_weights(self):
        return [v.numpy() for v in self.weights]

    def set_weights(self, weights):
        layer_weights = self.weights
        if len(layer_weights) != len(weights):
            raise ValueError(
                f"You called `set_weights(weights)` on layer '{self.name}' "
                f"with a weight list of length {len(weights)}, but the layer "
                f"was expecting {len(layer_weights)} weights."
            )
        for variable, value in zip(layer_weights, weights):
            if variable.shape != value.shape:
                raise ValueError(
                    f"Layer {self.name} weight shape {variable.shape} "
                    "is not compatible with provided weight "
                    f"shape {value.shape}."
                )
            variable.assign(value)

    @property
    def dtype(self):
        """The dtype of the state (weights) of the layer."""
        return self.variable_dtype

    @property
    def compute_dtype(self):
        """The dtype of the computations performed by the layer."""
        return self.dtype_policy.compute_dtype

    @property
    def variable_dtype(self):
        """The dtype of the state (weights) of the layer."""
        return self.dtype_policy.variable_dtype

    @property
    def supports_masking(self):
        """Whether this layer supports computing a mask using `compute_mask`."""
        return self._supports_masking

    @supports_masking.setter
    def supports_masking(self, value):
        self._supports_masking = value

    @utils.default
    def compute_mask(self, inputs, previous_mask):
        return previous_mask

    @traceback_utils.filter_traceback
    def __call__(self, *args, **kwargs):
        self._check_super_called()

        #####################################
        # 1. Convert any array arguments to tensors of correct dtype.
        def maybe_convert(x):
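            # Three cases: backend tensors are (optionally) cast to the
            # compute dtype, symbolic `KerasTensor`s get their dtype updated
            # in place, and array-likes (anything exposing `__array__`) are
            # converted to backend tensors.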
            if backend.is_tensor(x):
                if (
                    self.autocast
                    and backend.is_float_dtype(x.dtype)
                    and x.dtype != self.compute_dtype
                ):
                    return backend.cast(x, dtype=self.compute_dtype)
                return x
            elif isinstance(x, backend.KerasTensor):
                if (
                    self.autocast
                    and backend.is_float_dtype(x.dtype)
                    and x.dtype != self.compute_dtype
                ):
                    x.dtype = self.compute_dtype
                return x
            elif hasattr(x, "__array__"):
                return backend.convert_to_tensor(x, dtype=self.compute_dtype)
            return x

        args = nest.map_structure(maybe_convert, args)
        kwargs = nest.map_structure(maybe_convert, kwargs)

        ##########################################################
        # 2. Enforce that only tensors can be passed positionally.
        if not self._allow_non_tensor_positional_args:
            for arg in nest.flatten(args):
                if not isinstance(arg, KerasTensor) and not backend.is_tensor(
                    arg
                ):
                    raise ValueError(
                        "Only input tensors may be passed as "
                        "positional arguments. The following argument value "
                        f"should be passed as a keyword argument: {arg} "
                        f"(of type {type(arg)})"
                    )

        # Caches info about `call()` signature, args, kwargs.
        call_spec = CallSpec(self.call, args, kwargs)

        ############################################
        # 3. Check input spec for 1st positional arg.
        # TODO: consider extending this to all args and kwargs.
        self._assert_input_compatibility(call_spec.first_arg)

        ################
        # 4. Call build.
        self._maybe_build(call_spec)

        ##########################
        # 5. Infer training value
        # Training phase for `Layer.call` is set via (in order of priority):
        # (1) The `training` argument passed to this `Layer.call`, if not None
        # (2) The training argument of an outer `Layer.call`.
        # (3) Any non-None default value for `training` in the call signature
        # (4) False (treating the layer as if it's in inference)

        # Maintains info about the `Layer.call` stack
        # across nested calls.
        call_context = self._get_call_context()

        # This is the value explicitly passed by the user
        training = call_spec.user_arguments_dict.get("training", None)
        if training is None:
            # Wasn't passed explicitly: use context value
            training = call_context.training
            if training is None:
                # Get signature default value; else False
                training = call_spec.arguments_dict.get("training", False)
        call_context.training = training
        if self._call_has_training_arg():
            kwargs["training"] = training

        ##############################
        # 6. Populate mask argument(s)
        if self.supports_masking:
            if len(call_spec.tensor_arguments_dict) == 1:
                if (
                    "mask" in call_spec.argument_names
                    and call_spec.arguments_dict["mask"] is None
                ):
                    arg_name = list(call_spec.tensor_arguments_dict.keys())[0]
                    only_tensor_arg = call_spec.tensor_arguments_dict[arg_name]
                    mask = nest.map_structure(
                        lambda x: getattr(x, "_keras_mask", None),
                        only_tensor_arg,
                    )
                    kwargs["mask"] = mask
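            # With more than one tensor argument, masks are matched by name:
            # a tensor argument `foo` pairs with a mask argument `foo_mask`
            # (e.g. a hypothetical `call(self, query, value, query_mask=None,
            # value_mask=None)`).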
            elif len(call_spec.tensor_arguments_dict) > 1:
                for k, v in call_spec.tensor_arguments_dict.items():
                    expected_mask_arg_name = f"{k}_mask"
                    if expected_mask_arg_name in call_spec.argument_names:
                        if (
                            call_spec.arguments_dict[expected_mask_arg_name]
                            is None
                        ):
                            mask = nest.map_structure(
                                lambda x: getattr(x, "_keras_mask", None), v
                            )
                            kwargs[expected_mask_arg_name] = mask

        ####################
        # 7. Call the layer.
        try:
            with backend.name_scope(self.name):
                if self.autocast and self.compute_dtype != self.variable_dtype:
                    # For mixed precision, we automatically cast layer variables
                    # (float ones only) to the compute dtype upon access.
                    with backend.AutocastScope(self.compute_dtype):
                        outputs = super().__call__(*args, **kwargs)
                else:
                    outputs = super().__call__(*args, **kwargs)
                if not self.built:
                    self.built = True
                # Record activity regularizer loss.
                if self.activity_regularizer is not None:
                    for output in nest.flatten(outputs):
                        if backend.is_tensor(output):
                            self.add_loss(self.activity_regularizer(output))

            if self.supports_masking:
                # Set masks on outputs,
                # provided only the first positional input arg and its mask.
                # TODO: consider extending this to all args and kwargs.
                previous_mask = getattr(
                    call_spec.first_arg, "_keras_mask", None
                )
                self._set_mask_metadata(
                    call_spec.first_arg, outputs, previous_mask
                )
        finally:
            # Destroy call context if we created it
            self._maybe_reset_call_context()
        return outputs

    def call(self, *args, **kwargs):
        raise NotImplementedError

    @traceback_utils.filter_traceback
    def stateless_call(
        self,
        trainable_variables,
        non_trainable_variables,
        *args,
        return_losses=False,
        **kwargs,
    ):
        """Call the layer without any side effects.

        Args:
            trainable_variables: List of trainable variables of the model.
            non_trainable_variables: List of non-trainable variables of the
                model.
            *args: Positional arguments to be passed to `call()`.
            return_losses: If `True`, `stateless_call()` will return the list
                of losses created during `call()` as part of its return
                values.
            **kwargs: Keyword arguments to be passed to `call()`.

        Returns:
            A tuple. By default, returns `(outputs, non_trainable_variables)`.
            If `return_losses = True`, then returns
            `(outputs, non_trainable_variables, losses)`.

        Note: `non_trainable_variables` include not only non-trainable weights
        such as `BatchNormalization` statistics, but also RNG seed state
        (if any random operations are part of the layer, such as dropout),
        and `Metric` state (if there are any metrics attached to the layer).
        These are all elements of state of the layer.

        Example:

        ```python
        model = ...
        data = ...
        trainable_variables = model.trainable_variables
        non_trainable_variables = model.non_trainable_variables
        # Call the model with zero side effects
        outputs, non_trainable_variables = model.stateless_call(
            trainable_variables,
            non_trainable_variables,
            data,
        )
        # Attach the updated state to the model
        # (until you do this, the model is still in its pre-call state).
        for ref_var, value in zip(
            model.non_trainable_variables, non_trainable_variables
        ):
            ref_var.assign(value)
        ```
        """
        self._check_super_called()

        if not self.built:
            raise ValueError(
                f"To call stateless_call, {self.__class__.__name__} must be "
                "built (i.e. its variables must have been already created). "
                "You can build it by calling it on some data."
            )
        if len(trainable_variables) != len(self.trainable_variables):
            raise ValueError(
                "Argument `trainable_variables` must be a list of tensors "
                "corresponding 1:1 to "
                f"{self.__class__.__name__}().trainable_variables. "
                f"Received list with length {len(trainable_variables)}, "
                f"but expected {len(self.trainable_variables)} variables."
            )
        if len(non_trainable_variables) != len(self.non_trainable_variables):
            raise ValueError(
                "Argument `non_trainable_variables` must be a list of tensors "
                "corresponding 1:1 to "
                f"{self.__class__.__name__}().non_trainable_variables. "
                f"Received list with length {len(non_trainable_variables)}, "
                f"but expected {len(self.non_trainable_variables)} variables."
            )

        # Gather variable mapping
        trainable_mapping = zip(self.trainable_variables, trainable_variables)
        non_trainable_mapping = zip(
            self.non_trainable_variables, non_trainable_variables
        )
        mapping = list(trainable_mapping) + list(non_trainable_mapping)

        # Call in stateless scope
        with backend.StatelessScope(
            state_mapping=mapping, collect_losses=return_losses
        ) as scope:
            outputs = self.call(*args, **kwargs)

        # Gather updated non-trainable variables
        non_trainable_variables = []
        for v in self.non_trainable_variables:
            new_v = scope.get_current_value(v)
            if new_v is not None:
                non_trainable_variables.append(new_v)
            else:
                non_trainable_variables.append(v)

        if return_losses:
            return outputs, non_trainable_variables, scope.losses[:]
        return outputs, non_trainable_variables

    def compute_output_spec(self, *args, **kwargs):
        if utils.is_default(self.compute_output_shape):
            return super().compute_output_spec(*args, **kwargs)
        else:
            # Use compute_output_shape() to return the right output spec
            call_spec = CallSpec(self.call, args, kwargs)
            shapes_dict = get_shapes_dict(self.compute_output_shape, call_spec)
            if len(shapes_dict) == 1:
                # Single arg: pass it positionally
                input_shape = tuple(shapes_dict.values())[0]
                output_shape = self.compute_output_shape(input_shape)
            else:
                # More than one shape: pass them by name.
                output_shape = self.compute_output_shape(**shapes_dict)

            if (
                isinstance(output_shape, list)
                and output_shape
                and isinstance(output_shape[0], (int, type(None)))
            ):
                output_shape = tuple(output_shape)
            if not isinstance(output_shape, (list, tuple, dict)):
                try:
                    output_shape = tuple(output_shape)
                except:
                    raise ValueError(
                        "Method `compute_output_shape()` of layer "
                        f"{self.__class__.__name__} is returning "
                        "a type that cannot be interpreted as a shape. "
                        "It should return a shape tuple. "
                        f"Received: {output_shape}"
                    )
            if (
                isinstance(output_shape, tuple)
                and output_shape
                and isinstance(output_shape[0], (int, type(None)))
            ):
                return KerasTensor(output_shape, dtype=self.compute_dtype)
            # Case: nested. Could be a tuple/list of shapes, or a dict of
            # shapes. Could be deeply nested.
            return map_shape_structure(
                lambda s: KerasTensor(s, dtype=self.compute_dtype), output_shape
            )

    @utils.default
    def compute_output_shape(self, *args, **kwargs):
        raise NotImplementedError

    def add_loss(self, loss):
        # Eager only.
        losses = nest.flatten(loss)
        for x in losses:
            if not backend.is_tensor(x):
                raise ValueError(
                    "`add_loss()` can only be called from inside `build()` or "
                    f"`call()`, on a tensor input. Received invalid value: {x}"
                )
        if backend.in_stateless_scope():
            scope = backend.get_stateless_scope()
            if scope.collect_losses:
                for x in losses:
                    scope.add_loss(x)
        else:
            self._losses.extend(losses)

    @property
    def losses(self):
        losses = self._losses[:]
        for layer in self._layers:
            losses.extend(layer._losses)
        weight_regularization_losses = []
        for v in self.trainable_weights:
            regularizer = getattr(v, "regularizer", None)
            if regularizer:
                weight_regularization_losses.append(regularizer(v))
        losses.extend(weight_regularization_losses)
        return losses

    def save_own_variables(self, store):
        """Saves the state of the layer.

        You can override this method to take full control of how the state of
        the layer is saved upon calling `model.save()`.

        Args:
            store: Dict where the state of the model will be saved.
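
        Example (a minimal sketch of a custom override; `kernel` and `bias`
        are illustrative attribute names):

        ```python
        def save_own_variables(self, store):
            store["kernel"] = np.array(self.kernel)
            store["bias"] = np.array(self.bias)
        ```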
        """
        all_vars = self._trainable_variables + self._non_trainable_variables
        for i, v in enumerate(all_vars):
            store[f"{i}"] = np.array(v)

    def load_own_variables(self, store):
        """Loads the state of the layer.

        You can override this method to take full control of how the state of
        the layer is loaded upon calling `keras.models.load_model()`.

        Args:
            store: Dict from which the state of the model will be loaded.
        """
        all_vars = self._trainable_variables + self._non_trainable_variables
        if len(store.keys()) != len(all_vars):
            if len(all_vars) == 0 and not self.built:
                raise ValueError(
                    f"Layer '{self.name}' was never built "
                    "and thus it doesn't have any variables. "
                    f"However the weights file lists {len(store.keys())} "
                    "variables for this layer. In most cases, "
                    "this indicates that you need to implement the "
                    "`def build_from_config(self, config)` method "
                    "on the layer. "
                    "You might also want to implement the method "
                    "that generates the config at saving time, "
                    "`def get_build_config(self)`. "
                    "The method `build_from_config()` is meant "
                    "to create the state "
                    "of the layer (i.e. its variables) upon deserialization."
                )
            raise ValueError(
                f"Layer '{self.name}' expected {len(all_vars)} variables, "
                "but received "
                f"{len(store.keys())} variables during loading. "
                f"Expected: {[v.name for v in all_vars]}"
            )
        for i, v in enumerate(all_vars):
            v.assign(store[f"{i}"])

    def _clear_losses(self):
        if backend.in_stateless_scope():
            scope = backend.get_stateless_scope()
            if scope.collect_losses:
                for x in scope.losses:
                    if x in self._losses:
                        scope.losses.remove(x)
        self._losses = []

    def add_metric(self):
        # Permanently disabled
        raise NotImplementedError

    def count_params(self):
        """Count the total number of scalars composing the weights.

        Returns:
            An integer count.
        """
        if not self.built:
            raise ValueError(
                "You tried to call `count_params` "
                f"on layer '{self.name}', "
                "but the layer isn't built. "
                "You can build it manually via: "
                "`layer.build(input_shape)`."
            )
        return summary_utils.count_params(self.weights)

    def _maybe_build(self, call_spec):
        if not self.built:
            shapes_dict = get_shapes_dict(self.build, call_spec)
            self._build_shapes_dict = shapes_dict
            failure = False
            if len(shapes_dict) == 1:
                # Single arg: pass it positionally
                input_shape = tuple(shapes_dict.values())[0]
                with backend.name_scope(self.name):
                    if utils.is_default(
                        self.build
                    ) and might_have_unbuilt_state(self):
                        status = self._build_by_run_for_single_pos_arg(
                            input_shape
                        )
                        if not status:
                            failure = True
                    else:
                        self.build(input_shape)
            else:
                with backend.name_scope(self.name):
                    if utils.is_default(self.build):
                        if might_have_unbuilt_state(self):
                            status = self._build_by_run_for_kwargs(shapes_dict)
                            if not status:
                                failure = True
                    else:
                        self.build(**shapes_dict)
            if failure:
                if call_spec.eager:
                    # Will let the actual eager call do the state-building
                    return
                raise ValueError(
                    f"Layer '{self.name}' looks like it has "
                    "unbuilt state, but Keras is not able to "
                    "trace the layer `call()` in order to "
                    "build it automatically. You must implement "
                    "the `def build(self, input_shape)` method on your "
                    "layer. It should create all variables used by the "
                    "layer (e.g. by calling `layer.build()` on all its "
                    "children layers)."
                )
            self.built = True

            # Check input spec again (after build, since self.input_spec
            # may have been updated)
            self._assert_input_compatibility(call_spec.first_arg)

    def _build_by_run_for_single_pos_arg(self, input_shape):
        # Case: all inputs are in the first arg (possibly nested).
        if is_shape_tuple(input_shape):
            input_shape = tuple(input_shape)
        if isinstance(input_shape, list):
            input_tensors = [
                backend.KerasTensor(shape) for shape in input_shape
            ]
        elif isinstance(input_shape, dict):
            input_tensors = {
                k: backend.KerasTensor(shape)
                for k, shape in input_shape.items()
            }
        else:
            input_tensors = backend.KerasTensor(input_shape)
        try:
            backend.compute_output_spec(self.call, input_tensors)
            return True
        except:
            return False

    def _build_by_run_for_kwargs(self, shapes_dict):
        # Case: inputs were recorded as multiple keyword arguments.
        if all(is_shape_tuple(s) for s in shapes_dict.values()):
            # Case: all input keyword arguments were plain tensors.
            input_tensors = {
                # We strip the `_shape` suffix to recover kwarg names.
                k.removesuffix("_shape"): backend.KerasTensor(shape)
                for k, shape in shapes_dict.items()
            }
            try:
                backend.compute_output_spec(self.call, **input_tensors)
                return True
            except:
                return False
        else:
            # Not supported: nested input keyword arguments.
            return False

    def __repr__(self):
        return (
            f"<{self.__class__.__name__} "
            f"name={self.name}, built={self.built}>"
        )

    def __str__(self):
        return (
            f"<{self.__class__.__name__} "
            f"name={self.name}, built={self.built}>"
        )

    def __setattr__(self, name, value):
        # Prevent users from attaching state to the
        # layer before `super()` is called -- since that
        # state would silently not be tracked.
        if name != "_lock":
            self._check_super_called()
        # Track Variables, Layers, Metrics, SeedGenerators.
        if hasattr(self, "_tracker"):
            value = self._tracker.track(value)
        return super().__setattr__(name, value)

    def _check_super_called(self):
        if getattr(self, "_lock", True):
            raise RuntimeError(
                f"In layer '{self.__class__.__name__}', you forgot to call "
                "`super().__init__()` as the first statement "
                "in the `__init__()` method. Go add it!"
            )

    def _assert_input_compatibility(self, arg_0):
        if self.input_spec:
            input_spec.assert_input_compatibility(
                self.input_spec, arg_0, layer_name=self.name
            )

    def _call_has_training_arg(self):
        return "training" in self._call_signature_parameters

    def _call_has_mask_arg(self):
        return "mask" in self._call_signature_parameters

    def _get_call_context(self):
        """Returns currently active `CallContext`."""
        layer_call_ctx = global_state.get_global_attribute("current_call_ctx")
        if layer_call_ctx is None:
            # Enter new call context.
            layer_call_ctx = CallContext(entry_layer=self)
            global_state.set_global_attribute(
                "current_call_ctx", layer_call_ctx
            )
            self._clear_losses()
        return layer_call_ctx

    def _maybe_reset_call_context(self):
        layer_call_ctx = global_state.get_global_attribute("current_call_ctx")
        if layer_call_ctx is None or layer_call_ctx.entry_layer == self:
            global_state.set_global_attribute("current_call_ctx", None)

    def _flatten_layers(self, include_self=True, recursive=True):
        layers = []
        if include_self:
            layers.append(self)
        seen_object_ids = set()
        deque = collections.deque(self._layers)
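        # Iteratively walk the sublayer graph, de-duplicating layers that are
        # reachable through more than one path.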
        while deque:
            layer = deque.popleft()
            if id(layer) in seen_object_ids:
                continue
            seen_object_ids.add(id(layer))
            layers.append(layer)
            # Introspect recursively through sublayers.
            if recursive:
                deque.extendleft(layer._layers)
        return layers

    def _set_mask_metadata(self, inputs, outputs, previous_mask):
        flat_outputs = nest.flatten(outputs)

        mask_already_computed = all(
            getattr(x, "_keras_mask", None) is not None for x in flat_outputs
        )
        if mask_already_computed:
            return

        output_masks = self.compute_mask(inputs, previous_mask)
        if output_masks is None:
            return

        flat_masks = nest.flatten(output_masks)
        for tensor, mask in zip(flat_outputs, flat_masks):
            if getattr(tensor, "_keras_mask", None) is None:
                tensor._keras_mask = mask


def is_backend_tensor_or_symbolic(x):
    return backend.is_tensor(x) or isinstance(x, backend.KerasTensor)


class CallSpec:
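    """Binds `call()` arguments to their names and pre-computes metadata.

    Records which arguments hold tensors (or nested structures of tensors);
    this is used downstream for shape inference, input spec checking, and
    mask propagation.
    """
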
    def __init__(self, call_fn, args, kwargs):
        sig = inspect.signature(call_fn)
        bound_args = sig.bind(*args, **kwargs)
        self.user_arguments_dict = {
            k: v for k, v in bound_args.arguments.items()
        }
        bound_args.apply_defaults()
        arg_dict = {}
        arg_names = []
        tensor_arg_dict = {}
        tensor_args = []
        tensor_arg_names = []
        nested_tensor_arg_names = []
        for name, value in bound_args.arguments.items():
            arg_dict[name] = value
            arg_names.append(name)
            if is_backend_tensor_or_symbolic(value):
                tensor_args.append(value)
                tensor_arg_names.append(name)
                tensor_arg_dict[name] = value
            elif nest.is_nested(value):
                flat_values = nest.flatten(value)
                if all(is_backend_tensor_or_symbolic(x) for x in flat_values):
                    tensor_args.append(value)
                    tensor_arg_names.append(name)
                    tensor_arg_dict[name] = value
                    nested_tensor_arg_names.append(name)
                elif any(is_backend_tensor_or_symbolic(x) for x in flat_values):
                    raise ValueError(
                        "In a nested call() argument, "
                        "you cannot mix tensors and non-tensors. "
                        "Received invalid mixed argument: "
                        f"{name}={value}"
                    )
        self.arguments_dict = arg_dict
        self.argument_names = arg_names
        self.tensor_arguments_dict = tensor_arg_dict
        self.tensor_arguments_names = tensor_arg_names
        self.nested_tensor_argument_names = nested_tensor_arg_names
        self.first_arg = arg_dict[arg_names[0]]
        if all(
            backend.is_tensor(x) for x in self.tensor_arguments_dict.values()
        ):
            self.eager = True
        else:
            self.eager = False


def get_arguments_dict(fn, args, kwargs):
    """Return a dict mapping argument names to their values."""
    sig = inspect.signature(fn)
    bound_args = sig.bind(*args, **kwargs)
    arg_dict = {}
    for name, value in bound_args.arguments.items():
        arg_dict[name] = value
    return arg_dict


def get_shapes_dict(target_fn, call_spec):
    """Convert the call() arguments dict into a dict of input shape arguments.

    Example:

    ```
    >>> get_shapes_dict(self.build, call_spec)
    {"input_a_shape": (2, 3)}
    ```
    """
    expected_names = check_shapes_signature(target_fn, call_spec)
    shapes_dict = {}
    for k, v in call_spec.tensor_arguments_dict.items():
        if k == "mask" or k.endswith("_mask"):
            # Do not include mask tensors in shapes dict
            continue
        if k == "kwargs" or k == "args":
            # Do not include catch-alls in shapes dict
            continue
        if expected_names is not None and f"{k}_shape" not in expected_names:
            continue
        if k in call_spec.nested_tensor_argument_names:
            shapes_dict[f"{k}_shape"] = nest.map_structure(
                lambda x: backend.standardize_shape(x.shape), v
            )
        else:
            shapes_dict[f"{k}_shape"] = backend.standardize_shape(v.shape)
    return shapes_dict


def check_shapes_signature(target_fn, call_spec):
    """Asserts that the argument names in `target_fn` match arguments in `call`.

    We use this to check that `build()` and `compute_output_shape()` arguments
    align with `call()` arguments.

    For instance if `build()` has the signature
    `def build(self, a_shape, b_shape)`, we expect `call()` to accept the
    arguments `a` and `b`.

    When there is a single argument accepted by `target_fn`, we allow any
    name and do not check the call signature.

    Returns:
        The list of argument names expected by `target_fn`, or
        `None` if any passed name is acceptable.
    """
    if utils.is_default(target_fn):
        return None
    sig = inspect.signature(target_fn)
    expected_names = []
    for name, param in sig.parameters.items():
        if param.kind in (
            param.POSITIONAL_OR_KEYWORD,
            param.POSITIONAL_ONLY,
            param.KEYWORD_ONLY,
        ):
            expected_names.append(name)
    if len(expected_names) == 1:
        return None
    for name in expected_names:
        method_name = target_fn.__name__
        error_preamble = (
            f"For a `{method_name}()` method with more than one argument, all "
            "arguments should have a `_shape` suffix and match an argument "
            f"from `call()`. E.g. `{method_name}(self, foo_shape, bar_shape)` "
            "would match `call(self, foo, bar)`."
        )
        if not name.endswith("_shape"):
            raise ValueError(
                f"{error_preamble} Received `{method_name}()` argument "
                f"`{name}`, which does not end in `_shape`."
            )
        expected_call_arg = name.removesuffix("_shape")
        if expected_call_arg not in call_spec.arguments_dict:
            raise ValueError(
                f"{error_preamble} Received `{method_name}()` argument "
                f"`{name}`, but `call()` does not have argument "
                f"`{expected_call_arg}`."
            )
    return expected_names


class CallContext:
    def __init__(self, entry_layer):
        self.entry_layer = entry_layer
        self.training = None


def is_shape_tuple(s):
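    # A "shape" here is a flat list/tuple whose entries are ints or None,
    # e.g. `(None, 32)` or `[8, 8, 3]`.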
    return isinstance(s, (list, tuple)) and all(
        d is None or isinstance(d, int) for d in s
    )


def might_have_unbuilt_state(layer):
    return any(not lr.built for lr in layer._layers)