2023-05-05 22:48:24 +00:00
|
|
|
import jax
|
|
|
|
|
|
|
|
# Interpolation methods accepted by `resize`. `resize` rejects any `method`
# not listed here and otherwise passes it straight through to
# `jax.image.resize`.
RESIZE_METHODS = ("bilinear", "nearest", "lanczos3", "lanczos5", "bicubic")
|
|
|
|
|
|
|
|
|
|
|
|
def resize(
    image, size, method="bilinear", antialias=False, data_format="channels_last"
):
    """Resize a single image or a batch of images with `jax.image.resize`.

    Args:
        image: Array of rank 3 (a single image) or rank 4 (a batch of
            images). With `data_format="channels_last"` the two spatial axes
            are the ones just before the last axis; with
            `data_format="channels_first"` they are the last two axes.
        size: Sequence of exactly two elements `(height, width)` giving the
            target spatial size.
        method: Interpolation method; must be one of `RESIZE_METHODS`.
        antialias: Passed through unchanged to `jax.image.resize`.
        data_format: Either `"channels_last"` or `"channels_first"`.

    Returns:
        The resized array. It has the same rank as `image`, and every
        non-spatial axis (batch, channels) keeps its original size.

    Raises:
        ValueError: If `method` is not in `RESIZE_METHODS`, `size` does not
            have two elements, `data_format` is not a recognised value, or
            `image` is not of rank 3 or 4.
    """
    if method not in RESIZE_METHODS:
        raise ValueError(
            "Invalid value for argument `method`. Expected one of "
            f"{RESIZE_METHODS}. Received: method={method}"
        )
    if len(size) != 2:
        raise ValueError(
            "Argument `size` must be a tuple of two elements "
            f"(height, width). Received: size={size}"
        )
    # Validate explicitly: previously any value other than "channels_last"
    # (e.g. None or a typo) silently fell into the channels-first branch and
    # produced a wrongly shaped output.
    valid_data_formats = ("channels_last", "channels_first")
    if data_format not in valid_data_formats:
        raise ValueError(
            "Invalid value for argument `data_format`. Expected one of "
            f"{valid_data_formats}. Received: data_format={data_format}"
        )

    size = tuple(size)
    channels_last = data_format == "channels_last"

    # `jax.image.resize` takes the full target shape for every axis, so the
    # batch and channel axes are copied from the input unchanged and only
    # the spatial axes receive the requested `size`.
    rank = len(image.shape)
    if rank == 4:
        if channels_last:
            size = (image.shape[0],) + size + (image.shape[-1],)
        else:
            size = (image.shape[0], image.shape[1]) + size
    elif rank == 3:
        if channels_last:
            size = size + (image.shape[-1],)
        else:
            size = (image.shape[0],) + size
    else:
        raise ValueError(
            "Invalid input rank: expected rank 3 (single image) "
            "or rank 4 (batch of images). Received input with shape: "
            f"image.shape={image.shape}"
        )

    return jax.image.resize(image, size, method=method, antialias=antialias)
|