keras/keras_core/layers/preprocessing/random_crop.py
divyasreepat 3f4651e67b Add RandomRotation preprocessing layer (#148)
* added random rotation

* update test name

* updates

* updated the seed for random crop as well

* added name

* added name
2023-05-11 18:01:12 -07:00

77 lines
2.9 KiB
Python

import numpy as np
import tensorflow as tf
from keras_core import backend
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
@keras_core_export("keras_core.layers.RandomCrop")
class RandomCrop(Layer):
    """A preprocessing layer which randomly crops images during training.

    During training, this layer will randomly choose a location to crop images
    down to a target size. The layer will crop all the images in the same batch
    to the same cropping location.

    At inference time, and during training if an input image is smaller than
    the target size, the input will be resized and cropped so as to return the
    largest possible window in the image that matches the target aspect ratio.
    If you need to apply random cropping at inference time, set `training` to
    True when calling the layer.

    Input pixel values can be of any range (e.g. `[0., 1.)` or `[0, 255]`) and
    of integer or floating point dtype. By default, the layer will output
    floats.

    **Note:** This layer wraps `tf.keras.layers.RandomCrop`. It cannot
    be used as part of the compiled computation graph of a model with
    any backend other than TensorFlow.
    It can however be used with any backend when running eagerly.
    It can also always be used as part of an input preprocessing pipeline
    with any backend (outside the model itself), which is how we recommend
    using this layer.

    Input shape:
        3D (unbatched) or 4D (batched) tensor with shape:
        `(..., height, width, channels)`, in `"channels_last"` format.

    Output shape:
        3D (unbatched) or 4D (batched) tensor with shape:
        `(..., target_height, target_width, channels)`.

    Args:
        height: Integer, the height of the output shape.
        width: Integer, the width of the output shape.
        seed: Integer. Used to create a random seed. If `None`, a default
            seed is generated with `backend.random.make_default_seed()`.
            Any explicitly provided value (including `0`) is used as-is.
        **kwargs: Base layer keyword arguments, such as
            `name` and `dtype`.
    """

    def __init__(self, height, width, seed=None, name=None, **kwargs):
        super().__init__(name=name, **kwargs)
        # Only fall back to a generated seed when none was supplied: a falsy
        # but valid seed such as 0 must be honored for reproducibility.
        if seed is None:
            seed = backend.random.make_default_seed()
        self.seed = seed
        # The actual cropping is delegated to the TensorFlow implementation.
        self.layer = tf.keras.layers.RandomCrop(
            height=height,
            width=width,
            seed=self.seed,
            name=name,
        )
        self.supports_masking = False

    def call(self, inputs, training=True):
        """Crop `inputs`, randomly when `training` is truthy.

        Args:
            inputs: Image batch or single image (TF tensor, NumPy array,
                nested list/tuple, or another backend's tensor).
            training: Forwarded to the wrapped TF layer. When falsy, the
                wrapped layer performs the inference-time resize-and-crop
                described in the class docstring instead of a random crop.

        Returns:
            The cropped images. When the active backend is not TensorFlow,
            they are converted with `backend.convert_to_tensor`.
        """
        # Inputs the TF layer cannot consume directly (e.g. tensors from
        # another backend) are routed through NumPy first.
        if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)):
            inputs = tf.convert_to_tensor(np.array(inputs))
        # Forward `training` so inference mode actually takes the
        # deterministic path instead of always cropping randomly.
        outputs = self.layer.call(inputs, training=training)
        if backend.backend() != "tensorflow":
            outputs = backend.convert_to_tensor(outputs)
        return outputs

    def compute_output_shape(self, input_shape):
        """Return the output shape as a plain tuple.

        The wrapped TF layer computes the shape; it is converted to a tuple
        so callers do not receive a `tf.TensorShape`.
        """
        return tuple(self.layer.compute_output_shape(input_shape))

    def get_config(self):
        """Return a serializable config for this layer.

        Starts from the wrapped TF layer's config and records the resolved
        seed, so that a layer rebuilt from the config reuses the same seed.
        """
        config = self.layer.get_config()
        config.update({"seed": self.seed})
        return config