keras/keras_core/utils/image_utils.py
2023-04-30 15:11:54 -07:00

304 lines
11 KiB
Python

"""Utilities related to image handling."""
import io
import pathlib
import warnings
import numpy as np
from keras_core import backend
from keras_core.api_export import keras_core_export
# Pillow is an optional dependency. When it cannot be imported, both handles
# below stay `None` and the interpolation table is not defined at all.
try:
    from PIL import Image as pil_image
except ImportError:
    pil_image = None
    pil_image_resampling = None
else:
    # Use the `Image.Resampling` namespace when this Pillow build exposes it;
    # otherwise read the resampling constants from the `Image` module itself.
    pil_image_resampling = getattr(pil_image, "Resampling", pil_image)

if pil_image_resampling is not None:
    # Maps the interpolation names accepted by `load_img` to the matching
    # Pillow resampling filters. Key order is preserved because it is shown
    # verbatim in the "supported methods" error message.
    PIL_INTERPOLATION_METHODS = {
        method: getattr(pil_image_resampling, method.upper())
        for method in (
            "nearest",
            "bilinear",
            "bicubic",
            "hamming",
            "box",
            "lanczos",
        )
    }
@keras_core_export(
    [
        "keras_core.utils.array_to_img",
        "keras_core.preprocessing.image.array_to_img",
    ]
)
def array_to_img(x, data_format=None, scale=True, dtype=None):
    """Converts a 3D NumPy array to a PIL Image instance.

    Usage:

    ```python
    from PIL import Image
    img = np.random.random(size=(100, 100, 3))
    pil_img = keras_core.utils.array_to_img(img)
    ```

    Args:
        x: Input data, in any form that can be converted to a NumPy array.
        data_format: Image data format, can be either `"channels_first"` or
            `"channels_last"`. Defaults to `None`, in which case the global
            setting `keras_core.backend.image_data_format()` is used (unless
            you changed it, it defaults to `"channels_last"`).
        scale: Whether to rescale the image such that minimum and maximum
            values are 0 and 255 respectively. A constant image is mapped
            to all zeros. Defaults to `True`.
        dtype: Dtype used for the intermediate NumPy array. Defaults to
            `None`, in which case the global setting
            `keras_core.backend.floatx()` is used (unless you changed it, it
            defaults to `"float32"`). Integer dtypes are supported.

    Returns:
        A PIL Image instance. Arrays with 4 channels produce an `"RGBA"`
        image and arrays with 3 channels an `"RGB"` image. Single-channel
        arrays produce an `"L"` image, or an `"I"` (32-bit integer) image
        when any value exceeds 255.

    Raises:
        ImportError: If PIL is not available.
        ValueError: If `x` is not a rank-3 array, if `data_format` is not
            recognized, or if the number of channels is not 1, 3 or 4.
    """
    if data_format is None:
        data_format = backend.image_data_format()
    if dtype is None:
        dtype = backend.floatx()
    if pil_image is None:
        raise ImportError(
            "Could not import PIL.Image. "
            "The use of `array_to_img` requires PIL."
        )
    x = np.asarray(x, dtype=dtype)
    if x.ndim != 3:
        raise ValueError(
            "Expected image array to have rank 3 (single image). "
            f"Got array with shape: {x.shape}"
        )

    if data_format not in {"channels_first", "channels_last"}:
        raise ValueError(f"Invalid data_format: {data_format}")

    # Original NumPy array x has format (height, width, channel)
    # or (channel, height, width)
    # but target PIL image has format (width, height, channel)
    if data_format == "channels_first":
        x = x.transpose(1, 2, 0)
    if scale:
        # The subtraction allocates a new array, so the caller's data is
        # never modified.
        x = x - np.min(x)
        x_max = np.max(x)
        if x_max != 0:
            # Out-of-place division on purpose: `x /= x_max` cannot store
            # the floating-point quotient back into an integer array and
            # raises a casting error when an integer `dtype` was requested.
            x = x / x_max
        x = x * 255
    if x.shape[2] == 4:
        # RGBA
        return pil_image.fromarray(x.astype("uint8"), "RGBA")
    elif x.shape[2] == 3:
        # RGB
        return pil_image.fromarray(x.astype("uint8"), "RGB")
    elif x.shape[2] == 1:
        # grayscale
        if np.max(x) > 255:
            # 32-bit signed integer grayscale image. PIL mode "I"
            return pil_image.fromarray(x[:, :, 0].astype("int32"), "I")
        return pil_image.fromarray(x[:, :, 0].astype("uint8"), "L")
    else:
        raise ValueError(f"Unsupported channel number: {x.shape[2]}")
# The export paths are passed as a single list, matching every other export
# declared in this module.
@keras_core_export(
    [
        "keras_core.utils.img_to_array",
        "keras_core.preprocessing.image.img_to_array",
    ]
)
def img_to_array(img, data_format=None, dtype=None):
    """Converts a PIL Image instance to a NumPy array.

    Usage:

    ```python
    from PIL import Image
    img_data = np.random.random(size=(100, 100, 3))
    img = keras_core.utils.array_to_img(img_data)
    array = keras_core.utils.img_to_array(img)
    ```

    Args:
        img: Input PIL Image instance.
        data_format: Image data format, can be either `"channels_first"` or
            `"channels_last"`. Defaults to `None`, in which case the global
            setting `keras_core.backend.image_data_format()` is used (unless
            you changed it, it defaults to `"channels_last"`).
        dtype: Dtype to use. Defaults to `None`, in which case the global
            setting `keras_core.backend.floatx()` is used (unless you
            changed it, it defaults to `"float32"`).

    Returns:
        A 3D NumPy array. An image that converts to a 2D array gets a
        channel axis of size 1, placed first or last according to
        `data_format`.

    Raises:
        ValueError: If `data_format` is not recognized, or if the image
            converts to an array that is neither 2D nor 3D.
    """
    if data_format is None:
        data_format = backend.image_data_format()
    if dtype is None:
        dtype = backend.floatx()
    if data_format not in {"channels_first", "channels_last"}:
        raise ValueError(f"Unknown data_format: {data_format}")
    # NumPy array x has format (height, width, channel)
    # or (channel, height, width)
    # but original PIL image has format (width, height, channel)
    x = np.asarray(img, dtype=dtype)
    if len(x.shape) == 3:
        if data_format == "channels_first":
            x = x.transpose(2, 0, 1)
    elif len(x.shape) == 2:
        # The array has no channel axis: add one of size 1.
        if data_format == "channels_first":
            x = x.reshape((1, x.shape[0], x.shape[1]))
        else:
            x = x.reshape((x.shape[0], x.shape[1], 1))
    else:
        raise ValueError(f"Unsupported image shape: {x.shape}")
    return x
@keras_core_export(
    ["keras_core.utils.save_img", "keras_core.preprocessing.image.save_img"]
)
def save_img(path, x, data_format=None, file_format=None, scale=True, **kwargs):
    """Saves an image stored as a NumPy array to a path or file object.

    Args:
        path: Path or file object.
        x: NumPy array.
        data_format: Image data format, either `"channels_first"` or
            `"channels_last"`. Defaults to `None`, in which case the global
            setting `keras_core.backend.image_data_format()` is used.
        file_format: Optional file format override. If omitted, the format
            to use is determined from the filename extension. If a file
            object was used instead of a filename, this parameter should
            always be used. Matching is case-insensitive and `"jpg"` is
            accepted as an alias of `"jpeg"`.
        scale: Whether to rescale image values to be within `[0, 255]`.
        **kwargs: Additional keyword arguments passed to `PIL.Image.save()`.

    Note:
        JPEG cannot store an alpha channel. When the target is JPEG (either
        through `file_format` or through the extension of `path`) and the
        array has 4 channels, a warning is emitted and the image is
        converted to RGB before saving.
    """
    if data_format is None:
        data_format = backend.image_data_format()

    # Pillow registers its encoder under the name "JPEG"; "jpg" is only a
    # file extension and is rejected as a format name, so translate it.
    if isinstance(file_format, str) and file_format.lower() == "jpg":
        file_format = "jpeg"

    # Decide whether the output will be JPEG. An explicit format wins;
    # otherwise mirror Pillow's own inference from the filename extension.
    if file_format is not None:
        is_jpeg = (
            isinstance(file_format, str) and file_format.lower() == "jpeg"
        )
    elif isinstance(path, (str, pathlib.Path)):
        jpeg_suffixes = (".jpg", ".jpeg", ".jpe", ".jfif")
        is_jpeg = pathlib.Path(path).suffix.lower() in jpeg_suffixes
    else:
        is_jpeg = False

    img = array_to_img(x, data_format=data_format, scale=scale)
    if img.mode == "RGBA" and is_jpeg:
        warnings.warn(
            "The JPG format does not support RGBA images, converting to RGB."
        )
        img = img.convert("RGB")
    # Only the caller-supplied format (if any) is forwarded; when it is
    # `None`, Pillow keeps inferring the format from `path` itself.
    img.save(path, format=file_format, **kwargs)
@keras_core_export(
    ["keras_core.utils.load_img", "keras_core.preprocessing.image.load_img"]
)
def load_img(
    path,
    color_mode="rgb",
    target_size=None,
    interpolation="nearest",
    keep_aspect_ratio=False,
):
    """Loads an image into PIL format.

    Usage:

    ```python
    image = keras_core.utils.load_img(image_path)
    input_arr = keras_core.utils.img_to_array(image)
    input_arr = np.array([input_arr])  # Convert single image to a batch.
    predictions = model.predict(input_arr)
    ```

    Args:
        path: Path to the image file (`str`, `bytes` or `pathlib.Path`), or
            an `io.BytesIO` object holding the encoded image.
        color_mode: One of `"grayscale"`, `"rgb"`, `"rgba"`. Default:
            `"rgb"`. The desired image format.
        target_size: Either `None` (default to original size) or tuple of
            ints `(img_height, img_width)`.
        interpolation: Interpolation method used to resample the image if
            the target size is different from that of the loaded image.
            Supported methods are `"nearest"`, `"bilinear"`, `"bicubic"`,
            `"hamming"`, `"box"` and `"lanczos"`. By default, `"nearest"`
            is used. The value is only checked when a resize is needed.
        keep_aspect_ratio: Boolean, whether to resize images to a target
            size without aspect ratio distortion. The image is cropped in
            the center with target aspect ratio before resizing.

    Returns:
        A PIL Image instance.

    Raises:
        ImportError: If PIL is not available.
        TypeError: If `path` is neither path-like nor an `io.BytesIO`.
        ValueError: If `color_mode` is not supported, or if a resize is
            needed and `interpolation` is not supported.
    """
    if pil_image is None:
        raise ImportError(
            "Could not import PIL.Image. The use of `load_img` requires PIL."
        )

    # --- Open the image -----------------------------------------------------
    if not isinstance(path, (io.BytesIO, pathlib.Path, bytes, str)):
        raise TypeError(
            f"path should be path-like or io.BytesIO, not {type(path)}"
        )
    if isinstance(path, io.BytesIO):
        stream = path
    else:
        filename = path
        if isinstance(filename, pathlib.Path):
            filename = str(filename.resolve())
        # Read the whole file into memory so the file handle can be closed
        # before the image is handed back to the caller.
        with open(filename, "rb") as handle:
            stream = io.BytesIO(handle.read())
    img = pil_image.open(stream)

    # --- Convert to the requested color mode --------------------------------
    if color_mode == "grayscale":
        # 8-bit, 16-bit and 32-bit grayscale images are kept as they are;
        # anything else becomes an 8-bit grayscale image.
        accepted_modes, fallback_mode = ("L", "I;16", "I"), "L"
    elif color_mode == "rgba":
        accepted_modes, fallback_mode = ("RGBA",), "RGBA"
    elif color_mode == "rgb":
        accepted_modes, fallback_mode = ("RGB",), "RGB"
    else:
        raise ValueError('color_mode must be "grayscale", "rgb", or "rgba"')
    if img.mode not in accepted_modes:
        img = img.convert(fallback_mode)

    # --- Resize if requested ------------------------------------------------
    if target_size is None:
        return img
    # `target_size` is (height, width) whereas PIL sizes are (width, height).
    new_size = (target_size[1], target_size[0])
    if img.size == new_size:
        return img

    if interpolation not in PIL_INTERPOLATION_METHODS:
        raise ValueError(
            "Invalid interpolation method {} specified. Supported "
            "methods are {}".format(
                interpolation,
                ", ".join(PIL_INTERPOLATION_METHODS.keys()),
            )
        )
    resample = PIL_INTERPOLATION_METHODS[interpolation]

    if not keep_aspect_ratio:
        return img.resize(new_size, resample)

    # Largest centered crop that has the target aspect ratio; each dimension
    # is capped at the source dimension.
    src_width, src_height = img.size
    dst_width, dst_height = new_size
    crop_height = min(src_height, (src_width * dst_height) // dst_width)
    crop_width = min(src_width, (src_height * dst_width) // dst_height)
    left = (src_width - crop_width) // 2
    top = (src_height - crop_height) // 2
    crop_box = [left, top, left + crop_width, top + crop_height]
    return img.resize(new_size, resample, box=crop_box)