keras/keras_core/backend/tensorflow/image.py
2023-05-05 22:09:34 -07:00

45 lines
1.4 KiB
Python

import tensorflow as tf
# Interpolation methods that `resize` accepts; the chosen name is handed
# unchanged to `tf.image.resize` as its `method` argument.
RESIZE_METHODS = ("bilinear", "nearest", "lanczos3", "lanczos5", "bicubic")
def resize(
    image, size, method="bilinear", antialias=False, data_format="channels_last"
):
    """Resize a single image or a batch of images with `tf.image.resize`.

    `tf.image.resize` operates on channels-last data, so channels-first
    input is transposed to channels-last, resized, then transposed back.

    Args:
        image: Rank-3 (single image) or rank-4 (batch of images) tensor laid
            out according to `data_format`.
        size: Sequence of exactly two elements, `(height, width)`, giving the
            output spatial size. Converted to a tuple before use.
        method: Interpolation method name; must be one of `RESIZE_METHODS`.
        antialias: Forwarded unchanged to `tf.image.resize`.
        data_format: `"channels_first"` or `"channels_last"`.

    Returns:
        The resized tensor, in the same layout (`data_format`) as `image`.

    Raises:
        ValueError: If `method` is not in `RESIZE_METHODS`, if `size` does
            not have exactly two elements, or if `data_format` is
            `"channels_first"` and `image` is neither rank 3 nor rank 4.
    """
    if method not in RESIZE_METHODS:
        raise ValueError(
            "Invalid value for argument `method`. Expected one of "
            f"{RESIZE_METHODS}. Received: method={method}"
        )
    if len(size) != 2:
        raise ValueError(
            "Argument `size` must be a tuple of two elements "
            f"(height, width). Received: size={size}"
        )
    size = tuple(size)

    # NOTE(review): any value other than "channels_first" (including typos)
    # is treated as channels-last; `data_format` is not validated here.
    channels_first = data_format == "channels_first"
    if channels_first:
        # Computed once: transposing does not change the rank, so the same
        # value selects the inverse permutation after resizing.
        rank = len(image.shape)
        if rank == 4:
            # Move the channel axis from position 1 to the last position.
            image = tf.transpose(image, (0, 2, 3, 1))
        elif rank == 3:
            # Move the channel axis from position 0 to the last position.
            image = tf.transpose(image, (1, 2, 0))
        else:
            raise ValueError(
                "Invalid input rank: expected rank 3 (single image) "
                "or rank 4 (batch of images). Received input with shape: "
                f"image.shape={image.shape}"
            )

    resized = tf.image.resize(image, size, method=method, antialias=antialias)

    if channels_first:
        # Restore the caller's channels-first layout; `rank` is already
        # known to be 3 or 4 because any other rank raised above.
        if rank == 4:
            resized = tf.transpose(resized, (0, 3, 1, 2))
        else:
            resized = tf.transpose(resized, (2, 0, 1))
    return resized