keras/keras_core/layers/preprocessing/random_contrast.py
Neel Kovelamudi b518b9ef2b Adds RandomContrast and RandomBrightness and associated tests (#130)
* Adds unit normalization and tests

* Adds layer normalization and initial tests

* Fixes formatting in docstrings

* Fix type issues for JAX

* Fix nits

* Initial stash for group_normalization and spectral_normalization

* Adds spectral normalization and tests

* Adds group normalization and tests

* Formatting fixes

* Fix small nit in docstring

* Fix docstring and tests

* Adds RandomContrast and associated tests

* Remove arithmetic comment

* Adds RandomBrightness and tests

* Fix docstring and format

* Fix nits and add backend generator

* Inlines random_contrast helper
2023-05-11 04:19:03 +00:00

95 lines
3.4 KiB
Python

from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.backend import random
from keras_core.layers.layer import Layer
@keras_core_export("keras_core.layers.RandomContrast")
class RandomContrast(Layer):
    """A preprocessing layer which randomly adjusts contrast during training.

    This layer will randomly adjust the contrast of an image or images by a
    random factor. On each call in training mode, a single scalar contrast
    factor is sampled and applied to every image and channel of the input.
    The mean used for the adjustment is computed separately for each channel
    of each image.

    For each channel, this layer computes the mean of the image pixels in the
    channel and then adjusts each component `x` of each pixel to
    `(x - mean) * contrast_factor + mean`. The result is then clipped to
    `value_range`.

    Input pixel values can be of any range (e.g. `[0., 1.)` or `[0, 255]`) and
    in integer or floating point dtype. Set `value_range` to match the range
    of your inputs so that clipping is correct.
    By default, the layer will output floats.

    Input shape:
        3D (unbatched) or 4D (batched) tensor with shape:
        `(..., height, width, channels)`, in `"channels_last"` format.

    Output shape:
        3D (unbatched) or 4D (batched) tensor with shape:
        `(..., height, width, channels)`, in `"channels_last"` format.

    Args:
        factor: a non-negative float represented as fraction of value, or a
            tuple of size 2 representing lower and upper bound.
            When represented as a single float, lower = upper.
            The contrast factor will be randomly picked between
            `[1.0 - lower, 1.0 + upper]`. `lower` must be in `[0, 1]` and
            `upper` must be non-negative. For any pixel x in the channel,
            the output will be `(x - mean) * factor + mean`
            where `mean` is the mean value of the channel.
        seed: Integer. Used to create a random seed.
        value_range: tuple/list of two numbers `(min, max)`. Outputs are
            clipped to this range. Defaults to `(0, 255)`.
    """

    def __init__(self, factor, seed=None, value_range=(0, 255), **kwargs):
        super().__init__(**kwargs)
        self.factor = factor
        if isinstance(factor, (tuple, list)):
            # Reject malformed bounds instead of raising IndexError or
            # silently ignoring extra entries.
            if len(factor) != 2:
                raise ValueError(
                    "`factor` must be a single float or a tuple/list of two "
                    f"floats. Received: factor={factor}"
                )
            self.lower, self.upper = factor
        else:
            self.lower = self.upper = factor
        # `lower > 1` would make the minimum contrast factor negative;
        # `upper` may exceed 1 (contrast factors above 2 are allowed).
        if self.lower < 0.0 or self.upper < 0.0 or self.lower > 1.0:
            raise ValueError(
                "`factor` argument cannot have negative values or a lower "
                "bound greater than 1. "
                f"Received: factor={factor}"
            )
        if (
            not isinstance(value_range, (tuple, list))
            or len(value_range) != 2
            or value_range[0] > value_range[1]
        ):
            raise ValueError(
                "`value_range` must be a tuple/list of two numbers "
                "`(min, max)` with `min <= max`. "
                f"Received: value_range={value_range}"
            )
        self.value_range = value_range
        self.seed = seed
        self.generator = random.SeedGenerator(seed)

    def call(self, inputs, training=True):
        inputs = ops.cast(inputs, self.compute_dtype)
        if not training:
            return inputs
        # One scalar factor per call, shared by all images and channels.
        # Sample it in the compute dtype so that multiplying by it does not
        # promote reduced-precision inputs to float32.
        factor = ops.random.uniform(
            shape=(),
            minval=1.0 - self.lower,
            maxval=1.0 + self.upper,
            seed=self.generator,
            dtype=self.compute_dtype,
        )
        outputs = self._adjust_constrast(inputs, factor)
        # All operations above are elementwise or broadcast against a
        # keepdims mean, so the output already has the input's shape.
        return ops.clip(outputs, self.value_range[0], self.value_range[1])

    def _adjust_constrast(self, inputs, contrast_factor):
        """Scale each pixel's distance from its per-image, per-channel mean."""
        # Mean over height (axis -3), then width (axis -2), keeping dims so it
        # broadcasts back against `inputs` for both 3D and 4D inputs.
        inp_mean = ops.mean(inputs, axis=-3, keepdims=True)
        inp_mean = ops.mean(inp_mean, axis=-2, keepdims=True)
        return (inputs - inp_mean) * contrast_factor + inp_mean

    def compute_output_shape(self, input_shape):
        return input_shape

    def get_config(self):
        config = {
            "factor": self.factor,
            "seed": self.seed,
            "value_range": self.value_range,
        }
        base_config = super().get_config()
        return {**base_config, **config}