53 lines
1.7 KiB
Python
53 lines
1.7 KiB
Python
from keras_core import operations as ops
|
|
from keras_core.api_export import keras_core_export
|
|
from keras_core.layers.layer import Layer
|
|
|
|
|
|
@keras_core_export("keras_core.layers.Rescaling")
class Rescaling(Layer):
    """A preprocessing layer which rescales input values to a new range.

    This layer rescales every value of an input (often an image) by
    multiplying by `scale` and adding `offset`.

    For instance:

    1. To rescale an input in the `[0, 255]` range
    to be in the `[0, 1]` range, you would pass `scale=1./255`.

    2. To rescale an input in the `[0, 255]` range to be in the `[-1, 1]`
    range, you would pass `scale=1./127.5, offset=-1`.

    The rescaling is applied both during training and inference. Inputs can
    be of integer or floating point dtype, and by default the layer will
    output floats.

    Args:
        scale: Float, the scale to apply to the inputs.
        offset: Float, the offset to apply to the inputs. Defaults to `0.0`.
        **kwargs: Base layer keyword arguments, such as `name` and `dtype`.
    """

    def __init__(self, scale, offset=0.0, **kwargs):
        super().__init__(**kwargs)
        # Stored as given so that `get_config()` can return them unchanged.
        self.scale = scale
        self.offset = offset
        self.supports_masking = True

    def call(self, inputs):
        """Return `inputs * scale + offset`, computed in the compute dtype.

        Args:
            inputs: The values to rescale.

        Returns:
            The rescaled values, in `self.compute_dtype`.
        """
        dtype = self.compute_dtype
        scale = ops.cast(self.scale, dtype)
        offset = ops.cast(self.offset, dtype)
        # Cast the inputs too: integer inputs (e.g. uint8 images) would
        # otherwise be combined with a floating point `scale`, which fails
        # on backends that do not implicitly promote mixed dtypes.
        return ops.cast(inputs, dtype) * scale + offset

    def compute_output_shape(self, input_shape):
        """Return `input_shape` unchanged; rescaling is elementwise."""
        return input_shape

    def get_config(self):
        """Return the base layer config extended with `scale` and `offset`."""
        base_config = super().get_config()
        config = {
            "scale": self.scale,
            "offset": self.offset,
        }
        return {**base_config, **config}
|