keras/keras_core/layers/preprocessing/random_brightness.py
Neel Kovelamudi b518b9ef2b Adds RandomContrast and RandomBrightness and associated tests (#130)
* Adds unit normalization and tests

* Adds layer normalization and initial tests

* Fixes formatting in docstrings

* Fix type issues for JAX

* Fix nits

* Initial stash for group_normalization and spectral_normalization

* Adds spectral normalization and tests

* Adds group normalization and tests

* Formatting fixes

* Fix small nit in docstring

* Fix docstring and tests

* Adds RandomContrast and associated tests

* Remove arithmetic comment

* Adds RandomBrightness and tests

* Fix docstring and format

* Fix nits and add backend generator

* Inlines random_contrast helper
2023-05-11 04:19:03 +00:00

159 lines
6.0 KiB
Python

from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.backend import random
from keras_core.layers.layer import Layer
@keras_core_export("keras_core.layers.RandomBrightness")
class RandomBrightness(Layer):
    """A preprocessing layer which randomly adjusts brightness during training.

    This layer will randomly increase/reduce the brightness for the input RGB
    images. When called with `training=False`, the inputs are returned without
    any brightness adjustment (they are only cast to the layer's compute
    dtype). Call the layer with `training=True` to adjust the brightness of
    the input.

    Args:
        factor: Float or a list/tuple of 2 floats between -1.0 and 1.0. The
            factor is used to determine the lower bound and upper bound of the
            brightness adjustment. A float value will be chosen randomly
            between the limits. When -1.0 is chosen, the output image will be
            black, and when 1.0 is chosen, the image will be fully white.
            When only one float is provided, e.g. 0.2, then -0.2 will be used
            for the lower bound and 0.2 will be used for the upper bound.
        value_range: Optional list/tuple of 2 floats for the lower and upper
            limit of the values of the input data. Use `[0.0, 1.0]` if the
            image input has already been scaled to that range before this
            layer. Defaults to `[0.0, 255.0]`. The brightness adjustment will
            be scaled to this range, and the output values will be clipped to
            this range.
        seed: optional integer, for fixed RNG behavior.

    Inputs: 3D (HWC) or 4D (NHWC) tensor, with float or int dtype. Input pixel
        values can be of any range (e.g. `[0., 1.)` or `[0, 255]`)

    Output: 3D (HWC) or 4D (NHWC) tensor with brightness adjusted based on the
        `factor`. The inputs are cast to the layer's compute dtype, so by
        default the layer will output floats. The output values are clipped
        to `value_range`.

    Sample usage:

    ```python
    random_bright = keras_core.layers.RandomBrightness(factor=0.2)

    # An image with shape [2, 2, 3]
    image = [[[1, 2, 3], [4 ,5 ,6]], [[7, 8, 9], [10, 11, 12]]]

    # Assume we randomly select the factor to be 0.1, then it will apply
    # 0.1 * 255 to all the channel
    output = random_bright(image, training=True)

    # output will be a float tensor with 25.5 added to each channel.
    >>> array([[[26.5, 27.5, 28.5]
                [29.5, 30.5, 31.5]]
               [[32.5, 33.5, 34.5]
                [35.5, 36.5, 37.5]]],
              shape=(2, 2, 3))
    ```
    """

    _FACTOR_VALIDATION_ERROR = (
        "The `factor` argument should be a number (or a list of two numbers) "
        "in the range [-1.0, 1.0]. "
    )
    _VALUE_RANGE_VALIDATION_ERROR = (
        "The `value_range` argument should be a list of two numbers. "
    )

    def __init__(self, factor, value_range=(0, 255), seed=None, **kwargs):
        super().__init__(**kwargs)
        self._set_factor(factor)
        self._set_value_range(value_range)
        # The raw seed is kept so that `get_config()` can serialize it.
        self._seed = seed
        self._generator = random.SeedGenerator(seed)

    def _set_value_range(self, value_range):
        """Validate `value_range` and store it sorted as `[low, high]`.

        Raises:
            ValueError: if `value_range` is not a list/tuple of length 2.
        """
        if not isinstance(value_range, (tuple, list)) or len(value_range) != 2:
            raise ValueError(
                self._VALUE_RANGE_VALIDATION_ERROR
                + f"Received: value_range={value_range}"
            )
        self._value_range = sorted(value_range)

    def _set_factor(self, factor):
        """Validate `factor` and store it as a sorted `[lower, upper]` pair.

        A single number `f` is expanded to `[-abs(f), abs(f)]`.

        Raises:
            ValueError: if `factor` is neither a number nor a list/tuple of
                two numbers, or if any bound lies outside `[-1.0, 1.0]`.
        """
        if isinstance(factor, (tuple, list)):
            if len(factor) != 2:
                raise ValueError(
                    self._FACTOR_VALIDATION_ERROR + f"Received: factor={factor}"
                )
            self._check_factor_range(factor[0])
            self._check_factor_range(factor[1])
            self._factor = sorted(factor)
        elif isinstance(factor, (int, float)):
            self._check_factor_range(factor)
            factor = abs(factor)
            self._factor = [-factor, factor]
        else:
            raise ValueError(
                self._FACTOR_VALIDATION_ERROR + f"Received: factor={factor}"
            )

    def _check_factor_range(self, input_number):
        """Raise `ValueError` if `input_number` is outside `[-1.0, 1.0]`."""
        if input_number > 1.0 or input_number < -1.0:
            raise ValueError(
                self._FACTOR_VALIDATION_ERROR
                + f"Received: input_number={input_number}"
            )

    def call(self, inputs, training=True):
        """Cast `inputs` to the compute dtype; adjust brightness if training."""
        inputs = ops.cast(inputs, self.compute_dtype)
        if training:
            return self._brightness_adjust(inputs)
        else:
            return inputs

    def _brightness_adjust(self, images):
        """Add a random brightness offset to `images` and clip the result.

        Args:
            images: rank-3 or rank-4 tensor.

        Returns:
            Tensor of the same shape, offset by a random delta scaled to the
            width of `value_range`, then clipped to `value_range`.

        Raises:
            ValueError: if `images` is not rank 3 or rank 4.
        """
        rank = len(images.shape)
        if rank == 3:
            rgb_delta_shape = (1, 1, 1)
        elif rank == 4:
            # Keep only the batch dim. This ensures the same adjustment
            # within one image, but different adjustments across images.
            rgb_delta_shape = (images.shape[0], 1, 1, 1)
        else:
            raise ValueError(
                "Expected the input image to be rank 3 or 4. Received "
                f"inputs.shape = {images.shape}"
            )
        rgb_delta = ops.random.uniform(
            minval=self._factor[0],
            maxval=self._factor[1],
            shape=rgb_delta_shape,
            seed=self._generator,
        )
        # Scale the [-1, 1]-style factor to the width of the value range.
        rgb_delta = rgb_delta * (self._value_range[1] - self._value_range[0])
        rgb_delta = ops.cast(rgb_delta, images.dtype)
        # Out-of-place add: `images += ...` would write into the caller's
        # buffer on backends whose tensors implement in-place `__iadd__`
        # whenever the earlier cast handed back the original object.
        images = images + rgb_delta
        return ops.clip(images, self._value_range[0], self._value_range[1])

    def compute_output_shape(self, input_shape):
        """The layer is elementwise: output shape equals input shape."""
        return input_shape

    def get_config(self):
        """Return the base config extended with this layer's arguments."""
        config = {
            "factor": self._factor,
            "value_range": self._value_range,
            "seed": self._seed,
        }
        base_config = super().get_config()
        return {**base_config, **config}