Add `affine_transform` op to all backends (#477)
* Add affine op
* Sync import convention
* Use `np.random.random`
* Refactor jax implementation
* Fix
* Address fchollet's comments
* Update docstring
* Fix test
* Replace method with interpolation
* Replace method with interpolation
* Replace method with interpolation
* Update test
This commit is contained in:
parent
ed26f7f4c5
commit
d7b3eae31f
@ -1,6 +1,11 @@
|
|||||||
import jax
|
import functools
|
||||||
|
|
||||||
RESIZE_METHODS = (
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
|
||||||
|
from keras_core.backend.jax.core import convert_to_tensor
|
||||||
|
|
||||||
|
RESIZE_INTERPOLATIONS = (
|
||||||
"bilinear",
|
"bilinear",
|
||||||
"nearest",
|
"nearest",
|
||||||
"lanczos3",
|
"lanczos3",
|
||||||
@ -10,12 +15,16 @@ RESIZE_METHODS = (
|
|||||||
|
|
||||||
|
|
||||||
def resize(
|
def resize(
|
||||||
image, size, method="bilinear", antialias=False, data_format="channels_last"
|
image,
|
||||||
|
size,
|
||||||
|
interpolation="bilinear",
|
||||||
|
antialias=False,
|
||||||
|
data_format="channels_last",
|
||||||
):
|
):
|
||||||
if method not in RESIZE_METHODS:
|
if interpolation not in RESIZE_INTERPOLATIONS:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Invalid value for argument `method`. Expected of one "
|
"Invalid value for argument `interpolation`. Expected of one "
|
||||||
f"{RESIZE_METHODS}. Received: method={method}"
|
f"{RESIZE_INTERPOLATIONS}. Received: interpolation={interpolation}"
|
||||||
)
|
)
|
||||||
if not len(size) == 2:
|
if not len(size) == 2:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -39,4 +48,117 @@ def resize(
|
|||||||
"or rank 4 (batch of images). Received input with shape: "
|
"or rank 4 (batch of images). Received input with shape: "
|
||||||
f"image.shape={image.shape}"
|
f"image.shape={image.shape}"
|
||||||
)
|
)
|
||||||
return jax.image.resize(image, size, method=method, antialias=antialias)
|
return jax.image.resize(
|
||||||
|
image, size, method=interpolation, antialias=antialias
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
AFFINE_TRANSFORM_INTERPOLATIONS = {  # map to order
    "nearest": 0,
    "bilinear": 1,
}
# Only the keys are used by this backend (for validation): `fill_mode` is
# forwarded to `jax.scipy.ndimage.map_coordinates` unchanged.
AFFINE_TRANSFORM_FILL_MODES = {
    "constant": "grid-constant",
    "nearest": "nearest",
    "wrap": "grid-wrap",
    "mirror": "mirror",
    "reflect": "reflect",
}


def affine_transform(
    image,
    transform,
    interpolation="bilinear",
    fill_mode="constant",
    fill_value=0,
    data_format="channels_last",
):
    """Apply affine transform(s) to an image or a batch of images.

    Each row `[a0, a1, a2, b0, b1, b2, c0, c1]` of `transform` maps an output
    pixel `(x, y)` to the input location
    `(a0 * x + a1 * y + a2, b0 * x + b1 * y + b2)`. The projective terms `c0`
    and `c1` are ignored by this backend.

    Args:
        image: Rank-3 (single image) or rank-4 (batch) array. It is used
            as-is; it is not passed through `convert_to_tensor`.
        transform: Vector of length 8 or array of shape `(N, 8)`. A single
            transform is applied to every image of the batch.
        interpolation: One of the keys of `AFFINE_TRANSFORM_INTERPOLATIONS`.
        fill_mode: One of the keys of `AFFINE_TRANSFORM_FILL_MODES`.
        fill_value: Forwarded to `map_coordinates` as `cval`.
        data_format: `"channels_last"` or `"channels_first"`.

    Returns:
        The transformed image(s), with the same layout and rank as `image`.

    Raises:
        ValueError: On an unknown `interpolation` / `fill_mode`, or an
            unsupported `image` / `transform` shape.
    """
    if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS.keys():
        raise ValueError(
            "Invalid value for argument `interpolation`. Expected of one "
            f"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: "
            f"interpolation={interpolation}"
        )
    if fill_mode not in AFFINE_TRANSFORM_FILL_MODES.keys():
        raise ValueError(
            "Invalid value for argument `fill_mode`. Expected of one "
            f"{set(AFFINE_TRANSFORM_FILL_MODES.keys())}. "
            f"Received: fill_mode={fill_mode}"
        )

    transform = convert_to_tensor(transform)

    if len(image.shape) not in (3, 4):
        raise ValueError(
            "Invalid image rank: expected rank 3 (single image) "
            "or rank 4 (batch of images). Received input with shape: "
            f"image.shape={image.shape}"
        )
    if len(transform.shape) not in (1, 2):
        raise ValueError(
            "Invalid transform rank: expected rank 1 (single transform) "
            "or rank 2 (batch of transforms). Received input with shape: "
            f"transform.shape={transform.shape}"
        )
    if transform.shape[-1] != 8:
        raise ValueError(
            "Invalid transform shape: expected the last dimension to have "
            "size 8. Received input with shape: "
            f"transform.shape={transform.shape}"
        )

    # unbatched case
    need_squeeze = False
    if len(image.shape) == 3:
        image = jnp.expand_dims(image, axis=0)
        need_squeeze = True
    if len(transform.shape) == 1:
        transform = jnp.expand_dims(transform, axis=0)

    if data_format == "channels_first":
        image = jnp.transpose(image, (0, 2, 3, 1))

    batch_size = image.shape[0]
    # A single transform is shared by every image of the batch.
    transform = jnp.broadcast_to(transform, (batch_size, 8))

    # get indices: shape (H, W, C, 3), last axis is (row, col, channel)
    meshgrid = jnp.meshgrid(
        *[jnp.arange(size) for size in image.shape[1:]], indexing="ij"
    )
    indices = jnp.stack(meshgrid, axis=-1)

    # Build the per-sample 3x3 matrix in (row, col, channel) order. Rows are
    # indexed by the input axis and columns by the output axis because the
    # einsum below computes `indices @ matrix`. The channel row/column is the
    # identity so `c0` / `c1` can never leak into the spatial coordinates.
    a0, a1, a2, b0, b1, b2 = (transform[:, i] for i in range(6))
    zeros = jnp.zeros_like(a0)
    ones = jnp.ones_like(a0)
    matrix = jnp.stack(
        [
            jnp.stack([b1, a1, zeros], axis=-1),
            jnp.stack([b0, a0, zeros], axis=-1),
            jnp.stack([zeros, zeros, ones], axis=-1),
        ],
        axis=1,
    )
    offset = jnp.stack([b2, a2, zeros], axis=-1)

    # transform the indices: (B, H, W, C, 3) -> (B, 3, H, W, C)
    coordinates = jnp.einsum("hwij,Bjk->Bhwik", indices, matrix)
    coordinates = jnp.moveaxis(coordinates, source=-1, destination=1)
    # Positional arguments: the `newshape` keyword of `reshape` is deprecated.
    coordinates = coordinates + jnp.reshape(offset, (batch_size, 3, 1, 1, 1))

    # apply affine transformation
    _map_coordinates = functools.partial(
        jax.scipy.ndimage.map_coordinates,
        order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation],
        mode=fill_mode,
        cval=fill_value,
    )
    affined = jax.vmap(_map_coordinates)(image, coordinates)

    if data_format == "channels_first":
        affined = jnp.transpose(affined, (0, 3, 1, 2))
    if need_squeeze:
        affined = jnp.squeeze(affined, axis=0)
    return affined
|
||||||
|
@ -1,7 +1,10 @@
|
|||||||
import jax
|
import jax
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import scipy.ndimage
|
||||||
|
|
||||||
RESIZE_METHODS = (
|
from keras_core.backend.numpy.core import convert_to_tensor
|
||||||
|
|
||||||
|
RESIZE_INTERPOLATIONS = (
|
||||||
"bilinear",
|
"bilinear",
|
||||||
"nearest",
|
"nearest",
|
||||||
"lanczos3",
|
"lanczos3",
|
||||||
@ -11,12 +14,16 @@ RESIZE_METHODS = (
|
|||||||
|
|
||||||
|
|
||||||
def resize(
|
def resize(
|
||||||
image, size, method="bilinear", antialias=False, data_format="channels_last"
|
image,
|
||||||
|
size,
|
||||||
|
interpolation="bilinear",
|
||||||
|
antialias=False,
|
||||||
|
data_format="channels_last",
|
||||||
):
|
):
|
||||||
if method not in RESIZE_METHODS:
|
if interpolation not in RESIZE_INTERPOLATIONS:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Invalid value for argument `method`. Expected of one "
|
"Invalid value for argument `interpolation`. Expected of one "
|
||||||
f"{RESIZE_METHODS}. Received: method={method}"
|
f"{RESIZE_INTERPOLATIONS}. Received: interpolation={interpolation}"
|
||||||
)
|
)
|
||||||
if not len(size) == 2:
|
if not len(size) == 2:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -41,5 +48,121 @@ def resize(
|
|||||||
f"image.shape={image.shape}"
|
f"image.shape={image.shape}"
|
||||||
)
|
)
|
||||||
return np.array(
|
return np.array(
|
||||||
jax.image.resize(image, size, method=method, antialias=antialias)
|
jax.image.resize(image, size, method=interpolation, antialias=antialias)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
AFFINE_TRANSFORM_INTERPOLATIONS = {  # map to order
    "nearest": 0,
    "bilinear": 1,
}
# Maps the public fill-mode names to the `mode` names passed to
# `scipy.ndimage.map_coordinates`.
AFFINE_TRANSFORM_FILL_MODES = {
    "constant": "grid-constant",
    "nearest": "nearest",
    "wrap": "grid-wrap",
    "mirror": "mirror",
    "reflect": "reflect",
}


def affine_transform(
    image,
    transform,
    interpolation="bilinear",
    fill_mode="constant",
    fill_value=0,
    data_format="channels_last",
):
    """Apply affine transform(s) to an image or a batch of images.

    Each row `[a0, a1, a2, b0, b1, b2, c0, c1]` of `transform` maps an output
    pixel `(x, y)` to the input location
    `(a0 * x + a1 * y + a2, b0 * x + b1 * y + b2)`. The projective terms `c0`
    and `c1` are ignored by this backend.

    Args:
        image: Rank-3 (single image) or rank-4 (batch) array. It is used
            as-is; it is not passed through `convert_to_tensor`.
        transform: Vector of length 8 or array of shape `(N, 8)`. A single
            transform is applied to every image of the batch.
        interpolation: One of the keys of `AFFINE_TRANSFORM_INTERPOLATIONS`.
        fill_mode: One of the keys of `AFFINE_TRANSFORM_FILL_MODES`.
        fill_value: Forwarded to `map_coordinates` as `cval`.
        data_format: `"channels_last"` or `"channels_first"`.

    Returns:
        The transformed image(s), with the same layout and rank as `image`.

    Raises:
        ValueError: On an unknown `interpolation` / `fill_mode`, or an
            unsupported `image` / `transform` shape.
    """
    if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS.keys():
        raise ValueError(
            "Invalid value for argument `interpolation`. Expected of one "
            f"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: "
            f"interpolation={interpolation}"
        )
    if fill_mode not in AFFINE_TRANSFORM_FILL_MODES.keys():
        raise ValueError(
            "Invalid value for argument `fill_mode`. Expected of one "
            f"{set(AFFINE_TRANSFORM_FILL_MODES.keys())}. "
            f"Received: fill_mode={fill_mode}"
        )

    transform = convert_to_tensor(transform)

    if len(image.shape) not in (3, 4):
        raise ValueError(
            "Invalid image rank: expected rank 3 (single image) "
            "or rank 4 (batch of images). Received input with shape: "
            f"image.shape={image.shape}"
        )
    if len(transform.shape) not in (1, 2):
        raise ValueError(
            "Invalid transform rank: expected rank 1 (single transform) "
            "or rank 2 (batch of transforms). Received input with shape: "
            f"transform.shape={transform.shape}"
        )
    if transform.shape[-1] != 8:
        raise ValueError(
            "Invalid transform shape: expected the last dimension to have "
            "size 8. Received input with shape: "
            f"transform.shape={transform.shape}"
        )

    # unbatched case
    need_squeeze = False
    if len(image.shape) == 3:
        image = np.expand_dims(image, axis=0)
        need_squeeze = True
    if len(transform.shape) == 1:
        transform = np.expand_dims(transform, axis=0)

    if data_format == "channels_first":
        image = np.transpose(image, (0, 2, 3, 1))

    batch_size = image.shape[0]
    # A single transform is shared by every image of the batch. The result is
    # a read-only view, which is fine because it is never written to below.
    transform = np.broadcast_to(transform, (batch_size, 8))

    # get indices: shape (H, W, C, 3), last axis is (row, col, channel)
    meshgrid = np.meshgrid(
        *[np.arange(size) for size in image.shape[1:]], indexing="ij"
    )
    indices = np.stack(meshgrid, axis=-1)

    # Build the per-sample 3x3 matrix in (row, col, channel) order. Rows are
    # indexed by the input axis and columns by the output axis because the
    # einsum below computes `indices @ matrix`. The channel row/column is the
    # identity so `c0` / `c1` can never leak into the spatial coordinates.
    a0, a1, a2, b0, b1, b2 = (transform[:, i] for i in range(6))
    zeros = np.zeros_like(a0)
    ones = np.ones_like(a0)
    matrix = np.stack(
        [
            np.stack([b1, a1, zeros], axis=-1),
            np.stack([b0, a0, zeros], axis=-1),
            np.stack([zeros, zeros, ones], axis=-1),
        ],
        axis=1,
    )
    offset = np.stack([b2, a2, zeros], axis=-1)

    # transform the indices: (B, H, W, C, 3) -> (B, 3, H, W, C)
    coordinates = np.einsum("hwij,Bjk->Bhwik", indices, matrix)
    coordinates = np.moveaxis(coordinates, source=-1, destination=1)
    # Positional arguments: the `newshape` keyword of `reshape` is deprecated.
    coordinates = coordinates + np.reshape(offset, (batch_size, 3, 1, 1, 1))

    # apply affine transformation, one sample at a time
    order = AFFINE_TRANSFORM_INTERPOLATIONS[interpolation]
    mode = AFFINE_TRANSFORM_FILL_MODES[fill_mode]
    affined = np.stack(
        [
            scipy.ndimage.map_coordinates(
                sample,
                sample_coordinates,
                order=order,
                mode=mode,
                cval=fill_value,
                prefilter=False,
            )
            for sample, sample_coordinates in zip(image, coordinates)
        ],
        axis=0,
    )

    if data_format == "channels_first":
        affined = np.transpose(affined, (0, 3, 1, 2))
    if need_squeeze:
        affined = np.squeeze(affined, axis=0)
    return affined
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
RESIZE_METHODS = (
|
RESIZE_INTERPOLATIONS = (
|
||||||
"bilinear",
|
"bilinear",
|
||||||
"nearest",
|
"nearest",
|
||||||
"lanczos3",
|
"lanczos3",
|
||||||
@ -10,12 +10,16 @@ RESIZE_METHODS = (
|
|||||||
|
|
||||||
|
|
||||||
def resize(
|
def resize(
|
||||||
image, size, method="bilinear", antialias=False, data_format="channels_last"
|
image,
|
||||||
|
size,
|
||||||
|
interpolation="bilinear",
|
||||||
|
antialias=False,
|
||||||
|
data_format="channels_last",
|
||||||
):
|
):
|
||||||
if method not in RESIZE_METHODS:
|
if interpolation not in RESIZE_INTERPOLATIONS:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Invalid value for argument `method`. Expected of one "
|
"Invalid value for argument `interpolation`. Expected of one "
|
||||||
f"{RESIZE_METHODS}. Received: method={method}"
|
f"{RESIZE_INTERPOLATIONS}. Received: interpolation={interpolation}"
|
||||||
)
|
)
|
||||||
if not len(size) == 2:
|
if not len(size) == 2:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -35,10 +39,83 @@ def resize(
|
|||||||
f"image.shape={image.shape}"
|
f"image.shape={image.shape}"
|
||||||
)
|
)
|
||||||
|
|
||||||
resized = tf.image.resize(image, size, method=method, antialias=antialias)
|
resized = tf.image.resize(
|
||||||
|
image, size, method=interpolation, antialias=antialias
|
||||||
|
)
|
||||||
if data_format == "channels_first":
|
if data_format == "channels_first":
|
||||||
if len(image.shape) == 4:
|
if len(image.shape) == 4:
|
||||||
resized = tf.transpose(resized, (0, 3, 1, 2))
|
resized = tf.transpose(resized, (0, 3, 1, 2))
|
||||||
elif len(image.shape) == 3:
|
elif len(image.shape) == 3:
|
||||||
resized = tf.transpose(resized, (2, 0, 1))
|
resized = tf.transpose(resized, (2, 0, 1))
|
||||||
return resized
|
return resized
|
||||||
|
|
||||||
|
|
||||||
|
AFFINE_TRANSFORM_INTERPOLATIONS = (
    "nearest",
    "bilinear",
)
AFFINE_TRANSFORM_FILL_MODES = (
    "constant",
    "nearest",
    "wrap",
    # "mirror", not supported by TF
    "reflect",
)


def affine_transform(
    image,
    transform,
    interpolation="bilinear",
    fill_mode="constant",
    fill_value=0,
    data_format="channels_last",
):
    """Apply projective transform(s) to an image or a batch of images.

    Thin wrapper around `tf.raw_ops.ImageProjectiveTransformV3`; the output
    keeps the spatial size of the input.

    Args:
        image: Rank-3 (single image) or rank-4 (batch) tensor-like input.
        transform: Rank-1 (single transform) or rank-2 (batch of transforms)
            tensor-like input. It is cast to `float32` before the op call.
        interpolation: One of `AFFINE_TRANSFORM_INTERPOLATIONS`.
        fill_mode: One of `AFFINE_TRANSFORM_FILL_MODES`.
        fill_value: Forwarded to the op as `fill_value`.
        data_format: `"channels_last"` or `"channels_first"`.

    Returns:
        The transformed image(s), with the same layout and rank as `image`.

    Raises:
        ValueError: On an unknown `interpolation` / `fill_mode`, or an
            unsupported `image` / `transform` rank.
    """
    if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS:
        raise ValueError(
            "Invalid value for argument `interpolation`. Expected of one "
            f"{AFFINE_TRANSFORM_INTERPOLATIONS}. Received: "
            f"interpolation={interpolation}"
        )
    if fill_mode not in AFFINE_TRANSFORM_FILL_MODES:
        raise ValueError(
            "Invalid value for argument `fill_mode`. Expected of one "
            f"{AFFINE_TRANSFORM_FILL_MODES}. Received: fill_mode={fill_mode}"
        )

    # Accept plain Python sequences / NumPy arrays like the other backends
    # do: the rank checks below need a `.shape` attribute.
    image = tf.convert_to_tensor(image)
    transform = tf.convert_to_tensor(transform)

    if len(image.shape) not in (3, 4):
        raise ValueError(
            "Invalid image rank: expected rank 3 (single image) "
            "or rank 4 (batch of images). Received input with shape: "
            f"image.shape={image.shape}"
        )
    if len(transform.shape) not in (1, 2):
        raise ValueError(
            "Invalid transform rank: expected rank 1 (single transform) "
            "or rank 2 (batch of transforms). Received input with shape: "
            f"transform.shape={transform.shape}"
        )
    # unbatched case
    need_squeeze = False
    if len(image.shape) == 3:
        image = tf.expand_dims(image, axis=0)
        need_squeeze = True
    if len(transform.shape) == 1:
        transform = tf.expand_dims(transform, axis=0)

    # The raw op is called on channels-last data.
    if data_format == "channels_first":
        image = tf.transpose(image, (0, 2, 3, 1))

    affined = tf.raw_ops.ImageProjectiveTransformV3(
        images=image,
        transforms=tf.cast(transform, dtype=tf.float32),
        output_shape=tf.shape(image)[1:-1],
        fill_value=fill_value,
        interpolation=interpolation.upper(),
        fill_mode=fill_mode.upper(),
    )

    if data_format == "channels_first":
        affined = tf.transpose(affined, (0, 3, 1, 2))
    if need_squeeze:
        affined = tf.squeeze(affined, axis=0)
    return affined
|
||||||
|
@ -1,21 +1,28 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn.functional as tnn
|
||||||
|
|
||||||
from keras_core.backend.torch.core import convert_to_tensor
|
from keras_core.backend.torch.core import convert_to_tensor
|
||||||
|
|
||||||
RESIZE_METHODS = {} # populated after torchvision import
|
RESIZE_INTERPOLATIONS = {} # populated after torchvision import
|
||||||
|
|
||||||
UNSUPPORTED_METHODS = (
|
UNSUPPORTED_INTERPOLATIONS = (
|
||||||
"lanczos3",
|
"lanczos3",
|
||||||
"lanczos5",
|
"lanczos5",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def resize(
|
def resize(
|
||||||
image, size, method="bilinear", antialias=False, data_format="channels_last"
|
image,
|
||||||
|
size,
|
||||||
|
interpolation="bilinear",
|
||||||
|
antialias=False,
|
||||||
|
data_format="channels_last",
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
import torchvision
|
import torchvision
|
||||||
from torchvision.transforms import InterpolationMode as im
|
from torchvision.transforms import InterpolationMode as im
|
||||||
|
|
||||||
RESIZE_METHODS.update(
|
RESIZE_INTERPOLATIONS.update(
|
||||||
{
|
{
|
||||||
"bilinear": im.BILINEAR,
|
"bilinear": im.BILINEAR,
|
||||||
"nearest": im.NEAREST_EXACT,
|
"nearest": im.NEAREST_EXACT,
|
||||||
@ -27,16 +34,16 @@ def resize(
|
|||||||
"The torchvision package is necessary to use `resize` with the "
|
"The torchvision package is necessary to use `resize` with the "
|
||||||
"torch backend. Please install torchvision."
|
"torch backend. Please install torchvision."
|
||||||
)
|
)
|
||||||
if method in UNSUPPORTED_METHODS:
|
if interpolation in UNSUPPORTED_INTERPOLATIONS:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Resizing with Lanczos interpolation is "
|
"Resizing with Lanczos interpolation is "
|
||||||
"not supported by the PyTorch backend. "
|
"not supported by the PyTorch backend. "
|
||||||
f"Received: method={method}."
|
f"Received: interpolation={interpolation}."
|
||||||
)
|
)
|
||||||
if method not in RESIZE_METHODS:
|
if interpolation not in RESIZE_INTERPOLATIONS:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Invalid value for argument `method`. Expected of one "
|
"Invalid value for argument `interpolation`. Expected of one "
|
||||||
f"{RESIZE_METHODS}. Received: method={method}"
|
f"{RESIZE_INTERPOLATIONS}. Received: interpolation={interpolation}"
|
||||||
)
|
)
|
||||||
if not len(size) == 2:
|
if not len(size) == 2:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -60,7 +67,7 @@ def resize(
|
|||||||
resized = torchvision.transforms.functional.resize(
|
resized = torchvision.transforms.functional.resize(
|
||||||
img=image,
|
img=image,
|
||||||
size=size,
|
size=size,
|
||||||
interpolation=RESIZE_METHODS[method],
|
interpolation=RESIZE_INTERPOLATIONS[interpolation],
|
||||||
antialias=antialias,
|
antialias=antialias,
|
||||||
)
|
)
|
||||||
if data_format == "channels_last":
|
if data_format == "channels_last":
|
||||||
@ -69,3 +76,153 @@ def resize(
|
|||||||
elif len(image.shape) == 3:
|
elif len(image.shape) == 3:
|
||||||
resized = resized.permute((1, 2, 0))
|
resized = resized.permute((1, 2, 0))
|
||||||
return resized
|
return resized
|
||||||
|
|
||||||
|
|
||||||
|
AFFINE_TRANSFORM_INTERPOLATIONS = (
    "nearest",
    "bilinear",
)
# Maps the public fill-mode names to `grid_sample` `padding_mode` names.
AFFINE_TRANSFORM_FILL_MODES = {
    "constant": "zeros",
    "nearest": "border",
    # "wrap", not supported by torch
    # "mirror", not supported by torch
    "reflect": "reflection",
}


def _apply_grid_transform(
    img,
    grid,
    interpolation="bilinear",
    fill_mode="zeros",
    fill_value=None,
):
    """Sample `img` at the locations in `grid`, filling uncovered pixels.

    Modified from https://github.com/pytorch/vision/blob/main/torchvision/transforms/v2/functional/_geometry.py

    Args:
        img: 4D tensor indexed as `(batch, channels, height, width)`.
        grid: Floating-point sampling grid passed to `grid_sample`.
        interpolation: `mode` argument of `grid_sample`.
        fill_mode: `padding_mode` argument of `grid_sample`.
        fill_value: Scalar, or list/tuple of per-channel values, written
            where the grid falls outside the image. `None` disables filling.

    Returns:
        A tensor with the dtype of `img`. Values are rounded only when `img`
        has a non-floating-point dtype.
    """  # noqa: E501

    # We are using context knowledge that grid should have float dtype
    needs_cast = img.dtype != grid.dtype
    float_img = img.to(grid.dtype) if needs_cast else img

    shape = float_img.shape
    # Append a dummy mask for customized fill colors, should be faster than
    # grid_sample() twice
    if fill_value is not None:
        mask = torch.ones(
            (shape[0], 1, shape[2], shape[3]),
            dtype=float_img.dtype,
            device=float_img.device,
        )
        float_img = torch.cat((float_img, mask), dim=1)

    float_img = tnn.grid_sample(
        float_img,
        grid,
        mode=interpolation,
        padding_mode=fill_mode,
        align_corners=False,
    )
    # Fill with required color
    if fill_value is not None:
        float_img, mask = torch.tensor_split(float_img, indices=(-1,), dim=-3)
        mask = mask.expand_as(float_img)
        fill_list = (
            fill_value
            if isinstance(fill_value, (tuple, list))
            else [float(fill_value)]
        )
        fill_img = torch.tensor(
            fill_list, dtype=float_img.dtype, device=float_img.device
        ).view(1, -1, 1, 1)
        # Out-of-place ops on purpose: `float_img` is one of several views
        # returned by `tensor_split`, and modifying such a view in place is
        # rejected by autograd when gradients are tracked.
        if interpolation == "nearest":
            float_img = torch.where(mask < 0.5, fill_img, float_img)
        else:  # 'bilinear'
            # The following is mathematically equivalent to:
            # img * mask + (1.0 - mask) * fill
            float_img = mask * (float_img - fill_img) + fill_img

    if not needs_cast:
        return float_img
    # Round only when going back to an integer-like dtype. A floating-point
    # image whose dtype merely differs from the grid's (e.g. float64 image,
    # float32 grid) must keep its fractional values.
    if not img.is_floating_point():
        float_img = float_img.round()
    return float_img.to(img.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def affine_transform(
    image,
    transform,
    interpolation="bilinear",
    fill_mode="constant",
    fill_value=0,
    data_format="channels_last",
):
    """Apply affine transform(s) to an image or a batch of images.

    The first six entries `[a0, a1, a2, b0, b1, b2]` of each transform row
    are converted into a normalized `theta` for `affine_grid`; any remaining
    entries are not read.

    Args:
        image: Rank-3 (single image) or rank-4 (batch) tensor-like input.
        transform: Rank-1 (single transform) or rank-2 (batch of transforms)
            tensor-like input.
        interpolation: One of `AFFINE_TRANSFORM_INTERPOLATIONS`.
        fill_mode: One of the keys of `AFFINE_TRANSFORM_FILL_MODES`.
        fill_value: Fill used when `fill_mode="constant"`; ignored otherwise.
        data_format: `"channels_last"` or `"channels_first"`.

    Returns:
        The transformed image(s), with the same layout and rank as `image`.

    Raises:
        ValueError: On an unknown `interpolation` / `fill_mode`, or an
            unsupported `image` / `transform` rank.
    """
    if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS:
        raise ValueError(
            "Invalid value for argument `interpolation`. Expected of one "
            f"{AFFINE_TRANSFORM_INTERPOLATIONS}. Received: "
            f"interpolation={interpolation}"
        )
    if fill_mode not in AFFINE_TRANSFORM_FILL_MODES.keys():
        raise ValueError(
            "Invalid value for argument `fill_mode`. Expected of one "
            f"{set(AFFINE_TRANSFORM_FILL_MODES.keys())}. "
            f"Received: fill_mode={fill_mode}"
        )

    image = convert_to_tensor(image)
    transform = convert_to_tensor(transform)

    if image.ndim not in (3, 4):
        raise ValueError(
            "Invalid image rank: expected rank 3 (single image) "
            "or rank 4 (batch of images). Received input with shape: "
            f"image.shape={image.shape}"
        )
    if transform.ndim not in (1, 2):
        raise ValueError(
            "Invalid transform rank: expected rank 1 (single transform) "
            "or rank 2 (batch of transforms). Received input with shape: "
            f"transform.shape={transform.shape}"
        )

    # A custom fill colour only makes sense for constant padding.
    if fill_mode != "constant":
        fill_value = None
    fill_mode = AFFINE_TRANSFORM_FILL_MODES[fill_mode]

    # unbatched case
    need_squeeze = False
    if image.ndim == 3:
        image = image.unsqueeze(dim=0)
        need_squeeze = True
    if transform.ndim == 1:
        transform = transform.unsqueeze(dim=0)

    # grid_sample works on channels-first data.
    if data_format == "channels_last":
        image = image.permute((0, 3, 1, 2))

    # theta inherits its dtype from the transform, so an integer transform
    # would truncate the h / w scaling below; make it floating point first.
    if not transform.is_floating_point():
        transform = transform.to(torch.float32)

    # deal with transform: convert pixel-space coefficients to the
    # normalized [-1, 1] coordinates used by `affine_grid`.
    # NOTE(review): for non-unit scale / shear the translation terms look
    # like they are missing a half-pixel correction of (1 - a0 - a1) / w
    # (resp. (1 - b0 - b1) / h) under align_corners=False -- confirm against
    # the other backends before changing.
    h, w = image.shape[2], image.shape[3]
    theta = torch.zeros(
        (image.shape[0], 2, 3), dtype=transform.dtype, device=transform.device
    )
    theta[:, 0, 0] = transform[:, 0]
    theta[:, 0, 1] = transform[:, 1] * h / w
    theta[:, 0, 2] = (
        transform[:, 2] * 2 / w + theta[:, 0, 0] + theta[:, 0, 1] - 1
    )
    theta[:, 1, 0] = transform[:, 3] * w / h
    theta[:, 1, 1] = transform[:, 4]
    theta[:, 1, 2] = (
        transform[:, 5] * 2 / h + theta[:, 1, 0] + theta[:, 1, 1] - 1
    )

    # Explicit align_corners keeps the grid consistent with the
    # `grid_sample(..., align_corners=False)` call in `_apply_grid_transform`.
    grid = tnn.affine_grid(theta, image.shape, align_corners=False)
    affined = _apply_grid_transform(
        image, grid, interpolation, fill_mode, fill_value
    )

    if data_format == "channels_last":
        affined = affined.permute((0, 2, 3, 1))
    if need_squeeze:
        affined = affined.squeeze(dim=0)
    return affined
|
||||||
|
@ -81,7 +81,7 @@ class Resizing(TFDataLayer):
|
|||||||
outputs = self.backend.image.resize(
|
outputs = self.backend.image.resize(
|
||||||
inputs,
|
inputs,
|
||||||
size=size,
|
size=size,
|
||||||
method=self.interpolation,
|
interpolation=self.interpolation,
|
||||||
data_format=self.data_format,
|
data_format=self.data_format,
|
||||||
)
|
)
|
||||||
return outputs
|
return outputs
|
||||||
|
@ -160,7 +160,7 @@ class UpSampling2D(Layer):
|
|||||||
x = ops.repeat(x, height_factor, axis=1)
|
x = ops.repeat(x, height_factor, axis=1)
|
||||||
x = ops.repeat(x, width_factor, axis=2)
|
x = ops.repeat(x, width_factor, axis=2)
|
||||||
else:
|
else:
|
||||||
x = ops.image.resize(x, new_shape, method=interpolation)
|
x = ops.image.resize(x, new_shape, interpolation=interpolation)
|
||||||
if data_format == "channels_first":
|
if data_format == "channels_first":
|
||||||
x = ops.transpose(x, [0, 3, 1, 2])
|
x = ops.transpose(x, [0, 3, 1, 2])
|
||||||
|
|
||||||
|
@ -9,13 +9,13 @@ class Resize(Operation):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
size,
|
size,
|
||||||
method="bilinear",
|
interpolation="bilinear",
|
||||||
antialias=False,
|
antialias=False,
|
||||||
data_format="channels_last",
|
data_format="channels_last",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.size = tuple(size)
|
self.size = tuple(size)
|
||||||
self.method = method
|
self.interpolation = interpolation
|
||||||
self.antialias = antialias
|
self.antialias = antialias
|
||||||
self.data_format = data_format
|
self.data_format = data_format
|
||||||
|
|
||||||
@ -23,7 +23,7 @@ class Resize(Operation):
|
|||||||
return backend.image.resize(
|
return backend.image.resize(
|
||||||
image,
|
image,
|
||||||
self.size,
|
self.size,
|
||||||
method=self.method,
|
interpolation=self.interpolation,
|
||||||
antialias=self.antialias,
|
antialias=self.antialias,
|
||||||
data_format=self.data_format,
|
data_format=self.data_format,
|
||||||
)
|
)
|
||||||
@ -53,14 +53,18 @@ class Resize(Operation):
|
|||||||
|
|
||||||
@keras_core_export("keras_core.ops.image.resize")
|
@keras_core_export("keras_core.ops.image.resize")
|
||||||
def resize(
|
def resize(
|
||||||
image, size, method="bilinear", antialias=False, data_format="channels_last"
|
image,
|
||||||
|
size,
|
||||||
|
interpolation="bilinear",
|
||||||
|
antialias=False,
|
||||||
|
data_format="channels_last",
|
||||||
):
|
):
|
||||||
"""Resize images to size using the specified method.
|
"""Resize images to size using the specified interpolation method.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
image: Input image or batch of images. Must be 3D or 4D.
|
image: Input image or batch of images. Must be 3D or 4D.
|
||||||
size: Size of output image in `(height, width)` format.
|
size: Size of output image in `(height, width)` format.
|
||||||
method: Interpolation method. Available methods are `"nearest"`,
|
interpolation: Interpolation method. Available methods are `"nearest"`,
|
||||||
`"bilinear"`, and `"bicubic"`. Defaults to `"bilinear"`.
|
`"bilinear"`, and `"bicubic"`. Defaults to `"bilinear"`.
|
||||||
antialias: Whether to use an antialiasing filter when downsampling an
|
antialias: Whether to use an antialiasing filter when downsampling an
|
||||||
image. Defaults to `False`.
|
image. Defaults to `False`.
|
||||||
@ -97,8 +101,146 @@ def resize(
|
|||||||
|
|
||||||
if any_symbolic_tensors((image,)):
|
if any_symbolic_tensors((image,)):
|
||||||
return Resize(
|
return Resize(
|
||||||
size, method=method, antialias=antialias, data_format=data_format
|
size,
|
||||||
|
interpolation=interpolation,
|
||||||
|
antialias=antialias,
|
||||||
|
data_format=data_format,
|
||||||
).symbolic_call(image)
|
).symbolic_call(image)
|
||||||
return backend.image.resize(
|
return backend.image.resize(
|
||||||
image, size, method=method, antialias=antialias, data_format=data_format
|
image,
|
||||||
|
size,
|
||||||
|
interpolation=interpolation,
|
||||||
|
antialias=antialias,
|
||||||
|
data_format=data_format,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AffineTransform(Operation):
    """Operation that forwards to `backend.image.affine_transform`.

    The constructor arguments are stored and passed through unchanged on
    every call.
    """

    def __init__(
        self,
        interpolation="bilinear",
        fill_mode="constant",
        fill_value=0,
        data_format="channels_last",
    ):
        super().__init__()
        self.interpolation = interpolation
        self.fill_mode = fill_mode
        self.fill_value = fill_value
        self.data_format = data_format

    def call(self, image, transform):
        """Run the backend affine transform with the stored options."""
        options = {
            "interpolation": self.interpolation,
            "fill_mode": self.fill_mode,
            "fill_value": self.fill_value,
            "data_format": self.data_format,
        }
        return backend.image.affine_transform(image, transform, **options)

    def compute_output_spec(self, image, transform):
        """Validate input ranks; the output has the image's shape and dtype."""
        image_rank = len(image.shape)
        if image_rank != 3 and image_rank != 4:
            raise ValueError(
                "Invalid image rank: expected rank 3 (single image) "
                "or rank 4 (batch of images). Received input with shape: "
                f"image.shape={image.shape}"
            )
        transform_rank = len(transform.shape)
        if transform_rank != 1 and transform_rank != 2:
            raise ValueError(
                "Invalid transform rank: expected rank 1 (single transform) "
                "or rank 2 (batch of transforms). Received input with shape: "
                f"transform.shape={transform.shape}"
            )
        return KerasTensor(image.shape, dtype=image.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
@keras_core_export("keras_core.ops.image.affine_transform")
def affine_transform(
    image,
    transform,
    interpolation="bilinear",
    fill_mode="constant",
    fill_value=0,
    data_format="channels_last",
):
    """Applies the given transform(s) to the image(s).

    Args:
        image: Input image or batch of images. Must be 3D or 4D.
        transform: Projective transform matrix/matrices. A vector of
            length 8 or a tensor of size N x 8. If one row of `transform`
            is `[a0, a1, a2, b0, b1, b2, c0, c1]`, then it maps the output
            point `(x, y)` to a transformed input point
            `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`,
            where `k = c0 x + c1 y + 1`. The transform is inverted compared
            to the transform mapping input points to output points. Note
            that gradients are not backpropagated into transformation
            parameters. Note that `c0` and `c1` are only effective when
            using TensorFlow backend and will be considered as `0` when
            using other backends.
        interpolation: Interpolation method. Available methods are
            `"nearest"` and `"bilinear"`. Defaults to `"bilinear"`.
        fill_mode: Points outside the boundaries of the input are filled
            according to the given mode. Available methods are
            `"constant"`, `"nearest"`, `"wrap"` and `"reflect"`.
            Defaults to `"constant"`. Note that `"wrap"` is not supported
            by Torch backend.
        fill_value: Value used for points outside the boundaries of the
            input if `fill_mode="constant"`. Defaults to `0`.
        data_format: string, either `"channels_last"` or
            `"channels_first"`. The ordering of the dimensions in the
            inputs. `"channels_last"` corresponds to inputs with shape
            `(batch, height, width, channels)` while `"channels_first"`
            corresponds to inputs with shape
            `(batch, channels, height, width)`. Defaults to
            `"channels_last"`.

    Returns:
        Applied affine transform image or batch of images.

    Examples:

    >>> x = np.random.random((2, 64, 80, 3))  # batch of 2 RGB images
    >>> transform = np.array(
    ...     [
    ...         [1.5, 0, -20, 0, 1.5, -16, 0, 0],  # zoom
    ...         [1, 0, -20, 0, 1, -16, 0, 0],  # translation
    ...     ]
    ... )
    >>> y = keras_core.ops.image.affine_transform(x, transform)
    >>> y.shape
    (2, 64, 80, 3)

    >>> x = np.random.random((64, 80, 3))  # single RGB image
    >>> transform = np.array([1.0, 0.5, -20, 0.5, 1.0, -16, 0, 0])  # shear
    >>> y = keras_core.ops.image.affine_transform(x, transform)
    >>> y.shape
    (64, 80, 3)

    >>> x = np.random.random((2, 3, 64, 80))  # batch of 2 RGB images
    >>> transform = np.array(
    ...     [
    ...         [1.5, 0, -20, 0, 1.5, -16, 0, 0],  # zoom
    ...         [1, 0, -20, 0, 1, -16, 0, 0],  # translation
    ...     ]
    ... )
    >>> y = keras_core.ops.image.affine_transform(x, transform,
    ...     data_format="channels_first")
    >>> y.shape
    (2, 3, 64, 80)
    """
    # Both execution paths take exactly the same keyword settings.
    options = dict(
        interpolation=interpolation,
        fill_mode=fill_mode,
        fill_value=fill_value,
        data_format=data_format,
    )
    if not any_symbolic_tensors((image, transform)):
        # Eager inputs: run the backend op directly.
        return backend.image.affine_transform(image, transform, **options)
    # Symbolic inputs: go through the Operation so only specs are computed.
    return AffineTransform(**options).symbolic_call(image, transform)
|
||||||
|
@ -18,6 +18,12 @@ class ImageOpsDynamicShapeTest(testing.TestCase):
|
|||||||
out = kimage.resize(x, size=(15, 15))
|
out = kimage.resize(x, size=(15, 15))
|
||||||
self.assertEqual(out.shape, (15, 15, 3))
|
self.assertEqual(out.shape, (15, 15, 3))
|
||||||
|
|
||||||
|
def test_affine_transform(self):
    """Symbolic call with an unknown batch size keeps the image shape."""
    images = KerasTensor([None, 20, 20, 3])
    transforms = KerasTensor([None, 8])
    result = kimage.affine_transform(images, transforms)
    self.assertEqual(result.shape, (None, 20, 20, 3))
|
||||||
|
|
||||||
|
|
||||||
class ImageOpsStaticShapeTest(testing.TestCase):
|
class ImageOpsStaticShapeTest(testing.TestCase):
|
||||||
def test_resize(self):
|
def test_resize(self):
|
||||||
@ -25,6 +31,12 @@ class ImageOpsStaticShapeTest(testing.TestCase):
|
|||||||
out = kimage.resize(x, size=(15, 15))
|
out = kimage.resize(x, size=(15, 15))
|
||||||
self.assertEqual(out.shape, (15, 15, 3))
|
self.assertEqual(out.shape, (15, 15, 3))
|
||||||
|
|
||||||
|
def test_affine_transform(self):
    """Symbolic call on a single fully-shaped image keeps its shape."""
    single_image = KerasTensor([20, 20, 3])
    single_transform = KerasTensor([8])
    result = kimage.affine_transform(single_image, single_transform)
    self.assertEqual(result.shape, (20, 20, 3))
|
||||||
|
|
||||||
|
|
||||||
class ImageOpsCorrectnessTest(testing.TestCase, parameterized.TestCase):
|
class ImageOpsCorrectnessTest(testing.TestCase, parameterized.TestCase):
|
||||||
@parameterized.parameters(
|
@parameterized.parameters(
|
||||||
@ -42,20 +54,20 @@ class ImageOpsCorrectnessTest(testing.TestCase, parameterized.TestCase):
|
|||||||
("bilinear", True, "channels_first"),
|
("bilinear", True, "channels_first"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
def test_resize(self, method, antialias, data_format):
|
def test_resize(self, interpolation, antialias, data_format):
|
||||||
if backend.backend() == "torch":
|
if backend.backend() == "torch":
|
||||||
if "lanczos" in method:
|
if "lanczos" in interpolation:
|
||||||
self.skipTest(
|
self.skipTest(
|
||||||
"Resizing with Lanczos interpolation is "
|
"Resizing with Lanczos interpolation is "
|
||||||
"not supported by the PyTorch backend. "
|
"not supported by the PyTorch backend. "
|
||||||
f"Received: method={method}."
|
f"Received: interpolation={interpolation}."
|
||||||
)
|
)
|
||||||
if method == "bicubic" and antialias is False:
|
if interpolation == "bicubic" and antialias is False:
|
||||||
self.skipTest(
|
self.skipTest(
|
||||||
"Resizing with Bicubic interpolation in "
|
"Resizing with Bicubic interpolation in "
|
||||||
"PyTorch backend produces noise. Please "
|
"PyTorch backend produces noise. Please "
|
||||||
"turn on anti-aliasing. "
|
"turn on anti-aliasing. "
|
||||||
f"Received: method={method}, "
|
f"Received: interpolation={interpolation}, "
|
||||||
f"antialias={antialias}."
|
f"antialias={antialias}."
|
||||||
)
|
)
|
||||||
# Unbatched case
|
# Unbatched case
|
||||||
@ -66,14 +78,14 @@ class ImageOpsCorrectnessTest(testing.TestCase, parameterized.TestCase):
|
|||||||
out = kimage.resize(
|
out = kimage.resize(
|
||||||
x,
|
x,
|
||||||
size=(25, 25),
|
size=(25, 25),
|
||||||
method=method,
|
interpolation=interpolation,
|
||||||
antialias=antialias,
|
antialias=antialias,
|
||||||
data_format=data_format,
|
data_format=data_format,
|
||||||
)
|
)
|
||||||
if data_format == "channels_first":
|
if data_format == "channels_first":
|
||||||
x = np.transpose(x, (1, 2, 0))
|
x = np.transpose(x, (1, 2, 0))
|
||||||
ref_out = tf.image.resize(
|
ref_out = tf.image.resize(
|
||||||
x, size=(25, 25), method=method, antialias=antialias
|
x, size=(25, 25), method=interpolation, antialias=antialias
|
||||||
)
|
)
|
||||||
if data_format == "channels_first":
|
if data_format == "channels_first":
|
||||||
ref_out = np.transpose(ref_out, (2, 0, 1))
|
ref_out = np.transpose(ref_out, (2, 0, 1))
|
||||||
@ -88,16 +100,111 @@ class ImageOpsCorrectnessTest(testing.TestCase, parameterized.TestCase):
|
|||||||
out = kimage.resize(
|
out = kimage.resize(
|
||||||
x,
|
x,
|
||||||
size=(25, 25),
|
size=(25, 25),
|
||||||
method=method,
|
interpolation=interpolation,
|
||||||
antialias=antialias,
|
antialias=antialias,
|
||||||
data_format=data_format,
|
data_format=data_format,
|
||||||
)
|
)
|
||||||
if data_format == "channels_first":
|
if data_format == "channels_first":
|
||||||
x = np.transpose(x, (0, 2, 3, 1))
|
x = np.transpose(x, (0, 2, 3, 1))
|
||||||
ref_out = tf.image.resize(
|
ref_out = tf.image.resize(
|
||||||
x, size=(25, 25), method=method, antialias=antialias
|
x, size=(25, 25), method=interpolation, antialias=antialias
|
||||||
)
|
)
|
||||||
if data_format == "channels_first":
|
if data_format == "channels_first":
|
||||||
ref_out = np.transpose(ref_out, (0, 3, 1, 2))
|
ref_out = np.transpose(ref_out, (0, 3, 1, 2))
|
||||||
self.assertEqual(tuple(out.shape), tuple(ref_out.shape))
|
self.assertEqual(tuple(out.shape), tuple(ref_out.shape))
|
||||||
self.assertAllClose(ref_out, out, atol=0.3)
|
self.assertAllClose(ref_out, out, atol=0.3)
|
||||||
|
|
||||||
|
@parameterized.parameters(
    [
        ("bilinear", "constant", "channels_last"),
        ("nearest", "constant", "channels_last"),
        ("bilinear", "nearest", "channels_last"),
        ("nearest", "nearest", "channels_last"),
        ("bilinear", "wrap", "channels_last"),
        ("nearest", "wrap", "channels_last"),
        ("bilinear", "reflect", "channels_last"),
        ("nearest", "reflect", "channels_last"),
        ("bilinear", "constant", "channels_first"),
    ]
)
def test_affine_transform(self, interpolation, fill_mode, data_format):
    """Compare `kimage.affine_transform` with TF's projective transform.

    Runs one unbatched and one batched random input through the op and
    checks shape and values against
    `tf.raw_ops.ImageProjectiveTransformV3`.
    """
    active_backend = backend.backend()
    channels_first = data_format == "channels_first"

    if fill_mode == "wrap":
        if active_backend == "torch":
            self.skipTest(
                "Applying affine transform with fill_mode=wrap is not support"
                " in torch backend"
            )
        if active_backend in ("jax", "numpy"):
            self.skipTest(
                "The numerical results of applying affine transform with "
                "fill_mode=wrap in tensorflow is inconsistent with jax and "
                "numpy backends"
            )

    # Unbatched case
    single_shape = (3, 50, 50) if channels_first else (50, 50, 3)
    x = np.random.random(single_shape) * 255
    # Padding with two zeros makes c0, c1 always 0
    transform = np.pad(np.random.random(size=6), (0, 2))
    out = kimage.affine_transform(
        x,
        transform,
        interpolation=interpolation,
        fill_mode=fill_mode,
        data_format=data_format,
    )
    # The TF reference op is fed channels-last images.
    if channels_first:
        x = np.transpose(x, (1, 2, 0))
    ref_out = tf.raw_ops.ImageProjectiveTransformV3(
        images=tf.expand_dims(x, axis=0),
        transforms=tf.cast(tf.expand_dims(transform, axis=0), tf.float32),
        output_shape=tf.shape(x)[:-1],
        fill_value=0,
        interpolation=interpolation.upper(),
        fill_mode=fill_mode.upper(),
    )[0]
    if channels_first:
        ref_out = np.transpose(ref_out, (2, 0, 1))
    self.assertEqual(tuple(out.shape), tuple(ref_out.shape))
    if active_backend != "torch":
        self.assertAllClose(ref_out, out, atol=0.3)
    else:
        # TODO: cannot pass with torch backend
        with self.assertRaises(AssertionError):
            self.assertAllClose(ref_out, out, atol=0.3)

    # Batched case
    batch_shape = (2, 3, 50, 50) if channels_first else (2, 50, 50, 3)
    x = np.random.random(batch_shape) * 255
    # Padding each row with two zeros makes c0, c1 always 0
    transform = np.pad(np.random.random(size=(2, 6)), [(0, 0), (0, 2)])
    out = kimage.affine_transform(
        x,
        transform,
        interpolation=interpolation,
        fill_mode=fill_mode,
        data_format=data_format,
    )
    if channels_first:
        x = np.transpose(x, (0, 2, 3, 1))
    ref_out = tf.raw_ops.ImageProjectiveTransformV3(
        images=x,
        transforms=tf.cast(transform, tf.float32),
        output_shape=tf.shape(x)[1:-1],
        fill_value=0,
        interpolation=interpolation.upper(),
        fill_mode=fill_mode.upper(),
    )
    if channels_first:
        ref_out = np.transpose(ref_out, (0, 3, 1, 2))
    self.assertEqual(tuple(out.shape), tuple(ref_out.shape))
    if active_backend != "torch":
        self.assertAllClose(ref_out, out, atol=0.3)
    else:
        # TODO: cannot pass with torch backend
        with self.assertRaises(AssertionError):
            self.assertAllClose(ref_out, out, atol=0.3)
|
||||||
|
@ -439,7 +439,7 @@ def smart_resize(
|
|||||||
]
|
]
|
||||||
|
|
||||||
img = backend_module.image.resize(
|
img = backend_module.image.resize(
|
||||||
img, size=size, method=interpolation, data_format=data_format
|
img, size=size, interpolation=interpolation, data_format=data_format
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(x, np.ndarray):
|
if isinstance(x, np.ndarray):
|
||||||
|
Loading…
Reference in New Issue
Block a user