import jax RESIZE_METHODS = ( "bilinear", "nearest", "lanczos3", "lanczos5", "bicubic", ) def resize( image, size, method="bilinear", antialias=False, data_format="channels_last" ): if method not in RESIZE_METHODS: raise ValueError( "Invalid value for argument `method`. Expected of one " f"{RESIZE_METHODS}. Received: method={method}" ) if not len(size) == 2: raise ValueError( "Argument `size` must be a tuple of two elements " f"(height, width). Received: size={size}" ) size = tuple(size) if len(image.shape) == 4: if data_format == "channels_last": size = (image.shape[0],) + size + (image.shape[-1],) else: size = (image.shape[0], image.shape[1]) + size elif len(image.shape) == 3: if data_format == "channels_last": size = size + (image.shape[-1],) else: size = (image.shape[0],) + size else: raise ValueError( "Invalid input rank: expected rank 3 (single image) " "or rank 4 (batch of images). Received input with shape: " f"image.shape={image.shape}" ) return jax.image.resize(image, size, method=method, antialias=antialias)