add RandomCrop to preprocessing layers (#123)
* adding randomcrop * added test cases * update test names * updated test class name * added review changes requested * empty init
This commit is contained in:
parent
46378e883f
commit
dd18f20914
@ -64,6 +64,7 @@ from keras_core.layers.preprocessing.discretization import Discretization
|
|||||||
from keras_core.layers.preprocessing.hashing import Hashing
|
from keras_core.layers.preprocessing.hashing import Hashing
|
||||||
from keras_core.layers.preprocessing.integer_lookup import IntegerLookup
|
from keras_core.layers.preprocessing.integer_lookup import IntegerLookup
|
||||||
from keras_core.layers.preprocessing.normalization import Normalization
|
from keras_core.layers.preprocessing.normalization import Normalization
|
||||||
|
from keras_core.layers.preprocessing.random_crop import RandomCrop
|
||||||
from keras_core.layers.preprocessing.rescaling import Rescaling
|
from keras_core.layers.preprocessing.rescaling import Rescaling
|
||||||
from keras_core.layers.preprocessing.resizing import Resizing
|
from keras_core.layers.preprocessing.resizing import Resizing
|
||||||
from keras_core.layers.preprocessing.string_lookup import StringLookup
|
from keras_core.layers.preprocessing.string_lookup import StringLookup
|
||||||
|
65
keras_core/layers/preprocessing/random_crop.py
Normal file
65
keras_core/layers/preprocessing/random_crop.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from keras_core import backend
|
||||||
|
from keras_core.api_export import keras_core_export
|
||||||
|
from keras_core.layers.layer import Layer
|
||||||
|
|
||||||
|
|
||||||
|
@keras_core_export("keras_core.layers.RandomCrop")
|
||||||
|
class RandomCrop(Layer):
|
||||||
|
"""A preprocessing layer which randomly crops images during training.
|
||||||
|
|
||||||
|
During training, this layer will randomly choose a location to crop images
|
||||||
|
down to a target size. The layer will crop all the images in the same batch
|
||||||
|
to the same cropping location.
|
||||||
|
|
||||||
|
At inference time, and during training if an input image is smaller than the
|
||||||
|
target size, the input will be resized and cropped so as to return the
|
||||||
|
largest possible window in the image that matches the target aspect ratio.
|
||||||
|
If you need to apply random cropping at inference time, set `training` to
|
||||||
|
True when calling the layer.
|
||||||
|
|
||||||
|
Input pixel values can be of any range (e.g. `[0., 1.)` or `[0, 255]`) and
|
||||||
|
of integer or floating point dtype. By default, the layer will output
|
||||||
|
floats.
|
||||||
|
|
||||||
|
Input shape:
|
||||||
|
3D (unbatched) or 4D (batched) tensor with shape:
|
||||||
|
`(..., height, width, channels)`, in `"channels_last"` format.
|
||||||
|
|
||||||
|
Output shape:
|
||||||
|
3D (unbatched) or 4D (batched) tensor with shape:
|
||||||
|
`(..., target_height, target_width, channels)`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
height: Integer, the height of the output shape.
|
||||||
|
width: Integer, the width of the output shape.
|
||||||
|
seed: Integer. Used to create a random seed.
|
||||||
|
**kwargs: Base layer keyword arguments, such as
|
||||||
|
`name` and `dtype`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, height, width, seed=None, name=None, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
self.layer = tf.keras.layers.RandomCrop(
|
||||||
|
height=height,
|
||||||
|
width=width,
|
||||||
|
seed=seed,
|
||||||
|
name=name,
|
||||||
|
)
|
||||||
|
self.supports_masking = False
|
||||||
|
|
||||||
|
def call(self, inputs, training=True):
|
||||||
|
if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)):
|
||||||
|
inputs = tf.convert_to_tensor(np.array(inputs))
|
||||||
|
outputs = self.layer.call(inputs)
|
||||||
|
if backend.backend() != "tensorflow":
|
||||||
|
outputs = backend.convert_to_tensor(outputs)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def compute_output_shape(self, input_shape):
|
||||||
|
return tuple(self.layer.compute_output_shape(input_shape))
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
return self.layer.get_config()
|
61
keras_core/layers/preprocessing/random_crop_test.py
Normal file
61
keras_core/layers/preprocessing/random_crop_test.py
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from keras_core import layers
|
||||||
|
from keras_core import testing
|
||||||
|
|
||||||
|
|
||||||
|
class RandomCropTest(testing.TestCase):
|
||||||
|
def test_random_crop(self):
|
||||||
|
self.run_layer_test(
|
||||||
|
layers.RandomCrop,
|
||||||
|
init_kwargs={
|
||||||
|
"height": 1,
|
||||||
|
"width": 1,
|
||||||
|
},
|
||||||
|
input_shape=(2, 3, 4),
|
||||||
|
supports_masking=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_random_crop_full(self):
|
||||||
|
np.random.seed(1337)
|
||||||
|
height, width = 8, 16
|
||||||
|
inp = np.random.random((12, 8, 16, 3))
|
||||||
|
layer = layers.RandomCrop(height, width)
|
||||||
|
actual_output = layer(inp, training=False)
|
||||||
|
self.assertAllClose(inp, actual_output)
|
||||||
|
|
||||||
|
def test_random_crop_partial(self):
|
||||||
|
self.run_layer_test(
|
||||||
|
layers.RandomCrop,
|
||||||
|
init_kwargs={
|
||||||
|
"height": 8,
|
||||||
|
"width": 8,
|
||||||
|
},
|
||||||
|
input_shape=(12, 8, 16, 3),
|
||||||
|
expected_output_shape=(12, 8, 8, 3),
|
||||||
|
supports_masking=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_predicting_with_longer_height(self):
|
||||||
|
self.run_layer_test(
|
||||||
|
layers.RandomCrop,
|
||||||
|
init_kwargs={
|
||||||
|
"height": 10,
|
||||||
|
"width": 8,
|
||||||
|
},
|
||||||
|
input_shape=(12, 8, 16, 3),
|
||||||
|
expected_output_shape=(12, 10, 8, 3),
|
||||||
|
supports_masking=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_predicting_with_longer_width(self):
|
||||||
|
self.run_layer_test(
|
||||||
|
layers.RandomCrop,
|
||||||
|
init_kwargs={
|
||||||
|
"height": 8,
|
||||||
|
"width": 18,
|
||||||
|
},
|
||||||
|
input_shape=(12, 8, 16, 3),
|
||||||
|
expected_output_shape=(12, 8, 18, 3),
|
||||||
|
supports_masking=False,
|
||||||
|
)
|
Loading…
Reference in New Issue
Block a user