from tensorflow import nest

from keras_core import backend
from keras_core.api_export import keras_core_export


@keras_core_export(["keras_core.InputSpec", "keras_core.layers.InputSpec"])
class InputSpec:
    """Specifies the rank, dtype and shape of every input to a layer.

    Layers can expose (if appropriate) an `input_spec` attribute:
    an instance of `InputSpec`, or a nested structure of `InputSpec`
    instances (one per input tensor). These objects enable the layer to run
    input compatibility checks for input structure, input rank, input
    shape, and input dtype for the first argument of `Layer.__call__`.

    A `None` entry in a shape is compatible with any dimension.

    Args:
        dtype: Expected DataType of the input.
        shape: Shape tuple, expected shape of the input
            (may include None for unchecked axes). Includes the batch size.
        ndim: Integer, expected rank of the input.
        max_ndim: Integer, maximum rank of the input.
        min_ndim: Integer, minimum rank of the input.
        axes: Dictionary mapping integer axes to
            a specific dimension value.
        allow_last_axis_squeeze: If True, then allow inputs of rank N+1 as
            long as the last axis of the input is 1, as well as inputs of
            rank N-1 as long as the last axis of the spec is 1.
        name: Expected key corresponding to this input when passing data as
            a dictionary.

    Example:

    ```python
    class MyLayer(Layer):
        def __init__(self):
            super(MyLayer, self).__init__()
            # The layer will accept inputs with
            # shape (*, 28, 28) & (*, 28, 28, 1)
            # and raise an appropriate error message otherwise.
            self.input_spec = InputSpec(
                shape=(None, 28, 28, 1),
                allow_last_axis_squeeze=True)
    ```
    """

    def __init__(
        self,
        dtype=None,
        shape=None,
        ndim=None,
        max_ndim=None,
        min_ndim=None,
        axes=None,
        allow_last_axis_squeeze=False,
        name=None,
    ):
        self.dtype = (
            backend.standardize_dtype(dtype) if dtype is not None else None
        )
        if shape is not None:
            self.shape = backend.standardize_shape(shape)
            # An explicit shape pins the rank; any `ndim` argument is
            # ignored in this branch (it is derived from the shape).
            self.ndim = len(shape)
        else:
            self.ndim = ndim
            self.shape = None
        self.max_ndim = max_ndim
        self.min_ndim = min_ndim
        self.name = name
        self.allow_last_axis_squeeze = allow_last_axis_squeeze
        try:
            # Normalize axes keys to int; non-integer keys raise below.
            axes = axes or {}
            self.axes = {int(k): axes[k] for k in axes}
        except (ValueError, TypeError):
            raise TypeError(
                "Argument `axes` must be a dict with integer keys. "
                f"Received: axes={axes}"
            )

        # Validate that no axis constraint refers to a dimension beyond
        # the declared (or maximum) rank.
        if self.axes and (self.ndim is not None or self.max_ndim is not None):
            max_dim = (self.ndim if self.ndim else self.max_ndim) - 1
            max_axis = max(self.axes)
            if max_axis > max_dim:
                # f-string form of the original `.format()` message;
                # the rendered text is identical.
                raise ValueError(
                    f"Axis {max_axis} is greater than the maximum "
                    f"allowed value: {max_dim}"
                )

    def __repr__(self):
        # Only include the constraints that were actually set.
        spec = [
            ("dtype=" + str(self.dtype)) if self.dtype else "",
            ("shape=" + str(self.shape)) if self.shape else "",
            ("ndim=" + str(self.ndim)) if self.ndim else "",
            ("max_ndim=" + str(self.max_ndim)) if self.max_ndim else "",
            ("min_ndim=" + str(self.min_ndim)) if self.min_ndim else "",
            ("axes=" + str(self.axes)) if self.axes else "",
        ]
        return f"InputSpec({', '.join(x for x in spec if x)})"

    def get_config(self):
        """Returns a JSON-serializable dict of the spec's constraints.

        Note: `allow_last_axis_squeeze` and `name` are intentionally not
        serialized (matching the constructor's optional extras).
        """
        return {
            "dtype": self.dtype,
            "shape": self.shape,
            "ndim": self.ndim,
            "max_ndim": self.max_ndim,
            "min_ndim": self.min_ndim,
            "axes": self.axes,
        }

    @classmethod
    def from_config(cls, config):
        """Recreates an `InputSpec` from its `get_config()` output."""
        return cls(**config)


def assert_input_compatibility(input_spec, inputs, layer_name):
    """Checks compatibility between the layer and provided inputs.

    This checks that the tensor(s) `inputs` verify the input assumptions
    of a layer (if any). If not, a clear and actionable exception gets
    raised.

    Args:
        input_spec: An InputSpec instance, list of InputSpec instances, a
            nested structure of InputSpec instances, or None.
        inputs: Input tensor, list of input tensors, or a nested structure
            of input tensors.
        layer_name: String, name of the layer (for error message
            formatting).

    Raises:
        ValueError: in case of mismatch between the provided inputs and
            the expectations of the layer.
    """
    if not input_spec:
        return

    input_spec = nest.flatten(input_spec)
    if isinstance(inputs, dict):
        # Flatten `inputs` by reference order if input spec names are
        # provided.
        names = [spec.name for spec in input_spec]
        if all(names):
            list_inputs = []
            for name in names:
                if name not in inputs:
                    raise ValueError(
                        f'Missing data for input "{name}". '
                        "You passed a data dictionary with keys "
                        f"{list(inputs.keys())}. "
                        f"Expected the following keys: {names}"
                    )
                list_inputs.append(inputs[name])
            inputs = list_inputs

    inputs = nest.flatten(inputs)
    for x in inputs:
        # Having a shape/dtype is the only commonality of the various
        # tensor-like objects that may be passed. The most common kind of
        # invalid type we are guarding for is a Layer instance
        # (Functional API), which does not have a `shape` attribute.
        if not hasattr(x, "shape"):
            raise TypeError(
                f"Inputs to a layer should be tensors. Got '{x}' "
                f"(of type {type(x)}) as input for layer '{layer_name}'."
            )

    if len(inputs) != len(input_spec):
        raise ValueError(
            f'Layer "{layer_name}" expects {len(input_spec)} input(s),'
            f" but it received {len(inputs)} input tensors. "
            f"Inputs received: {inputs}"
        )
    for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):
        if spec is None:
            continue

        shape = backend.standardize_shape(x.shape)
        ndim = len(shape)
        # Check ndim. Skipped when allow_last_axis_squeeze is set, since
        # the rank may then legitimately differ by one (handled by the
        # shape check below).
        if spec.ndim is not None and not spec.allow_last_axis_squeeze:
            if ndim != spec.ndim:
                raise ValueError(
                    f'Input {input_index} of layer "{layer_name}" '
                    "is incompatible with the layer: "
                    f"expected ndim={spec.ndim}, found ndim={ndim}. "
                    f"Full shape received: {shape}"
                )
        # `ndim` is always an int here (len of the standardized shape),
        # so the original `ndim is not None` guards were dead and have
        # been dropped.
        if spec.max_ndim is not None:
            if ndim > spec.max_ndim:
                raise ValueError(
                    f'Input {input_index} of layer "{layer_name}" '
                    "is incompatible with the layer: "
                    f"expected max_ndim={spec.max_ndim}, "
                    f"found ndim={ndim}"
                )
        if spec.min_ndim is not None:
            if ndim < spec.min_ndim:
                raise ValueError(
                    f'Input {input_index} of layer "{layer_name}" '
                    "is incompatible with the layer: "
                    f"expected min_ndim={spec.min_ndim}, "
                    f"found ndim={ndim}. "
                    f"Full shape received: {shape}"
                )
        # Check dtype.
        if spec.dtype is not None:
            dtype = backend.standardize_dtype(x.dtype)
            if dtype != spec.dtype:
                raise ValueError(
                    f'Input {input_index} of layer "{layer_name}" '
                    "is incompatible with the layer: "
                    f"expected dtype={spec.dtype}, "
                    f"found dtype={dtype}"
                )
        # Check specific shape axes. A `None` in the received shape is
        # treated as compatible with any constraint.
        if spec.axes:
            for axis, value in spec.axes.items():
                if value is not None and shape[axis] not in {
                    value,
                    None,
                }:
                    raise ValueError(
                        f'Input {input_index} of layer "{layer_name}" is '
                        f"incompatible with the layer: expected axis {axis} "
                        f"of input shape to have value {value}, "
                        "but received input with "
                        f"shape {shape}"
                    )
        # Check shape.
        if spec.shape is not None:
            spec_shape = spec.shape
            if spec.allow_last_axis_squeeze:
                # Drop a trailing 1 from either side so that rank-N+1 /
                # rank-N-1 inputs with a size-1 last axis still match.
                if shape and shape[-1] == 1:
                    shape = shape[:-1]
                if spec_shape and spec_shape[-1] == 1:
                    spec_shape = spec_shape[:-1]
            for spec_dim, dim in zip(spec_shape, shape):
                if spec_dim is not None and dim is not None:
                    if spec_dim != dim:
                        raise ValueError(
                            f'Input {input_index} of layer "{layer_name}" is '
                            "incompatible with the layer: "
                            f"expected shape={spec.shape}, "
                            f"found shape={shape}"
                        )