45 lines
1.4 KiB
Python
45 lines
1.4 KiB
Python
import tensorflow as tf
|
|
|
|
# Interpolation methods accepted by `resize`. Each name is forwarded as-is as
# the `method` argument of `tf.image.resize`. The order matters only for the
# text of the validation error message, which prints this tuple.
RESIZE_METHODS = ("bilinear", "nearest", "lanczos3", "lanczos5", "bicubic")
|
|
|
|
|
|
def resize(
    image, size, method="bilinear", antialias=False, data_format="channels_last"
):
    """Resize a single image or a batch of images using `tf.image.resize`.

    Args:
        image: Image tensor. Expected to be rank 3 (single image) or rank 4
            (batch of images); the rank is only checked here when
            `data_format == "channels_first"`.
        size: Sequence of exactly two elements, `(height, width)`, giving the
            output spatial size. It is converted to a tuple before use.
        method: Interpolation method name; must be one of `RESIZE_METHODS`.
        antialias: Forwarded unchanged to `tf.image.resize`.
        data_format: `"channels_first"` means the channel axis comes before
            the spatial axes. Any other value is treated as channels-last and
            `image` is passed to `tf.image.resize` untransposed.

    Returns:
        The resized tensor, laid out in the same data format as the input.

    Raises:
        ValueError: If `method` is not in `RESIZE_METHODS`, if `size` does not
            have exactly two elements, or if `data_format` is
            `"channels_first"` and `image` is neither rank 3 nor rank 4.
    """
    if method not in RESIZE_METHODS:
        raise ValueError(
            "Invalid value for argument `method`. Expected one of "
            f"{RESIZE_METHODS}. Received: method={method}"
        )
    if len(size) != 2:
        raise ValueError(
            "Argument `size` must be a tuple of two elements "
            f"(height, width). Received: size={size}"
        )
    size = tuple(size)

    # tf.image.resize operates on channels-last tensors, so channels-first
    # input is transposed before resizing. `restore_perm` records the inverse
    # permutation to apply afterwards; it stays None for channels-last input.
    restore_perm = None
    if data_format == "channels_first":
        rank = len(image.shape)
        if rank == 4:
            image = tf.transpose(image, (0, 2, 3, 1))
            restore_perm = (0, 3, 1, 2)
        elif rank == 3:
            image = tf.transpose(image, (1, 2, 0))
            restore_perm = (2, 0, 1)
        else:
            raise ValueError(
                "Invalid input rank: expected rank 3 (single image) "
                "or rank 4 (batch of images). Received input with shape: "
                f"image.shape={image.shape}"
            )

    resized = tf.image.resize(image, size, method=method, antialias=antialias)

    if restore_perm is not None:
        resized = tf.transpose(resized, restore_perm)
    return resized
|