Finalize Dense layer.

Francois Chollet 2023-04-22 19:26:17 -07:00
parent c690e1b5a6
commit 1181f444f2
3 changed files with 94 additions and 4 deletions

@@ -3,10 +3,59 @@ from keras_core import constraints
from keras_core import initializers
from keras_core import operations as ops
from keras_core import regularizers
from keras_core.api_export import keras_core_export
from keras_core.layers.input_spec import InputSpec
from keras_core.layers.layer import Layer


@keras_core_export("keras_core.layers.Dense")
class Dense(Layer):
    """Just your regular densely-connected NN layer.

    `Dense` implements the operation:
    `output = activation(dot(input, kernel) + bias)`
    where `activation` is the element-wise activation function
    passed as the `activation` argument, `kernel` is a weights matrix
    created by the layer, and `bias` is a bias vector created by the layer
    (only applicable if `use_bias` is `True`).

    Note: If the input to the layer has a rank greater than 2, `Dense`
    computes the dot product between the `inputs` and the `kernel` along
    the last axis of the `inputs` and axis 0 of the `kernel` (using
    `tf.tensordot`). For example, if the input has dimensions
    `(batch_size, d0, d1)`, then we create a `kernel` with shape
    `(d1, units)`, and the `kernel` operates along axis 2 of the `input`,
    on every sub-tensor of shape `(1, 1, d1)` (there are `batch_size * d0`
    such sub-tensors). The output in this case will have shape
    `(batch_size, d0, units)`.

    Args:
        units: Positive integer, dimensionality of the output space.
        activation: Activation function to use.
            If you don't specify anything, no activation is applied
            (i.e. "linear" activation: `a(x) = x`).
        use_bias: Boolean, whether the layer uses a bias vector.
        kernel_initializer: Initializer for the `kernel` weights matrix.
        bias_initializer: Initializer for the bias vector.
        kernel_regularizer: Regularizer function applied to
            the `kernel` weights matrix.
        bias_regularizer: Regularizer function applied to the bias vector.
        activity_regularizer: Regularizer function applied to
            the output of the layer (its "activation").
        kernel_constraint: Constraint function applied to
            the `kernel` weights matrix.
        bias_constraint: Constraint function applied to the bias vector.

    Input shape:
        N-D tensor with shape: `(batch_size, ..., input_dim)`.
        The most common situation would be
        a 2D input with shape `(batch_size, input_dim)`.

    Output shape:
        N-D tensor with shape: `(batch_size, ..., units)`.
        For instance, for a 2D input with shape `(batch_size, input_dim)`,
        the output would have shape `(batch_size, units)`.
    """

    def __init__(
        self,
        units,
@@ -32,9 +81,7 @@ class Dense(Layer):
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        if activity_regularizer:
            # TODO
            raise ValueError("activity_regularizer not yet supported.")
        self.input_spec = InputSpec(min_ndim=2)

    def build(self, input_shape):
        input_dim = input_shape[-1]
@@ -49,6 +96,7 @@ class Dense(Layer):
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
            )
        self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim})
        self.built = True

    def call(self, inputs):

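The rank-3 rule described in the new docstring can be checked with plain NumPy. The following is a minimal sketch of that contraction only, not the layer's actual backend-dispatched implementation:

    import numpy as np

    batch_size, d0, d1, units = 4, 3, 5, 2
    inputs = np.random.rand(batch_size, d0, d1)
    kernel = np.random.rand(d1, units)  # the layer creates kernel as (d1, units)
    bias = np.zeros(units)

    # Contract the last axis of `inputs` with axis 0 of `kernel`,
    # the same pairing the Dense docstring describes for tf.tensordot.
    outputs = np.tensordot(inputs, kernel, axes=([2], [0])) + bias
    assert outputs.shape == (batch_size, d0, units)
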
@@ -2,6 +2,7 @@ import numpy as np
from keras_core import layers
from keras_core import testing
from keras_core.backend import keras_tensor


class DenseTest(testing.TestCase):
@@ -55,3 +56,9 @@ class DenseTest(testing.TestCase):
            [[-1.0, 2.0]],
        )
        self.assertAllClose(layer(inputs), [[10.0, 0.0]])

    def test_dense_errors(self):
        with self.assertRaisesRegex(ValueError, "incompatible with the layer"):
            layer = layers.Dense(units=2, activation="relu")
            layer(keras_tensor.KerasTensor((1, 2)))
            layer(keras_tensor.KerasTensor((1, 3)))

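The new `test_dense_errors` case exercises the `InputSpec(min_ndim=2, axes={-1: input_dim})` that `build()` now sets. A hedged usage sketch of the behavior it checks, assuming NumPy inputs are accepted by `__call__`:

    import numpy as np
    from keras_core import layers

    layer = layers.Dense(units=2)
    layer(np.zeros((1, 2)))  # first call builds a (2, 2) kernel, pins axes={-1: 2}
    try:
        layer(np.zeros((1, 3)))  # last axis is 3, but the spec expects 2
    except ValueError as err:
        print(err)  # message mentions "incompatible with the layer"
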
@@ -13,6 +13,7 @@ And some more magic:
- add_loss
- metric tracking
- RNG seed tracking
- activity regularization
"""
import collections
import inspect
@@ -24,6 +25,7 @@ from tensorflow import nest
from keras_core import backend
from keras_core import initializers
from keras_core import regularizers
from keras_core import utils
from keras_core.api_export import keras_core_export
from keras_core.backend import KerasTensor
@@ -38,8 +40,11 @@ from keras_core.utils.tracking import Tracker


@keras_core_export(["keras_core.Layer", "keras_core.layers.Layer"])
class Layer(Operation):
    def __init__(self, trainable=True, dtype=None, name=None):
    def __init__(
        self, activity_regularizer=None, trainable=True, dtype=None, name=None
    ):
        super().__init__(name=name)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self._trainable = trainable
        if dtype is None:
            dtype = backend.floatx()
@@ -316,9 +321,16 @@ class Layer(Operation):
            kwargs["training"] = training

        # TODO: Populate mask argument(s)

        # Call the layer.
        with backend.name_scope(self.name):
            outputs = super().__call__(*args, **kwargs)

        # Record activity regularizer loss.
        if self.activity_regularizer is not None:
            self.add_loss(self.activity_regularizer(outputs))

        # TODO: Set masks on outputs
        # self._set_mask_metadata(inputs, outputs, previous_mask)

        # Destroy call context if we created it
        self._maybe_reset_call_context()
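With the raise removed from `Dense.__init__` and this hook added to `__call__`, activity regularization now flows through the base class for every layer. A hedged sketch of the resulting behavior, assuming `regularizers.L2` and the layer's `losses` collection are available in keras_core:

    import numpy as np
    from keras_core import layers
    from keras_core import regularizers

    layer = layers.Dense(units=2, activity_regularizer=regularizers.L2(0.01))
    outputs = layer(np.ones((1, 4)))
    # add_loss() recorded 0.01 * sum(outputs ** 2) during the call above.
    print(layer.losses)
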
@@ -566,6 +578,29 @@ class Layer(Operation):
                deque.extendleft(layer._layers)
        return layers

    def _set_mask_metadata(self, inputs, outputs, previous_mask):
        # Many `Layer`s don't need to call `compute_mask`.
        # This method is optimized to do as little work as needed for the
        # common case.
        if not self._supports_masking:
            return
        flat_outputs = nest.flatten(outputs)
        mask_already_computed = all(
            getattr(x, "_keras_mask", None) is not None for x in flat_outputs
        )
        if mask_already_computed:
            return
        output_masks = self.compute_mask(inputs, previous_mask)
        if output_masks is None:
            return
        flat_masks = nest.flatten(output_masks)
        for tensor, mask in zip(flat_outputs, flat_masks):
            tensor._keras_mask = mask
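The final attach step of `_set_mask_metadata` can be shown in isolation. A small sketch using TensorFlow tensors (assumed here because plain NumPy arrays do not accept new attributes like `_keras_mask`):

    import tensorflow as tf
    from tensorflow import nest

    outputs = {"a": tf.zeros((2, 3)), "b": tf.zeros((2, 4))}
    masks = {"a": tf.ones((2,), dtype=tf.bool), "b": tf.ones((2,), dtype=tf.bool)}
    # Flatten both structures in the same order, then store each mask on its
    # corresponding output tensor, mirroring the loop above.
    for tensor, mask in zip(nest.flatten(outputs), nest.flatten(masks)):
        tensor._keras_mask = mask
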

def get_arguments_dict(fn, *args, **kwargs):
    """Return a dict mapping argument names to their values."""