from keras_core import activations
from keras_core import constraints
from keras_core import initializers
from keras_core import operations as ops
from keras_core import regularizers
from keras_core.api_export import keras_core_export
from keras_core.layers.input_spec import InputSpec
from keras_core.layers.layer import Layer


@keras_core_export("keras_core.layers.Dense")
class Dense(Layer):
    """Just your regular densely-connected NN layer.

    `Dense` implements the operation:
    `output = activation(dot(input, kernel) + bias)`
    where `activation` is the element-wise activation function
    passed as the `activation` argument, `kernel` is a weights matrix
    created by the layer, and `bias` is a bias vector created by the layer
    (only applicable if `use_bias` is `True`).

    Note: If the input to the layer has a rank greater than 2, `Dense`
    computes the dot product between the `inputs` and the `kernel` along the
    last axis of the `inputs` and axis 0 of the `kernel` (using
    `ops.matmul`). For example, if input has dimensions
    `(batch_size, d0, d1)`, then we create a `kernel` with shape
    `(d1, units)`, and the `kernel` operates along axis 2 of the `input`,
    on every sub-tensor of shape `(1, 1, d1)` (there are `batch_size * d0`
    such sub-tensors). The output in this case will have shape
    `(batch_size, d0, units)`.

    Args:
        units: Positive integer, dimensionality of the output space.
        activation: Activation function to use.
            If you don't specify anything, no activation is applied
            (i.e. "linear" activation: `a(x) = x`).
        use_bias: Boolean, whether the layer uses a bias vector.
        kernel_initializer: Initializer for the `kernel` weights matrix.
        bias_initializer: Initializer for the bias vector.
        kernel_regularizer: Regularizer function applied to
            the `kernel` weights matrix.
        bias_regularizer: Regularizer function applied to the bias vector.
        activity_regularizer: Regularizer function applied to
            the output of the layer (its "activation").
        kernel_constraint: Constraint function applied to
            the `kernel` weights matrix.
        bias_constraint: Constraint function applied to the bias vector.

    Input shape:
        N-D tensor with shape: `(batch_size, ..., input_dim)`.
        The most common situation would be
        a 2D input with shape `(batch_size, input_dim)`.

    Output shape:
        N-D tensor with shape: `(batch_size, ..., units)`.
        For instance, for a 2D input with shape `(batch_size, input_dim)`,
        the output would have shape `(batch_size, units)`.
    """

    def __init__(
        self,
        units,
        activation=None,
        use_bias=True,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        **kwargs,
    ):
        """Store the layer configuration; weights are created in `build`.

        Identifiers passed for the activation, initializers, regularizers
        and constraints are resolved through the corresponding `get`
        helpers. Remaining keyword arguments are forwarded to `Layer`.
        """
        super().__init__(activity_regularizer=activity_regularizer, **kwargs)
        self.units = units
        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        # The last axis size is unknown until `build`; only the rank is
        # constrained for now.
        self.input_spec = InputSpec(min_ndim=2)

    def build(self, input_shape):
        """Create the layer weights from the last dimension of the input.

        Args:
            input_shape: Shape of the input; only `input_shape[-1]` is
                read, and it becomes the first dimension of `kernel`.
        """
        input_dim = input_shape[-1]
        # NOTE(review): `kernel_constraint` / `bias_constraint` are stored
        # and serialized but are not forwarded to `add_weight` here.
        # Confirm whether `Layer.add_weight` accepts a `constraint` argument.
        self.kernel = self.add_weight(
            shape=(input_dim, self.units),
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
        )
        if self.use_bias:
            self.bias = self.add_weight(
                shape=(self.units,),
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
            )
        else:
            # Always define the attribute so that `layer.bias` does not
            # raise `AttributeError` on a bias-free layer.
            self.bias = None
        # Pin the size of the last axis now that it is known.
        self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim})
        self.built = True

    def call(self, inputs):
        """Return `activation(matmul(inputs, kernel) + bias)`.

        The bias term is only added when `use_bias` is true.
        """
        x = ops.matmul(inputs, self.kernel)
        if self.use_bias:
            x = x + self.bias
        return self.activation(x)

    def compute_output_shape(self, input_shape):
        """Return `input_shape` as a tuple with its last entry set to `units`."""
        output_shape = list(input_shape)
        output_shape[-1] = self.units
        return tuple(output_shape)

    def get_config(self):
        """Return the base `Layer` config extended with this layer's settings.

        Keys defined here take precedence over same-named keys from the
        base config.
        """
        base_config = super().get_config()
        config = {
            "units": self.units,
            "activation": activations.serialize(self.activation),
            "use_bias": self.use_bias,
            "kernel_initializer": initializers.serialize(
                self.kernel_initializer
            ),
            "bias_initializer": initializers.serialize(self.bias_initializer),
            "kernel_regularizer": regularizers.serialize(
                self.kernel_regularizer
            ),
            "bias_regularizer": regularizers.serialize(self.bias_regularizer),
            "kernel_constraint": constraints.serialize(self.kernel_constraint),
            "bias_constraint": constraints.serialize(self.bias_constraint),
        }
        return {**base_config, **config}