Add Resizing layer and corresponding backend ops.
This commit is contained in:
parent
a333b9bd1b
commit
8c014ca995
@ -38,7 +38,7 @@ def resize(
|
||||
resized = tf.image.resize(image, size, method=method, antialias=antialias)
|
||||
if data_format == "channels_first":
|
||||
if len(image.shape) == 4:
|
||||
resized = tf.transpose(image, (0, 3, 1, 2))
|
||||
resized = tf.transpose(resized, (0, 3, 1, 2))
|
||||
elif len(image.shape) == 3:
|
||||
image = tf.transpose(image, (2, 0, 1))
|
||||
resized = tf.transpose(resized, (2, 0, 1))
|
||||
return resized
|
||||
|
@ -26,10 +26,6 @@ from keras_core.layers.merging.subtract import subtract
|
||||
from keras_core.layers.normalization.batch_normalization import (
|
||||
BatchNormalization,
|
||||
)
|
||||
from keras_core.layers.normalization.layer_normalization import (
|
||||
LayerNormalization,
|
||||
)
|
||||
from keras_core.layers.normalization.unit_normalization import UnitNormalization
|
||||
from keras_core.layers.pooling.average_pooling1d import AveragePooling1D
|
||||
from keras_core.layers.pooling.average_pooling2d import AveragePooling2D
|
||||
from keras_core.layers.pooling.average_pooling3d import AveragePooling3D
|
||||
@ -50,6 +46,7 @@ from keras_core.layers.pooling.max_pooling2d import MaxPooling2D
|
||||
from keras_core.layers.pooling.max_pooling3d import MaxPooling3D
|
||||
from keras_core.layers.preprocessing.normalization import Normalization
|
||||
from keras_core.layers.preprocessing.rescaling import Rescaling
|
||||
from keras_core.layers.preprocessing.resizing import Resizing
|
||||
from keras_core.layers.regularization.activity_regularization import (
|
||||
ActivityRegularization,
|
||||
)
|
||||
|
@ -1,239 +0,0 @@
|
||||
from keras_core import constraints
|
||||
from keras_core import initializers
|
||||
from keras_core import operations as ops
|
||||
from keras_core import regularizers
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.layers.layer import Layer
|
||||
|
||||
|
||||
@keras_core_export("keras_core.layers.LayerNormalization")
|
||||
class LayerNormalization(Layer):
|
||||
"""Layer normalization layer (Ba et al., 2016).
|
||||
|
||||
Normalize the activations of the previous layer for each given example in a
|
||||
batch independently, rather than across a batch like Batch Normalization.
|
||||
i.e. applies a transformation that maintains the mean activation within each
|
||||
example close to 0 and the activation standard deviation close to 1.
|
||||
|
||||
If `scale` or `center` are enabled, the layer will scale the normalized
|
||||
outputs by broadcasting them with a trainable variable `gamma`, and center
|
||||
the outputs by broadcasting with a trainable variable `beta`. `gamma` will
|
||||
default to a ones tensor and `beta` will default to a zeros tensor, so that
|
||||
centering and scaling are no-ops before training has begun.
|
||||
|
||||
So, with scaling and centering enabled the normalization equations
|
||||
are as follows:
|
||||
|
||||
Let the intermediate activations for a mini-batch to be the `inputs`.
|
||||
|
||||
For each sample `x_i` in `inputs` with `k` features, we compute the mean and
|
||||
variance of the sample:
|
||||
|
||||
```python
|
||||
mean_i = sum(x_i[j] for j in range(k)) / k
|
||||
var_i = sum((x_i[j] - mean_i) ** 2 for j in range(k)) / k
|
||||
```
|
||||
|
||||
and then compute a normalized `x_i_normalized`, including a small factor
|
||||
`epsilon` for numerical stability.
|
||||
|
||||
```python
|
||||
x_i_normalized = (x_i - mean_i) / sqrt(var_i + epsilon)
|
||||
```
|
||||
|
||||
And finally `x_i_normalized ` is linearly transformed by `gamma` and `beta`,
|
||||
which are learned parameters:
|
||||
|
||||
```python
|
||||
output_i = x_i_normalized * gamma + beta
|
||||
```
|
||||
|
||||
`gamma` and `beta` will span the axes of `inputs` specified in `axis`, and
|
||||
this part of the inputs' shape must be fully defined.
|
||||
|
||||
For example:
|
||||
|
||||
>>> layer = keras_core.layers.LayerNormalization(axis=[1, 2, 3])
|
||||
>>> layer.build([5, 20, 30, 40])
|
||||
>>> print(layer.beta.shape)
|
||||
(20, 30, 40)
|
||||
>>> print(layer.gamma.shape)
|
||||
(20, 30, 40)
|
||||
|
||||
Note that other implementations of layer normalization may choose to define
|
||||
`gamma` and `beta` over a separate set of axes from the axes being
|
||||
normalized across. For example, Group Normalization
|
||||
([Wu et al. 2018](https://arxiv.org/abs/1803.08494)) with group size of 1
|
||||
corresponds to a Layer Normalization that normalizes across height, width,
|
||||
and channel and has `gamma` and `beta` span only the channel dimension.
|
||||
So, this Layer Normalization implementation will not match a Group
|
||||
Normalization layer with group size set to 1.
|
||||
|
||||
Args:
|
||||
axis: Integer or List/Tuple. The axis or axes to normalize across.
|
||||
Typically, this is the features axis/axes. The left-out axes are
|
||||
typically the batch axis/axes. `-1` is the last dimension in the
|
||||
input. Defaults to `-1`.
|
||||
epsilon: Small float added to variance to avoid dividing by zero.
|
||||
Defaults to 1e-3.
|
||||
center: If True, add offset of `beta` to normalized tensor. If False,
|
||||
`beta` is ignored. Defaults to `True`.
|
||||
scale: If True, multiply by `gamma`. If False, `gamma` is not used.
|
||||
When the next layer is linear (also e.g. `nn.relu`), this can be
|
||||
disabled since the scaling will be done by the next layer.
|
||||
Defaults to `True`.
|
||||
beta_initializer: Initializer for the beta weight. Defaults to zeros.
|
||||
gamma_initializer: Initializer for the gamma weight. Defaults to ones.
|
||||
beta_regularizer: Optional regularizer for the beta weight.
|
||||
None by default.
|
||||
gamma_regularizer: Optional regularizer for the gamma weight.
|
||||
None by default.
|
||||
beta_constraint: Optional constraint for the beta weight.
|
||||
None by default.
|
||||
gamma_constraint: Optional constraint for the gamma weight.
|
||||
None by default.
|
||||
**kwargs: Base layer keyword arguments (e.g. `name` and `dtype`).
|
||||
|
||||
|
||||
Reference:
|
||||
|
||||
- [Lei Ba et al., 2016](https://arxiv.org/abs/1607.06450).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
axis=-1,
|
||||
epsilon=1e-3,
|
||||
center=True,
|
||||
scale=True,
|
||||
beta_initializer="zeros",
|
||||
gamma_initializer="ones",
|
||||
beta_regularizer=None,
|
||||
gamma_regularizer=None,
|
||||
beta_constraint=None,
|
||||
gamma_constraint=None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
if isinstance(axis, (list, tuple)):
|
||||
self.axis = list(axis)
|
||||
elif isinstance(axis, int):
|
||||
self.axis = axis
|
||||
else:
|
||||
raise TypeError(
|
||||
"Expected an int or a list/tuple of ints for the "
|
||||
"argument 'axis', but received: %r" % axis
|
||||
)
|
||||
|
||||
self.epsilon = epsilon
|
||||
self.center = center
|
||||
self.scale = scale
|
||||
self.beta_initializer = initializers.get(beta_initializer)
|
||||
self.gamma_initializer = initializers.get(gamma_initializer)
|
||||
self.beta_regularizer = regularizers.get(beta_regularizer)
|
||||
self.gamma_regularizer = regularizers.get(gamma_regularizer)
|
||||
self.beta_constraint = constraints.get(beta_constraint)
|
||||
self.gamma_constraint = constraints.get(gamma_constraint)
|
||||
|
||||
self.supports_masking = True
|
||||
|
||||
def build(self, input_shape):
|
||||
if isinstance(self.axis, list):
|
||||
shape = tuple([input_shape[dim] for dim in self.axis])
|
||||
else:
|
||||
shape = (input_shape[self.axis],)
|
||||
self.axis = [self.axis]
|
||||
if self.scale:
|
||||
self.gamma = self.add_weight(
|
||||
name="gamma",
|
||||
shape=shape,
|
||||
initializer=self.gamma_initializer,
|
||||
regularizer=self.gamma_regularizer,
|
||||
constraint=self.gamma_constraint,
|
||||
trainable=True,
|
||||
)
|
||||
else:
|
||||
self.gamma = None
|
||||
|
||||
if self.center:
|
||||
self.beta = self.add_weight(
|
||||
name="beta",
|
||||
shape=shape,
|
||||
initializer=self.beta_initializer,
|
||||
regularizer=self.beta_regularizer,
|
||||
constraint=self.beta_constraint,
|
||||
trainable=True,
|
||||
)
|
||||
else:
|
||||
self.beta = None
|
||||
|
||||
self.built = True
|
||||
|
||||
def call(self, inputs):
|
||||
inputs = ops.cast(inputs, self.compute_dtype)
|
||||
# Compute the axes along which to reduce the mean / variance
|
||||
input_shape = inputs.shape
|
||||
ndims = len(input_shape)
|
||||
|
||||
# Broadcasting only necessary for norm when the axis is not just
|
||||
# the last dimension
|
||||
broadcast_shape = [1] * ndims
|
||||
for dim in self.axis:
|
||||
broadcast_shape[dim] = input_shape[dim]
|
||||
|
||||
def _broadcast(v):
|
||||
if (
|
||||
v is not None
|
||||
and len(v.shape) != ndims
|
||||
and self.axis != [ndims - 1]
|
||||
):
|
||||
return ops.reshape(v, broadcast_shape)
|
||||
return v
|
||||
|
||||
input_dtype = inputs.dtype
|
||||
if input_dtype in ("float16", "bfloat16") and self.dtype == "float32":
|
||||
# If mixed precision is used, cast inputs to float32 so that
|
||||
# this is at least as numerically stable as the fused version.
|
||||
inputs = ops.cast(inputs, "float32")
|
||||
|
||||
# Calculate the mean and variance last axis (layer activations).
|
||||
mean = ops.mean(inputs, axis=self.axis, keepdims=True)
|
||||
variance = ops.var(inputs, axis=self.axis, keepdims=True)
|
||||
|
||||
scale, offset = _broadcast(self.gamma), _broadcast(self.beta)
|
||||
|
||||
# Compute the batch normalization.
|
||||
inv = 1 / ops.sqrt(variance + self.epsilon)
|
||||
if scale is not None:
|
||||
inv *= scale
|
||||
|
||||
x = offset - mean * inv if offset is not None else -mean * inv
|
||||
outputs = inputs * ops.cast(inv, inputs.dtype) + ops.cast(
|
||||
x, inputs.dtype
|
||||
)
|
||||
|
||||
outputs = ops.cast(outputs, input_dtype)
|
||||
|
||||
# If some components of the shape got lost due to adjustments, fix that.
|
||||
outputs = ops.reshape(outputs, input_shape)
|
||||
|
||||
return outputs
|
||||
|
||||
def compute_output_shape(self, input_shape):
|
||||
return input_shape
|
||||
|
||||
def get_config(self):
|
||||
config = {
|
||||
"axis": self.axis,
|
||||
"epsilon": self.epsilon,
|
||||
"center": self.center,
|
||||
"scale": self.scale,
|
||||
"beta_initializer": initializers.serialize(self.beta_initializer),
|
||||
"gamma_initializer": initializers.serialize(self.gamma_initializer),
|
||||
"beta_regularizer": regularizers.serialize(self.beta_regularizer),
|
||||
"gamma_regularizer": regularizers.serialize(self.gamma_regularizer),
|
||||
"beta_constraint": constraints.serialize(self.beta_constraint),
|
||||
"gamma_constraint": constraints.serialize(self.gamma_constraint),
|
||||
}
|
||||
base_config = super().get_config()
|
||||
return {**base_config, **config}
|
@ -1,85 +0,0 @@
|
||||
import numpy as np
|
||||
|
||||
from keras_core import layers
|
||||
from keras_core import operations as ops
|
||||
from keras_core import regularizers
|
||||
from keras_core import testing
|
||||
|
||||
|
||||
class LayerNormalizationTest(testing.TestCase):
|
||||
def test_ln_basics(self):
|
||||
self.run_layer_test(
|
||||
layers.LayerNormalization,
|
||||
init_kwargs={
|
||||
"gamma_regularizer": regularizers.L2(0.01),
|
||||
"beta_regularizer": regularizers.L2(0.01),
|
||||
},
|
||||
input_shape=(3, 4, 2),
|
||||
expected_output_shape=(3, 4, 2),
|
||||
expected_num_trainable_weights=2,
|
||||
expected_num_non_trainable_weights=0,
|
||||
expected_num_seed_generators=0,
|
||||
expected_num_losses=2,
|
||||
supports_masking=True,
|
||||
)
|
||||
self.run_layer_test(
|
||||
layers.LayerNormalization,
|
||||
init_kwargs={
|
||||
"gamma_initializer": "ones",
|
||||
"beta_initializer": "ones",
|
||||
},
|
||||
input_shape=(3, 4, 2),
|
||||
expected_output_shape=(3, 4, 2),
|
||||
expected_num_trainable_weights=2,
|
||||
expected_num_non_trainable_weights=0,
|
||||
expected_num_seed_generators=0,
|
||||
expected_num_losses=0,
|
||||
supports_masking=True,
|
||||
)
|
||||
self.run_layer_test(
|
||||
layers.LayerNormalization,
|
||||
init_kwargs={"scale": False, "center": False},
|
||||
input_shape=(3, 3),
|
||||
expected_output_shape=(3, 3),
|
||||
expected_num_trainable_weights=0,
|
||||
expected_num_non_trainable_weights=0,
|
||||
expected_num_seed_generators=0,
|
||||
expected_num_losses=0,
|
||||
supports_masking=True,
|
||||
)
|
||||
self.run_layer_test(
|
||||
layers.LayerNormalization,
|
||||
init_kwargs={"axis": (-3, -2, -1)},
|
||||
input_shape=(2, 8, 8, 3),
|
||||
expected_output_shape=(2, 8, 8, 3),
|
||||
expected_num_trainable_weights=2,
|
||||
expected_num_non_trainable_weights=0,
|
||||
expected_num_seed_generators=0,
|
||||
expected_num_losses=0,
|
||||
supports_masking=True,
|
||||
)
|
||||
self.run_layer_test(
|
||||
layers.LayerNormalization,
|
||||
init_kwargs={},
|
||||
input_shape=(1, 0, 10),
|
||||
expected_output_shape=(1, 0, 10),
|
||||
expected_num_trainable_weights=2,
|
||||
expected_num_non_trainable_weights=0,
|
||||
expected_num_seed_generators=0,
|
||||
expected_num_losses=0,
|
||||
supports_masking=True,
|
||||
)
|
||||
|
||||
def test_correctness(self):
|
||||
layer = layers.LayerNormalization(dtype="float32")
|
||||
layer.build(input_shape=(2, 2, 2))
|
||||
inputs = np.random.normal(
|
||||
loc=5.0, scale=10.0, size=(1000, 2, 2, 2)
|
||||
).astype("float32")
|
||||
|
||||
out = layer(inputs)
|
||||
out -= layer.beta
|
||||
out /= layer.gamma
|
||||
|
||||
self.assertAllClose(ops.mean(out), 0.0, atol=1e-1)
|
||||
self.assertAllClose(ops.std(out), 1.0, atol=1e-1)
|
@ -1,57 +0,0 @@
|
||||
from keras_core import operations as ops
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.layers.layer import Layer
|
||||
|
||||
|
||||
@keras_core_export("keras_core.layers.UnitNormalization")
|
||||
class UnitNormalization(Layer):
|
||||
"""Unit normalization layer.
|
||||
|
||||
Normalize a batch of inputs so that each input in the batch has a L2 norm
|
||||
equal to 1 (across the axes specified in `axis`).
|
||||
|
||||
Example:
|
||||
|
||||
>>> data = np.arange(6).reshape(2, 3)
|
||||
>>> normalized_data = keras_core.layers.UnitNormalization()(data)
|
||||
>>> print(np.sum(normalized_data[0, :] ** 2)
|
||||
1.0
|
||||
|
||||
Args:
|
||||
axis: Integer or list/tuple. The axis or axes to normalize across.
|
||||
Typically, this is the features axis or axes. The left-out axes are
|
||||
typically the batch axis or axes. `-1` is the last dimension
|
||||
in the input. Defaults to `-1`.
|
||||
"""
|
||||
|
||||
def __init__(self, axis=-1, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if isinstance(axis, (list, tuple)):
|
||||
self.axis = list(axis)
|
||||
elif isinstance(axis, int):
|
||||
self.axis = axis
|
||||
else:
|
||||
raise TypeError(
|
||||
"Invalid value for `axis` argument: "
|
||||
"expected an int or a list/tuple of ints. "
|
||||
f"Received: axis={axis}"
|
||||
)
|
||||
self.supports_masking = True
|
||||
|
||||
def build(self, input_shape):
|
||||
self.built = True
|
||||
|
||||
def call(self, inputs):
|
||||
x = ops.cast(inputs, self.compute_dtype)
|
||||
|
||||
square_sum = ops.sum(ops.square(x), axis=self.axis, keepdims=True)
|
||||
x_inv_norm = 1 / ops.sqrt(ops.maximum(square_sum, 1e-12))
|
||||
return ops.multiply(x, x_inv_norm)
|
||||
|
||||
def compute_output_shape(self, input_shape):
|
||||
return input_shape
|
||||
|
||||
def get_config(self):
|
||||
config = super().get_config()
|
||||
config.update({"axis": self.axis})
|
||||
return config
|
@ -1,47 +0,0 @@
|
||||
import numpy as np
|
||||
|
||||
from keras_core import layers
|
||||
from keras_core import testing
|
||||
|
||||
|
||||
def squared_l2_norm(x):
|
||||
return np.sum(x**2)
|
||||
|
||||
|
||||
class UnitNormalizationTest(testing.TestCase):
|
||||
def test_un_basics(self):
|
||||
self.run_layer_test(
|
||||
layers.UnitNormalization,
|
||||
init_kwargs={"axis": -1},
|
||||
input_shape=(2, 3),
|
||||
expected_output_shape=(2, 3),
|
||||
supports_masking=True,
|
||||
)
|
||||
self.run_layer_test(
|
||||
layers.UnitNormalization,
|
||||
init_kwargs={"axis": (1, 2)},
|
||||
input_shape=(1, 3, 3),
|
||||
expected_output_shape=(1, 3, 3),
|
||||
supports_masking=True,
|
||||
)
|
||||
|
||||
def test_correctness(self):
|
||||
layer = layers.UnitNormalization(axis=-1)
|
||||
inputs = np.random.normal(size=(2, 3))
|
||||
outputs = layer(inputs)
|
||||
self.assertAllClose(squared_l2_norm(outputs[0, :]), 1.0)
|
||||
self.assertAllClose(squared_l2_norm(outputs[1, :]), 1.0)
|
||||
|
||||
layer = layers.UnitNormalization(axis=(1, 2))
|
||||
inputs = np.random.normal(size=(2, 3, 3))
|
||||
outputs = layer(inputs)
|
||||
self.assertAllClose(squared_l2_norm(outputs[0, :, :]), 1.0)
|
||||
self.assertAllClose(squared_l2_norm(outputs[1, :, :]), 1.0)
|
||||
|
||||
layer = layers.UnitNormalization(axis=1)
|
||||
inputs = np.random.normal(size=(2, 3, 2))
|
||||
outputs = layer(inputs)
|
||||
self.assertAllClose(squared_l2_norm(outputs[0, :, 0]), 1.0)
|
||||
self.assertAllClose(squared_l2_norm(outputs[1, :, 0]), 1.0)
|
||||
self.assertAllClose(squared_l2_norm(outputs[0, :, 1]), 1.0)
|
||||
self.assertAllClose(squared_l2_norm(outputs[1, :, 1]), 1.0)
|
@ -25,7 +25,7 @@ class Rescaling(Layer):
|
||||
Args:
|
||||
scale: Float, the scale to apply to the inputs.
|
||||
offset: Float, the offset to apply to the inputs.
|
||||
**kwargs: Base layer keyword arguments, such as `name` and `dtype.
|
||||
**kwargs: Base layer keyword arguments, such as `name` and `dtype`.
|
||||
"""
|
||||
|
||||
def __init__(self, scale, offset=0.0, **kwargs):
|
||||
|
102
keras_core/layers/preprocessing/resizing.py
Normal file
102
keras_core/layers/preprocessing/resizing.py
Normal file
@ -0,0 +1,102 @@
|
||||
from keras_core import backend
|
||||
from keras_core import operations as ops
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.layers.layer import Layer
|
||||
from keras_core.utils import image_utils
|
||||
|
||||
|
||||
@keras_core_export("keras_core.layers.Resizing")
|
||||
class Resizing(Layer):
|
||||
"""A preprocessing layer which resizes images.
|
||||
|
||||
This layer resizes an image input to a target height and width. The input
|
||||
should be a 4D (batched) or 3D (unbatched) tensor in `"channels_last"`
|
||||
format. Input pixel values can be of any range
|
||||
(e.g. `[0., 1.)` or `[0, 255]`).
|
||||
|
||||
Args:
|
||||
height: Integer, the height of the output shape.
|
||||
width: Integer, the width of the output shape.
|
||||
interpolation: String, the interpolation method.
|
||||
Supports `"bilinear"`, `"nearest"`, `"bicubic"`,
|
||||
`"lanczos3"`, `"lanczos5"`. Defaults to `"bilinear"`.
|
||||
crop_to_aspect_ratio: If `True`, resize the images without aspect
|
||||
ratio distortion. When the original aspect ratio differs
|
||||
from the target aspect ratio, the output image will be
|
||||
cropped so as to return the
|
||||
largest possible window in the image (of size `(height, width)`)
|
||||
that matches the target aspect ratio. By default
|
||||
(`crop_to_aspect_ratio=False`), aspect ratio may not be preserved.
|
||||
data_format: string, either `"channels_last"` or `"channels_first"`.
|
||||
The ordering of the dimensions in the inputs. `"channels_last"`
|
||||
corresponds to inputs with shape `(batch, height, width, channels)`
|
||||
while `"channels_first"` corresponds to inputs with shape
|
||||
`(batch, channels, height, width)`. It defaults to the
|
||||
`image_data_format` value found in your Keras config file at
|
||||
`~/.keras/keras.json`. If you never set it, then it will be
|
||||
`"channels_last"`.
|
||||
**kwargs: Base layer keyword arguments, such as `name` and `dtype`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
height,
|
||||
width,
|
||||
interpolation="bilinear",
|
||||
crop_to_aspect_ratio=False,
|
||||
data_format=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.interpolation = interpolation
|
||||
self.data_format = data_format or backend.image_data_format()
|
||||
self.crop_to_aspect_ratio = crop_to_aspect_ratio
|
||||
|
||||
def call(self, inputs):
|
||||
size = (self.height, self.width)
|
||||
if self.crop_to_aspect_ratio:
|
||||
outputs = image_utils.smart_resize(
|
||||
inputs,
|
||||
size=size,
|
||||
interpolation=self.interpolation,
|
||||
data_format=self.data_format,
|
||||
)
|
||||
else:
|
||||
outputs = ops.image.resize(
|
||||
inputs,
|
||||
size=size,
|
||||
method=self.interpolation,
|
||||
data_format=self.data_format,
|
||||
)
|
||||
return outputs
|
||||
|
||||
def compute_output_shape(self, input_shape):
|
||||
input_shape = list(input_shape)
|
||||
if len(input_shape) == 4:
|
||||
if self.data_format == "channels_last":
|
||||
input_shape[1] = self.height
|
||||
input_shape[2] = self.width
|
||||
else:
|
||||
input_shape[2] = self.height
|
||||
input_shape[3] = self.width
|
||||
else:
|
||||
if self.data_format == "channels_last":
|
||||
input_shape[0] = self.height
|
||||
input_shape[1] = self.width
|
||||
else:
|
||||
input_shape[1] = self.height
|
||||
input_shape[2] = self.width
|
||||
return tuple(input_shape)
|
||||
|
||||
def get_config(self):
|
||||
base_config = super().get_config()
|
||||
config = {
|
||||
"height": self.height,
|
||||
"width": self.width,
|
||||
"interpolation": self.interpolation,
|
||||
"crop_to_aspect_ratio": self.crop_to_aspect_ratio,
|
||||
"data_format": self.data_format,
|
||||
}
|
||||
return {**base_config, **config}
|
140
keras_core/layers/preprocessing/resizing_test.py
Normal file
140
keras_core/layers/preprocessing/resizing_test.py
Normal file
@ -0,0 +1,140 @@
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from absl.testing import parameterized
|
||||
|
||||
from keras_core import layers
|
||||
from keras_core import testing
|
||||
|
||||
|
||||
class ResizingTest(testing.TestCase, parameterized.TestCase):
|
||||
def test_resizing_basics(self):
|
||||
self.run_layer_test(
|
||||
layers.Resizing,
|
||||
init_kwargs={
|
||||
"height": 6,
|
||||
"width": 6,
|
||||
"data_format": "channels_last",
|
||||
"interpolation": "bicubic",
|
||||
"crop_to_aspect_ratio": True,
|
||||
},
|
||||
input_shape=(2, 12, 12, 3),
|
||||
expected_output_shape=(2, 6, 6, 3),
|
||||
expected_num_trainable_weights=0,
|
||||
expected_num_non_trainable_weights=0,
|
||||
expected_num_seed_generators=0,
|
||||
expected_num_losses=0,
|
||||
supports_masking=False,
|
||||
)
|
||||
self.run_layer_test(
|
||||
layers.Resizing,
|
||||
init_kwargs={
|
||||
"height": 6,
|
||||
"width": 6,
|
||||
"data_format": "channels_first",
|
||||
"interpolation": "bilinear",
|
||||
"crop_to_aspect_ratio": True,
|
||||
},
|
||||
input_shape=(2, 3, 12, 12),
|
||||
expected_output_shape=(2, 3, 6, 6),
|
||||
expected_num_trainable_weights=0,
|
||||
expected_num_non_trainable_weights=0,
|
||||
expected_num_seed_generators=0,
|
||||
expected_num_losses=0,
|
||||
supports_masking=False,
|
||||
)
|
||||
self.run_layer_test(
|
||||
layers.Resizing,
|
||||
init_kwargs={
|
||||
"height": 6,
|
||||
"width": 6,
|
||||
"data_format": "channels_last",
|
||||
"interpolation": "nearest",
|
||||
"crop_to_aspect_ratio": False,
|
||||
},
|
||||
input_shape=(2, 12, 12, 3),
|
||||
expected_output_shape=(2, 6, 6, 3),
|
||||
expected_num_trainable_weights=0,
|
||||
expected_num_non_trainable_weights=0,
|
||||
expected_num_seed_generators=0,
|
||||
expected_num_losses=0,
|
||||
supports_masking=False,
|
||||
)
|
||||
self.run_layer_test(
|
||||
layers.Resizing,
|
||||
init_kwargs={
|
||||
"height": 6,
|
||||
"width": 6,
|
||||
"data_format": "channels_first",
|
||||
"interpolation": "lanczos5",
|
||||
"crop_to_aspect_ratio": False,
|
||||
},
|
||||
input_shape=(2, 3, 12, 12),
|
||||
expected_output_shape=(2, 3, 6, 6),
|
||||
expected_num_trainable_weights=0,
|
||||
expected_num_non_trainable_weights=0,
|
||||
expected_num_seed_generators=0,
|
||||
expected_num_losses=0,
|
||||
supports_masking=False,
|
||||
)
|
||||
|
||||
@parameterized.parameters(
|
||||
[
|
||||
((5, 7), "channels_first", True),
|
||||
((5, 7), "channels_last", True),
|
||||
((6, 8), "channels_first", False),
|
||||
((6, 8), "channels_last", False),
|
||||
]
|
||||
)
|
||||
def test_resizing_correctness(
|
||||
self, size, data_format, crop_to_aspect_ratio
|
||||
):
|
||||
# batched case
|
||||
if data_format == "channels_first":
|
||||
img = np.random.random((2, 3, 9, 11))
|
||||
else:
|
||||
img = np.random.random((2, 9, 11, 3))
|
||||
out = layers.Resizing(
|
||||
size[0],
|
||||
size[1],
|
||||
data_format=data_format,
|
||||
crop_to_aspect_ratio=crop_to_aspect_ratio,
|
||||
)(img)
|
||||
if data_format == "channels_first":
|
||||
img_transpose = np.transpose(img, (0, 2, 3, 1))
|
||||
|
||||
ref_out = tf.transpose(
|
||||
tf.keras.layers.Resizing(
|
||||
size[0], size[1], crop_to_aspect_ratio=crop_to_aspect_ratio
|
||||
)(img_transpose),
|
||||
(0, 3, 1, 2),
|
||||
)
|
||||
else:
|
||||
ref_out = tf.keras.layers.Resizing(
|
||||
size[0], size[1], crop_to_aspect_ratio=crop_to_aspect_ratio
|
||||
)(img)
|
||||
self.assertAllClose(ref_out, out)
|
||||
|
||||
# unbatched case
|
||||
if data_format == "channels_first":
|
||||
img = np.random.random((3, 9, 11))
|
||||
else:
|
||||
img = np.random.random((9, 11, 3))
|
||||
out = layers.Resizing(
|
||||
size[0],
|
||||
size[1],
|
||||
data_format=data_format,
|
||||
crop_to_aspect_ratio=crop_to_aspect_ratio,
|
||||
)(img)
|
||||
if data_format == "channels_first":
|
||||
img_transpose = np.transpose(img, (1, 2, 0))
|
||||
ref_out = tf.transpose(
|
||||
tf.keras.layers.Resizing(
|
||||
size[0], size[1], crop_to_aspect_ratio=crop_to_aspect_ratio
|
||||
)(img_transpose),
|
||||
(2, 0, 1),
|
||||
)
|
||||
else:
|
||||
ref_out = tf.keras.layers.Resizing(
|
||||
size[0], size[1], crop_to_aspect_ratio=crop_to_aspect_ratio
|
||||
)(img)
|
||||
self.assertAllClose(ref_out, out)
|
@ -9,6 +9,7 @@ from keras_core.backend import is_tensor
|
||||
from keras_core.backend import name_scope
|
||||
from keras_core.backend import random
|
||||
from keras_core.backend import shape
|
||||
from keras_core.operations import image
|
||||
from keras_core.operations import operation_utils
|
||||
from keras_core.operations.math import * # noqa: F403
|
||||
from keras_core.operations.nn import * # noqa: F403
|
||||
|
@ -34,35 +34,60 @@ class ImageOpsStaticShapeTest(testing.TestCase):
|
||||
class ImageOpsCorrectnessTest(testing.TestCase, parameterized.TestCase):
|
||||
@parameterized.parameters(
|
||||
[
|
||||
("bilinear", True),
|
||||
("nearest", True),
|
||||
("lanczos3", True),
|
||||
("lanczos5", True),
|
||||
("bicubic", True),
|
||||
("bilinear", False),
|
||||
("nearest", False),
|
||||
("lanczos3", False),
|
||||
("lanczos5", False),
|
||||
("bicubic", False),
|
||||
("bilinear", True, "channels_last"),
|
||||
("nearest", True, "channels_last"),
|
||||
("lanczos3", True, "channels_last"),
|
||||
("lanczos5", True, "channels_last"),
|
||||
("bicubic", True, "channels_last"),
|
||||
("bilinear", False, "channels_last"),
|
||||
("nearest", False, "channels_last"),
|
||||
("lanczos3", False, "channels_last"),
|
||||
("lanczos5", False, "channels_last"),
|
||||
("bicubic", False, "channels_last"),
|
||||
("bilinear", True, "channels_first"),
|
||||
]
|
||||
)
|
||||
def test_resize(self, method, antialias):
|
||||
x = np.random.random((50, 50, 3)) * 255
|
||||
def test_resize(self, method, antialias, data_format):
|
||||
# Unbatched case
|
||||
if data_format == "channels_first":
|
||||
x = np.random.random((3, 50, 50)) * 255
|
||||
else:
|
||||
x = np.random.random((50, 50, 3)) * 255
|
||||
out = kimage.resize(
|
||||
x, size=(25, 25), method=method, antialias=antialias
|
||||
x,
|
||||
size=(25, 25),
|
||||
method=method,
|
||||
antialias=antialias,
|
||||
data_format=data_format,
|
||||
)
|
||||
if data_format == "channels_first":
|
||||
x = np.transpose(x, (1, 2, 0))
|
||||
ref_out = tf.image.resize(
|
||||
x, size=(25, 25), method=method, antialias=antialias
|
||||
)
|
||||
if data_format == "channels_first":
|
||||
ref_out = np.transpose(ref_out, (2, 0, 1))
|
||||
self.assertEqual(tuple(out.shape), tuple(ref_out.shape))
|
||||
self.assertAllClose(ref_out, out, atol=0.3)
|
||||
|
||||
x = np.random.random((2, 50, 50, 3)) * 255
|
||||
# Batched case
|
||||
if data_format == "channels_first":
|
||||
x = np.random.random((2, 3, 50, 50)) * 255
|
||||
else:
|
||||
x = np.random.random((2, 50, 50, 3)) * 255
|
||||
out = kimage.resize(
|
||||
x, size=(25, 25), method=method, antialias=antialias
|
||||
x,
|
||||
size=(25, 25),
|
||||
method=method,
|
||||
antialias=antialias,
|
||||
data_format=data_format,
|
||||
)
|
||||
if data_format == "channels_first":
|
||||
x = np.transpose(x, (0, 2, 3, 1))
|
||||
ref_out = tf.image.resize(
|
||||
x, size=(25, 25), method=method, antialias=antialias
|
||||
)
|
||||
if data_format == "channels_first":
|
||||
ref_out = np.transpose(ref_out, (0, 3, 1, 2))
|
||||
self.assertEqual(tuple(out.shape), tuple(ref_out.shape))
|
||||
self.assertAllClose(ref_out, out, atol=0.3)
|
||||
|
@ -7,6 +7,7 @@ import warnings
|
||||
import numpy as np
|
||||
|
||||
from keras_core import backend
|
||||
from keras_core import operations as ops
|
||||
from keras_core.api_export import keras_core_export
|
||||
|
||||
try:
|
||||
@ -112,8 +113,10 @@ def array_to_img(x, data_format=None, scale=True, dtype=None):
|
||||
|
||||
|
||||
@keras_core_export(
|
||||
"keras_core.utils.img_to_array",
|
||||
"keras_core.preprocessing.image.img_to_array",
|
||||
[
|
||||
"keras_core.utils.img_to_array",
|
||||
"keras_core.preprocessing.image.img_to_array",
|
||||
]
|
||||
)
|
||||
def img_to_array(img, data_format=None, dtype=None):
|
||||
"""Converts a PIL Image instance to a NumPy array.
|
||||
@ -301,3 +304,142 @@ def load_img(
|
||||
else:
|
||||
img = img.resize(width_height_tuple, resample)
|
||||
return img
|
||||
|
||||
|
||||
def smart_resize(
|
||||
x, size, interpolation="bilinear", data_format="channels_last"
|
||||
):
|
||||
"""Resize images to a target size without aspect ratio distortion.
|
||||
|
||||
Image datasets typically yield images that have each a different
|
||||
size. However, these images need to be batched before they can be
|
||||
processed by Keras layers. To be batched, images need to share the same
|
||||
height and width.
|
||||
|
||||
You could simply do, in TF (or JAX equivalent):
|
||||
|
||||
```python
|
||||
size = (200, 200)
|
||||
ds = ds.map(lambda img: resize(img, size))
|
||||
```
|
||||
|
||||
However, if you do this, you distort the aspect ratio of your images, since
|
||||
in general they do not all have the same aspect ratio as `size`. This is
|
||||
fine in many cases, but not always (e.g. for image generation models
|
||||
this can be a problem).
|
||||
|
||||
Note that passing the argument `preserve_aspect_ratio=True` to `resize`
|
||||
will preserve the aspect ratio, but at the cost of no longer respecting the
|
||||
provided target size.
|
||||
|
||||
This calls for:
|
||||
|
||||
```python
|
||||
size = (200, 200)
|
||||
ds = ds.map(lambda img: smart_resize(img, size))
|
||||
```
|
||||
|
||||
Your output images will actually be `(200, 200)`, and will not be distorted.
|
||||
Instead, the parts of the image that do not fit within the target size
|
||||
get cropped out.
|
||||
|
||||
The resizing process is:
|
||||
|
||||
1. Take the largest centered crop of the image that has the same aspect
|
||||
ratio as the target size. For instance, if `size=(200, 200)` and the input
|
||||
image has size `(340, 500)`, we take a crop of `(340, 340)` centered along
|
||||
the width.
|
||||
2. Resize the cropped image to the target size. In the example above,
|
||||
we resize the `(340, 340)` crop to `(200, 200)`.
|
||||
|
||||
Args:
|
||||
x: Input image or batch of images (as a tensor or NumPy array).
|
||||
Must be in format `(height, width, channels)`
|
||||
or `(batch_size, height, width, channels)`.
|
||||
size: Tuple of `(height, width)` integer. Target size.
|
||||
interpolation: String, interpolation to use for resizing.
|
||||
Defaults to `'bilinear'`.
|
||||
Supports `bilinear`, `nearest`, `bicubic`,
|
||||
`lanczos3`, `lanczos5`.
|
||||
data_format: `"channels_last"` or `"channels_first"`.
|
||||
|
||||
Returns:
|
||||
Array with shape `(size[0], size[1], channels)`. If the input image was a
|
||||
NumPy array, the output is a NumPy array,
|
||||
and if it was a backend-native tensor,
|
||||
the output is a backend-native tensor.
|
||||
"""
|
||||
if len(size) != 2:
|
||||
raise ValueError(
|
||||
f"Expected `size` to be a tuple of 2 integers, but got: {size}."
|
||||
)
|
||||
img = backend.convert_to_tensor(x)
|
||||
if len(img.shape) is not None:
|
||||
if len(img.shape) < 3 or len(img.shape) > 4:
|
||||
raise ValueError(
|
||||
"Expected an image array with shape `(height, width, "
|
||||
"channels)`, or `(batch_size, height, width, channels)`, but "
|
||||
f"got input with incorrect rank, of shape {img.shape}."
|
||||
)
|
||||
shape = ops.shape(img)
|
||||
if data_format == "channels_last":
|
||||
height, width = shape[-3], shape[-2]
|
||||
else:
|
||||
height, width = shape[-2], shape[-1]
|
||||
target_height, target_width = size
|
||||
|
||||
crop_height = ops.cast(
|
||||
ops.cast(width * target_height, "float32") / target_width, "int32"
|
||||
)
|
||||
crop_width = ops.cast(
|
||||
ops.cast(height * target_width, "float32") / target_height, "int32"
|
||||
)
|
||||
|
||||
# Set back to input height / width if crop_height / crop_width is not
|
||||
# smaller.
|
||||
crop_height = ops.minimum(height, crop_height)
|
||||
crop_width = ops.minimum(width, crop_width)
|
||||
|
||||
crop_box_hstart = ops.cast(
|
||||
ops.cast(height - crop_height, "float32") / 2, "int32"
|
||||
)
|
||||
crop_box_wstart = ops.cast(
|
||||
ops.cast(width - crop_width, "float32") / 2, "int32"
|
||||
)
|
||||
|
||||
if data_format == "channels_last":
|
||||
if len(img.shape) == 4:
|
||||
img = img[
|
||||
:,
|
||||
crop_box_hstart : crop_box_hstart + crop_height,
|
||||
crop_box_wstart : crop_box_wstart + crop_width,
|
||||
:,
|
||||
]
|
||||
else:
|
||||
img = img[
|
||||
crop_box_hstart : crop_box_hstart + crop_height,
|
||||
crop_box_wstart : crop_box_wstart + crop_width,
|
||||
:,
|
||||
]
|
||||
else:
|
||||
if len(img.shape) == 4:
|
||||
img = img[
|
||||
:,
|
||||
:,
|
||||
crop_box_hstart : crop_box_hstart + crop_height,
|
||||
crop_box_wstart : crop_box_wstart + crop_width,
|
||||
]
|
||||
else:
|
||||
img = img[
|
||||
:,
|
||||
crop_box_hstart : crop_box_hstart + crop_height,
|
||||
crop_box_wstart : crop_box_wstart + crop_width,
|
||||
]
|
||||
|
||||
img = ops.image.resize(
|
||||
img, size=size, method=interpolation, data_format=data_format
|
||||
)
|
||||
|
||||
if isinstance(x, np.ndarray):
|
||||
return np.array(img)
|
||||
return img
|
||||
|
Loading…
Reference in New Issue
Block a user