112 lines
4.4 KiB
Python
112 lines
4.4 KiB
Python
import numpy as np
|
|
import tensorflow as tf
|
|
|
|
from keras_core import backend
|
|
from keras_core.api_export import keras_core_export
|
|
from keras_core.layers.layer import Layer
|
|
|
|
|
|
@keras_core_export("keras_core.layers.RandomRotation")
class RandomRotation(Layer):
    """A preprocessing layer which randomly rotates images during training.

    This layer will apply random rotations to each image, filling empty space
    according to `fill_mode`.

    By default, random rotations are only applied during training.
    At inference time, the layer does nothing. If you need to apply random
    rotations at inference time, set `training` to True when calling the layer.

    Input pixel values can be of any range (e.g. `[0., 1.)` or `[0, 255]`) and
    of integer or floating point dtype.
    By default, the layer will output floats.

    **Note:** This layer wraps `tf.keras.layers.RandomRotation`. It cannot
    be used as part of the compiled computation graph of a model with
    any backend other than TensorFlow.
    It can however be used with any backend when running eagerly.
    It can also always be used as part of an input preprocessing pipeline
    with any backend (outside the model itself), which is how we recommend
    to use this layer.

    Input shape:
        3D (unbatched) or 4D (batched) tensor with shape:
        `(..., height, width, channels)`, in `"channels_last"` format

    Output shape:
        3D (unbatched) or 4D (batched) tensor with shape:
        `(..., height, width, channels)`, in `"channels_last"` format

    Args:
        factor: a float represented as a fraction of 2 Pi, or a tuple of
            size 2 representing lower and upper bound for rotating clockwise
            and counter-clockwise. A positive value means rotating
            counter-clockwise, while a negative value means clockwise.
            When represented as a single float, this value is used for both
            the upper and lower bound. For instance, `factor=(-0.2, 0.3)`
            results in an output rotation by a random amount in the range
            `[-20% * 2pi, 30% * 2pi]`. `factor=0.2` results in an output
            rotating by a random amount in the range
            `[-20% * 2pi, 20% * 2pi]`.
        fill_mode: Points outside the boundaries of the input are filled
            according to the given mode
            (one of `{"constant", "reflect", "wrap", "nearest"}`).
            - *reflect*: `(d c b a | a b c d | d c b a)`
                The input is extended by reflecting about
                the edge of the last pixel.
            - *constant*: `(k k k k | a b c d | k k k k)`
                The input is extended by filling all values beyond
                the edge with the same constant value k = 0.
            - *wrap*: `(a b c d | a b c d | a b c d)` The input is extended
                by wrapping around to the opposite edge.
            - *nearest*: `(a a a a | a b c d | d d d d)`
                The input is extended by the nearest pixel.
        interpolation: Interpolation mode. Supported values: `"nearest"`,
            `"bilinear"`.
        seed: Integer. Used to create a random seed. When `None`, a default
            seed is obtained from `backend.random.make_default_seed()`;
            any other value (including `0`) is used as given.
        fill_value: a float representing the value to be filled outside
            the boundaries when `fill_mode="constant"`.
    """

    def __init__(
        self,
        factor,
        fill_mode="reflect",
        interpolation="bilinear",
        seed=None,
        fill_value=0.0,
        name=None,
        **kwargs,
    ):
        super().__init__(name=name, **kwargs)
        # Only fall back to a generated seed when none was supplied:
        # `seed or ...` would wrongly discard an explicit seed of 0.
        if seed is None:
            seed = backend.random.make_default_seed()
        self.seed = seed
        # The actual augmentation is delegated to the tf.keras layer.
        self.layer = tf.keras.layers.RandomRotation(
            factor=factor,
            fill_mode=fill_mode,
            interpolation=interpolation,
            seed=self.seed,
            fill_value=fill_value,
            name=name,
            **kwargs,
        )

    def call(self, inputs, training=True):
        """Apply the wrapped random rotation to `inputs`.

        Args:
            inputs: images as a `tf.Tensor`, NumPy array, or list/tuple.
                Any other object (e.g. a tensor from another backend) is
                first converted with `np.array` and then `tf.convert_to_tensor`.
            training: whether the layer runs in training mode. Forwarded to
                the wrapped `tf.keras` layer; per the class docstring, no
                rotation is applied at inference time.

        Returns:
            The wrapped layer's output, converted with
            `backend.convert_to_tensor` when the active backend is not
            TensorFlow.
        """
        if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)):
            inputs = tf.convert_to_tensor(np.array(inputs))
        # Forward `training`: without it the wrapped layer uses its own
        # default and would keep rotating even in inference mode.
        outputs = self.layer.call(inputs, training=training)
        if backend.backend() != "tensorflow":
            outputs = backend.convert_to_tensor(outputs)
        return outputs

    def compute_output_shape(self, input_shape):
        """Return the wrapped layer's output shape as a plain tuple."""
        return tuple(self.layer.compute_output_shape(input_shape))

    def get_config(self):
        """Return the wrapped layer's config with this layer's `seed`.

        Overriding `seed` ensures the serialized config carries the seed
        actually in use, including one generated in `__init__`.
        """
        config = self.layer.get_config()
        config.update({"seed": self.seed})
        return config
|