keras/keras_core/layers/preprocessing/rescaling.py
2023-05-05 17:12:09 -07:00

53 lines
1.7 KiB
Python

from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
@keras_core_export("keras_core.layers.Rescaling")
class Rescaling(Layer):
    """A preprocessing layer which rescales input values to a new range.

    This layer rescales every value of an input (often an image) by
    multiplying by `scale` and adding `offset`.

    For instance:

    1. To rescale an input in the `[0, 255]` range
    to be in the `[0, 1]` range, you would pass `scale=1./255`.

    2. To rescale an input in the `[0, 255]` range to be in the `[-1, 1]`
    range, you would pass `scale=1./127.5, offset=-1`.

    The rescaling is applied both during training and inference. Inputs can
    be of integer or floating point dtype, and by default the layer will
    output floats.

    Args:
        scale: Float, the scale to apply to the inputs.
        offset: Float, the offset to apply to the inputs. Defaults to `0.0`.
        **kwargs: Base layer keyword arguments, such as `name` and `dtype`.
    """

    def __init__(self, scale, offset=0.0, **kwargs):
        """Store the rescaling parameters; `kwargs` go to the base layer."""
        super().__init__(**kwargs)
        self.scale = scale
        self.offset = offset
        self.supports_masking = True

    def call(self, inputs):
        """Return `inputs * scale + offset`, computed in the compute dtype.

        Args:
            inputs: The values to rescale.

        Returns:
            The rescaled values, in `self.compute_dtype`.
        """
        dtype = self.compute_dtype
        scale = ops.cast(self.scale, dtype)
        offset = ops.cast(self.offset, dtype)
        # `scale` and `offset` are already in the compute dtype; cast the
        # inputs as well so that integer inputs (e.g. uint8 images) do not
        # lead to mixed-dtype arithmetic, whose result or validity is
        # backend-dependent.
        return ops.cast(inputs, dtype) * scale + offset

    def compute_output_shape(self, input_shape):
        """Return `input_shape` unchanged."""
        return input_shape

    def get_config(self):
        """Return the base layer config extended with `scale` and `offset`."""
        base_config = super().get_config()
        config = {
            "scale": self.scale,
            "offset": self.offset,
        }
        return {**base_config, **config}