Add Resizing layer and corresponding backend ops.

This commit is contained in:
Francois Chollet 2023-05-05 17:12:09 -07:00
parent a333b9bd1b
commit 8c014ca995
12 changed files with 431 additions and 452 deletions

@ -38,7 +38,7 @@ def resize(
resized = tf.image.resize(image, size, method=method, antialias=antialias)
if data_format == "channels_first":
if len(image.shape) == 4:
resized = tf.transpose(image, (0, 3, 1, 2))
resized = tf.transpose(resized, (0, 3, 1, 2))
elif len(image.shape) == 3:
image = tf.transpose(image, (2, 0, 1))
resized = tf.transpose(resized, (2, 0, 1))
return resized

@ -26,10 +26,6 @@ from keras_core.layers.merging.subtract import subtract
from keras_core.layers.normalization.batch_normalization import (
BatchNormalization,
)
from keras_core.layers.normalization.layer_normalization import (
LayerNormalization,
)
from keras_core.layers.normalization.unit_normalization import UnitNormalization
from keras_core.layers.pooling.average_pooling1d import AveragePooling1D
from keras_core.layers.pooling.average_pooling2d import AveragePooling2D
from keras_core.layers.pooling.average_pooling3d import AveragePooling3D
@ -50,6 +46,7 @@ from keras_core.layers.pooling.max_pooling2d import MaxPooling2D
from keras_core.layers.pooling.max_pooling3d import MaxPooling3D
from keras_core.layers.preprocessing.normalization import Normalization
from keras_core.layers.preprocessing.rescaling import Rescaling
from keras_core.layers.preprocessing.resizing import Resizing
from keras_core.layers.regularization.activity_regularization import (
ActivityRegularization,
)

@ -1,239 +0,0 @@
from keras_core import constraints
from keras_core import initializers
from keras_core import operations as ops
from keras_core import regularizers
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
@keras_core_export("keras_core.layers.LayerNormalization")
class LayerNormalization(Layer):
    """Layer normalization layer (Ba et al., 2016).

    Normalizes the activations of the previous layer independently for each
    example in a batch, instead of across the batch as Batch Normalization
    does. In other words, it applies a transformation that keeps the mean
    activation within each example close to 0 and the activation standard
    deviation close to 1.

    When `scale` or `center` are enabled, the normalized outputs are
    multiplied by a trainable variable `gamma` and shifted by a trainable
    variable `beta` (both broadcast against the outputs). By default `gamma`
    starts as a ones tensor and `beta` as a zeros tensor, so scaling and
    centering are no-ops before training has begun.

    With scaling and centering enabled, the computation is as follows. Let
    `inputs` be the intermediate activations for a mini-batch. For each
    sample `x_i` in `inputs` with `k` features, the mean and variance of the
    sample are computed:

    ```python
    mean_i = sum(x_i[j] for j in range(k)) / k
    var_i = sum((x_i[j] - mean_i) ** 2 for j in range(k)) / k
    ```

    A normalized `x_i_normalized` is then derived, with a small `epsilon`
    added for numerical stability:

    ```python
    x_i_normalized = (x_i - mean_i) / sqrt(var_i + epsilon)
    ```

    Finally, `x_i_normalized` is linearly transformed by the learned
    parameters `gamma` and `beta`:

    ```python
    output_i = x_i_normalized * gamma + beta
    ```

    `gamma` and `beta` span the axes of `inputs` listed in `axis`, and that
    part of the inputs' shape must be fully defined.

    For example:

    >>> layer = keras_core.layers.LayerNormalization(axis=[1, 2, 3])
    >>> layer.build([5, 20, 30, 40])
    >>> print(layer.beta.shape)
    (20, 30, 40)
    >>> print(layer.gamma.shape)
    (20, 30, 40)

    Other implementations of layer normalization may define `gamma` and
    `beta` over a different set of axes than the ones being normalized
    across. For instance, Group Normalization
    ([Wu et al. 2018](https://arxiv.org/abs/1803.08494)) with a group size of
    1 corresponds to a Layer Normalization that normalizes across height,
    width, and channel while `gamma` and `beta` span only the channel
    dimension. This implementation therefore does not match a Group
    Normalization layer with group size set to 1.

    Args:
        axis: Integer or List/Tuple. The axis or axes to normalize across.
            Typically, this is the features axis/axes. The left-out axes are
            typically the batch axis/axes. `-1` is the last dimension in the
            input. Defaults to `-1`.
        epsilon: Small float added to variance to avoid dividing by zero.
            Defaults to 1e-3.
        center: If True, add offset of `beta` to normalized tensor. If
            False, `beta` is ignored. Defaults to `True`.
        scale: If True, multiply by `gamma`. If False, `gamma` is not used.
            When the next layer is linear (also e.g. `nn.relu`), this can be
            disabled since the scaling will be done by the next layer.
            Defaults to `True`.
        beta_initializer: Initializer for the beta weight. Defaults to
            zeros.
        gamma_initializer: Initializer for the gamma weight. Defaults to
            ones.
        beta_regularizer: Optional regularizer for the beta weight.
            None by default.
        gamma_regularizer: Optional regularizer for the gamma weight.
            None by default.
        beta_constraint: Optional constraint for the beta weight.
            None by default.
        gamma_constraint: Optional constraint for the gamma weight.
            None by default.
        **kwargs: Base layer keyword arguments (e.g. `name` and `dtype`).

    Reference:

    - [Lei Ba et al., 2016](https://arxiv.org/abs/1607.06450).
    """

    def __init__(
        self,
        axis=-1,
        epsilon=1e-3,
        center=True,
        scale=True,
        beta_initializer="zeros",
        gamma_initializer="ones",
        beta_regularizer=None,
        gamma_regularizer=None,
        beta_constraint=None,
        gamma_constraint=None,
        **kwargs
    ):
        super().__init__(**kwargs)
        # Reject anything that is neither a single int nor a list/tuple.
        if not isinstance(axis, (list, tuple, int)):
            raise TypeError(
                "Expected an int or a list/tuple of ints for the "
                "argument 'axis', but received: %r" % axis
            )
        # Sequences are stored as a list; a lone int is kept as-is until
        # `build()` wraps it.
        self.axis = list(axis) if isinstance(axis, (list, tuple)) else axis

        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)

        self.supports_masking = True

    def build(self, input_shape):
        """Create `gamma` / `beta` with the shape of the normalized axes."""
        single_axis = not isinstance(self.axis, list)
        axes = [self.axis] if single_axis else self.axis
        param_shape = tuple(input_shape[dim] for dim in axes)
        if single_axis:
            # From here on `self.axis` is always a list.
            self.axis = axes

        gamma = None
        if self.scale:
            gamma = self.add_weight(
                name="gamma",
                shape=param_shape,
                initializer=self.gamma_initializer,
                regularizer=self.gamma_regularizer,
                constraint=self.gamma_constraint,
                trainable=True,
            )
        self.gamma = gamma

        beta = None
        if self.center:
            beta = self.add_weight(
                name="beta",
                shape=param_shape,
                initializer=self.beta_initializer,
                regularizer=self.beta_regularizer,
                constraint=self.beta_constraint,
                trainable=True,
            )
        self.beta = beta

        self.built = True

    def call(self, inputs):
        """Normalize `inputs` over `self.axis`, then scale and shift."""
        inputs = ops.cast(inputs, self.compute_dtype)
        input_shape = inputs.shape
        ndims = len(input_shape)

        # Shape that lets `gamma` / `beta` broadcast against the inputs:
        # size 1 everywhere except on the normalized axes.
        broadcast_shape = [1] * ndims
        for dim in self.axis:
            broadcast_shape[dim] = input_shape[dim]

        def _broadcast(param):
            # No reshape is needed when the parameter is absent, already has
            # the input rank, or only the last axis is normalized.
            if (
                param is None
                or len(param.shape) == ndims
                or self.axis == [ndims - 1]
            ):
                return param
            return ops.reshape(param, broadcast_shape)

        input_dtype = inputs.dtype
        if input_dtype in ("float16", "bfloat16") and self.dtype == "float32":
            # With mixed precision, compute in float32 so that this is at
            # least as numerically stable as the fused version.
            inputs = ops.cast(inputs, "float32")

        # Per-example statistics over the normalized axes.
        mean = ops.mean(inputs, axis=self.axis, keepdims=True)
        variance = ops.var(inputs, axis=self.axis, keepdims=True)

        scale = _broadcast(self.gamma)
        offset = _broadcast(self.beta)

        # Fold normalization, scaling and shifting into a single
        # multiply-add: outputs = inputs * inv + shift.
        inv = 1 / ops.sqrt(variance + self.epsilon)
        if scale is not None:
            inv *= scale
        if offset is None:
            shift = -mean * inv
        else:
            shift = offset - mean * inv

        outputs = inputs * ops.cast(inv, inputs.dtype) + ops.cast(
            shift, inputs.dtype
        )
        outputs = ops.cast(outputs, input_dtype)

        # If some components of the shape got lost due to adjustments, fix
        # that.
        return ops.reshape(outputs, input_shape)

    def compute_output_shape(self, input_shape):
        """The output has the same shape as the input."""
        return input_shape

    def get_config(self):
        """Return the layer configuration as a serializable dict."""
        layer_config = {
            "axis": self.axis,
            "epsilon": self.epsilon,
            "center": self.center,
            "scale": self.scale,
            "beta_initializer": initializers.serialize(self.beta_initializer),
            "gamma_initializer": initializers.serialize(
                self.gamma_initializer
            ),
            "beta_regularizer": regularizers.serialize(self.beta_regularizer),
            "gamma_regularizer": regularizers.serialize(
                self.gamma_regularizer
            ),
            "beta_constraint": constraints.serialize(self.beta_constraint),
            "gamma_constraint": constraints.serialize(self.gamma_constraint),
        }
        return {**super().get_config(), **layer_config}

@ -1,85 +0,0 @@
import numpy as np
from keras_core import layers
from keras_core import operations as ops
from keras_core import regularizers
from keras_core import testing
class LayerNormalizationTest(testing.TestCase):
    """Tests for `layers.LayerNormalization`."""

    def test_ln_basics(self):
        """Run the generic layer test over several configurations."""
        # Each entry: (init_kwargs, input/output shape,
        #              expected trainable weights, expected losses).
        cases = [
            (
                {
                    "gamma_regularizer": regularizers.L2(0.01),
                    "beta_regularizer": regularizers.L2(0.01),
                },
                (3, 4, 2),
                2,
                2,
            ),
            (
                {
                    "gamma_initializer": "ones",
                    "beta_initializer": "ones",
                },
                (3, 4, 2),
                2,
                0,
            ),
            ({"scale": False, "center": False}, (3, 3), 0, 0),
            ({"axis": (-3, -2, -1)}, (2, 8, 8, 3), 2, 0),
            ({}, (1, 0, 10), 2, 0),
        ]
        for init_kwargs, shape, num_trainable, num_losses in cases:
            # The layer never changes the shape, so the expected output
            # shape equals the input shape.
            self.run_layer_test(
                layers.LayerNormalization,
                init_kwargs=init_kwargs,
                input_shape=shape,
                expected_output_shape=shape,
                expected_num_trainable_weights=num_trainable,
                expected_num_non_trainable_weights=0,
                expected_num_seed_generators=0,
                expected_num_losses=num_losses,
                supports_masking=True,
            )

    def test_correctness(self):
        """Outputs should have roughly zero mean and unit std."""
        layer = layers.LayerNormalization(dtype="float32")
        layer.build(input_shape=(2, 2, 2))
        samples = np.random.normal(loc=5.0, scale=10.0, size=(1000, 2, 2, 2))
        inputs = samples.astype("float32")

        out = layer(inputs)
        # Undo the affine transform to inspect the normalized activations.
        out -= layer.beta
        out /= layer.gamma

        self.assertAllClose(ops.mean(out), 0.0, atol=1e-1)
        self.assertAllClose(ops.std(out), 1.0, atol=1e-1)

@ -1,57 +0,0 @@
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
@keras_core_export("keras_core.layers.UnitNormalization")
class UnitNormalization(Layer):
    """Unit normalization layer.

    Rescales a batch of inputs so that every input in the batch has an L2
    norm equal to 1 (computed across the axes given in `axis`).

    Example:

    >>> data = np.arange(6).reshape(2, 3)
    >>> normalized_data = keras_core.layers.UnitNormalization()(data)
    >>> print(np.sum(normalized_data[0, :] ** 2))
    1.0

    Args:
        axis: Integer or list/tuple. The axis or axes to normalize across.
            Typically, this is the features axis or axes. The left-out axes
            are typically the batch axis or axes. `-1` is the last dimension
            in the input. Defaults to `-1`.
    """

    def __init__(self, axis=-1, **kwargs):
        super().__init__(**kwargs)
        # Reject anything that is neither a single int nor a list/tuple.
        if not isinstance(axis, (list, tuple, int)):
            raise TypeError(
                "Invalid value for `axis` argument: "
                "expected an int or a list/tuple of ints. "
                f"Received: axis={axis}"
            )
        # Sequences are stored as a list; a lone int is kept as-is.
        self.axis = list(axis) if isinstance(axis, (list, tuple)) else axis
        self.supports_masking = True

    def build(self, input_shape):
        """The layer has no weights; just mark it as built."""
        self.built = True

    def call(self, inputs):
        """Divide `inputs` by their L2 norm along `self.axis`."""
        x = ops.cast(inputs, self.compute_dtype)
        squared = ops.square(x)
        square_sum = ops.sum(squared, axis=self.axis, keepdims=True)
        # Clamp the squared norm from below before taking the square root,
        # so the reciprocal never divides by zero.
        norm = ops.sqrt(ops.maximum(square_sum, 1e-12))
        inverse_norm = 1 / norm
        return ops.multiply(x, inverse_norm)

    def compute_output_shape(self, input_shape):
        """The output has the same shape as the input."""
        return input_shape

    def get_config(self):
        """Return the base layer config extended with `axis`."""
        config = super().get_config()
        config["axis"] = self.axis
        return config

@ -1,47 +0,0 @@
import numpy as np
from keras_core import layers
from keras_core import testing
def squared_l2_norm(x):
    """Return the sum of the squared entries of `x` (its squared L2 norm)."""
    squares = x**2
    return np.sum(squares)
class UnitNormalizationTest(testing.TestCase):
    """Tests for `layers.UnitNormalization`."""

    def test_un_basics(self):
        """Run the generic layer test for a single axis and for two axes."""
        for axis, shape in ((-1, (2, 3)), ((1, 2), (1, 3, 3))):
            self.run_layer_test(
                layers.UnitNormalization,
                init_kwargs={"axis": axis},
                input_shape=shape,
                expected_output_shape=shape,
                supports_masking=True,
            )

    def test_correctness(self):
        """Every normalized slice must have a squared L2 norm of 1."""
        # Normalizing over the last axis of a 2D input: one slice per row.
        layer = layers.UnitNormalization(axis=-1)
        outputs = layer(np.random.normal(size=(2, 3)))
        for sample in range(2):
            self.assertAllClose(squared_l2_norm(outputs[sample, :]), 1.0)

        # Normalizing jointly over the two trailing axes of a 3D input.
        layer = layers.UnitNormalization(axis=(1, 2))
        outputs = layer(np.random.normal(size=(2, 3, 3)))
        for sample in range(2):
            self.assertAllClose(squared_l2_norm(outputs[sample, :, :]), 1.0)

        # Normalizing over the middle axis: one slice per (sample, column).
        layer = layers.UnitNormalization(axis=1)
        outputs = layer(np.random.normal(size=(2, 3, 2)))
        for column in range(2):
            for sample in range(2):
                self.assertAllClose(
                    squared_l2_norm(outputs[sample, :, column]), 1.0
                )

@ -25,7 +25,7 @@ class Rescaling(Layer):
Args:
scale: Float, the scale to apply to the inputs.
offset: Float, the offset to apply to the inputs.
**kwargs: Base layer keyword arguments, such as `name` and `dtype.
**kwargs: Base layer keyword arguments, such as `name` and `dtype`.
"""
def __init__(self, scale, offset=0.0, **kwargs):

@ -0,0 +1,102 @@
from keras_core import backend
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
from keras_core.utils import image_utils
@keras_core_export("keras_core.layers.Resizing")
class Resizing(Layer):
    """A preprocessing layer which resizes images.

    This layer resizes an image input to a target height and width. The
    input should be a 4D (batched) or 3D (unbatched) tensor laid out
    according to `data_format`. Input pixel values can be of any range
    (e.g. `[0., 1.)` or `[0, 255]`).

    Args:
        height: Integer, the height of the output shape.
        width: Integer, the width of the output shape.
        interpolation: String, the interpolation method.
            Supports `"bilinear"`, `"nearest"`, `"bicubic"`,
            `"lanczos3"`, `"lanczos5"`. Defaults to `"bilinear"`.
        crop_to_aspect_ratio: If `True`, resize the images without aspect
            ratio distortion. When the original aspect ratio differs
            from the target aspect ratio, the output image will be
            cropped so as to return the largest possible window in the
            image (of size `(height, width)`) that matches the target
            aspect ratio. By default (`crop_to_aspect_ratio=False`),
            aspect ratio may not be preserved.
        data_format: string, either `"channels_last"` or
            `"channels_first"`. The ordering of the dimensions in the
            inputs. `"channels_last"` corresponds to inputs with shape
            `(batch, height, width, channels)` while `"channels_first"`
            corresponds to inputs with shape
            `(batch, channels, height, width)`. When left unset, the value
            returned by `backend.image_data_format()` is used (the
            `image_data_format` value found in your Keras config file at
            `~/.keras/keras.json`; if you never set it, then it will be
            `"channels_last"`).
        **kwargs: Base layer keyword arguments, such as `name` and `dtype`.
    """

    def __init__(
        self,
        height,
        width,
        interpolation="bilinear",
        crop_to_aspect_ratio=False,
        data_format=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.height = height
        self.width = width
        self.interpolation = interpolation
        # Fall back to the globally configured layout when none is given.
        self.data_format = data_format or backend.image_data_format()
        self.crop_to_aspect_ratio = crop_to_aspect_ratio

    def call(self, inputs):
        """Resize `inputs` to `(height, width)`."""
        size = (self.height, self.width)
        if not self.crop_to_aspect_ratio:
            # Plain resize: the aspect ratio may be distorted.
            return ops.image.resize(
                inputs,
                size=size,
                method=self.interpolation,
                data_format=self.data_format,
            )
        # Crop to the target aspect ratio first, then resize.
        return image_utils.smart_resize(
            inputs,
            size=size,
            interpolation=self.interpolation,
            data_format=self.data_format,
        )

    def compute_output_shape(self, input_shape):
        """Return `input_shape` with its height/width entries replaced."""
        output_shape = list(input_shape)
        # The height axis comes after the batch axis (4D inputs only) and,
        # for channels-first layouts, after the channels axis as well.
        height_axis = 1 if len(output_shape) == 4 else 0
        if self.data_format != "channels_last":
            height_axis += 1
        output_shape[height_axis] = self.height
        output_shape[height_axis + 1] = self.width
        return tuple(output_shape)

    def get_config(self):
        """Return the layer configuration as a serializable dict."""
        return {
            **super().get_config(),
            "height": self.height,
            "width": self.width,
            "interpolation": self.interpolation,
            "crop_to_aspect_ratio": self.crop_to_aspect_ratio,
            "data_format": self.data_format,
        }

@ -0,0 +1,140 @@
import numpy as np
import tensorflow as tf
from absl.testing import parameterized
from keras_core import layers
from keras_core import testing
class ResizingTest(testing.TestCase, parameterized.TestCase):
    """Tests for `layers.Resizing`."""

    def test_resizing_basics(self):
        """Run the generic layer test over layouts and interpolations."""
        channels_last_shapes = ((2, 12, 12, 3), (2, 6, 6, 3))
        channels_first_shapes = ((2, 3, 12, 12), (2, 3, 6, 6))
        # Each entry: (data_format, interpolation, crop_to_aspect_ratio,
        #              (input_shape, expected_output_shape)).
        cases = [
            ("channels_last", "bicubic", True, channels_last_shapes),
            ("channels_first", "bilinear", True, channels_first_shapes),
            ("channels_last", "nearest", False, channels_last_shapes),
            ("channels_first", "lanczos5", False, channels_first_shapes),
        ]
        for data_format, interpolation, crop, shapes in cases:
            input_shape, output_shape = shapes
            self.run_layer_test(
                layers.Resizing,
                init_kwargs={
                    "height": 6,
                    "width": 6,
                    "data_format": data_format,
                    "interpolation": interpolation,
                    "crop_to_aspect_ratio": crop,
                },
                input_shape=input_shape,
                expected_output_shape=output_shape,
                expected_num_trainable_weights=0,
                expected_num_non_trainable_weights=0,
                expected_num_seed_generators=0,
                expected_num_losses=0,
                supports_masking=False,
            )

    @parameterized.parameters(
        [
            ((5, 7), "channels_first", True),
            ((5, 7), "channels_last", True),
            ((6, 8), "channels_first", False),
            ((6, 8), "channels_last", False),
        ]
    )
    def test_resizing_correctness(
        self, size, data_format, crop_to_aspect_ratio
    ):
        """Compare against `tf.keras.layers.Resizing` as the reference."""
        target_height, target_width = size
        channels_first = data_format == "channels_first"
        # Batched case first, then unbatched. Each entry:
        # (channels_first shape, channels_last shape,
        #  permutation to channels_last, permutation back to channels_first).
        cases = [
            ((2, 3, 9, 11), (2, 9, 11, 3), (0, 2, 3, 1), (0, 3, 1, 2)),
            ((3, 9, 11), (9, 11, 3), (1, 2, 0), (2, 0, 1)),
        ]
        for first_shape, last_shape, to_last, to_first in cases:
            img = np.random.random(
                first_shape if channels_first else last_shape
            )
            out = layers.Resizing(
                target_height,
                target_width,
                data_format=data_format,
                crop_to_aspect_ratio=crop_to_aspect_ratio,
            )(img)

            # The reference layer is applied to channels_last data, so
            # channels_first inputs are transposed before and after.
            reference_layer = tf.keras.layers.Resizing(
                target_height,
                target_width,
                crop_to_aspect_ratio=crop_to_aspect_ratio,
            )
            if channels_first:
                ref_out = tf.transpose(
                    reference_layer(np.transpose(img, to_last)), to_first
                )
            else:
                ref_out = reference_layer(img)
            self.assertAllClose(ref_out, out)

@ -9,6 +9,7 @@ from keras_core.backend import is_tensor
from keras_core.backend import name_scope
from keras_core.backend import random
from keras_core.backend import shape
from keras_core.operations import image
from keras_core.operations import operation_utils
from keras_core.operations.math import * # noqa: F403
from keras_core.operations.nn import * # noqa: F403

@ -34,35 +34,60 @@ class ImageOpsStaticShapeTest(testing.TestCase):
class ImageOpsCorrectnessTest(testing.TestCase, parameterized.TestCase):
@parameterized.parameters(
[
("bilinear", True),
("nearest", True),
("lanczos3", True),
("lanczos5", True),
("bicubic", True),
("bilinear", False),
("nearest", False),
("lanczos3", False),
("lanczos5", False),
("bicubic", False),
("bilinear", True, "channels_last"),
("nearest", True, "channels_last"),
("lanczos3", True, "channels_last"),
("lanczos5", True, "channels_last"),
("bicubic", True, "channels_last"),
("bilinear", False, "channels_last"),
("nearest", False, "channels_last"),
("lanczos3", False, "channels_last"),
("lanczos5", False, "channels_last"),
("bicubic", False, "channels_last"),
("bilinear", True, "channels_first"),
]
)
def test_resize(self, method, antialias):
x = np.random.random((50, 50, 3)) * 255
def test_resize(self, method, antialias, data_format):
# Unbatched case
if data_format == "channels_first":
x = np.random.random((3, 50, 50)) * 255
else:
x = np.random.random((50, 50, 3)) * 255
out = kimage.resize(
x, size=(25, 25), method=method, antialias=antialias
x,
size=(25, 25),
method=method,
antialias=antialias,
data_format=data_format,
)
if data_format == "channels_first":
x = np.transpose(x, (1, 2, 0))
ref_out = tf.image.resize(
x, size=(25, 25), method=method, antialias=antialias
)
if data_format == "channels_first":
ref_out = np.transpose(ref_out, (2, 0, 1))
self.assertEqual(tuple(out.shape), tuple(ref_out.shape))
self.assertAllClose(ref_out, out, atol=0.3)
x = np.random.random((2, 50, 50, 3)) * 255
# Batched case
if data_format == "channels_first":
x = np.random.random((2, 3, 50, 50)) * 255
else:
x = np.random.random((2, 50, 50, 3)) * 255
out = kimage.resize(
x, size=(25, 25), method=method, antialias=antialias
x,
size=(25, 25),
method=method,
antialias=antialias,
data_format=data_format,
)
if data_format == "channels_first":
x = np.transpose(x, (0, 2, 3, 1))
ref_out = tf.image.resize(
x, size=(25, 25), method=method, antialias=antialias
)
if data_format == "channels_first":
ref_out = np.transpose(ref_out, (0, 3, 1, 2))
self.assertEqual(tuple(out.shape), tuple(ref_out.shape))
self.assertAllClose(ref_out, out, atol=0.3)

@ -7,6 +7,7 @@ import warnings
import numpy as np
from keras_core import backend
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
try:
@ -112,8 +113,10 @@ def array_to_img(x, data_format=None, scale=True, dtype=None):
@keras_core_export(
"keras_core.utils.img_to_array",
"keras_core.preprocessing.image.img_to_array",
[
"keras_core.utils.img_to_array",
"keras_core.preprocessing.image.img_to_array",
]
)
def img_to_array(img, data_format=None, dtype=None):
"""Converts a PIL Image instance to a NumPy array.
@ -301,3 +304,142 @@ def load_img(
else:
img = img.resize(width_height_tuple, resample)
return img
def smart_resize(
    x, size, interpolation="bilinear", data_format="channels_last"
):
    """Resize images to a target size without aspect ratio distortion.

    Image datasets typically yield images that have each a different
    size. However, these images need to be batched before they can be
    processed by Keras layers. To be batched, images need to share the same
    height and width.

    You could simply do, in TF (or JAX equivalent):

    ```python
    size = (200, 200)
    ds = ds.map(lambda img: resize(img, size))
    ```

    However, if you do this, you distort the aspect ratio of your images,
    since in general they do not all have the same aspect ratio as `size`.
    This is fine in many cases, but not always (e.g. for image generation
    models this can be a problem).

    Note that passing the argument `preserve_aspect_ratio=True` to `resize`
    will preserve the aspect ratio, but at the cost of no longer respecting
    the provided target size.

    This calls for:

    ```python
    size = (200, 200)
    ds = ds.map(lambda img: smart_resize(img, size))
    ```

    Your output images will actually be `(200, 200)`, and will not be
    distorted. Instead, the parts of the image that do not fit within the
    target size get cropped out.

    The resizing process is:

    1. Take the largest centered crop of the image that has the same aspect
    ratio as the target size. For instance, if `size=(200, 200)` and the
    input image has size `(340, 500)`, we take a crop of `(340, 340)`
    centered along the width.
    2. Resize the cropped image to the target size. In the example above,
    we resize the `(340, 340)` crop to `(200, 200)`.

    Args:
        x: Input image or batch of images (as a tensor or NumPy array).
            Must have rank 3 or 4. With `data_format="channels_last"` the
            layout is `(height, width, channels)` or
            `(batch_size, height, width, channels)`; otherwise it is
            `(channels, height, width)` or
            `(batch_size, channels, height, width)`.
        size: Tuple of `(height, width)` integer. Target size.
        interpolation: String, interpolation to use for resizing.
            Defaults to `'bilinear'`.
            Supports `bilinear`, `nearest`, `bicubic`,
            `lanczos3`, `lanczos5`.
        data_format: `"channels_last"` or `"channels_first"`. Any value
            other than `"channels_last"` is handled as channels-first when
            locating the height and width axes.

    Returns:
        The cropped input after `ops.image.resize` has been applied to it
        with the target `size`, i.e. an image (or batch of images) whose
        height and width are expected to be `size[0]` and `size[1]`.
        If the input image was a NumPy array, the output is a NumPy array,
        and if it was a backend-native tensor, the output is a
        backend-native tensor.

    Raises:
        ValueError: If `size` does not have exactly two entries, or if the
            input does not have rank 3 or 4.
    """
    if len(size) != 2:
        raise ValueError(
            f"Expected `size` to be a tuple of 2 integers, but got: {size}."
        )
    img = backend.convert_to_tensor(x)
    # `len()` always returns an int, so the rank can be validated directly;
    # the former `len(img.shape) is not None` guard was always true.
    rank = len(img.shape)
    if rank < 3 or rank > 4:
        raise ValueError(
            "Expected an image array with shape `(height, width, "
            "channels)`, or `(batch_size, height, width, channels)`, but "
            f"got input with incorrect rank, of shape {img.shape}."
        )

    shape = ops.shape(img)
    channels_last = data_format == "channels_last"
    if channels_last:
        height, width = shape[-3], shape[-2]
    else:
        height, width = shape[-2], shape[-1]
    target_height, target_width = size

    # Largest crop with the target aspect ratio: each crop dimension is
    # derived from the opposite input dimension.
    crop_height = ops.cast(
        ops.cast(width * target_height, "float32") / target_width, "int32"
    )
    crop_width = ops.cast(
        ops.cast(height * target_width, "float32") / target_height, "int32"
    )

    # Set back to input height / width if crop_height / crop_width is not
    # smaller.
    crop_height = ops.minimum(height, crop_height)
    crop_width = ops.minimum(width, crop_width)

    # Center the crop window inside the input image.
    crop_box_hstart = ops.cast(
        ops.cast(height - crop_height, "float32") / 2, "int32"
    )
    crop_box_wstart = ops.cast(
        ops.cast(width - crop_width, "float32") / 2, "int32"
    )

    if channels_last:
        if rank == 4:
            img = img[
                :,
                crop_box_hstart : crop_box_hstart + crop_height,
                crop_box_wstart : crop_box_wstart + crop_width,
                :,
            ]
        else:
            img = img[
                crop_box_hstart : crop_box_hstart + crop_height,
                crop_box_wstart : crop_box_wstart + crop_width,
                :,
            ]
    else:
        if rank == 4:
            img = img[
                :,
                :,
                crop_box_hstart : crop_box_hstart + crop_height,
                crop_box_wstart : crop_box_wstart + crop_width,
            ]
        else:
            img = img[
                :,
                crop_box_hstart : crop_box_hstart + crop_height,
                crop_box_wstart : crop_box_wstart + crop_width,
            ]

    img = ops.image.resize(
        img, size=size, method=interpolation, data_format=data_format
    )

    # Hand NumPy callers back a NumPy array rather than a backend tensor.
    if isinstance(x, np.ndarray):
        return np.array(img)
    return img