keras/keras_core/backend/jax/image.py
2023-05-05 22:09:34 -07:00

43 lines
1.3 KiB
Python

import jax
RESIZE_METHODS = (
"bilinear",
"nearest",
"lanczos3",
"lanczos5",
"bicubic",
)
def resize(
image, size, method="bilinear", antialias=False, data_format="channels_last"
):
if method not in RESIZE_METHODS:
raise ValueError(
"Invalid value for argument `method`. Expected of one "
f"{RESIZE_METHODS}. Received: method={method}"
)
if not len(size) == 2:
raise ValueError(
"Argument `size` must be a tuple of two elements "
f"(height, width). Received: size={size}"
)
size = tuple(size)
if len(image.shape) == 4:
if data_format == "channels_last":
size = (image.shape[0],) + size + (image.shape[-1],)
else:
size = (image.shape[0], image.shape[1]) + size
elif len(image.shape) == 3:
if data_format == "channels_last":
size = size + (image.shape[-1],)
else:
size = (image.shape[0],) + size
else:
raise ValueError(
"Invalid input rank: expected rank 3 (single image) "
"or rank 4 (batch of images). Received input with shape: "
f"image.shape={image.shape}"
)
return jax.image.resize(image, size, method=method, antialias=antialias)