keras/keras_core/layers/core/dense.py
2023-04-22 09:08:31 -07:00

76 lines
2.7 KiB
Python

from keras_core import activations
from keras_core import constraints
from keras_core import initializers
from keras_core import operations as ops
from keras_core import regularizers
from keras_core.layers.layer import Layer
class Dense(Layer):
    """Densely-connected layer computing `activation(inputs @ kernel + bias)`.

    Args:
        units: Number of output units. Used as the kernel's last dimension
            and as the bias length.
        activation: Activation applied to the output, resolved through
            `activations.get`. (`None` presumably maps to an identity /
            linear activation. Confirm against `activations.get`.)
        use_bias: Whether to create a bias vector and add it to the output.
        kernel_initializer: Kernel initializer, resolved through
            `initializers.get`.
        bias_initializer: Bias initializer, resolved through
            `initializers.get`.
        kernel_regularizer: Kernel regularizer, resolved through
            `regularizers.get`.
        bias_regularizer: Bias regularizer, resolved through
            `regularizers.get`.
        activity_regularizer: Not yet supported. Any truthy value raises
            `ValueError`.
        kernel_constraint: Constraint attached to the kernel weight, resolved
            through `constraints.get`.
        bias_constraint: Constraint attached to the bias weight, resolved
            through `constraints.get`.
        name: Optional layer name, forwarded to `Layer.__init__`.
        **kwargs: Extra keyword arguments forwarded to `Layer.__init__`.
            This lets base-layer keys emitted by `Layer.get_config` be passed
            back when rebuilding the layer from its config.

    Raises:
        ValueError: If `activity_regularizer` is provided.
    """

    def __init__(
        self,
        units,
        activation=None,
        use_bias=True,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        name=None,
        **kwargs,
    ):
        super().__init__(name=name, **kwargs)
        self.units = units
        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        if activity_regularizer:
            # TODO: activity regularization is not implemented for this
            # layer yet. Fail loudly rather than silently ignore it.
            raise ValueError("activity_regularizer not yet supported.")

    def build(self, input_shape):
        """Create the kernel and, if `use_bias`, the bias weights.

        Args:
            input_shape: Shape of the layer inputs. Only the last dimension
                is read; it becomes the kernel's input dimension.

        Raises:
            ValueError: If the last dimension of `input_shape` is `None`.
        """
        input_dim = input_shape[-1]
        if input_dim is None:
            # A kernel cannot be allocated with an unknown input dimension.
            raise ValueError(
                "The last dimension of the inputs to a Dense layer should be "
                f"defined. Found None. Full input shape received: {input_shape}"
            )
        self.kernel = self.add_weight(
            shape=(input_dim, self.units),
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            # Previously stored but never attached, so it was silently ignored.
            constraint=self.kernel_constraint,
        )
        if self.use_bias:
            self.bias = self.add_weight(
                shape=(self.units,),
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
            )

    def call(self, inputs):
        """Return `activation(inputs @ kernel (+ bias))`."""
        x = ops.matmul(inputs, self.kernel)
        if self.use_bias:
            # `bias` has shape (units,) and is broadcast against the output.
            x = x + self.bias
        return self.activation(x)

    def get_config(self):
        """Return the base layer config merged with this layer's arguments.

        Values are serialized with each module's `serialize` helper.
        `activity_regularizer` is omitted because the constructor rejects
        any truthy value for it.
        """
        base_config = super().get_config()
        config = {
            "units": self.units,
            "activation": activations.serialize(self.activation),
            "use_bias": self.use_bias,
            "kernel_initializer": initializers.serialize(
                self.kernel_initializer
            ),
            "bias_initializer": initializers.serialize(self.bias_initializer),
            "kernel_regularizer": regularizers.serialize(
                self.kernel_regularizer
            ),
            "bias_regularizer": regularizers.serialize(self.bias_regularizer),
            "kernel_constraint": constraints.serialize(self.kernel_constraint),
            "bias_constraint": constraints.serialize(self.bias_constraint),
        }
        return {**base_config, **config}