keras/keras_core/layers/activations/activation.py
2023-05-19 11:40:25 -07:00

40 lines
1.2 KiB
Python

from keras_core import activations
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
@keras_core_export("keras_core.layers.Activation")
class Activation(Layer):
    """Layer that passes its inputs through an activation function.

    The activation is looked up once, when the layer is constructed, via
    `keras_core.activations.get`. It is then applied to the inputs on every
    call.

    Args:
        activation: Either a callable, or the name of an activation that
            lives in the `keras_core.activations` namespace.
        **kwargs: Keyword arguments forwarded unchanged to the base `Layer`
            constructor (for example `name` or `dtype`).

    Example:

    >>> layer = keras_core.layers.Activation('relu')
    >>> layer([-3.0, -1.0, 0.0, 2.0])
    [0.0, 0.0, 0.0, 2.0]
    >>> layer = keras_core.layers.Activation(keras_core.activations.relu)
    >>> layer([-3.0, -1.0, 0.0, 2.0])
    [0.0, 0.0, 0.0, 2.0]
    """

    def __init__(self, activation, **kwargs):
        super().__init__(**kwargs)
        # Resolve the name/callable once here instead of on every call.
        resolved_fn = activations.get(activation)
        self.supports_masking = True
        self.activation = resolved_fn

    def call(self, inputs):
        """Return the stored activation applied to `inputs`."""
        activation_fn = self.activation
        return activation_fn(inputs)

    def compute_output_shape(self, input_shape):
        """Return `input_shape` unchanged; this layer never reshapes."""
        return input_shape

    def get_config(self):
        """Return the base layer config extended with the serialized activation."""
        serialized = activations.serialize(self.activation)
        # Base-layer entries come first; the "activation" entry is added last
        # (and would override a base entry of the same name).
        return {**super().get_config(), "activation": serialized}