304 lines
11 KiB
Python
304 lines
11 KiB
Python
"""Utilities related to image handling."""
|
|
|
|
import io
|
|
import pathlib
|
|
import warnings
|
|
|
|
import numpy as np
|
|
|
|
from keras_core import backend
|
|
from keras_core.api_export import keras_core_export
|
|
|
|
# PIL is an optional dependency: when it cannot be imported both names below
# are set to `None` and the functions that need PIL raise at call time.
try:
    from PIL import Image as pil_image

    # Use `Image.Resampling` when this Pillow build provides it; otherwise the
    # filter constants are read from the `Image` module itself.
    pil_image_resampling = getattr(pil_image, "Resampling", pil_image)
except ImportError:
    pil_image = None
    pil_image_resampling = None


# Maps the user-facing interpolation names to the PIL resampling filters.
# Only defined when PIL could be imported.
if pil_image_resampling is not None:
    PIL_INTERPOLATION_METHODS = {
        method: getattr(pil_image_resampling, method.upper())
        for method in (
            "nearest",
            "bilinear",
            "bicubic",
            "hamming",
            "box",
            "lanczos",
        )
    }
|
|
|
|
|
|
@keras_core_export(
    [
        "keras_core.utils.array_to_img",
        "keras_core.preprocessing.image.array_to_img",
    ]
)
def array_to_img(x, data_format=None, scale=True, dtype=None):
    """Converts a 3D NumPy array to a PIL Image instance.

    Usage:

    ```python
    from PIL import Image
    img = np.random.random(size=(100, 100, 3))
    pil_img = keras_core.utils.array_to_img(img)
    ```

    Args:
        x: Input data, in any form that can be converted to a NumPy array.
        data_format: Image data format, can be either `"channels_first"` or
            `"channels_last"`. Defaults to `None`, in which case the global
            setting `keras_core.backend.image_data_format()` is used (unless you
            changed it, it defaults to `"channels_last"`).
        scale: Whether to rescale the image such that minimum and maximum values
            are 0 and 255 respectively. A constant image is mapped to all
            zeros. Defaults to `True`.
        dtype: Dtype to use. Default to `None`, in which case the global setting
            `keras_core.backend.floatx()` is used (unless you changed it, it
            defaults to `"float32"`).

    Returns:
        A PIL Image instance.

    Raises:
        ImportError: If PIL is not available.
        ValueError: If `x` is not a rank-3 array, if `data_format` is not
            `"channels_first"` or `"channels_last"`, or if the number of
            channels is not 1, 3 or 4.
    """

    if data_format is None:
        data_format = backend.image_data_format()
    if dtype is None:
        dtype = backend.floatx()
    if pil_image is None:
        raise ImportError(
            "Could not import PIL.Image. "
            "The use of `array_to_img` requires PIL."
        )
    x = np.asarray(x, dtype=dtype)
    if x.ndim != 3:
        raise ValueError(
            "Expected image array to have rank 3 (single image). "
            f"Got array with shape: {x.shape}"
        )

    if data_format not in {"channels_first", "channels_last"}:
        raise ValueError(f"Invalid data_format: {data_format}")

    # Original NumPy array x has format (height, width, channel)
    # or (channel, height, width)
    # but target PIL image has format (width, height, channel)
    if data_format == "channels_first":
        x = x.transpose(1, 2, 0)
    if scale:
        # The subtraction allocates a new array, so the caller's data is
        # never modified by the in-place multiplication below.
        x = x - np.min(x)
        x_max = np.max(x)
        if x_max != 0:
            # Out-of-place division: `x /= x_max` raises a casting error when
            # `dtype` is an integer type, because true division yields floats.
            # For floating dtypes the result is the same as before.
            x = x / x_max
        x *= 255
    if x.shape[2] == 4:
        # RGBA
        return pil_image.fromarray(x.astype("uint8"), "RGBA")
    elif x.shape[2] == 3:
        # RGB
        return pil_image.fromarray(x.astype("uint8"), "RGB")
    elif x.shape[2] == 1:
        # grayscale
        if np.max(x) > 255:
            # 32-bit signed integer grayscale image. PIL mode "I"
            return pil_image.fromarray(x[:, :, 0].astype("int32"), "I")
        return pil_image.fromarray(x[:, :, 0].astype("uint8"), "L")
    else:
        raise ValueError(f"Unsupported channel number: {x.shape[2]}")
|
|
|
|
|
|
@keras_core_export(
    [
        "keras_core.utils.img_to_array",
        "keras_core.preprocessing.image.img_to_array",
    ]
)
def img_to_array(img, data_format=None, dtype=None):
    """Converts a PIL Image instance to a NumPy array.

    Usage:

    ```python
    from PIL import Image
    img_data = np.random.random(size=(100, 100, 3))
    img = keras_core.utils.array_to_img(img_data)
    array = keras_core.utils.img_to_array(img)
    ```

    Args:
        img: Input PIL Image instance.
        data_format: Image data format, can be either `"channels_first"` or
            `"channels_last"`. Defaults to `None`, in which case the global
            setting `keras_core.backend.image_data_format()` is used (unless you
            changed it, it defaults to `"channels_last"`).
        dtype: Dtype to use. Default to `None`, in which case the global setting
            `keras_core.backend.floatx()` is used (unless you changed it, it
            defaults to `"float32"`).

    Returns:
        A 3D NumPy array. A 2D (single-channel) input gains a channel axis
        of size 1 in the position given by `data_format`.

    Raises:
        ValueError: If `data_format` is not `"channels_first"` or
            `"channels_last"`, or if the converted array is neither 2D nor 3D.
    """

    if data_format is None:
        data_format = backend.image_data_format()
    if dtype is None:
        dtype = backend.floatx()
    if data_format not in {"channels_first", "channels_last"}:
        raise ValueError(f"Unknown data_format: {data_format}")
    # NumPy array x has format (height, width, channel)
    # or (channel, height, width)
    # but original PIL image has format (width, height, channel)
    x = np.asarray(img, dtype=dtype)
    if len(x.shape) == 3:
        # Already (height, width, channel); only move the channel axis first
        # when requested.
        if data_format == "channels_first":
            x = x.transpose(2, 0, 1)
    elif len(x.shape) == 2:
        # Single-channel image: add the missing channel axis.
        if data_format == "channels_first":
            x = x.reshape((1, x.shape[0], x.shape[1]))
        else:
            x = x.reshape((x.shape[0], x.shape[1], 1))
    else:
        raise ValueError(f"Unsupported image shape: {x.shape}")
    return x
|
|
|
|
|
|
@keras_core_export(
    ["keras_core.utils.save_img", "keras_core.preprocessing.image.save_img"]
)
def save_img(path, x, data_format=None, file_format=None, scale=True, **kwargs):
    """Saves an image stored as a NumPy array to a path or file object.

    Args:
        path: Path or file object.
        x: NumPy array.
        data_format: Image data format, either `"channels_first"` or
            `"channels_last"`. Defaults to `None`, in which case the global
            setting `keras_core.backend.image_data_format()` is used.
        file_format: Optional file format override. If omitted, the format to
            use is determined from the filename extension. If a file object was
            used instead of a filename, this parameter should always be used.
            `"jpg"` is accepted as an alias of `"jpeg"` (case-insensitive).
        scale: Whether to rescale image values to be within `[0, 255]`.
        **kwargs: Additional keyword arguments passed to `PIL.Image.save()`.
    """
    if data_format is None:
        data_format = backend.image_data_format()

    # Pillow names the format "JPEG"; "jpg" is only a file extension, so
    # translate it before it is forwarded to `Image.save()`.
    if isinstance(file_format, str) and file_format.lower() == "jpg":
        file_format = "jpeg"

    # Decide whether the output is a JPEG file. An explicit `file_format`
    # wins; otherwise fall back to the filename extension, which is what
    # Pillow uses when no format is given.
    if file_format is not None:
        saving_jpeg = str(file_format).lower() == "jpeg"
    elif isinstance(path, (str, pathlib.Path)):
        jpeg_suffixes = (".jpg", ".jpeg", ".jpe", ".jfif")
        saving_jpeg = pathlib.Path(path).suffix.lower() in jpeg_suffixes
    else:
        # File objects carry no reliable extension to inspect.
        saving_jpeg = False

    img = array_to_img(x, data_format=data_format, scale=scale)
    if img.mode == "RGBA" and saving_jpeg:
        warnings.warn(
            "The JPG format does not support RGBA images, converting to RGB."
        )
        img = img.convert("RGB")
    img.save(path, format=file_format, **kwargs)
|
|
|
|
|
|
@keras_core_export(
    ["keras_core.utils.load_img", "keras_core.preprocessing.image.load_img"]
)
def load_img(
    path,
    color_mode="rgb",
    target_size=None,
    interpolation="nearest",
    keep_aspect_ratio=False,
):
    """Loads an image into PIL format.

    Usage:

    ```python
    image = keras_core.utils.load_img(image_path)
    input_arr = keras_core.utils.img_to_array(image)
    input_arr = np.array([input_arr])  # Convert single image to a batch.
    predictions = model.predict(input_arr)
    ```

    Args:
        path: Location of the image: a `str`, `bytes` or `pathlib.Path`
            filename, or an `io.BytesIO` object holding the encoded image.
        color_mode: The desired image format, one of `"grayscale"`, `"rgb"`
            or `"rgba"`. Default: `"rgb"`.
        target_size: Either `None` (keep the original size) or a tuple of
            ints `(img_height, img_width)`.
        interpolation: Name of the resampling method applied when the loaded
            image does not already have the target size. Must be a key of
            `PIL_INTERPOLATION_METHODS` (`"nearest"`, `"bilinear"`,
            `"bicubic"`, `"hamming"`, `"box"` or `"lanczos"`). By default,
            `"nearest"` is used.
        keep_aspect_ratio: Boolean. When `True`, the image is first cropped
            around its center to the aspect ratio of `target_size`, so that
            resizing does not distort it.

    Returns:
        A PIL Image instance.

    Raises:
        ImportError: If PIL is not available.
        TypeError: If `path` is of an unsupported type.
        ValueError: If `color_mode` is unknown, or if a resize is required
            and `interpolation` is not a supported method.
    """
    if pil_image is None:
        raise ImportError(
            "Could not import PIL.Image. The use of `load_img` requires PIL."
        )

    if isinstance(path, io.BytesIO):
        img = pil_image.open(path)
    elif isinstance(path, (pathlib.Path, bytes, str)):
        filename = path
        if isinstance(filename, pathlib.Path):
            filename = str(filename.resolve())
        # Read the whole file up front so the image is decoded from an
        # in-memory buffer and the file handle is closed immediately.
        with open(filename, "rb") as handle:
            img = pil_image.open(io.BytesIO(handle.read()))
    else:
        raise TypeError(
            f"path should be path-like or io.BytesIO, not {type(path)}"
        )

    # For each color mode: the PIL modes that are kept untouched and the mode
    # every other image is converted to.
    if color_mode == "grayscale":
        # 8-bit, 16-bit and 32-bit grayscale images are left as they are;
        # anything else becomes 8-bit grayscale.
        accepted_modes = ("L", "I;16", "I")
        fallback_mode = "L"
    elif color_mode == "rgba":
        accepted_modes = ("RGBA",)
        fallback_mode = "RGBA"
    elif color_mode == "rgb":
        accepted_modes = ("RGB",)
        fallback_mode = "RGB"
    else:
        raise ValueError('color_mode must be "grayscale", "rgb", or "rgba"')
    if img.mode not in accepted_modes:
        img = img.convert(fallback_mode)

    if target_size is None:
        return img

    # `target_size` is (height, width) whereas PIL sizes are (width, height).
    new_size = (target_size[1], target_size[0])
    if img.size == new_size:
        return img

    if interpolation not in PIL_INTERPOLATION_METHODS:
        supported = ", ".join(PIL_INTERPOLATION_METHODS.keys())
        raise ValueError(
            f"Invalid interpolation method {interpolation} specified. "
            f"Supported methods are {supported}"
        )
    resample = PIL_INTERPOLATION_METHODS[interpolation]

    if not keep_aspect_ratio:
        return img.resize(new_size, resample)

    src_width, src_height = img.size
    dst_width, dst_height = new_size

    # Largest centered region with the target aspect ratio; each side is
    # capped at the corresponding side of the source image.
    region_height = min(src_height, (src_width * dst_height) // dst_width)
    region_width = min(src_width, (src_height * dst_width) // dst_height)

    left = (src_width - region_width) // 2
    top = (src_height - region_height) // 2
    region = [left, top, left + region_width, top + region_height]
    return img.resize(new_size, resample, box=region)
|