b518b9ef2b
* Adds unit normalization and tests * Adds layer normalization and initial tests * Fixes formatting in docstrings * Fix type issues for JAX * Fix nits * Initial stash for group_normalization and spectral_normalization * Adds spectral normalization and tests * Adds group normalization and tests * Formatting fixes * Fix small nit in docstring * Fix docstring and tests * Adds RandomContrast and associated tests * Remove arithmetic comment * Adds RandomBrightness and tests * Fix docstring and format * Fix nits and add backend generator * Inlines random_contrast helper
95 lines
3.4 KiB
Python
95 lines
3.4 KiB
Python
from keras_core import operations as ops
|
|
from keras_core.api_export import keras_core_export
|
|
from keras_core.backend import random
|
|
from keras_core.layers.layer import Layer
|
|
|
|
|
|
@keras_core_export("keras_core.layers.RandomContrast")
class RandomContrast(Layer):
    """A preprocessing layer which randomly adjusts contrast during training.

    This layer will randomly adjust the contrast of an image or images
    by a random factor. A single contrast factor is sampled on every
    training call and applied to all channels of all images in the input;
    the mean that the adjustment pivots around is computed separately
    for each channel of each image.

    For each channel, this layer computes the mean of the image pixels in the
    channel and then adjusts each component `x` of each pixel to
    `(x - mean) * contrast_factor + mean`. The adjusted values are then
    clipped to `[0, 255]`.

    Input pixel values can be of any range (e.g. `[0., 1.)` or `[0, 255]`) and
    in integer or floating point dtype. Inputs are cast to the layer's
    compute dtype, so by default the layer will output floats.

    When called with `training=False`, the (cast) inputs are returned
    without any contrast adjustment or clipping.

    Input shape:
        3D (unbatched) or 4D (batched) tensor with shape:
        `(..., height, width, channels)`, in `"channels_last"` format.

    Output shape:
        3D (unbatched) or 4D (batched) tensor with shape:
        `(..., height, width, channels)`, in `"channels_last"` format.

    Args:
        factor: a positive float represented as fraction of value, or a tuple
            of size 2 representing lower and upper bound.
            When represented as a single float, lower = upper.
            The contrast factor will be randomly picked between
            `[1.0 - lower, 1.0 + upper]`. For any pixel x in the channel,
            the output will be `(x - mean) * factor + mean`
            where `mean` is the mean value of the channel.
        seed: Integer. Used to create a random seed.

    Raises:
        ValueError: if either bound of `factor` is negative, or if the
            lower bound is greater than 1.
    """

    def __init__(self, factor, seed=None, **kwargs):
        super().__init__(**kwargs)
        # Keep the raw argument so `get_config()` can round-trip it unchanged.
        self.factor = factor
        if isinstance(factor, (tuple, list)):
            self.lower = factor[0]
            self.upper = factor[1]
        else:
            self.lower = self.upper = factor
        # A lower bound above 1 would make `1.0 - lower` negative, i.e. a
        # negative contrast factor.
        if self.lower < 0.0 or self.upper < 0.0 or self.lower > 1.0:
            raise ValueError(
                "`factor` argument cannot have negative values or values "
                "greater than 1. "
                f"Received: factor={factor}"
            )
        self.seed = seed
        self.generator = random.SeedGenerator(seed)

    def call(self, inputs, training=True):
        """Apply a random contrast adjustment when `training` is truthy.

        Args:
            inputs: image tensor laid out as `(..., height, width, channels)`.
            training: whether to apply the random adjustment. When falsy,
                the inputs are only cast to the compute dtype.

        Returns:
            A tensor with the same shape as `inputs`, in the layer's
            compute dtype.
        """
        inputs = ops.cast(inputs, self.compute_dtype)
        if not training:
            return inputs

        # One scalar factor per call, shared by every image and channel.
        factor = ops.random.uniform(
            shape=(),
            minval=1.0 - self.lower,
            maxval=1.0 + self.upper,
            seed=self.generator,
        )
        outputs = self._adjust_constrast(inputs, factor)
        # The adjustment is elementwise against a broadcast mean, so
        # `outputs` already has the shape of `inputs`; no reshape is needed.
        return ops.clip(outputs, 0, 255)

    def _adjust_constrast(self, inputs, contrast_factor):
        """Scale `inputs` around its per-channel spatial mean.

        Args:
            inputs: tensor laid out as `(..., height, width, channels)`.
            contrast_factor: multiplier applied to the deviation from the
                mean.

        Returns:
            `(inputs - mean) * contrast_factor + mean`, where `mean` is
            taken over the height and width axes.
        """
        # Reduce over height (axis -3), then width (axis -2), keeping the
        # reduced dims so the mean broadcasts back against `inputs`.
        inp_mean = ops.mean(inputs, axis=-3, keepdims=True)
        inp_mean = ops.mean(inp_mean, axis=-2, keepdims=True)

        outputs = (inputs - inp_mean) * contrast_factor + inp_mean
        return outputs

    def compute_output_shape(self, input_shape):
        """Return `input_shape` unchanged; the layer preserves shape."""
        return input_shape

    def get_config(self):
        """Return the layer config, including `factor` and `seed`."""
        config = {
            "factor": self.factor,
            "seed": self.seed,
        }
        base_config = super().get_config()
        return {**base_config, **config}
|