# keras/keras_core/layers/preprocessing/resizing.py
#
# 114 lines
# 4.4 KiB
# Python

from keras_core import backend
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
from keras_core.utils import image_utils
@keras_core_export("keras_core.layers.Resizing")
class Resizing(Layer):
    """A preprocessing layer which resizes images.

    This layer resizes an image input to a target height and width. The input
    should be a 4D (batched) or 3D (unbatched) tensor in `"channels_last"` or
    `"channels_first"` format (see `data_format`). Input pixel values can be
    of any range (e.g. `[0., 1.)` or `[0, 255]`).

    Input shape:
        3D (unbatched) or 4D (batched) tensor with shape:
        `(..., height, width, channels)`, in `"channels_last"` format,
        or `(..., channels, height, width)`, in `"channels_first"` format.

    Output shape:
        3D (unbatched) or 4D (batched) tensor with shape:
        `(..., target_height, target_width, channels)`, in `"channels_last"`
        format, or `(..., channels, target_height, target_width)`, in
        `"channels_first"` format.

    Args:
        height: Integer, the height of the output shape.
        width: Integer, the width of the output shape.
        interpolation: String, the interpolation method.
            Supports `"bilinear"`, `"nearest"`, `"bicubic"`,
            `"lanczos3"`, `"lanczos5"`. Defaults to `"bilinear"`.
        crop_to_aspect_ratio: If `True`, resize the images without aspect
            ratio distortion. When the original aspect ratio differs
            from the target aspect ratio, the output image will be
            cropped so as to return the
            largest possible window in the image (of size `(height, width)`)
            that matches the target aspect ratio. By default
            (`crop_to_aspect_ratio=False`), aspect ratio may not be preserved.
        data_format: string, either `"channels_last"` or `"channels_first"`.
            The ordering of the dimensions in the inputs. `"channels_last"`
            corresponds to inputs with shape `(batch, height, width, channels)`
            while `"channels_first"` corresponds to inputs with shape
            `(batch, channels, height, width)`. It defaults to the
            `image_data_format` value found in your Keras config file at
            `~/.keras/keras.json`. If you never set it, then it will be
            `"channels_last"`.
        **kwargs: Base layer keyword arguments, such as `name` and `dtype`.
    """

    def __init__(
        self,
        height,
        width,
        interpolation="bilinear",
        crop_to_aspect_ratio=False,
        data_format=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.height = height
        self.width = width
        self.interpolation = interpolation
        # `None` is resolved to the globally configured image data format.
        self.data_format = backend.standardize_data_format(data_format)
        self.crop_to_aspect_ratio = crop_to_aspect_ratio

    def call(self, inputs):
        """Resize `inputs` to `(self.height, self.width)`.

        With `crop_to_aspect_ratio=True` the work is delegated to
        `image_utils.smart_resize`; otherwise to `ops.image.resize`.
        """
        size = (self.height, self.width)
        if self.crop_to_aspect_ratio:
            return image_utils.smart_resize(
                inputs,
                size=size,
                interpolation=self.interpolation,
                data_format=self.data_format,
            )
        return ops.image.resize(
            inputs,
            size=size,
            method=self.interpolation,
            data_format=self.data_format,
        )

    def compute_output_shape(self, input_shape):
        """Return the static output shape for a given input shape.

        Only the two spatial axes are replaced; all other dimensions
        (batch, channels) are carried over unchanged.

        Args:
            input_shape: Shape tuple/list of a 3D (unbatched) or 4D
                (batched) input, laid out according to `self.data_format`.

        Returns:
            A tuple equal to `input_shape` with the height and width axes
            set to `self.height` and `self.width`.

        Raises:
            ValueError: If `input_shape` has fewer than 3 dimensions.
        """
        input_shape = list(input_shape)
        if len(input_shape) < 3:
            raise ValueError(
                "Expected the input to `Resizing` to have rank 3 (unbatched) "
                "or 4 (batched). Received input_shape="
                f"{tuple(input_shape)}"
            )
        # Index spatial axes from the end so the same logic covers both
        # unbatched and batched inputs (and any extra leading dimensions).
        if self.data_format == "channels_last":
            height_axis, width_axis = -3, -2
        else:
            height_axis, width_axis = -2, -1
        input_shape[height_axis] = self.height
        input_shape[width_axis] = self.width
        return tuple(input_shape)

    def get_config(self):
        """Return the layer config, merging base-layer and resize settings."""
        base_config = super().get_config()
        config = {
            "height": self.height,
            "width": self.width,
            "interpolation": self.interpolation,
            "crop_to_aspect_ratio": self.crop_to_aspect_ratio,
            "data_format": self.data_format,
        }
        return {**base_config, **config}