keras/keras_core/layers/layer.py

"""Layer is an Operation with state.
Takes care of:
- Weights / variables (and tracking thereof)
- deferred build
- trainable argument value inference
- masking
- autocasting
And some more magic:
- add_loss
- metric tracking
- RNG seed tracking
- activity regularization
"""
import collections
import inspect
import warnings
import numpy as np
from tensorflow import nest
from keras_core import backend
from keras_core import initializers
from keras_core import mixed_precision
from keras_core import regularizers
from keras_core import utils
from keras_core.api_export import keras_core_export
from keras_core.backend import KerasTensor
from keras_core.backend.common import global_state
from keras_core.layers import input_spec
from keras_core.metrics.metric import Metric
from keras_core.operations.operation import Operation
from keras_core.utils import python_utils
from keras_core.utils import summary_utils
from keras_core.utils import traceback_utils
from keras_core.utils import tracking
from keras_core.utils.shape_utils import map_shape_structure
if backend.backend() == "tensorflow":
from keras_core.backend.tensorflow.layer import TFLayer as BackendLayer
elif backend.backend() == "jax":
from keras_core.backend.jax.layer import JaxLayer as BackendLayer
elif backend.backend() == "torch":
from keras_core.backend.torch.layer import TorchLayer as BackendLayer
else:
raise RuntimeError(
f"Backend '{backend.backend()}' must implement a layer mixin class."
)
@keras_core_export(["keras_core.Layer", "keras_core.layers.Layer"])
class Layer(BackendLayer, Operation):
"""This is the class from which all layers inherit.
A layer is a callable object that takes as input one or more tensors and
that outputs one or more tensors. It involves *computation*, defined
in the `call()` method, and a *state* (weight variables). State can be
created:
* in `__init__()`, for instance via `self.add_weight()`;
* in the optional `build()` method, which is invoked by the first
`__call__()` to the layer, and supplies the shape(s) of the input(s),
which may not have been known at initialization time.
Layers are recursively composable: If you assign a Layer instance as an
attribute of another Layer, the outer layer will start tracking the weights
created by the inner layer. Nested layers should be instantiated in the
`__init__()` method or `build()` method.
Users will just instantiate a layer and then treat it as a callable.
Args:
trainable: Boolean, whether the layer's variables should be trainable.
name: String name of the layer.
dtype: The dtype of the layer's computations and weights. Can also be a
`keras_core.mixed_precision.DTypePolicy`,
which allows the computation and
weight dtype to differ. Default of `None` means to use
`keras_core.mixed_precision.dtype_policy()`,
which is a `float32` policy unless set to a different value
(via `keras_core.mixed_precision.set_dtype_policy()`).
Attributes:
name: The name of the layer (string).
dtype: The dtype of the layer's weights.
variable_dtype: Dtype of the layer's variables.
compute_dtype: The dtype of the layer's computations.
Layers automatically cast inputs to this dtype, which causes
the computations and output to also be in this dtype.
When mixed precision is used with a
`keras_core.mixed_precision.DTypePolicy`, this will be different
than `variable_dtype`.
trainable_weights: List of variables to be included in backprop.
non_trainable_weights: List of variables that should not be
included in backprop.
weights: The concatenation of the lists trainable_weights and
non_trainable_weights (in this order).
trainable: Whether the layer should be trained (boolean), i.e.
whether its potentially-trainable weights should be returned
as part of `layer.trainable_weights`.
input_spec: Optional (list of) `InputSpec` object(s) specifying the
constraints on inputs that can be accepted by the layer.
We recommend that descendants of `Layer` implement the following methods:
* `__init__()`: Defines custom layer attributes, and creates layer weights
that do not depend on input shapes, using `add_weight()`,
or other state.
* `build(self, input_shape)`: This method can be used to create weights that
depend on the shape(s) of the input(s), using `add_weight()`, or other
state. `__call__()` will automatically build the layer
(if it has not been built yet) by calling `build()`.
* `call(self, *args, **kwargs)`: Called in `__call__` after making
sure `build()` has been called. `call()` performs the logic of applying
the layer to the input arguments.
Two reserved keyword arguments you can optionally use in `call()` are:
1. `training` (boolean, whether the call is in inference mode or
training mode).
2. `mask` (boolean tensor encoding masked timesteps in the input,
used e.g. in RNN layers).
A typical signature for this method is `call(self, inputs)`, and the user
could optionally add `training` and `mask` if the layer needs them.
* `get_config(self)`: Returns a dictionary containing the configuration
used to initialize this layer. If the keys differ from the arguments
in `__init__()`, then override `from_config(self)` as well.
This method is used when saving
the layer or a model that contains this layer.
Examples:
Here's a basic example: a layer with two variables, `w` and `b`,
that returns `y = w . x + b`.
It shows how to implement `build()` and `call()`.
Variables set as attributes of a layer are tracked as weights
of the layers (in `layer.weights`).
```python
class SimpleDense(Layer):
def __init__(self, units=32):
super().__init__()
self.units = units
# Create the state of the layer (weights)
def build(self, input_shape):
self.kernel = self.add_weight(
shape=(input_shape[-1], self.units),
initializer="glorot_uniform",
trainable=True,
name="kernel",
)
self.bias = self.add_weight(
shape=(self.units,),
initializer="zeros",
trainable=True,
name="bias",
)
# Defines the computation
def call(self, inputs):
return ops.matmul(inputs, self.kernel) + self.bias
# Instantiates the layer.
linear_layer = SimpleDense(4)
# This will also call `build(input_shape)` and create the weights.
y = linear_layer(ops.ones((2, 2)))
assert len(linear_layer.weights) == 2
# These weights are trainable, so they're listed in `trainable_weights`:
assert len(linear_layer.trainable_weights) == 2
```
Besides trainable weights, updated via backpropagation during training,
layers can also have non-trainable weights. These weights are meant to
be updated manually during `call()`. Here's an example layer that computes
the running sum of its inputs:
```python
class ComputeSum(Layer):
def __init__(self, input_dim):
super(ComputeSum, self).__init__()
# Create a non-trainable weight.
self.total = self.add_weight(
shape=(),
initializer="zeros",
trainable=False,
name="total",
)
def call(self, inputs):
self.total.assign(self.total + ops.sum(inputs))
return self.total
my_sum = ComputeSum(2)
x = ops.ones((2, 2))
y = my_sum(x)
assert my_sum.weights == [my_sum.total]
assert my_sum.non_trainable_weights == [my_sum.total]
assert my_sum.trainable_weights == []
```
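Finally, here's a sketch of `get_config()` for the `SimpleDense` layer above
(illustrative only; include whatever constructor arguments your layer takes):
```python
class SimpleDense(Layer):

    def __init__(self, units=32, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def get_config(self):
        # Include constructor arguments so the layer can be
        # re-created by `from_config()` when loading a saved model.
        config = super().get_config()
        config.update({"units": self.units})
        return config
```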
"""
def __init__(
self,
*,
activity_regularizer=None,
trainable=True,
dtype=None,
autocast=True,
name=None,
**kwargs,
):
BackendLayer.__init__(self)
self._lock = False
Operation.__init__(self, name=name)
self.activity_regularizer = regularizers.get(activity_regularizer)
input_dim_arg = kwargs.pop("input_dim", None)
if input_dim_arg is not None:
input_shape_arg = (input_dim_arg,)
else:
input_shape_arg = kwargs.pop("input_shape", None)
if input_shape_arg is not None:
warnings.warn(
"Do not pass an `input_shape`/`input_dim` argument to "
"a layer. When using Sequential models, "
"prefer using an `Input(shape)` object as the "
"first layer in the model instead.",
stacklevel=2,
)
self._input_shape_arg = input_shape_arg
if kwargs:
raise ValueError(
"Unrecognized keyword arguments "
f"passed to {self.__class__.__name__}: {kwargs}"
)
self.built = False
self.dtype_policy = mixed_precision.resolve_policy(dtype)
self.autocast = autocast
self.input_spec = None
self.supports_jit = True
self._trainable = trainable
self._losses = []
self._supports_masking = not utils.is_default(self.compute_mask)
# Whether to automatically convert (+ auto-cast) inputs to `call()`.
self._convert_input_args = True
# Whether to allow non-tensors as positional arguments in `call()`.
self._allow_non_tensor_positional_args = False
# Dict of shapes that were used to call `build()`.
self._build_shapes_dict = None
self._call_signature_parameters = [
p.name for p in inspect.signature(self.call).parameters.values()
]
self._initializer_tracker()
@tracking.no_automatic_dependency_tracking
def _initializer_tracker(self):
if hasattr(self, "_tracker"):
return
trainable_variables = []
non_trainable_variables = []
layers = []
metrics = []
seed_generators = []
self._tracker = tracking.Tracker(
{
"trainable_variables": (
lambda x: isinstance(x, backend.Variable) and x.trainable,
trainable_variables,
),
"non_trainable_variables": (
lambda x: isinstance(x, backend.Variable)
and not x.trainable,
non_trainable_variables,
),
"metrics": (lambda x: isinstance(x, Metric), metrics),
"layers": (
lambda x: isinstance(x, Layer)
and not isinstance(x, Metric),
layers,
),
"seed_generators": (
lambda x: isinstance(x, backend.random.SeedGenerator),
seed_generators,
),
}
)
self._trainable_variables = trainable_variables
self._non_trainable_variables = non_trainable_variables
self._layers = layers
self._metrics = metrics
self._seed_generators = seed_generators
@utils.default
def build(self, input_shape):
if utils.is_default(self.build) and might_have_unbuilt_state(self):
warnings.warn(
f"`build()` was called on layer '{self.name}', however "
"the layer does not have a `build()` method implemented "
"and it looks like it has unbuilt state. This will cause "
"the layer to be marked as built, despite not being "
"actually built, which may cause failures down the line. "
"Make sure to implement a proper `build()` method."
)
self.built = True
def get_build_config(self):
"""Returns a dictionary with the layer's input shape.
This method returns a config dict that can be used by
`build_from_config(config)` to create all states (e.g. Variables and
Lookup tables) needed by the layer.
By default, the config only contains the input shape that the layer
was built with. If you're writing a custom layer that creates state in
an unusual way, you should override this method to make sure this state
is already created when Keras attempts to load its value upon model
loading.
Returns:
A dict containing the input shape associated with the layer.
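Example (a sketch using the `SimpleDense` layer from the class docstring;
the exact shape reflects the inputs the layer was first called on):
```python
layer = SimpleDense(4)
layer(ops.ones((2, 16)))  # Builds the layer on first call.
layer.get_build_config()
# {'input_shape': (2, 16)}
```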
"""
if self._build_shapes_dict is not None:
if len(self._build_shapes_dict) == 1:
return {
"input_shape": tuple(self._build_shapes_dict.values())[0],
}
else:
return {"shapes_dict": self._build_shapes_dict}
def build_from_config(self, config):
"""Builds the layer's states with the supplied config dict.
By default, this method calls the `build(config["input_shape"])` method,
which creates weights based on the layer's input shape in the supplied
config. If your config contains other information needed to load the
layer's state, you should override this method.
Args:
config: Dict containing the input shape associated with this layer.
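Example (a minimal sketch of the save/reload round trip):
```python
# At saving time:
build_config = layer.get_build_config()
# At loading time, on a freshly constructed (unbuilt) layer:
layer.build_from_config(build_config)
```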
"""
if config:
if "input_shape" in config:
self.build(config["input_shape"])
self._build_shapes_dict = config
elif "shapes_dict" in config:
self.build(**config["shapes_dict"])
self._build_shapes_dict = config["shapes_dict"]
self.built = True
def add_variable(
self,
shape,
initializer,
dtype=None,
trainable=True,
regularizer=None,
constraint=None,
name=None,
):
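"""Add a weight variable to the layer.
Args:
shape: Shape tuple for the variable.
initializer: Initializer object or string name (resolved via
`initializers.get()`).
dtype: Dtype of the variable. Defaults to the layer's `variable_dtype`.
trainable: Boolean, whether the variable should be trainable.
regularizer: Optional regularizer; its penalty is collected in
`layer.losses`.
constraint: Optional constraint attached to the variable.
name: Optional string name for the variable.
Returns:
The created `backend.Variable`.
"""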
# TODO: handle layout
self._check_super_called()
initializer = initializers.get(initializer)
variable = backend.Variable(
initializer=initializer,
shape=shape,
dtype=dtype or self.variable_dtype,
trainable=trainable,
name=name,
)
# Will be added to layer.losses
variable.regularizer = regularizer
variable.constraint = constraint
self._track_variable(variable)
return variable
def add_weight(self, *args, **kwargs):
return self.add_variable(*args, **kwargs)
@property
def trainable(self):
return self._trainable
@trainable.setter
def trainable(self, value):
"""Sets trainable attribute for the layer and its sublayers.
When this value is changed during training (e.g. with a
`Callback`), you need to call the parent
`Model.make_train_function` with `force=True` in order to
recompile the training graph.
Args:
value: Boolean with the desired state for the layer's trainable
attribute.
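Example (sketch: freeze a layer and all of its sublayers):
```python
layer.trainable = False
assert layer.trainable_weights == []
```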
"""
value = bool(value)
self._trainable = value
for v in self._trainable_variables:
v.trainable = value
for layer in self._layers:
layer.trainable = value
@property
def variables(self):
# Return only weights/rng state/metric variables
# of all Layers, recursively.
# Also deduplicate them.
variables = []
seen_ids = set()
for v in self._trainable_variables + self._non_trainable_variables:
if id(v) not in seen_ids:
variables.append(v)
seen_ids.add(id(v))
for m in self._metrics:
variables.extend(m.variables)
for sg in self._seed_generators:
variables.append(sg.state)
for layer in self._layers:
for v in layer.variables:
if id(v) not in seen_ids:
variables.append(v)
seen_ids.add(id(v))
return variables
@property
def trainable_variables(self):
if not self.trainable:
return []
return [v for v in self.variables if v.trainable]
@property
def non_trainable_variables(self):
if not self.trainable:
return self.variables
return [v for v in self.variables if not v.trainable]
@property
def weights(self):
# Return only "own weights" of all Layers, recursively.
# Also deduplicate them.
weights = []
seen_ids = set()
for w in self._trainable_variables + self._non_trainable_variables:
if id(w) not in seen_ids:
weights.append(w)
seen_ids.add(id(w))
for layer in self._layers:
for w in layer.weights:
if id(w) not in seen_ids:
weights.append(w)
seen_ids.add(id(w))
return weights
@property
def trainable_weights(self):
if not self.trainable:
return []
return [v for v in self.weights if v.trainable]
@property
def non_trainable_weights(self):
if not self.trainable:
return self.weights
return [v for v in self.weights if not v.trainable]
def get_weights(self):
return [v.numpy() for v in self.weights]
def set_weights(self, weights):
layer_weights = self.weights
if len(layer_weights) != len(weights):
raise ValueError(
f"You called `set_weights(weights)` on layer '{self.name}' "
f"with a weight list of length {len(weights)}, but the layer "
f"was expecting {len(layer_weights)} weights."
)
for variable, value in zip(layer_weights, weights):
if variable.shape != value.shape:
raise ValueError(
f"Layer {self.name} weight shape {variable.shape} "
"is not compatible with provided weight "
f"shape {value.shape}."
)
variable.assign(value)
@property
def dtype(self):
"""The dtype of the state (weights) of the layer."""
return self.variable_dtype
@property
def compute_dtype(self):
"""The dtype of the computations performed by the layer."""
return self.dtype_policy.compute_dtype
@property
def variable_dtype(self):
"""The dtype of the state (weights) of the layer."""
return self.dtype_policy.variable_dtype
@property
def supports_masking(self):
"""Whether this layer supports computing a mask using `compute_mask`."""
return self._supports_masking
@supports_masking.setter
def supports_masking(self, value):
self._supports_masking = value
@utils.default
def compute_mask(self, inputs, previous_mask):
return previous_mask
@traceback_utils.filter_traceback
def __call__(self, *args, **kwargs):
self._check_super_called()
#####################################
# 1. Convert any array arguments to tensors of correct dtype.
def maybe_convert(x):
if backend.is_tensor(x):
# Handle Torch device placement.
if backend.backend() == "torch":
x = backend.convert_to_tensor(x)
if (
self.autocast
and backend.is_float_dtype(x.dtype)
and x.dtype != self.compute_dtype
):
x = backend.cast(x, dtype=self.compute_dtype)
return x
elif isinstance(x, backend.KerasTensor):
if (
self.autocast
and backend.is_float_dtype(x.dtype)
and x.dtype != self.compute_dtype
):
x.dtype = self.compute_dtype
return x
elif hasattr(x, "__array__"):
return backend.convert_to_tensor(x, dtype=self.compute_dtype)
return x
if self._convert_input_args:
args = nest.map_structure(maybe_convert, args)
kwargs = nest.map_structure(maybe_convert, kwargs)
##########################################################
# 2. Enforce that only tensors can be passed positionally.
if not self._allow_non_tensor_positional_args:
for arg in nest.flatten(args):
if not isinstance(arg, KerasTensor) and not backend.is_tensor(
arg
):
raise ValueError(
"Only input tensors may be passed as "
"positional arguments. The following argument value "
f"should be passed as a keyword argument: {arg} "
f"(of type {type(arg)})"
)
# Caches info about `call()` signature, args, kwargs.
call_spec = CallSpec(self.call, args, kwargs)
############################################
# 3. Check input spec for 1st positional arg.
# TODO: consider extending this to all args and kwargs.
self._assert_input_compatibility(call_spec.first_arg)
################
# 4. Call build.
self._maybe_build(call_spec)
##########################
# 5. Infer training value
# Training phase for `Layer.call` is set via (in order of priority):
# (1) The `training` argument passed to this `Layer.call`, if not None
# (2) The training argument of an outer `Layer.call`.
# (3) Any non-None default value for `training` in the call signature
# (4) False (treating the layer as if it's in inference)
# Maintains info about the `Layer.call` stack
# across nested calls.
call_context = self._get_call_context()
# This is the value explicitly passed by the user
training = call_spec.user_arguments_dict.get("training", None)
if training is None:
# Wasn't passed explicitly: use context value
training = call_context.training
if training is None:
# Get signature default value; else False
training = call_spec.arguments_dict.get("training", False)
call_context.training = training
if self._call_has_training_arg():
kwargs["training"] = training
##############################
# 6. Populate mask argument(s)
if self.supports_masking:
if len(call_spec.tensor_arguments_dict) == 1:
if (
"mask" in call_spec.argument_names
and call_spec.arguments_dict["mask"] is None
):
arg_name = list(call_spec.tensor_arguments_dict.keys())[0]
only_tensor_arg = call_spec.tensor_arguments_dict[arg_name]
mask = nest.map_structure(
lambda x: getattr(x, "_keras_mask", None),
only_tensor_arg,
)
kwargs["mask"] = mask
elif len(call_spec.tensor_arguments_dict) > 1:
for k, v in call_spec.tensor_arguments_dict.items():
expected_mask_arg_name = f"{k}_mask"
if expected_mask_arg_name in call_spec.argument_names:
if (
call_spec.arguments_dict[expected_mask_arg_name]
is None
):
mask = nest.map_structure(
lambda x: getattr(x, "_keras_mask", None), v
)
kwargs[expected_mask_arg_name] = mask
####################
# 7. Call the layer.
try:
with backend.name_scope(self.name):
if self.autocast and self.compute_dtype != self.variable_dtype:
# For mixed precision, we automatically cast layer variables
# (float ones only) to the compute dtype upon access.
with backend.AutocastScope(self.compute_dtype):
outputs = super().__call__(*args, **kwargs)
else:
outputs = super().__call__(*args, **kwargs)
if not self.built:
self.built = True
# Record activity regularizer loss.
if self.activity_regularizer is not None:
for output in nest.flatten(outputs):
if backend.is_tensor(output):
self.add_loss(self.activity_regularizer(output))
if self.supports_masking:
# Set masks on outputs,
# provided only the first positional input arg and its mask.
# TODO: consider extending this to all args and kwargs.
previous_mask = getattr(
call_spec.first_arg, "_keras_mask", None
)
self._set_mask_metadata(
call_spec.first_arg, outputs, previous_mask
)
finally:
# Destroy call context if we created it
self._maybe_reset_call_context()
return outputs
def call(self, *args, **kwargs):
raise NotImplementedError
@traceback_utils.filter_traceback
def stateless_call(
self,
trainable_variables,
non_trainable_variables,
*args,
return_losses=False,
**kwargs,
):
"""Call the layer without any side effects.
Args:
trainable_variables: List of trainable variables of the model.
non_trainable_variables: List of non-trainable variables of the
model.
*args: Positional arguments to be passed to `call()`.
return_losses: If `True`, `stateless_call()` will return the list of
losses created during `call()` as part of its return values.
**kwargs: Keyword arguments to be passed to `call()`.
Returns:
A tuple. By default, returns `(outputs, non_trainable_variables)`.
If `return_losses = True`, then returns
`(outputs, non_trainable_variables, losses)`.
Note: `non_trainable_variables` include not only non-trainable weights
such as `BatchNormalization` statistics, but also RNG seed state
(if the layer includes any random operations, such as dropout),
and `Metric` state (if there are any metrics attached to the layer).
These are all elements of state of the layer.
Example:
```python
model = ...
data = ...
trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
# Call the model with zero side effects
outputs, non_trainable_variables = model.stateless_call(
trainable_variables,
non_trainable_variables,
data,
)
# Attach the updated state to the model
# (until you do this, the model is still in its pre-call state).
for ref_var, value in zip(
model.non_trainable_variables, non_trainable_variables
):
ref_var.assign(value)
```
"""
self._check_super_called()
if not self.built:
raise ValueError(
f"To call stateless_call, {self.__class__.__name__} must be "
"built (i.e. its variables must have been already created). "
"You can build it by calling it on some data."
)
if len(trainable_variables) != len(self.trainable_variables):
raise ValueError(
"Argument `trainable_variables` must be a list of tensors "
"corresponding 1:1 to "
f"{self.__class__.__name__}().trainable_variables. "
f"Received list with length {len(trainable_variables)}, "
f"but expected {len(self.trainable_variables)} variables."
)
if len(non_trainable_variables) != len(self.non_trainable_variables):
raise ValueError(
"Argument `non_trainable_variables` must be a list of tensors "
"corresponding 1:1 to "
f"{self.__class__.__name__}().non_trainable_variables. "
f"Received list with length {len(non_trainable_variables)}, "
f"but expected {len(self.non_trainable_variables)} variables."
)
# Gather variable mapping
trainable_mapping = zip(self.trainable_variables, trainable_variables)
non_trainable_mapping = zip(
self.non_trainable_variables, non_trainable_variables
)
mapping = list(trainable_mapping) + list(non_trainable_mapping)
# Call in stateless scope
with backend.StatelessScope(
state_mapping=mapping, collect_losses=return_losses
) as scope:
outputs = self.call(*args, **kwargs)
# Gather updated non-trainable variables
non_trainable_variables = []
for v in self.non_trainable_variables:
new_v = scope.get_current_value(v)
if new_v is not None:
non_trainable_variables.append(new_v)
else:
non_trainable_variables.append(v)
if return_losses:
return outputs, non_trainable_variables, scope.losses[:]
return outputs, non_trainable_variables
def compute_output_spec(self, *args, **kwargs):
if utils.is_default(self.compute_output_shape):
return super().compute_output_spec(*args, **kwargs)
else:
# Use compute_output_shape() to return the right output spec
call_spec = CallSpec(self.call, args, kwargs)
shapes_dict = get_shapes_dict(
self.compute_output_shape, call_spec, self.__class__
)
if len(shapes_dict) == 1:
# Single arg: pass it positionally
input_shape = tuple(shapes_dict.values())[0]
output_shape = self.compute_output_shape(input_shape)
else:
# More than one shape: pass them by name.
output_shape = self.compute_output_shape(**shapes_dict)
if (
isinstance(output_shape, list)
and output_shape
and isinstance(output_shape[0], (int, type(None)))
):
output_shape = tuple(output_shape)
if not isinstance(output_shape, (list, tuple, dict)):
try:
output_shape = tuple(output_shape)
except:
raise ValueError(
"Method `compute_output_shape()` of layer "
f"{self.__class__.__name__} is returning "
"a type that cannot be interpreted as a shape. "
"It should return a shape tuple. "
f"Received: {output_shape}"
)
if (
isinstance(output_shape, tuple)
and output_shape
and isinstance(output_shape[0], (int, type(None)))
):
return KerasTensor(output_shape, dtype=self.compute_dtype)
# Case: nested. Could be a tuple/list of shapes, or a dict of
# shapes. Could be deeply nested.
return map_shape_structure(
lambda s: KerasTensor(s, dtype=self.compute_dtype), output_shape
)
@utils.default
def compute_output_shape(self, *args, **kwargs):
raise NotImplementedError
def add_loss(self, loss):
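"""Add a scalar loss tensor, to be tracked in `layer.losses`.
Can only be called from inside `build()` or `call()`, on an eager
tensor input.
"""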
# Eager only.
losses = nest.flatten(loss)
for x in losses:
if not backend.is_tensor(x):
raise ValueError(
"`add_loss()` can only be called from inside `build()` or "
f"`call()`, on a tensor input. Received invalid value: {x}"
)
if backend.in_stateless_scope():
scope = backend.get_stateless_scope()
if scope.collect_losses:
for x in losses:
scope.add_loss(x)
else:
self._losses.extend(losses)
@property
def losses(self):
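"""List of scalar losses from `add_loss()`, sublayers, and weight regularizers."""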
losses = self._losses[:]
for layer in self._layers:
losses.extend(layer._losses)
weight_regularization_losses = []
for v in self.trainable_weights:
regularizer = getattr(v, "regularizer", None)
if regularizer:
weight_regularization_losses.append(regularizer(v))
losses.extend(weight_regularization_losses)
return losses
def save_own_variables(self, store):
"""Saves the state of the layer.
You can override this method to take full control of how the state of
the layer is saved upon calling `model.save()`.
Args:
store: Dict where the state of the model will be saved.
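Example (a sketch of overriding the default index-based keys with
explicit names; the attribute names are illustrative):
```python
def save_own_variables(self, store):
    store["kernel"] = np.array(self.kernel)
    store["bias"] = np.array(self.bias)
```
Pair this with a matching `load_own_variables()` override.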
"""
all_vars = self._trainable_variables + self._non_trainable_variables
for i, v in enumerate(all_vars):
store[f"{i}"] = np.array(v)
def load_own_variables(self, store):
"""Loads the state of the layer.
You can override this method to take full control of how the state of
the layer is loaded upon calling `keras.models.load_model()`.
Args:
store: Dict from which the state of the model will be loaded.
"""
all_vars = self._trainable_variables + self._non_trainable_variables
if len(store.keys()) != len(all_vars):
if len(all_vars) == 0 and not self.built:
raise ValueError(
f"Layer '{self.name}' was never built "
"and thus it doesn't have any variables. "
f"However the weights file lists {len(store.keys())} "
"variables for this layer.\n"
"In most cases, this error indicates that either:\n\n"
"1. The layer is owned by a parent layer that "
"implements a `build()` method, but calling the "
"parent's `build()` method did NOT create the state of "
f"the child layer '{self.name}'. A `build()` method "
"must create ALL state for the layer, including "
"the state of any children layers.\n\n"
"2. You need to implement "
"the `def build_from_config(self, config)` method "
f"on layer '{self.name}', to specify how to rebuild "
"it during loading. "
"In this case, you might also want to implement the "
"method that generates the build config at saving time, "
"`def get_build_config(self)`. "
"The method `build_from_config()` is meant "
"to create the state "
"of the layer (i.e. its variables) upon deserialization.",
)
raise ValueError(
f"Layer '{self.name}' expected {len(all_vars)} variables, "
"but received "
f"{len(store.keys())} variables during loading. "
f"Expected: {[v.name for v in all_vars]}"
)
for i, v in enumerate(all_vars):
v.assign(store[f"{i}"])
def _clear_losses(self):
if backend.in_stateless_scope():
scope = backend.get_stateless_scope()
if scope.collect_losses:
for x in scope.losses:
if x in self._losses:
scope.losses.remove(x)
self._losses.clear()
for layer in self._layers:
layer._clear_losses()
def _track_variable(self, variable):
if variable.trainable:
self._tracker.add_to_store("trainable_variables", variable)
else:
self._tracker.add_to_store("non_trainable_variables", variable)
def add_metric(self):
# Permanently disabled
raise NotImplementedError
def count_params(self):
"""Count the total number of scalars composing the weights.
Returns:
An integer count.
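Example (a sketch using the `SimpleDense` layer from the class docstring):
```python
layer = SimpleDense(4)
layer(ops.ones((2, 16)))  # Build the layer by calling it.
layer.count_params()  # 16 * 4 + 4 == 68
```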
"""
if not self.built:
raise ValueError(
"You tried to call `count_params` "
f"on layer '{self.name}', "
"but the layer isn't built. "
"You can build it manually via: "
f"`layer.build(input_shape)`."
)
return summary_utils.count_params(self.weights)
def _maybe_build(self, call_spec):
if not self.built:
shapes_dict = get_shapes_dict(self.build, call_spec, self.__class__)
self._build_shapes_dict = shapes_dict
failure = False
if len(shapes_dict) == 1:
# Single arg: pass it positionally
input_shape = tuple(shapes_dict.values())[0]
with backend.name_scope(self.name):
if utils.is_default(
self.build
) and might_have_unbuilt_state(self):
status = self._build_by_run_for_single_pos_arg(
input_shape
)
if not status:
failure = True
else:
self.build(input_shape)
else:
with backend.name_scope(self.name):
if utils.is_default(self.build):
if might_have_unbuilt_state(self):
status = self._build_by_run_for_kwargs(shapes_dict)
if not status:
failure = True
else:
self.build(**shapes_dict)
if failure:
if call_spec.eager:
# Will let the actual eager call do the state-building
return
raise ValueError(
f"Layer '{self.name}' looks like it has "
"unbuilt state, but Keras is not able to "
"trace the layer `call()` in order to "
"build it automatically. You must implement "
"the `def build(self, input_shape)` method on your "
"layer. It should create all variables used by the "
"layer (e.g. by calling `layer.build()` on all its "
"children layers)."
)
self.built = True
# Check input spec again (after build, since self.input_spec
# may have been updated).
self._assert_input_compatibility(call_spec.first_arg)
# Hook used to do post-build actions
self._post_build()
if not self._tracker.locked:
# No state updates past this point.
self._tracker.lock(
msg=(
"You cannot add new elements of state "
"(variables or sub-layers) "
"to a layer that is already built. All state "
"must be created in the `__init__()` method or "
"in the`build()` method."
)
)
def _build_by_run_for_single_pos_arg(self, input_shape):
# Case: all inputs are in the first arg (possibly nested).
input_tensors = map_shape_structure(
lambda s: backend.KerasTensor(s), input_shape
)
try:
backend.compute_output_spec(self.call, input_tensors)
return True
except:
return False
def _build_by_run_for_kwargs(self, shapes_dict):
# Case: inputs were recorded as multiple keyword arguments.
if all(is_shape_tuple(s) for s in shapes_dict.values()):
# Case: all input keyword arguments were plain tensors.
input_tensors = {
# We strip the `_shape` suffix to recover kwarg names.
utils.removesuffix(k, "_shape"): backend.KerasTensor(shape)
for k, shape in shapes_dict.items()
}
try:
backend.compute_output_spec(self.call, **input_tensors)
return True
except:
return False
else:
# Not supported: nested input keyword arguments.
return False
def __repr__(self):
return (
f"<{self.__class__.__name__} "
f"name={self.name}, built={self.built}>"
)
def __str__(self):
return (
f"<{self.__class__.__name__} "
f"name={self.name}, built={self.built}>"
)
def __setattr__(self, name, value):
# Track Variables, Layers, Metrics, SeedGenerators.
if hasattr(self, "_tracker"):
value = self._tracker.track(value)
elif name != "_tracker":
self._initializer_tracker()
return super().__setattr__(name, value)
def _check_super_called(self):
if getattr(self, "_lock", True):
raise RuntimeError(
f"In layer '{self.__class__.__name__}', you forgot to call "
"`super().__init__()` as the first statement "
"in the `__init__()` method. Go add it!"
)
def _assert_input_compatibility(self, arg_0):
if self.input_spec:
input_spec.assert_input_compatibility(
self.input_spec, arg_0, layer_name=self.name
)
def _call_has_training_arg(self):
return "training" in self._call_signature_parameters
def _call_has_mask_arg(self):
return "mask" in self._call_signature_parameters
def _get_call_context(self):
"""Returns currently active `CallContext`."""
layer_call_ctx = global_state.get_global_attribute("current_call_ctx")
if layer_call_ctx is None:
# Enter new call context.
layer_call_ctx = CallContext(entry_layer=self)
global_state.set_global_attribute(
"current_call_ctx", layer_call_ctx
)
self._clear_losses()
return layer_call_ctx
def _maybe_reset_call_context(self):
layer_call_ctx = global_state.get_global_attribute("current_call_ctx")
if layer_call_ctx is None or layer_call_ctx.entry_layer == self:
global_state.set_global_attribute("current_call_ctx", None)
def _flatten_layers(self, include_self=True, recursive=True):
layers = []
if include_self:
layers.append(self)
seen_object_ids = set()
deque = collections.deque(self._layers)
while deque:
layer = deque.popleft()
if id(layer) in seen_object_ids:
continue
seen_object_ids.add(id(layer))
layers.append(layer)
# Introspect recursively through sublayers.
if recursive:
deque.extendleft(layer._layers)
return layers
def _set_mask_metadata(self, inputs, outputs, previous_mask):
flat_outputs = nest.flatten(outputs)
mask_already_computed = all(
getattr(x, "_keras_mask", None) is not None for x in flat_outputs
)
if mask_already_computed:
return
output_masks = self.compute_mask(inputs, previous_mask)
if output_masks is None:
return
flat_masks = nest.flatten(output_masks)
for tensor, mask in zip(flat_outputs, flat_masks):
if getattr(tensor, "_keras_mask", None) is None:
try:
tensor._keras_mask = mask
except AttributeError:
# It's a C type.
pass
@python_utils.default
def get_config(self):
base_config = super().get_config()
config = {
"trainable": self.trainable,
"dtype": self.dtype_policy.name,
}
return {**base_config, **config}
def is_backend_tensor_or_symbolic(x):
return backend.is_tensor(x) or isinstance(x, backend.KerasTensor)
class CallSpec:
def __init__(self, call_fn, args, kwargs):
sig = inspect.signature(call_fn)
# `training` and `mask` are special kwargs that are always available to
# a layer. If the user passes `training` but `call()` does not accept it,
# we remove it so that the arguments can still be bound; the layer does
# not use `training` anyway, so it can safely be ignored.
# TODO: If necessary use workaround for `mask`
if "training" in kwargs and "training" not in sig.parameters:
kwargs.pop("training")
bound_args = sig.bind(*args, **kwargs)
else:
bound_args = sig.bind(*args, **kwargs)
self.user_arguments_dict = {
k: v for k, v in bound_args.arguments.items()
}
bound_args.apply_defaults()
arg_dict = {}
arg_names = []
tensor_arg_dict = {}
tensor_args = []
tensor_arg_names = []
nested_tensor_arg_names = []
for name, value in bound_args.arguments.items():
arg_dict[name] = value
arg_names.append(name)
if is_backend_tensor_or_symbolic(value):
tensor_args.append(value)
tensor_arg_names.append(name)
tensor_arg_dict[name] = value
elif nest.is_nested(value):
flat_values = nest.flatten(value)
if all(is_backend_tensor_or_symbolic(x) for x in flat_values):
tensor_args.append(value)
tensor_arg_names.append(name)
tensor_arg_dict[name] = value
nested_tensor_arg_names.append(name)
elif any(is_backend_tensor_or_symbolic(x) for x in flat_values):
raise ValueError(
"In a nested call() argument, "
"you cannot mix tensors and non-tensors. "
"Received invalid mixed argument: "
f"{name}={value}"
)
self.arguments_dict = arg_dict
self.argument_names = arg_names
self.tensor_arguments_dict = tensor_arg_dict
self.tensor_arguments_names = tensor_arg_names
self.nested_tensor_argument_names = nested_tensor_arg_names
self.first_arg = arg_dict[arg_names[0]]
if all(
backend.is_tensor(x) for x in self.tensor_arguments_dict.values()
):
self.eager = True
else:
self.eager = False
def get_arguments_dict(fn, args, kwargs):
"""Return a dict mapping argument names to their values."""
sig = inspect.signature(fn)
bound_args = sig.bind(*args, **kwargs)
arg_dict = {}
for name, value in bound_args.arguments.items():
arg_dict[name] = value
return arg_dict
def get_shapes_dict(target_fn, call_spec, cls):
"""Convert the call() arguments dict into a dict of input shape arguments.
Example:
```
>>> get_shapes_dict(self.build, call_spec, cls)
{"input_a_shape": (2, 3)}
```
"""
expected_names = check_shapes_signature(target_fn, call_spec, cls)
shapes_dict = {}
for k, v in call_spec.tensor_arguments_dict.items():
if k == "mask" or k.startswith("mask_"):
# Do not include mask tensors in shapes dict
continue
if k == "kwargs" or k == "args":
# Do not include catch-alls in shapes dict
continue
if expected_names is not None and f"{k}_shape" not in expected_names:
continue
if k in call_spec.nested_tensor_argument_names:
shapes_dict[f"{k}_shape"] = nest.map_structure(
lambda x: backend.standardize_shape(x.shape), v
)
else:
shapes_dict[f"{k}_shape"] = backend.standardize_shape(v.shape)
return shapes_dict
def check_shapes_signature(target_fn, call_spec, cls):
"""Asserts that the argument names in `target_fn` match arguments in `call`.
We use this to check that `build()` and `compute_output_shape()` arguments
align with `call()` arguments.
For instance if `build()` has the signature
`def build(self, a_shape, b_shape)` we expect `call()` to accept the
arguments `a` and `b`.
When there is a single argument accepted by `target_fn`, we do allow any
name and do not check the call signature.
Returns:
The list of argument names expected by the `target_fn` or
`None` if any passed name is acceptable.
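Example (an illustrative pair of matching signatures):
```
class MyLayer(Layer):

    def build(self, query_shape, value_shape):
        ...

    def call(self, query, value):
        ...
```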
"""
if utils.is_default(target_fn):
return None
sig = inspect.signature(target_fn)
expected_names = []
for name, param in sig.parameters.items():
if param.kind in (
param.POSITIONAL_OR_KEYWORD,
param.POSITIONAL_ONLY,
param.KEYWORD_ONLY,
):
expected_names.append(name)
if len(expected_names) == 1:
return None
for name in expected_names:
method_name = target_fn.__name__
error_preamble = (
f"For a `{method_name}()` method with more than one argument, all "
"arguments should have a `_shape` suffix and match an argument "
f"from `call()`. E.g. `{method_name}(self, foo_shape, bar_shape)` "
"would match `call(self, foo, bar)`."
)
if not name.endswith("_shape"):
raise ValueError(
f"{error_preamble} For layer '{cls.__name__}', "
f"Received `{method_name}()` argument "
f"`{name}`, which does not end in `_shape`."
)
expected_call_arg = utils.removesuffix(name, "_shape")
if expected_call_arg not in call_spec.arguments_dict:
raise ValueError(
f"{error_preamble} For layer '{cls.__name__}', "
f"received `{method_name}()` argument "
f"`{name}`, but `call()` does not have argument "
f"`{expected_call_arg}`."
)
return expected_names
class CallContext:
def __init__(self, entry_layer):
self.entry_layer = entry_layer
self.training = None
def is_shape_tuple(s):
return isinstance(s, (list, tuple)) and all(
d is None or isinstance(d, int) for d in s
)
def might_have_unbuilt_state(layer):
return any(not lr.built for lr in layer._layers)