keras/keras_core/layers/core/wrapper.py
2023-05-08 13:51:15 -07:00

48 lines
1.5 KiB
Python

from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
from keras_core.saving import serialization_lib
@keras_core_export("keras_core.layers.Wrapper")
class Wrapper(Layer):
    """Abstract wrapper base class.

    Wrappers take another layer and augment it in various ways.
    Do not use this class as a layer; it is only an abstract base class.
    Two usable wrappers are the `TimeDistributed` and `Bidirectional` layers.

    Args:
        layer: The layer to be wrapped.
    """

    def __init__(self, layer, **kwargs):
        # Validate with an explicit check rather than `assert`: assertions are
        # stripped when Python runs with `-O`, which would silently let a
        # non-Layer object through and fail much later with an obscure error.
        if not isinstance(layer, Layer):
            raise ValueError(
                f"Layer {layer} supplied to Wrapper isn't "
                "a supported layer type. Please "
                "ensure wrapped layer is a valid Keras layer."
            )
        super().__init__(**kwargs)
        self.layer = layer

    def build(self, input_shape=None):
        """Build the wrapped layer if needed, then mark this wrapper as built.

        Args:
            input_shape: Forwarded unchanged to the wrapped layer's `build()`
                when the wrapped layer has not been built yet.
        """
        if not self.layer.built:
            self.layer.build(input_shape)
            # Mark the wrapped layer built explicitly, so the check above is
            # not repeated even if its `build()` does not set `built` itself.
            self.layer.built = True
        self.built = True

    def get_config(self):
        """Return the base layer config plus the serialized wrapped layer.

        Returns:
            A dict merging `Layer.get_config()` with a `"layer"` entry that
            holds the serialized wrapped layer (the `"layer"` entry wins on
            key collision).
        """
        config = {"layer": serialization_lib.serialize_keras_object(self.layer)}
        base_config = super().get_config()
        return {**base_config, **config}

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Recreate a wrapper from a config produced by `get_config()`.

        Args:
            config: Mapping with a `"layer"` entry (the serialized wrapped
                layer) plus keyword arguments for the constructor. The
                caller's mapping is not modified.
            custom_objects: Optional mapping forwarded to
                `serialization_lib.deserialize_keras_object`.

        Returns:
            A new instance of `cls` wrapping the deserialized layer.
        """
        # Shallow-copy so popping "layer" below does not mutate the caller's
        # dict (which would break reusing the same config a second time).
        config = dict(config)
        layer = serialization_lib.deserialize_keras_object(
            config.pop("layer"),
            custom_objects=custom_objects,
        )
        return cls(layer, **config)