Add Dropout layer.
This commit is contained in:
parent
f6df67f2d2
commit
3aa1d977b1
@ -5,11 +5,12 @@ class SeedGenerator:
|
|||||||
def __init__(self, seed):
    """Initialize the generator state from an integer seed.

    Args:
        seed: Python integer seed, or `None` to draw a default seed.

    Raises:
        ValueError: If `seed` is neither `None` nor an integer.
    """
    # Imported lazily to avoid a circular import between the backend
    # module and this file at import time.
    from keras_core.backend import Variable

    if seed is None:
        seed = make_default_seed()
    if not isinstance(seed, int):
        raise ValueError(
            "Argument `seed` must be an integer. " f"Received: seed={seed}"
        )
    # NOTE(review): the diff also contained `seed = seed or
    # make_default_seed()`, which would silently discard an explicit
    # `seed=0` (0 is falsy). The explicit `is None` check above is the
    # correct guard, so that line is dropped.
    # State is a pair [seed, counter]; the counter is advanced per draw.
    self.state = Variable([seed, 0], dtype="uint32", trainable=False)
|
||||||
|
|
||||||
|
|
||||||
|
@ -2,5 +2,4 @@ from keras_core.layers.core.dense import Dense
|
|||||||
from keras_core.layers.core.input_layer import Input
|
from keras_core.layers.core.input_layer import Input
|
||||||
from keras_core.layers.core.input_layer import InputLayer
|
from keras_core.layers.core.input_layer import InputLayer
|
||||||
from keras_core.layers.layer import Layer
|
from keras_core.layers.layer import Layer
|
||||||
|
from keras_core.layers.regularization.dropout import Dropout
|
||||||
# from keras_core.layers.regularization.dropout import Dropout
|
|
||||||
|
@ -0,0 +1,77 @@
|
|||||||
|
from keras_core import backend
|
||||||
|
from keras_core import layers
|
||||||
|
from keras_core.api_export import keras_core_export
|
||||||
|
|
||||||
|
|
||||||
|
@keras_core_export("keras_core.layers.Dropout")
class Dropout(layers.Layer):
    """Applies dropout to the input.

    The `Dropout` layer randomly sets input units to 0 with a frequency of
    `rate` at each step during training time, which helps prevent
    overfitting. Inputs not set to 0 are scaled up by `1 / (1 - rate)` such
    that the sum over all inputs is unchanged.

    Note that the `Dropout` layer only applies when `training` is set to
    `True` in `call()`, such that no values are dropped during inference.
    When using `model.fit`, `training` will be appropriately set to `True`
    automatically. In other contexts, you can set the argument explicitly
    to `True` when calling the layer.

    (This is in contrast to setting `trainable=False` for a `Dropout` layer.
    `trainable` does not affect the layer's behavior, as `Dropout` does
    not have any variables/weights that can be frozen during training.)

    Args:
        rate: Float between 0 and 1. Fraction of the input units to drop.
        noise_shape: 1D integer tensor representing the shape of the
            binary dropout mask that will be multiplied with the input.
            For instance, if your inputs have shape
            `(batch_size, timesteps, features)` and
            you want the dropout mask to be the same for all timesteps,
            you can use `noise_shape=(batch_size, 1, features)`.
        seed: A Python integer to use as random seed.

    Call arguments:
        inputs: Input tensor (of any rank).
        training: Python boolean indicating whether the layer should behave in
            training mode (adding dropout) or in inference mode (doing nothing).
    """

    def __init__(
        self, rate, noise_shape=None, seed=None, name=None, dtype=None
    ):
        super().__init__(name=name, dtype=dtype)
        # Only plain numeric rates are range-checked here; other values
        # (e.g. symbolic tensors) pass through unvalidated.
        if isinstance(rate, (int, float)) and not 0 <= rate <= 1:
            raise ValueError(
                # Fixed: the first fragment was an f-string with no
                # placeholders; only the last fragment interpolates.
                "Invalid value received for argument "
                "`rate`. Expected a float value between 0 and 1. "
                f"Received: rate={rate}"
            )
        self.rate = rate
        self.seed = seed
        self.noise_shape = noise_shape
        # Stateful generator: each call draws a fresh mask while staying
        # reproducible for a fixed `seed`.
        self.seed_generator = backend.random.SeedGenerator(seed)
        self.supports_masking = True

    def call(self, inputs, training=False):
        # Dropout is a no-op at inference time and when rate == 0.
        if training and self.rate > 0:
            return backend.random.dropout(
                inputs,
                self.rate,
                noise_shape=self.noise_shape,
                seed=self.seed_generator,
            )
        return inputs

    def compute_output_shape(self, input_shape):
        # Dropout never changes the shape of its input.
        return input_shape

    def get_config(self):
        """Returns the layer config for serialization round-trips."""
        base_config = super().get_config()
        config = {
            "rate": self.rate,
            "seed": self.seed,
            "noise_shape": self.noise_shape,
        }
        return {**base_config, **config}
|
34
keras_core/layers/regularization/dropout_test.py
Normal file
34
keras_core/layers/regularization/dropout_test.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from keras_core import backend
|
||||||
|
from keras_core import layers
|
||||||
|
from keras_core import testing
|
||||||
|
|
||||||
|
|
||||||
|
class DropoutTest(testing.TestCase):
    """Unit tests for the `Dropout` layer."""

    def test_dropout_supports_masking(self):
        # Dropout passes masks through untouched, so masking is supported.
        layer = layers.Dropout(0.5)
        self.assertEqual(True, layer.supports_masking)

    def test_dropout_rescaling(self):
        # Kept units are scaled by 1/(1-rate), so for rate=0.5 the mean of
        # an all-ones input stays ~1.0 and surviving entries become 2.0.
        x = np.ones((20, 500))
        dropout_layer = layers.Dropout(0.5, seed=1337)
        y = dropout_layer(x, training=True)
        self.assertAllClose(np.mean(y), 1.0, atol=0.02)
        self.assertAllClose(np.max(y), 2.0)

    @pytest.mark.skipif(
        backend.backend() != "tensorflow", reason="Requires dynamic shapes"
    )
    def test_dropout_partial_noise_shape_dynamic(self):
        # A broadcast (None, 1, None) noise shape shares one mask across
        # the timestep axis.
        x = np.ones((20, 5, 10))
        dropout_layer = layers.Dropout(0.5, noise_shape=(None, 1, None))
        y = dropout_layer(x, training=True)
        self.assertAllClose(y[:, 0, :], y[:, 1, :])

    def test_dropout_partial_noise_shape_static(self):
        # Same check with a fully static noise shape.
        x = np.ones((20, 5, 10))
        dropout_layer = layers.Dropout(0.5, noise_shape=(20, 1, 10))
        y = dropout_layer(x, training=True)
        self.assertAllClose(y[:, 0, :], y[:, 1, :])
|
@ -1,8 +1,6 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from keras_core import backend
|
from keras_core import backend
|
||||||
from keras_core import initializers
|
|
||||||
from keras_core import operations as ops
|
|
||||||
from keras_core import regularizers
|
from keras_core import regularizers
|
||||||
from keras_core import testing
|
from keras_core import testing
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user