Add affine_transform op to all backends (#477)

* Add affine op

* Sync import convention

* Use `np.random.random`

* Refactor jax implementation

* Fix

* Address fchollet's comments

* Update docstring

* Fix test

* Replace method with interpolation

* Replace method with interpolation

* Replace method with interpolation

* Update test
HongYu 2023-07-21 01:02:36 +08:00 committed by Francois Chollet
parent ed26f7f4c5
commit d7b3eae31f
9 changed files with 777 additions and 49 deletions

@@ -1,6 +1,11 @@
import jax
import functools
RESIZE_METHODS = (
import jax
import jax.numpy as jnp
from keras_core.backend.jax.core import convert_to_tensor
RESIZE_INTERPOLATIONS = (
"bilinear",
"nearest",
"lanczos3",
@@ -10,12 +15,16 @@ RESIZE_METHODS = (
def resize(
image, size, method="bilinear", antialias=False, data_format="channels_last"
image,
size,
interpolation="bilinear",
antialias=False,
data_format="channels_last",
):
if method not in RESIZE_METHODS:
if interpolation not in RESIZE_INTERPOLATIONS:
raise ValueError(
"Invalid value for argument `method`. Expected of one "
f"{RESIZE_METHODS}. Received: method={method}"
"Invalid value for argument `interpolation`. Expected of one "
f"{RESIZE_INTERPOLATIONS}. Received: interpolation={interpolation}"
)
if not len(size) == 2:
raise ValueError(
@@ -39,4 +48,117 @@ def resize(
"or rank 4 (batch of images). Received input with shape: "
f"image.shape={image.shape}"
)
return jax.image.resize(image, size, method=method, antialias=antialias)
return jax.image.resize(
image, size, method=interpolation, antialias=antialias
)
AFFINE_TRANSFORM_INTERPOLATIONS = { # map to order
"nearest": 0,
"bilinear": 1,
}
AFFINE_TRANSFORM_FILL_MODES = {
"constant": "grid-constant",
"nearest": "nearest",
"wrap": "grid-wrap",
"mirror": "mirror",
"reflect": "reflect",
}
def affine_transform(
image,
transform,
interpolation="bilinear",
fill_mode="constant",
fill_value=0,
data_format="channels_last",
):
if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS.keys():
raise ValueError(
"Invalid value for argument `interpolation`. Expected of one "
f"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: "
f"interpolation={interpolation}"
)
if fill_mode not in AFFINE_TRANSFORM_FILL_MODES.keys():
raise ValueError(
"Invalid value for argument `fill_mode`. Expected of one "
f"{set(AFFINE_TRANSFORM_FILL_MODES.keys())}. "
f"Received: fill_mode={fill_mode}"
)
transform = convert_to_tensor(transform)
if len(image.shape) not in (3, 4):
raise ValueError(
"Invalid image rank: expected rank 3 (single image) "
"or rank 4 (batch of images). Received input with shape: "
f"image.shape={image.shape}"
)
if len(transform.shape) not in (1, 2):
raise ValueError(
"Invalid transform rank: expected rank 1 (single transform) "
"or rank 2 (batch of transforms). Received input with shape: "
f"transform.shape={transform.shape}"
)
# unbatched case
need_squeeze = False
if len(image.shape) == 3:
image = jnp.expand_dims(image, axis=0)
need_squeeze = True
if len(transform.shape) == 1:
transform = jnp.expand_dims(transform, axis=0)
if data_format == "channels_first":
image = jnp.transpose(image, (0, 2, 3, 1))
batch_size = image.shape[0]
# get indices
meshgrid = jnp.meshgrid(
*[jnp.arange(size) for size in image.shape[1:]], indexing="ij"
)
indices = jnp.concatenate(
[jnp.expand_dims(x, axis=-1) for x in meshgrid], axis=-1
)
indices = jnp.tile(indices, (batch_size, 1, 1, 1, 1))
# swap the values
a0 = transform[:, 0]
a2 = transform[:, 2]
b1 = transform[:, 4]
b2 = transform[:, 5]
transform = transform.at[:, 0].set(b1)
transform = transform.at[:, 2].set(b2)
transform = transform.at[:, 4].set(a0)
transform = transform.at[:, 5].set(a2)
# deal with transform
transform = jnp.pad(
transform, pad_width=[[0, 0], [0, 1]], constant_values=1
)
transform = jnp.reshape(transform, (batch_size, 3, 3))
offset = transform[:, 0:2, 2]
offset = jnp.pad(offset, pad_width=[[0, 0], [0, 1]])
transform = transform.at[:, 0:2, 2].set(0)
# transform the indices
coordinates = jnp.einsum("Bhwij, Bjk -> Bhwik", indices, transform)
coordinates = jnp.moveaxis(coordinates, source=-1, destination=1)
coordinates += jnp.reshape(a=offset, newshape=(*offset.shape, 1, 1, 1))
# apply affine transformation
_map_coordinates = functools.partial(
jax.scipy.ndimage.map_coordinates,
order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation],
mode=fill_mode,
cval=fill_value,
)
affined = jax.vmap(_map_coordinates)(image, coordinates)
if data_format == "channels_first":
affined = jnp.transpose(affined, (0, 3, 1, 2))
if need_squeeze:
affined = jnp.squeeze(affined, axis=0)
return affined
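
A minimal plain-NumPy sketch of the mapping implemented above (the helper name and toy values are illustrative, not part of the codebase): the 8-vector `[a0, a1, a2, b0, b1, b2, c0, c1]` sends an output point `(x, y)` to the input point `(x', y') = (a0 x + a1 y + a2, b0 x + b1 y + b2)`, with `c0 = c1 = 0` on this backend. The swap of `a0`/`a2` with `b1`/`b2` exists because the meshgrid indices are ordered `(row=y, col=x, channel)` while the transform is specified in `(x, y)` order; reordering the 3x3 matrix lets it act on index vectors directly.

    import numpy as np

    def map_output_point(transform, x, y):
        # Output->input mapping of the op, ignoring the perspective terms.
        a0, a1, a2, b0, b1, b2, _, _ = transform
        return a0 * x + a1 * y + a2, b0 * x + b1 * y + b2

    # Pure translation: output (x, y) samples input (x + 5, y + 3).
    print(map_output_point([1, 0, 5, 0, 1, 3, 0, 0], x=10, y=20))  # (15, 23)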

@@ -1,7 +1,10 @@
import jax
import numpy as np
import scipy.ndimage
RESIZE_METHODS = (
from keras_core.backend.numpy.core import convert_to_tensor
RESIZE_INTERPOLATIONS = (
"bilinear",
"nearest",
"lanczos3",
@@ -11,12 +14,16 @@ RESIZE_METHODS = (
def resize(
image, size, method="bilinear", antialias=False, data_format="channels_last"
image,
size,
interpolation="bilinear",
antialias=False,
data_format="channels_last",
):
if method not in RESIZE_METHODS:
if interpolation not in RESIZE_INTERPOLATIONS:
raise ValueError(
"Invalid value for argument `method`. Expected of one "
f"{RESIZE_METHODS}. Received: method={method}"
"Invalid value for argument `interpolation`. Expected of one "
f"{RESIZE_INTERPOLATIONS}. Received: interpolation={interpolation}"
)
if not len(size) == 2:
raise ValueError(
@@ -41,5 +48,121 @@ def resize(
f"image.shape={image.shape}"
)
return np.array(
jax.image.resize(image, size, method=method, antialias=antialias)
jax.image.resize(image, size, method=interpolation, antialias=antialias)
)
AFFINE_TRANSFORM_INTERPOLATIONS = { # map to order
"nearest": 0,
"bilinear": 1,
}
AFFINE_TRANSFORM_FILL_MODES = {
"constant": "grid-constant",
"nearest": "nearest",
"wrap": "grid-wrap",
"mirror": "mirror",
"reflect": "reflect",
}
def affine_transform(
image,
transform,
interpolation="bilinear",
fill_mode="constant",
fill_value=0,
data_format="channels_last",
):
if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS.keys():
raise ValueError(
"Invalid value for argument `interpolation`. Expected of one "
f"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: "
f"interpolation={interpolation}"
)
if fill_mode not in AFFINE_TRANSFORM_FILL_MODES.keys():
raise ValueError(
"Invalid value for argument `fill_mode`. Expected of one "
f"{set(AFFINE_TRANSFORM_FILL_MODES.keys())}. "
f"Received: fill_mode={fill_mode}"
)
transform = convert_to_tensor(transform)
if len(image.shape) not in (3, 4):
raise ValueError(
"Invalid image rank: expected rank 3 (single image) "
"or rank 4 (batch of images). Received input with shape: "
f"image.shape={image.shape}"
)
if len(transform.shape) not in (1, 2):
raise ValueError(
"Invalid transform rank: expected rank 1 (single transform) "
"or rank 2 (batch of transforms). Received input with shape: "
f"transform.shape={transform.shape}"
)
# unbatched case
need_squeeze = False
if len(image.shape) == 3:
image = np.expand_dims(image, axis=0)
need_squeeze = True
if len(transform.shape) == 1:
transform = np.expand_dims(transform, axis=0)
if data_format == "channels_first":
image = np.transpose(image, (0, 2, 3, 1))
batch_size = image.shape[0]
# get indices
meshgrid = np.meshgrid(
*[np.arange(size) for size in image.shape[1:]], indexing="ij"
)
indices = np.concatenate(
[np.expand_dims(x, axis=-1) for x in meshgrid], axis=-1
)
indices = np.tile(indices, (batch_size, 1, 1, 1, 1))
# swap the values
a0 = transform[:, 0].copy()
a2 = transform[:, 2].copy()
b1 = transform[:, 4].copy()
b2 = transform[:, 5].copy()
transform[:, 0] = b1
transform[:, 2] = b2
transform[:, 4] = a0
transform[:, 5] = a2
# deal with transform
transform = np.pad(transform, pad_width=[[0, 0], [0, 1]], constant_values=1)
transform = np.reshape(transform, (batch_size, 3, 3))
offset = transform[:, 0:2, 2].copy()
offset = np.pad(offset, pad_width=[[0, 0], [0, 1]])
transform[:, 0:2, 2] = 0
# transform the indices
coordinates = np.einsum("Bhwij, Bjk -> Bhwik", indices, transform)
coordinates = np.moveaxis(coordinates, source=-1, destination=1)
coordinates += np.reshape(a=offset, newshape=(*offset.shape, 1, 1, 1))
# apply affine transformation
affined = np.stack(
[
scipy.ndimage.map_coordinates(
image[i],
coordinates[i],
order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation],
mode=AFFINE_TRANSFORM_FILL_MODES[fill_mode],
cval=fill_value,
prefilter=False,
)
for i in range(batch_size)
],
axis=0,
)
if data_format == "channels_first":
affined = np.transpose(affined, (0, 3, 1, 2))
if need_squeeze:
affined = np.squeeze(affined, axis=0)
return affined
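
A hedged usage sketch for the NumPy implementation above (toy image and offsets are illustrative): with `interpolation="nearest"` and an in-bounds translation, every output pixel equals the input pixel at the shifted location.

    import numpy as np

    image = np.arange(25, dtype="float32").reshape(5, 5, 1)
    dx, dy = 1, 2
    # [a0, a1, a2, b0, b1, b2, c0, c1]: identity scale, translate by (dx, dy).
    transform = np.array([1, 0, dx, 0, 1, dy, 0, 0], dtype="float32")
    out = affine_transform(image, transform, interpolation="nearest")
    # Output pixel (row=y, col=x) samples input pixel (y + dy, x + dx).
    assert out[0, 0, 0] == image[dy, dx, 0]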

@@ -1,6 +1,6 @@
import tensorflow as tf
RESIZE_METHODS = (
RESIZE_INTERPOLATIONS = (
"bilinear",
"nearest",
"lanczos3",
@@ -10,12 +10,16 @@ RESIZE_METHODS = (
def resize(
image, size, method="bilinear", antialias=False, data_format="channels_last"
image,
size,
interpolation="bilinear",
antialias=False,
data_format="channels_last",
):
if method not in RESIZE_METHODS:
if interpolation not in RESIZE_INTERPOLATIONS:
raise ValueError(
"Invalid value for argument `method`. Expected of one "
f"{RESIZE_METHODS}. Received: method={method}"
"Invalid value for argument `interpolation`. Expected of one "
f"{RESIZE_INTERPOLATIONS}. Received: interpolation={interpolation}"
)
if not len(size) == 2:
raise ValueError(
@@ -35,10 +39,83 @@ def resize(
f"image.shape={image.shape}"
)
resized = tf.image.resize(image, size, method=method, antialias=antialias)
resized = tf.image.resize(
image, size, method=interpolation, antialias=antialias
)
if data_format == "channels_first":
if len(image.shape) == 4:
resized = tf.transpose(resized, (0, 3, 1, 2))
elif len(image.shape) == 3:
resized = tf.transpose(resized, (2, 0, 1))
return resized
AFFINE_TRANSFORM_INTERPOLATIONS = (
"nearest",
"bilinear",
)
AFFINE_TRANSFORM_FILL_MODES = (
"constant",
"nearest",
"wrap",
# "mirror", not supported by TF
"reflect",
)
def affine_transform(
image,
transform,
interpolation="bilinear",
fill_mode="constant",
fill_value=0,
data_format="channels_last",
):
if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS:
raise ValueError(
"Invalid value for argument `interpolation`. Expected of one "
f"{AFFINE_TRANSFORM_INTERPOLATIONS}. Received: "
f"interpolation={interpolation}"
)
if fill_mode not in AFFINE_TRANSFORM_FILL_MODES:
raise ValueError(
"Invalid value for argument `fill_mode`. Expected of one "
f"{AFFINE_TRANSFORM_FILL_MODES}. Received: fill_mode={fill_mode}"
)
if len(image.shape) not in (3, 4):
raise ValueError(
"Invalid image rank: expected rank 3 (single image) "
"or rank 4 (batch of images). Received input with shape: "
f"image.shape={image.shape}"
)
if len(transform.shape) not in (1, 2):
raise ValueError(
"Invalid transform rank: expected rank 1 (single transform) "
"or rank 2 (batch of transforms). Received input with shape: "
f"transform.shape={transform.shape}"
)
# unbatched case
need_squeeze = False
if len(image.shape) == 3:
image = tf.expand_dims(image, axis=0)
need_squeeze = True
if len(transform.shape) == 1:
transform = tf.expand_dims(transform, axis=0)
if data_format == "channels_first":
image = tf.transpose(image, (0, 2, 3, 1))
affined = tf.raw_ops.ImageProjectiveTransformV3(
images=image,
transforms=tf.cast(transform, dtype=tf.float32),
output_shape=tf.shape(image)[1:-1],
fill_value=fill_value,
interpolation=interpolation.upper(),
fill_mode=fill_mode.upper(),
)
if data_format == "channels_first":
affined = tf.transpose(affined, (0, 3, 1, 2))
if need_squeeze:
affined = tf.squeeze(affined, axis=0)
return affined
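
TensorFlow gets this nearly for free: `tf.raw_ops.ImageProjectiveTransformV3` natively consumes the same output-to-input 8-parameter convention, and it is the only backend that honors `c0`/`c1`. Below is a hedged sketch of building such a vector for a rotation about the image center; it mirrors the construction used by Keras' random-rotation preprocessing, and the sign conventions are an assumption worth verifying.

    import numpy as np

    def rotation_transform(theta, height, width):
        # Output->input rotation by `theta` radians about the image center.
        cos, sin = np.cos(theta), np.sin(theta)
        cx, cy = (width - 1) / 2, (height - 1) / 2
        # Offsets chosen so the pivot is the center rather than the origin.
        a2 = cx - cx * cos + cy * sin
        b2 = cy - cx * sin - cy * cos
        return np.array([cos, -sin, a2, sin, cos, b2, 0, 0], dtype="float32")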

@@ -1,21 +1,28 @@
import torch
import torch.nn.functional as tnn
from keras_core.backend.torch.core import convert_to_tensor
RESIZE_METHODS = {} # populated after torchvision import
RESIZE_INTERPOLATIONS = {} # populated after torchvision import
UNSUPPORTED_METHODS = (
UNSUPPORTED_INTERPOLATIONS = (
"lanczos3",
"lanczos5",
)
def resize(
image, size, method="bilinear", antialias=False, data_format="channels_last"
image,
size,
interpolation="bilinear",
antialias=False,
data_format="channels_last",
):
try:
import torchvision
from torchvision.transforms import InterpolationMode as im
RESIZE_METHODS.update(
RESIZE_INTERPOLATIONS.update(
{
"bilinear": im.BILINEAR,
"nearest": im.NEAREST_EXACT,
@@ -27,16 +34,16 @@ def resize(
"The torchvision package is necessary to use `resize` with the "
"torch backend. Please install torchvision."
)
if method in UNSUPPORTED_METHODS:
if interpolation in UNSUPPORTED_INTERPOLATIONS:
raise ValueError(
"Resizing with Lanczos interpolation is "
"not supported by the PyTorch backend. "
f"Received: method={method}."
f"Received: interpolation={interpolation}."
)
if method not in RESIZE_METHODS:
if interpolation not in RESIZE_INTERPOLATIONS:
raise ValueError(
"Invalid value for argument `method`. Expected of one "
f"{RESIZE_METHODS}. Received: method={method}"
"Invalid value for argument `interpolation`. Expected of one "
f"{RESIZE_INTERPOLATIONS}. Received: interpolation={interpolation}"
)
if not len(size) == 2:
raise ValueError(
@@ -60,7 +67,7 @@ def resize(
resized = torchvision.transforms.functional.resize(
img=image,
size=size,
interpolation=RESIZE_METHODS[method],
interpolation=RESIZE_INTERPOLATIONS[interpolation],
antialias=antialias,
)
if data_format == "channels_last":
@@ -69,3 +76,153 @@ def resize(
elif len(image.shape) == 3:
resized = resized.permute((1, 2, 0))
return resized
AFFINE_TRANSFORM_INTERPOLATIONS = (
"nearest",
"bilinear",
)
AFFINE_TRANSFORM_FILL_MODES = {
"constant": "zeros",
"nearest": "border",
# "wrap", not supported by torch
# "mirror", not supported by torch
"reflect": "reflection",
}
def _apply_grid_transform(
img,
grid,
interpolation="bilinear",
fill_mode="zeros",
fill_value=None,
):
"""
Modified from https://github.com/pytorch/vision/blob/main/torchvision/transforms/v2/functional/_geometry.py
""" # noqa: E501
# We are using context knowledge that grid should have float dtype
fp = img.dtype == grid.dtype
float_img = img if fp else img.to(grid.dtype)
shape = float_img.shape
# Append a dummy mask for customized fill colors, should be faster than
# grid_sample() twice
if fill_value is not None:
mask = torch.ones(
(shape[0], 1, shape[2], shape[3]),
dtype=float_img.dtype,
device=float_img.device,
)
float_img = torch.cat((float_img, mask), dim=1)
float_img = tnn.grid_sample(
float_img,
grid,
mode=interpolation,
padding_mode=fill_mode,
align_corners=False,
)
# Fill with required color
if fill_value is not None:
float_img, mask = torch.tensor_split(float_img, indices=(-1,), dim=-3)
mask = mask.expand_as(float_img)
fill_list = (
fill_value
if isinstance(fill_value, (tuple, list))
else [float(fill_value)]
)
fill_img = torch.tensor(
fill_list, dtype=float_img.dtype, device=float_img.device
).view(1, -1, 1, 1)
if interpolation == "nearest":
bool_mask = mask < 0.5
float_img[bool_mask] = fill_img.expand_as(float_img)[bool_mask]
else: # 'bilinear'
# The following is mathematically equivalent to:
# img * mask + (1.0 - mask) * fill =
# img * mask - fill * mask + fill =
# mask * (img - fill) + fill
float_img = float_img.sub_(fill_img).mul_(mask).add_(fill_img)
img = float_img.round_().to(img.dtype) if not fp else float_img
return img
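
A quick numeric check (plain torch, illustrative shapes) of the blend identity the `bilinear` branch above relies on: `img * mask + (1 - mask) * fill == mask * (img - fill) + fill`.

    import torch

    img = torch.rand(2, 3, 4, 4)
    mask = torch.rand(2, 1, 4, 4).expand_as(img)
    fill = torch.tensor([0.5, 0.25, 1.0]).view(1, -1, 1, 1)
    assert torch.allclose(
        img * mask + (1.0 - mask) * fill, (img - fill) * mask + fill
    )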
def affine_transform(
image,
transform,
interpolation="bilinear",
fill_mode="constant",
fill_value=0,
data_format="channels_last",
):
if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS:
raise ValueError(
"Invalid value for argument `interpolation`. Expected of one "
f"{AFFINE_TRANSFORM_INTERPOLATIONS}. Received: "
f"interpolation={interpolation}"
)
if fill_mode not in AFFINE_TRANSFORM_FILL_MODES.keys():
raise ValueError(
"Invalid value for argument `fill_mode`. Expected of one "
f"{set(AFFINE_TRANSFORM_FILL_MODES.keys())}. "
f"Received: fill_mode={fill_mode}"
)
image = convert_to_tensor(image)
transform = convert_to_tensor(transform)
if image.ndim not in (3, 4):
raise ValueError(
"Invalid image rank: expected rank 3 (single image) "
"or rank 4 (batch of images). Received input with shape: "
f"image.shape={image.shape}"
)
if transform.ndim not in (1, 2):
raise ValueError(
"Invalid transform rank: expected rank 1 (single transform) "
"or rank 2 (batch of transforms). Received input with shape: "
f"transform.shape={transform.shape}"
)
if fill_mode != "constant":
fill_value = None
fill_mode = AFFINE_TRANSFORM_FILL_MODES[fill_mode]
# unbatched case
need_squeeze = False
if image.ndim == 3:
image = image.unsqueeze(dim=0)
need_squeeze = True
if transform.ndim == 1:
transform = transform.unsqueeze(dim=0)
if data_format == "channels_last":
image = image.permute((0, 3, 1, 2))
# deal with transform
h, w = image.shape[2], image.shape[3]
theta = torch.zeros((image.shape[0], 2, 3)).to(transform)
theta[:, 0, 0] = transform[:, 0]
theta[:, 0, 1] = transform[:, 1] * h / w
theta[:, 0, 2] = (
transform[:, 2] * 2 / w + theta[:, 0, 0] + theta[:, 0, 1] - 1
)
theta[:, 1, 0] = transform[:, 3] * w / h
theta[:, 1, 1] = transform[:, 4]
theta[:, 1, 2] = (
transform[:, 5] * 2 / h + theta[:, 1, 0] + theta[:, 1, 1] - 1
)
grid = tnn.affine_grid(theta, image.shape)
affined = _apply_grid_transform(
image, grid, interpolation, fill_mode, fill_value
)
if data_format == "channels_last":
affined = affined.permute((0, 2, 3, 1))
if need_squeeze:
affined = affined.squeeze(dim=0)
return affined
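
The `theta` construction above is a change of variables: `affine_grid`/`grid_sample` work in normalized coordinates where each axis spans `[-1, 1]`, while the 8-vector works in pixel coordinates, which is where the `h / w` cross terms and the `2 / w`, `2 / h` scalings come from. A minimal sketch of the two conversions under the `align_corners=False` convention (pixel `x` sits at `u = (2x + 1) / w - 1`) follows; that convention differs from the pixel-space formula by a sub-pixel offset, which appears consistent with the TODO in the tests below where the torch backend is expected to disagree with the TF reference.

    def pixel_to_normalized(x, size):
        # align_corners=False: pixel center x maps to (2x + 1) / size - 1.
        return (2 * x + 1) / size - 1

    def normalized_to_pixel(u, size):
        # Inverse of the conversion above.
        return ((u + 1) * size - 1) / 2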

@@ -81,7 +81,7 @@ class Resizing(TFDataLayer):
outputs = self.backend.image.resize(
inputs,
size=size,
method=self.interpolation,
interpolation=self.interpolation,
data_format=self.data_format,
)
return outputs

@@ -160,7 +160,7 @@ class UpSampling2D(Layer):
x = ops.repeat(x, height_factor, axis=1)
x = ops.repeat(x, width_factor, axis=2)
else:
x = ops.image.resize(x, new_shape, method=interpolation)
x = ops.image.resize(x, new_shape, interpolation=interpolation)
if data_format == "channels_first":
x = ops.transpose(x, [0, 3, 1, 2])

@@ -9,13 +9,13 @@ class Resize(Operation):
def __init__(
self,
size,
method="bilinear",
interpolation="bilinear",
antialias=False,
data_format="channels_last",
):
super().__init__()
self.size = tuple(size)
self.method = method
self.interpolation = interpolation
self.antialias = antialias
self.data_format = data_format
@@ -23,7 +23,7 @@ class Resize(Operation):
return backend.image.resize(
image,
self.size,
method=self.method,
interpolation=self.interpolation,
antialias=self.antialias,
data_format=self.data_format,
)
@@ -53,14 +53,18 @@ class Resize(Operation):
@keras_core_export("keras_core.ops.image.resize")
def resize(
image, size, method="bilinear", antialias=False, data_format="channels_last"
image,
size,
interpolation="bilinear",
antialias=False,
data_format="channels_last",
):
"""Resize images to size using the specified method.
"""Resize images to size using the specified interpolation method.
Args:
image: Input image or batch of images. Must be 3D or 4D.
size: Size of output image in `(height, width)` format.
method: Interpolation method. Available methods are `"nearest"`,
interpolation: Interpolation method. Available methods are `"nearest"`,
`"bilinear"`, and `"bicubic"`. Defaults to `"bilinear"`.
antialias: Whether to use an antialiasing filter when downsampling an
image. Defaults to `False`.
@@ -97,8 +101,146 @@ def resize(
if any_symbolic_tensors((image,)):
return Resize(
size, method=method, antialias=antialias, data_format=data_format
size,
interpolation=interpolation,
antialias=antialias,
data_format=data_format,
).symbolic_call(image)
return backend.image.resize(
image, size, method=method, antialias=antialias, data_format=data_format
image,
size,
interpolation=interpolation,
antialias=antialias,
data_format=data_format,
)
class AffineTransform(Operation):
def __init__(
self,
interpolation="bilinear",
fill_mode="constant",
fill_value=0,
data_format="channels_last",
):
super().__init__()
self.interpolation = interpolation
self.fill_mode = fill_mode
self.fill_value = fill_value
self.data_format = data_format
def call(self, image, transform):
return backend.image.affine_transform(
image,
transform,
interpolation=self.interpolation,
fill_mode=self.fill_mode,
fill_value=self.fill_value,
data_format=self.data_format,
)
def compute_output_spec(self, image, transform):
if len(image.shape) not in (3, 4):
raise ValueError(
"Invalid image rank: expected rank 3 (single image) "
"or rank 4 (batch of images). Received input with shape: "
f"image.shape={image.shape}"
)
if len(transform.shape) not in (1, 2):
raise ValueError(
"Invalid transform rank: expected rank 1 (single transform) "
"or rank 2 (batch of transforms). Received input with shape: "
f"transform.shape={transform.shape}"
)
return KerasTensor(image.shape, dtype=image.dtype)
@keras_core_export("keras_core.ops.image.affine_transform")
def affine_transform(
image,
transform,
interpolation="bilinear",
fill_mode="constant",
fill_value=0,
data_format="channels_last",
):
"""Applies the given transform(s) to the image(s).
Args:
image: Input image or batch of images. Must be 3D or 4D.
transform: Projective transform matrix/matrices. A vector of length 8 or
tensor of size N x 8. If one row of transform is
`[a0, a1, a2, b0, b1, b2, c0, c1]`, then it maps the output point
`(x, y)` to a transformed input point
`(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`,
where `k = c0 x + c1 y + 1`. The transform is inverted compared to
the transform mapping input points to output points. Note that
gradients are not backpropagated into transformation parameters.
Note that `c0` and `c1` are only effective when using the TensorFlow
backend and are treated as `0` by other backends.
interpolation: Interpolation method. Available methods are `"nearest"`,
and `"bilinear"`. Defaults to `"bilinear"`.
fill_mode: Points outside the boundaries of the input are filled
according to the given mode. Available methods are `"constant"`,
`"nearest"`, `"wrap"` and `"reflect"`. Defaults to `"constant"`.
Note that `"wrap"` is not supported by Torch backend.
fill_value: Value used for points outside the boundaries of the input if
`fill_mode="constant"`. Defaults to `0`.
data_format: string, either `"channels_last"` or `"channels_first"`.
The ordering of the dimensions in the inputs. `"channels_last"`
corresponds to inputs with shape `(batch, height, width, channels)`
while `"channels_first"` corresponds to inputs with shape
`(batch, channels, height, width)`. It defaults to the
`image_data_format` value found in your Keras config file at
`~/.keras/keras.json`. If you never set it, then it will be
`"channels_last"`.
Returns:
Applied affine transform image or batch of images.
Examples:
>>> x = np.random.random((2, 64, 80, 3)) # batch of 2 RGB images
>>> transform = np.array(
... [
... [1.5, 0, -20, 0, 1.5, -16, 0, 0], # zoom
... [1, 0, -20, 0, 1, -16, 0, 0], # translation
... ]
... )
>>> y = keras_core.ops.image.affine_transform(x, transform)
>>> y.shape
(2, 64, 80, 3)
>>> x = np.random.random((64, 80, 3)) # single RGB image
>>> transform = np.array([1.0, 0.5, -20, 0.5, 1.0, -16, 0, 0]) # shear
>>> y = keras_core.ops.image.affine_transform(x, transform)
>>> y.shape
(64, 80, 3)
>>> x = np.random.random((2, 3, 64, 80)) # batch of 2 RGB images
>>> transform = np.array(
... [
... [1.5, 0, -20, 0, 1.5, -16, 0, 0], # zoom
... [1, 0, -20, 0, 1, -16, 0, 0], # translation
... ]
... )
>>> y = keras_core.ops.image.affine_transform(x, transform,
... data_format="channels_first")
>>> y.shape
(2, 3, 64, 80)
"""
if any_symbolic_tensors((image, transform)):
return AffineTransform(
interpolation=interpolation,
fill_mode=fill_mode,
fill_value=fill_value,
data_format=data_format,
).symbolic_call(image, transform)
return backend.image.affine_transform(
image,
transform,
interpolation=interpolation,
fill_mode=fill_mode,
fill_value=fill_value,
data_format=data_format,
)
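
A hedged sketch of the symbolic path through the wrapper above: with `KerasTensor` inputs, `any_symbolic_tensors` routes the call to `AffineTransform.compute_output_spec`, so output shapes are inferred without eager data (the import paths here are assumptions).

    from keras_core import KerasTensor
    from keras_core import ops

    images = KerasTensor((None, 64, 80, 3))
    transforms = KerasTensor((None, 8))
    y = ops.image.affine_transform(images, transforms)
    print(y.shape)  # (None, 64, 80, 3)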

@@ -18,6 +18,12 @@ class ImageOpsDynamicShapeTest(testing.TestCase):
out = kimage.resize(x, size=(15, 15))
self.assertEqual(out.shape, (15, 15, 3))
def test_affine_transform(self):
x = KerasTensor([None, 20, 20, 3])
transform = KerasTensor([None, 8])
out = kimage.affine_transform(x, transform)
self.assertEqual(out.shape, (None, 20, 20, 3))
class ImageOpsStaticShapeTest(testing.TestCase):
def test_resize(self):
@@ -25,6 +31,12 @@ class ImageOpsStaticShapeTest(testing.TestCase):
out = kimage.resize(x, size=(15, 15))
self.assertEqual(out.shape, (15, 15, 3))
def test_affine_transform(self):
x = KerasTensor([20, 20, 3])
transform = KerasTensor([8])
out = kimage.affine_transform(x, transform)
self.assertEqual(out.shape, (20, 20, 3))
class ImageOpsCorrectnessTest(testing.TestCase, parameterized.TestCase):
@parameterized.parameters(
@@ -42,20 +54,20 @@ class ImageOpsCorrectnessTest(testing.TestCase, parameterized.TestCase):
("bilinear", True, "channels_first"),
]
)
def test_resize(self, method, antialias, data_format):
def test_resize(self, interpolation, antialias, data_format):
if backend.backend() == "torch":
if "lanczos" in method:
if "lanczos" in interpolation:
self.skipTest(
"Resizing with Lanczos interpolation is "
"not supported by the PyTorch backend. "
f"Received: method={method}."
f"Received: interpolation={interpolation}."
)
if method == "bicubic" and antialias is False:
if interpolation == "bicubic" and antialias is False:
self.skipTest(
"Resizing with Bicubic interpolation in "
"PyTorch backend produces noise. Please "
"turn on anti-aliasing. "
f"Received: method={method}, "
f"Received: interpolation={interpolation}, "
f"antialias={antialias}."
)
# Unbatched case
@@ -66,14 +78,14 @@ class ImageOpsCorrectnessTest(testing.TestCase, parameterized.TestCase):
out = kimage.resize(
x,
size=(25, 25),
method=method,
interpolation=interpolation,
antialias=antialias,
data_format=data_format,
)
if data_format == "channels_first":
x = np.transpose(x, (1, 2, 0))
ref_out = tf.image.resize(
x, size=(25, 25), method=method, antialias=antialias
x, size=(25, 25), method=interpolation, antialias=antialias
)
if data_format == "channels_first":
ref_out = np.transpose(ref_out, (2, 0, 1))
@@ -88,16 +100,111 @@ class ImageOpsCorrectnessTest(testing.TestCase, parameterized.TestCase):
out = kimage.resize(
x,
size=(25, 25),
method=method,
interpolation=interpolation,
antialias=antialias,
data_format=data_format,
)
if data_format == "channels_first":
x = np.transpose(x, (0, 2, 3, 1))
ref_out = tf.image.resize(
x, size=(25, 25), method=method, antialias=antialias
x, size=(25, 25), method=interpolation, antialias=antialias
)
if data_format == "channels_first":
ref_out = np.transpose(ref_out, (0, 3, 1, 2))
self.assertEqual(tuple(out.shape), tuple(ref_out.shape))
self.assertAllClose(ref_out, out, atol=0.3)
@parameterized.parameters(
[
("bilinear", "constant", "channels_last"),
("nearest", "constant", "channels_last"),
("bilinear", "nearest", "channels_last"),
("nearest", "nearest", "channels_last"),
("bilinear", "wrap", "channels_last"),
("nearest", "wrap", "channels_last"),
("bilinear", "reflect", "channels_last"),
("nearest", "reflect", "channels_last"),
("bilinear", "constant", "channels_first"),
]
)
def test_affine_transform(self, interpolation, fill_mode, data_format):
if fill_mode == "wrap" and backend.backend() == "torch":
self.skipTest(
"Applying affine transform with fill_mode=wrap is not support"
" in torch backend"
)
if fill_mode == "wrap" and backend.backend() in ("jax", "numpy"):
self.skipTest(
"The numerical results of applying affine transform with "
"fill_mode=wrap in tensorflow is inconsistent with jax and "
"numpy backends"
)
# Unbatched case
if data_format == "channels_first":
x = np.random.random((3, 50, 50)) * 255
else:
x = np.random.random((50, 50, 3)) * 255
transform = np.random.random(size=(6))
transform = np.pad(transform, (0, 2)) # makes c0, c1 always 0
out = kimage.affine_transform(
x,
transform,
interpolation=interpolation,
fill_mode=fill_mode,
data_format=data_format,
)
if data_format == "channels_first":
x = np.transpose(x, (1, 2, 0))
ref_out = tf.raw_ops.ImageProjectiveTransformV3(
images=tf.expand_dims(x, axis=0),
transforms=tf.cast(tf.expand_dims(transform, axis=0), tf.float32),
output_shape=tf.shape(x)[:-1],
fill_value=0,
interpolation=interpolation.upper(),
fill_mode=fill_mode.upper(),
)
ref_out = ref_out[0]
if data_format == "channels_first":
ref_out = np.transpose(ref_out, (2, 0, 1))
self.assertEqual(tuple(out.shape), tuple(ref_out.shape))
if backend.backend() == "torch":
# TODO: cannot pass with torch backend
with self.assertRaises(AssertionError):
self.assertAllClose(ref_out, out, atol=0.3)
else:
self.assertAllClose(ref_out, out, atol=0.3)
# Batched case
if data_format == "channels_first":
x = np.random.random((2, 3, 50, 50)) * 255
else:
x = np.random.random((2, 50, 50, 3)) * 255
transform = np.random.random(size=(2, 6))
transform = np.pad(transform, [(0, 0), (0, 2)]) # makes c0, c1 always 0
out = kimage.affine_transform(
x,
transform,
interpolation=interpolation,
fill_mode=fill_mode,
data_format=data_format,
)
if data_format == "channels_first":
x = np.transpose(x, (0, 2, 3, 1))
ref_out = tf.raw_ops.ImageProjectiveTransformV3(
images=x,
transforms=tf.cast(transform, tf.float32),
output_shape=tf.shape(x)[1:-1],
fill_value=0,
interpolation=interpolation.upper(),
fill_mode=fill_mode.upper(),
)
if data_format == "channels_first":
ref_out = np.transpose(ref_out, (0, 3, 1, 2))
self.assertEqual(tuple(out.shape), tuple(ref_out.shape))
if backend.backend() == "torch":
# TODO: cannot pass with torch backend
with self.assertRaises(AssertionError):
self.assertAllClose(ref_out, out, atol=0.3)
else:
self.assertAllClose(ref_out, out, atol=0.3)

@@ -439,7 +439,7 @@ def smart_resize(
]
img = backend_module.image.resize(
img, size=size, method=interpolation, data_format=data_format
img, size=size, interpolation=interpolation, data_format=data_format
)
if isinstance(x, np.ndarray):