"""Layer is an Operation with state.
Takes care of:
- Weights / variables (and tracking thereof)
- deferred build
- trainable argument value inference
- masking
- autocasting
And some more magic:
- add_loss
- metric tracking
- RNG seed tracking
2023-04-23 02:26:17 +00:00
- activity regularization
2023-04-09 19:21:45 +00:00
"""
import collections
import inspect
import threading
import warnings

import numpy as np
from tensorflow import keras as tf_keras
from tensorflow import nest

from keras_core import backend
from keras_core import initializers
from keras_core import regularizers
from keras_core import utils
from keras_core.api_export import keras_core_export
from keras_core.backend import KerasTensor
from keras_core.layers import input_spec
from keras_core.metrics.metric import Metric
from keras_core.operations.operation import Operation
from keras_core.utils import summary_utils
from keras_core.utils.tracking import Tracker
# TODO: cache all call signature processing.
# See layer_utils.CallFunctionSpec() in Keras.


@keras_core_export(["keras_core.Layer", "keras_core.layers.Layer"])
class Layer(Operation):
    def __init__(
        self, activity_regularizer=None, trainable=True, dtype=None, name=None
    ):
        super().__init__(name=name)
        self.activity_regularizer = regularizers.get(activity_regularizer)
self._trainable = trainable
if dtype is None:
dtype = backend.floatx()
self.built = False
self.dtype_policy = tf_keras.mixed_precision.Policy(dtype)
self.input_spec = None
self._layers = []
self._metrics = []
self._seed_generators = []
        self._losses = []
        self._variables = []
        self._supports_masking = not utils.is_default(self.compute_mask)
        self._build_shapes_dict = None
        self._call_signature_parameters = [
            p.name for p in inspect.signature(self.call).parameters.values()
        ]
self._tracker = Tracker(
{
"variables": (
lambda x: isinstance(x, backend.Variable),
self._variables,
),
"metrics": (lambda x: isinstance(x, Metric), self._metrics),
"layers": (
lambda x: isinstance(x, Layer)
and not isinstance(x, Metric),
                    self._layers,
                ),
                "seed_generators": (
                    lambda x: isinstance(x, backend.random.SeedGenerator),
                    self._seed_generators,
                ),
            }
        )

    @utils.default
    def build(self, input_shape):
        self.built = True
def get_build_config(self):
"""Returns a dictionary with the layer's input shape.
This method returns a config dict that can be used by
`build_from_config(config)` to create all states (e.g. Variables and
Lookup tables) needed by the layer.
By default, the config only contains the input shape that the layer
was built with. If you're writing a custom layer that creates state in
an unusual way, you should override this method to make sure this state
is already created when Keras attempts to load its value upon model
loading.
Returns:
A dict containing the input shape associated with the layer.
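
        Example (a sketch of the default behavior for a layer already built
        on a single input; the shape shown is illustrative):

        ```
        >>> layer.get_build_config()  # layer built on (None, 4) inputs
        {'input_shape': (None, 4)}
        ```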
"""
if self._build_shapes_dict is not None:
if len(self._build_shapes_dict) == 1:
return {
"input_shape": tuple(self._build_shapes_dict.values())[0],
}
else:
return {"shapes_dict": self._build_shapes_dict}
def build_from_config(self, config):
"""Builds the layer's states with the supplied config dict.
By default, this method calls the `build(config["input_shape"])` method,
which creates weights based on the layer's input shape in the supplied
config. If your config contains other information needed to load the
layer's state, you should override this method.
Args:
config: Dict containing the input shape associated with this layer.
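
        Example (a minimal sketch, using a config of the form produced by
        `get_build_config()`; the shape is illustrative):

        ```
        >>> layer.build_from_config({"input_shape": (None, 4)})  # illustrative shape
        >>> layer.built
        True
        ```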
"""
        if config:
            if "input_shape" in config:
                self.build(config["input_shape"])
                self._build_shapes_dict = config
            elif "shapes_dict" in config:
                self.build(**config["shapes_dict"])
                self._build_shapes_dict = config["shapes_dict"]
            self.built = True
    def add_variable(
        self,
        shape,
        initializer,
        dtype=None,
        trainable=True,
        regularizer=None,
        constraint=None,
        name=None,
    ):
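        """Adds a new variable (weight) to the layer and returns it.

        Example (a minimal sketch of typical usage inside a custom layer's
        `build()` method; `self.units` is a hypothetical attribute):

        ```
        def build(self, input_shape):
            self.kernel = self.add_variable(
                shape=(input_shape[-1], self.units),  # hypothetical sizes
                initializer="glorot_uniform",
                trainable=True,
                name="kernel",
            )
        ```
        """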
        # TODO: handle constraint (in the optimizer)
        # TODO: handle layout
        self._check_super_called()
        initializer = initializers.get(initializer)
        variable = backend.Variable(
            initializer=initializer,
            shape=shape,
            dtype=dtype,
            trainable=trainable,
            name=name,
        )
        # Will be added to layer.losses
        variable.regularizer = regularizer
        variable.constraint = constraint
        self._variables.append(variable)
        # Prevent double-tracking
        self._tracker.stored_ids["variables"].add(id(variable))
        return variable
    def add_weight(self, *args, **kwargs):
        return self.add_variable(*args, **kwargs)
@property
def trainable(self):
return self._trainable
    @trainable.setter
    def trainable(self, value):
        """Sets trainable attribute for the layer and its sublayers.

        When this value is changed during training (e.g. with a
        `Callback`) you need to call the parent
        `Model.make_train_function` with `force=True` in order to
        recompile the training graph.

        Args:
            value: Boolean with the desired state for the layer's trainable
                attribute.
        """
        self._trainable = value
        for layer in self._layers:
            layer.trainable = value
    @property
    def variables(self):
        # Includes weights, seed generator state, and metric variables.
        variables = self.weights[:]
        for m in self._metrics:
            variables.extend(m.variables)
        for sg in self._seed_generators:
            variables.append(sg.state)
        return variables
@property
def trainable_variables(self):
return [v for v in self.variables if v.trainable]
@property
def non_trainable_variables(self):
return [v for v in self.variables if not v.trainable]
@property
def weights(self):
# Return only "own weights" of all Layers, recursively
weights = self._variables[:]
for layer in self._layers:
weights.extend(layer._variables)
return weights
@property
def trainable_weights(self):
return [v for v in self.weights if v.trainable]
@property
def non_trainable_weights(self):
return [v for v in self.weights if not v.trainable]
def get_weights(self):
return [v.numpy() for v in self.weights]
def set_weights(self, weights):
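        """Sets the values of the layer's weights from a list of arrays.

        Example (a minimal round-trip sketch; `layer` is any built layer):

        ```
        >>> values = layer.get_weights()  # list of NumPy arrays
        >>> layer.set_weights(values)     # shapes must match 1:1
        ```
        """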
layer_weights = self.weights
if len(layer_weights) != len(weights):
raise ValueError(
f"You called `set_weights(weights)` on layer '{self.name}' "
f"with a weight list of length {len(weights)}, but the layer was "
f"expecting {len(layer_weights)} weights."
)
for variable, value in zip(layer_weights, weights):
if variable.shape != value.shape:
raise ValueError(
f"Layer {self.name} weight shape {variable.shape} "
"is not compatible with provided weight "
f"shape {value.shape}."
)
variable.assign(value)
@property
def dtype(self):
"""The dtype of the state (weights) of the layer."""
return self.variable_dtype
@property
def compute_dtype(self):
"""The dtype of the computations performed by the layer."""
return self.dtype_policy.compute_dtype
    @property
    def variable_dtype(self):
        """The dtype of the state (weights) of the layer."""
        return self.dtype_policy.variable_dtype
@property
def supports_masking(self):
"""Whether this layer supports computing a mask using `compute_mask`."""
return self._supports_masking
@supports_masking.setter
def supports_masking(self, value):
self._supports_masking = value
@utils.default
def compute_mask(self, *args, **kwargs):
raise NotImplementedError
def __call__(self, *args, **kwargs):
self._check_super_called()
        ######################################
        # Argument validation and conversion. #

        # 1. Convert any array arguments to tensors of correct dtype.
        def maybe_convert(x):
            if isinstance(x, np.ndarray) or backend.is_tensor(x):
                return backend.convert_to_tensor(x, dtype=self.compute_dtype)
            # TODO: cast KerasTensor too
            return x

        args = nest.map_structure(maybe_convert, args)
        kwargs = nest.map_structure(maybe_convert, kwargs)

        # 2. Enforce that only tensors can be passed positionally.
        for arg in nest.flatten(args):
            if not isinstance(arg, KerasTensor) and not backend.is_tensor(arg):
                raise ValueError(
                    "Only input tensors may be passed as "
                    "positional arguments. The following argument value "
                    f"should be passed as a keyword argument: {arg} "
                    f"(of type {type(arg)})"
                )

        # 3. Check input spec for 1st positional arg.
        # TODO: consider extending this to all args and kwargs.
        self._assert_input_compatibility(*args)
######################################
###############
# Call build. #
self._maybe_build(*args, **kwargs)
###############
# Maintains info about the `Layer.call` stack.
call_context = self._get_call_context()
        # Infer training value
        # Training phase for `Layer.call` is set via (in order of priority):
        # (1) The `training` argument passed to this `Layer.call`, if not None
        # (2) The training argument of an outer `Layer.call`
        # (3) Any non-None default value for `training` in the call signature
        # (4) False (treating the layer as if it's in inference)
arguments_dict = get_arguments_dict(self.call, *args, **kwargs)
training = arguments_dict.get("training", None)
if training is None:
training = call_context.training
if training is None:
training = self._get_default_training_value()
if training is None:
training = False
call_context.training = training
if self._call_has_training_arg():
kwargs["training"] = training
# TODO: Populate mask argument(s)
        # Call the layer.
        with backend.name_scope(self.name):
            outputs = super().__call__(*args, **kwargs)

        # Record activity regularizer loss.
        if self.activity_regularizer is not None:
            self.add_loss(self.activity_regularizer(outputs))

        # TODO: Set masks on outputs
        # self._set_mask_metadata(inputs, outputs, previous_mask)

        # Destroy call context if we created it
        self._maybe_reset_call_context()
        return outputs
def call(self, *args, **kwargs):
raise NotImplementedError
def stateless_call(
self, trainable_variables, non_trainable_variables, *args, **kwargs
):
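        """Calls the layer without updating its variables in place.

        The supplied values are used in place of the layer's own variables
        while `call()` runs, and the (possibly updated) non-trainable
        variable values are returned together with the outputs.

        Example (a minimal sketch; `trainable_values` and
        `non_trainable_values` are hypothetical lists of tensors matching
        `layer.trainable_variables` and `layer.non_trainable_variables` 1:1):

        ```
        outputs, non_trainable_values = layer.stateless_call(
            trainable_values, non_trainable_values, some_inputs  # placeholders
        )
        ```
        """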
# TODO: also handle losses
self._check_super_called()
        if not self.built:
            raise ValueError(
                f"To call stateless_call, {self.__class__.__name__} must be "
                "built (i.e. its variables must have been already created). "
                "You can build it by calling it on some data."
            )
if len(trainable_variables) != len(self.trainable_variables):
raise ValueError(
"Argument `trainable_variables` must be a list of tensors "
f"corresponding 1:1 to {self.__class__.__name__}().trainable_variables. "
f"Received list with length {len(trainable_variables)}, but expected "
f"{len(self.trainable_variables)} variables."
)
if len(non_trainable_variables) != len(self.non_trainable_variables):
raise ValueError(
"Argument `non_trainable_variables` must be a list of tensors "
f"corresponding 1:1 to {self.__class__.__name__}().non_trainable_variables. "
f"Received list with length {len(non_trainable_variables)}, but expected "
f"{len(self.non_trainable_variables)} variables."
)
# Gather variable mapping
trainable_mapping = zip(self.trainable_variables, trainable_variables)
non_trainable_mapping = zip(
self.non_trainable_variables, non_trainable_variables
)
mapping = list(trainable_mapping) + list(non_trainable_mapping)
# Call in stateless scope
with backend.StatelessScope(state_mapping=mapping) as scope:
outputs = self.call(*args, **kwargs)
# Gather updated non-trainable variables
non_trainable_variables = []
for v in self.non_trainable_variables:
new_v = scope.get_current_value(v)
if new_v is not None:
non_trainable_variables.append(new_v)
else:
non_trainable_variables.append(v)
return outputs, non_trainable_variables
def compute_output_spec(self, *args, **kwargs):
if utils.is_default(self.compute_output_shape):
return super().compute_output_spec(*args, **kwargs)
else:
# Use compute_output_shape() to return the right output spec
arguments_dict = get_arguments_dict(self.call, *args, **kwargs)
shapes_dict = get_shapes_dict(arguments_dict)
if len(shapes_dict) == 1:
# Single arg: pass it positionally
input_shape = tuple(shapes_dict.values())[0]
output_shape = self.compute_output_shape(input_shape)
else:
# More than one shape: pass them by name.
output_shape = self.compute_output_shape(**shapes_dict)
if (
isinstance(output_shape, tuple)
and output_shape
and isinstance(output_shape[0], (int, type(None)))
):
return KerasTensor(output_shape, dtype=self.compute_dtype)
return nest.map_structure(
lambda s: KerasTensor(s, dtype=self.compute_dtype), output_shape
)
    @utils.default
    def compute_output_shape(self, *args, **kwargs):
        raise NotImplementedError
def add_loss(self, loss):
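        """Adds a loss tensor that will be tracked in `self.losses`.

        Example (a sketch of typical usage from inside a subclassed layer's
        `call()`; `some_scalar_loss` stands for any scalar tensor computed
        from the inputs):

        ```
        def call(self, x):
            self.add_loss(some_scalar_loss)  # any loss tensor works here
            return x
        ```
        """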
# Eager only.
losses = nest.flatten(loss)
for x in losses:
if not backend.is_tensor(x):
raise ValueError(
"`add_loss()` can only be called from inside `build()` or `call()`, "
f"on a tensor input. Received invalid value: {x}"
)
if backend.in_stateless_scope():
scope = backend.get_stateless_scope()
if scope.collect_losses:
for x in losses:
                    scope.add_loss(x)
else:
self._losses.extend(losses)
@property
def losses(self):
losses = self._losses[:]
for layer in self._layers:
losses.extend(layer._losses)
weight_regularization_losses = []
for v in self.trainable_weights:
            regularizer = getattr(v, "regularizer", None)
if regularizer:
weight_regularization_losses.append(regularizer(v))
losses.extend(weight_regularization_losses)
return losses
def save_own_variables(self, store):
"""Saves the state of the layer.
You can override this method to take full control of how the state of
the layer is saved upon calling `model.save()`.
Args:
store: Dict where the state of the model will be saved.
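
        Example (a sketch of the default behavior, which keys each variable
        by its index as a string):

        ```
        >>> store = {}
        >>> layer.save_own_variables(store)
        >>> sorted(store.keys())  # e.g. a layer with two variables
        ['0', '1']
        ```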
"""
all_vars = self._variables
for i, v in enumerate(all_vars):
store[f"{i}"] = np.array(v)
def load_own_variables(self, store):
"""Loads the state of the layer.
You can override this method to take full control of how the state of
the layer is loaded upon calling `keras.models.load_model()`.
Args:
store: Dict from which the state of the model will be loaded.
"""
all_vars = self._variables
if len(store.keys()) != len(all_vars):
if len(all_vars) == 0 and not self.built:
raise ValueError(
f"Layer '{self.name}' was never built "
"and thus it doesn't have any variables. "
f"However the weights file lists {len(store.keys())} "
"variables for this layer. In most cases, "
"this indicates that you need to implement the "
"`def build_from_config(self, config)` method "
"on the layer. "
"You might also want to implement the method "
"that generates the config at saving time, "
"`def get_build_config(self)`. "
"The method `build_from_config()` is meant "
"to create the state "
"of the layer (i.e. its variables) upon deserialization.",
)
raise ValueError(
f"Layer '{self.name}' expected {len(all_vars)} variables, "
"but received "
f"{len(store.keys())} variables during loading. "
f"Expected: {[v.name for v in all_vars]}"
)
for i, v in enumerate(all_vars):
v.assign(store[f"{i}"])
def _clear_losses(self):
if backend.in_stateless_scope():
scope = backend.get_stateless_scope()
if scope.collect_losses:
for x in scope.losses:
if x in self._losses:
scope.losses.remove(x)
self._losses = []
def add_metric(self):
# Permanently disabled
raise NotImplementedError
    def count_params(self):
        """Count the total number of scalars composing the weights.

        Returns:
            An integer count.
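
        Example (a rough sketch; assumes a `Dense` layer is available in the
        installed keras_core version):

        ```
        >>> layer = keras_core.layers.Dense(3)
        >>> layer.build((None, 4))
        >>> layer.count_params()  # kernel (4 * 3) + bias (3)
        15
        ```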
"""
if not self.built:
raise ValueError(
"You tried to call `count_params` "
f"on layer '{self.name}'"
", but the layer isn't built. "
"You can build it manually via: "
f"`layer.build(input_shape)`."
)
return summary_utils.count_params(self.weights)
    def _maybe_build(self, *args, **kwargs):
        if not self.built:
            arguments_dict = get_arguments_dict(self.call, *args, **kwargs)
            shapes_dict = get_shapes_dict(arguments_dict)
            self._build_shapes_dict = shapes_dict
            failure = False
            if len(shapes_dict) == 1:
                # Single arg: pass it positionally
                input_shape = tuple(shapes_dict.values())[0]
                with backend.name_scope(self.name):
                    if utils.is_default(
                        self.build
                    ) and might_have_unbuilt_state(self):
                        status = self._build_by_run_for_single_pos_arg(
                            input_shape
                        )
                        if not status:
                            failure = True
                    else:
                        self.build(input_shape)
            else:
                # More than one shape: pass them by name,
                # and check that build() expects the right args.
                check_build_signature(self.build, shapes_dict)
                with backend.name_scope(self.name):
                    if utils.is_default(
                        self.build
                    ) and might_have_unbuilt_state(self):
                        status = self._build_by_run_for_kwargs(shapes_dict)
                        if not status:
                            failure = True
                    else:
                        self.build(**shapes_dict)
            if failure:  # TODO: warn or raise
                pass
            self.built = True
            # Check input spec again (after build, since self.input_spec
            # may have been updated).
            self._assert_input_compatibility(*args)
    def _build_by_run_for_single_pos_arg(self, input_shape):
        # Case: all inputs are in the first arg (possibly nested).
        if is_shape_tuple(input_shape):
            input_shape = tuple(input_shape)
        if isinstance(input_shape, list):
            input_tensors = [
                backend.KerasTensor(shape, record_history=False)
                for shape in input_shape
            ]
        elif isinstance(input_shape, dict):
            input_tensors = {
                k: backend.KerasTensor(shape, record_history=False)
                for k, shape in input_shape.items()
            }
        else:
            input_tensors = backend.KerasTensor(
                input_shape, record_history=False
            )
        try:
            self.compute_output_spec(input_tensors)
            return True
        except Exception:
            return False
    def _build_by_run_for_kwargs(self, shapes_dict):
        # Case: inputs were recorded as multiple keyword arguments.
        if all(is_shape_tuple(s) for s in shapes_dict.values()):
            # Case: all input keyword arguments were plain tensors.
            input_tensors = {
                # Strip the `_shape` suffix added by `get_shapes_dict()` to
                # recover the original `call()` argument names.
                k[: -len("_shape")]: backend.KerasTensor(
                    v, record_history=False
                )
                for k, v in shapes_dict.items()
            }
            try:
                self.compute_output_spec(**input_tensors)
                return True
            except Exception:
                return False
        else:
            # Not supported: nested input keyword arguments.
            return False
    def __repr__(self):
        # TODO: improve
        return f"<{self.__class__.__name__} name={self.name}>"

    def __str__(self):
        # TODO: improve
        return f"<{self.__class__.__name__} name={self.name}>"
def __setattr__(self, name, value):
# Track Variables, Layers, Metrics
if hasattr(self, "_tracker"):
value = self._tracker.track(value)
return super().__setattr__(name, value)
def _check_super_called(self):
if not hasattr(self, "_tracker"):
raise RuntimeError(
f"In layer '{self.__class__.__name__}', you forgot to call "
"`super().__init__()` in the `__init__()` method. "
"Go add it!"
)
    def _assert_input_compatibility(self, *args):
if args and self.input_spec:
input_spec.assert_input_compatibility(
self.input_spec, args[0], layer_name=self.name
)
def _call_has_training_arg(self):
        return "training" in self._call_signature_parameters

    def _call_has_mask_arg(self):
        return "mask" in self._call_signature_parameters
def _get_call_context(self):
"""Returns currently active `CallContext`."""
global CALL_CTX
call_ctx = getattr(CALL_CTX, "current", None)
if call_ctx is None:
# Enter new call context.
call_ctx = CallContext(entry_layer=self)
CALL_CTX.current = call_ctx
self._clear_losses()
return call_ctx
def _maybe_reset_call_context(self):
global CALL_CTX
call_ctx = getattr(CALL_CTX, "current", None)
        if call_ctx is not None and call_ctx.entry_layer == self:
CALL_CTX.current = None
    def _get_default_training_value(self):
        # Read defaults from the call signature (this also covers
        # keyword-only arguments, unlike `self.call.__defaults__`).
        signature = inspect.signature(self.call)
        defaults = {
            p.name: p.default
            for p in signature.parameters.values()
            if p.default is not inspect.Parameter.empty
        }
        return defaults.get("training", None)
def _flatten_layers(self, include_self=True, recursive=True):
layers = []
if include_self:
layers.append(self)
seen_object_ids = set()
deque = collections.deque(self._layers)
while deque:
layer = deque.popleft()
if id(layer) in seen_object_ids:
continue
seen_object_ids.add(id(layer))
            layers.append(layer)
            # Introspect recursively through sublayers.
            if recursive:
                deque.extendleft(layer._layers)
        return layers
def _set_mask_metadata(self, inputs, outputs, previous_mask):
# Many `Layer`s don't need to call `compute_mask`.
# This method is optimized to do as little work as needed for the common
# case.
if not self._supports_masking:
return
flat_outputs = nest.flatten(outputs)
mask_already_computed = all(
getattr(x, "_keras_mask", None) is not None for x in flat_outputs
)
if mask_already_computed:
return
output_masks = self.compute_mask(inputs, previous_mask)
if output_masks is None:
return
flat_masks = nest.flatten(output_masks)
for tensor, mask in zip(flat_outputs, flat_masks):
tensor._keras_mask = mask
def get_arguments_dict(fn, *args, **kwargs):
"""Return a dict mapping argument names to their values."""
sig = inspect.signature(fn)
bound_args = sig.bind(*args, **kwargs)
arg_dict = {}
for name, value in bound_args.arguments.items():
arg_dict[name] = value
return arg_dict
def get_shapes_dict(arguments_dict):
"""Convert the call() arguments dict into a dict of input shape arguments.
Example:
```
>>> get_shapes_dict({"input_a": KerasTensor(shape=(2, 3)), "training": False})
{"input_a_shape": (2, 3)}
```
"""
shapes_dict = {}
for k, v in arguments_dict.items():
if isinstance(v, KerasTensor) or backend.is_tensor(v):
shapes_dict[f"{k}_shape"] = backend.standardize_shape(v.shape)
elif nest.is_nested(v):
flat = nest.flatten(v)
            if any(
                isinstance(x, KerasTensor) or backend.is_tensor(x)
                for x in flat
            ):
                if not all(
                    isinstance(x, KerasTensor) or backend.is_tensor(x)
                    for x in flat
                ):
raise ValueError(
"You cannot mix tensors and non-tensors in a nested argument. "
f"Invalid argument: {k}={v}"
)
shapes_dict[f"{k}_shape"] = nest.map_structure(
lambda x: backend.standardize_shape(x.shape), v
)
return shapes_dict
def check_build_signature(build_fn, shapes_dict):
"""Asserts that the argument names in build_fn match the entries in shapes_dict.
For instance if call() has the signature `def call(self, a, b)`
then we'll see `shapes_dict == {"a_shape": (...), "b_shape": (...)}
and we expect build() to have signature `def build(self, a_shape, b_shape)`.
When there is a single tensor argument, we pass it positionally and thus
don't check names (if we did, it would force call() to always take
`input` as its first argument, which is usually not the case).
"""
if len(shapes_dict) == 1:
return
if utils.is_default(build_fn):
return
sig = inspect.signature(build_fn)
expected_names = []
for name, param in sig.parameters.items():
if param.kind in (
param.POSITIONAL_OR_KEYWORD,
param.POSITIONAL_ONLY,
param.KEYWORD_ONLY,
):
expected_names.append(name)
if set(expected_names) != set(shapes_dict.keys()):
comma_separated = ", ".join(shapes_dict.keys())
        raise ValueError(
            "For a `call()` method with more than one tensor argument, "
            "the arguments of the `build()` method should match the "
            "tensor arguments of the `call()` method. Here we expect the "
            f"signature `build(self, {comma_separated})`."
        )
CALL_CTX = threading.local()
class CallContext:
def __init__(self, entry_layer):
self.entry_layer = entry_layer
self.training = None
def is_shape_tuple(s):
return isinstance(s, (list, tuple)) and all(
d is None or isinstance(d, int) for d in s
)
def might_have_unbuilt_state(layer):
return any(not lr.built for lr in layer._layers)