Convert to Keras 3: super_resolution_sub_pixel example (#18505)

* add: `super_resolution_sub_pixel` example

* add: old version

* Revert "add: old version"

This reverts commit 6fc71031f5df0e7ec10a5cdbfcff7914331ffae3.

* add: accelerator

* remove: hardcoded backend

* add: comment on TF

* update: use python3 style init
"""
Title: Image Super-Resolution using an Efficient Sub-Pixel CNN
Author: [Xingyu Long](https://github.com/xingyu-long)
Converted to Keras 3 by: [Md Awsafur Rahman](https://awsaf49.github.io)
Date created: 2020/07/28
Last modified: 2020/08/27
Description: Implementing Super-Resolution using an Efficient Sub-Pixel model on BSDS500.
Accelerator: GPU
"""
"""
## Introduction
ESPCN (Efficient Sub-Pixel CNN), proposed by [Shi, 2016](https://arxiv.org/abs/1609.05158),
is a model that reconstructs a high-resolution version of an image given a low-resolution
version. It leverages efficient "sub-pixel convolution" layers, which learn an array of
image upscaling filters.
In this code example, we will implement the model from the paper and train it on a small
dataset,
[BSDS500](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/resources.html).
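Concretely, for an upscaling factor `r`, a sub-pixel convolution layer rearranges a tensor of
shape `(H, W, C * r^2)` into `(H * r, W * r, C)`, so the upscaling is learned in channel space
and only becomes spatial at the very end.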
"""
"""
## Setup
"""
import os
import math

import numpy as np
import tensorflow as tf  # only for data preprocessing

import keras
from keras import layers
from keras import ops
from keras.utils import load_img
from keras.utils import array_to_img
from keras.utils import img_to_array
from keras.preprocessing import image_dataset_from_directory

from IPython.display import display
"""
## Load data: BSDS500 dataset
### Download dataset
We use the built-in `keras.utils.get_file` utility to retrieve the dataset.
"""
dataset_url = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz"
data_dir = keras.utils.get_file(origin=dataset_url, fname="BSR", untar=True)
root_dir = os.path.join(data_dir, "BSDS500/data")
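"""
By default, `keras.utils.get_file` caches the downloaded archive under
`~/.keras/datasets/` and, with `untar=True`, returns the path to the extracted folder.
"""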
"""
We create training and validation datasets via `image_dataset_from_directory`.
"""
crop_size = 300
upscale_factor = 3
input_size = crop_size // upscale_factor
batch_size = 8
train_ds = image_dataset_from_directory(
    root_dir,
    batch_size=batch_size,
    image_size=(crop_size, crop_size),
    validation_split=0.2,
    subset="training",
    seed=1337,
    label_mode=None,
)

valid_ds = image_dataset_from_directory(
    root_dir,
    batch_size=batch_size,
    image_size=(crop_size, crop_size),
    validation_split=0.2,
    subset="validation",
    seed=1337,
    label_mode=None,
)
"""
We rescale the images to take values in the range [0, 1].
"""
def scaling(input_image):
    input_image = input_image / 255.0
    return input_image


# Scale from (0, 255) to (0, 1)
train_ds = train_ds.map(scaling)
valid_ds = valid_ds.map(scaling)
"""
Let's visualize a few sample images:
"""
for batch in train_ds.take(1):
    for img in batch:
        display(array_to_img(img))
"""
We prepare a dataset of test image paths that we will use for
visual evaluation at the end of this example.
"""
dataset = os.path.join(root_dir, "images")
test_path = os.path.join(dataset, "test")
test_img_paths = sorted(
    [
        os.path.join(test_path, fname)
        for fname in os.listdir(test_path)
        if fname.endswith(".jpg")
    ]
)
"""
## Crop and resize images
Let's process image data.
First, we convert our images from the RGB color space to the
[YUV colour space](https://en.wikipedia.org/wiki/YUV).
For the input data (low-resolution images),
we crop the image, retrieve the `y` channel (luminance),
and resize it with the `area` method (use `BICUBIC` if you use PIL).
We only consider the luminance channel
in the YUV color space because humans are more sensitive to
luminance change.
For the target data (high-resolution images), we just crop the image
and retrieve the `y` channel.
"""
# Use TF Ops to process.
def process_input(input, input_size, upscale_factor):
    input = tf.image.rgb_to_yuv(input)
    last_dimension_axis = len(input.shape) - 1
    y, u, v = tf.split(input, 3, axis=last_dimension_axis)
    return tf.image.resize(y, [input_size, input_size], method="area")


def process_target(input):
    input = tf.image.rgb_to_yuv(input)
    last_dimension_axis = len(input.shape) - 1
    y, u, v = tf.split(input, 3, axis=last_dimension_axis)
    return y


train_ds = train_ds.map(
    lambda x: (process_input(x, input_size, upscale_factor), process_target(x))
)
train_ds = train_ds.prefetch(buffer_size=32)

valid_ds = valid_ds.map(
    lambda x: (process_input(x, input_size, upscale_factor), process_target(x))
)
valid_ds = valid_ds.prefetch(buffer_size=32)
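"""
As an optional sanity check (not part of the original example), you can inspect the
dataset's `element_spec`: inputs should be `(batch, 100, 100, 1)` patches and targets
`(batch, 300, 300, 1)` patches.
"""

# Optional: verify the (input, target) shapes produced by the pipeline.
print(train_ds.element_spec)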
"""
Let's take a look at the input and target data.
"""
for batch in train_ds.take(1):
    for img in batch[0]:
        display(array_to_img(img))
    for img in batch[1]:
        display(array_to_img(img))
"""
## Build a model
Compared to the paper, we add one more layer and we use the `relu` activation function
instead of `tanh`.
It achieves better performance even though we train the model for fewer epochs.
"""
class DepthToSpace(layers.Layer):
    def __init__(self, block_size):
        super().__init__()
        self.block_size = block_size

    def call(self, input):
        batch, height, width, depth = ops.shape(input)
        depth = depth // (self.block_size**2)

        x = ops.reshape(
            input, [batch, height, width, self.block_size, self.block_size, depth]
        )
        x = ops.transpose(x, [0, 1, 3, 2, 4, 5])
        x = ops.reshape(
            x, [batch, height * self.block_size, width * self.block_size, depth]
        )
        return x
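"""
As a quick sanity check (a minimal sketch, not part of the original example):
with `block_size=3`, the nine channels of each pixel should be laid out as a 3x3
spatial block, turning a `(1, 2, 2, 9)` tensor into a `(1, 6, 6, 1)` tensor.
"""

# Hypothetical shape check for the DepthToSpace layer.
print(DepthToSpace(block_size=3)(ops.ones((1, 2, 2, 9))).shape)  # (1, 6, 6, 1)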
def get_model(upscale_factor=3, channels=1):
    conv_args = {
        "activation": "relu",
        "kernel_initializer": "orthogonal",
        "padding": "same",
    }
    inputs = keras.Input(shape=(None, None, channels))
    x = layers.Conv2D(64, 5, **conv_args)(inputs)
    x = layers.Conv2D(64, 3, **conv_args)(x)
    x = layers.Conv2D(32, 3, **conv_args)(x)
    x = layers.Conv2D(channels * (upscale_factor**2), 3, **conv_args)(x)
    outputs = DepthToSpace(upscale_factor)(x)

    return keras.Model(inputs, outputs)
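"""
Because the model is fully convolutional (`shape=(None, None, channels)`), it accepts
inputs of any spatial size. As a hypothetical check (not in the original example),
a `(1, 100, 100, 1)` input should produce a `(1, 300, 300, 1)` output with the
default `upscale_factor=3`.
"""

# Hypothetical shape check on a dummy input.
print(get_model()(ops.ones((1, 100, 100, 1))).shape)  # expected: (1, 300, 300, 1)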
"""
## Define utility functions
We need to define several utility functions to monitor our results:
- `plot_results` to plot and save an image.
- `get_lowres_image` to convert an image to its low-resolution version.
- `upscale_image` to turn a low-resolution image into
a high-resolution version reconstructed by the model.
In this function, we use the `y` channel from the YUV color space
as input to the model and then combine the output with the
other channels to obtain an RGB image.
"""
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
import PIL
def plot_results(img, prefix, title):
    """Plot the result with zoom-in area."""
    img_array = img_to_array(img)
    img_array = img_array.astype("float32") / 255.0

    # Create a new figure with a default 111 subplot.
    fig, ax = plt.subplots()
    im = ax.imshow(img_array[::-1], origin="lower")

    plt.title(title)
    # zoom-factor: 2.0, location: upper-left
    axins = zoomed_inset_axes(ax, 2, loc=2)
    axins.imshow(img_array[::-1], origin="lower")

    # Specify the limits.
    x1, x2, y1, y2 = 200, 300, 100, 200
    # Apply the x-limits.
    axins.set_xlim(x1, x2)
    # Apply the y-limits.
    axins.set_ylim(y1, y2)

    plt.yticks(visible=False)
    plt.xticks(visible=False)

    # Make the line.
    mark_inset(ax, axins, loc1=1, loc2=3, fc="none", ec="blue")
    plt.savefig(str(prefix) + "-" + title + ".png")
    plt.show()


def get_lowres_image(img, upscale_factor):
    """Return low-resolution image to use as model input."""
    return img.resize(
        (img.size[0] // upscale_factor, img.size[1] // upscale_factor),
        PIL.Image.BICUBIC,
    )
def upscale_image(model, img):
    """Predict the result based on input image and restore the image as RGB."""
    ycbcr = img.convert("YCbCr")
    y, cb, cr = ycbcr.split()
    y = img_to_array(y)
    y = y.astype("float32") / 255.0

    input = np.expand_dims(y, axis=0)
    out = model.predict(input)

    out_img_y = out[0]
    out_img_y *= 255.0

    # Restore the image in RGB color space.
    out_img_y = out_img_y.clip(0, 255)
    out_img_y = out_img_y.reshape((np.shape(out_img_y)[0], np.shape(out_img_y)[1]))
    out_img_y = PIL.Image.fromarray(np.uint8(out_img_y), mode="L")
    out_img_cb = cb.resize(out_img_y.size, PIL.Image.BICUBIC)
    out_img_cr = cr.resize(out_img_y.size, PIL.Image.BICUBIC)
    out_img = PIL.Image.merge("YCbCr", (out_img_y, out_img_cb, out_img_cr)).convert(
        "RGB"
    )
    return out_img
"""
## Define callbacks to monitor training
The `ESPCNCallback` object will compute and display
the [PSNR](https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio) metric.
This is the main metric we use to evaluate super-resolution performance.
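Since our images are scaled to `[0, 1]` (so `max_val = 1`), PSNR reduces to
`10 * log10(1 / MSE)`, which is exactly what the callback computes below from each
validation batch's MSE loss.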
"""
class ESPCNCallback(keras.callbacks.Callback):
    def __init__(self):
        super().__init__()
        self.test_img = get_lowres_image(load_img(test_img_paths[0]), upscale_factor)

    # Store PSNR value in each epoch.
    def on_epoch_begin(self, epoch, logs=None):
        self.psnr = []

    def on_epoch_end(self, epoch, logs=None):
        print("Mean PSNR for epoch: %.2f" % (np.mean(self.psnr)))
        if epoch % 20 == 0:
            prediction = upscale_image(self.model, self.test_img)
            plot_results(prediction, "epoch-" + str(epoch), "prediction")

    def on_test_batch_end(self, batch, logs=None):
        self.psnr.append(10 * math.log10(1 / logs["loss"]))
"""
Define `ModelCheckpoint` and `EarlyStopping` callbacks.
"""
early_stopping_callback = keras.callbacks.EarlyStopping(monitor="loss", patience=10)
checkpoint_filepath = "/tmp/checkpoint.keras"
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=False,
    monitor="loss",
    mode="min",
    save_best_only=True,
)
model = get_model(upscale_factor=upscale_factor, channels=1)
model.summary()
callbacks = [ESPCNCallback(), early_stopping_callback, model_checkpoint_callback]
loss_fn = keras.losses.MeanSquaredError()
optimizer = keras.optimizers.Adam(learning_rate=0.001)
"""
## Train the model
"""
epochs = 100
model.compile(
    optimizer=optimizer,
    loss=loss_fn,
)

model.fit(
    train_ds, epochs=epochs, callbacks=callbacks, validation_data=valid_ds, verbose=2
)
# The model weights (that are considered the best) are loaded into the model.
model.load_weights(checkpoint_filepath)
"""
## Run model prediction and plot the results
Let's compute the reconstructed version of a few images and save the results.
"""
total_bicubic_psnr = 0.0
total_test_psnr = 0.0
for index, test_img_path in enumerate(test_img_paths[50:60]):
    img = load_img(test_img_path)
    lowres_input = get_lowres_image(img, upscale_factor)
    w = lowres_input.size[0] * upscale_factor
    h = lowres_input.size[1] * upscale_factor
    highres_img = img.resize((w, h))
    prediction = upscale_image(model, lowres_input)
    lowres_img = lowres_input.resize((w, h))
    lowres_img_arr = img_to_array(lowres_img)
    highres_img_arr = img_to_array(highres_img)
    predict_img_arr = img_to_array(prediction)
    bicubic_psnr = tf.image.psnr(lowres_img_arr, highres_img_arr, max_val=255)
    test_psnr = tf.image.psnr(predict_img_arr, highres_img_arr, max_val=255)

    total_bicubic_psnr += bicubic_psnr
    total_test_psnr += test_psnr

    print(
        "PSNR of low resolution image and high resolution image is %.4f" % bicubic_psnr
    )
    print("PSNR of prediction and high resolution is %.4f" % test_psnr)
    plot_results(lowres_img, index, "lowres")
    plot_results(highres_img, index, "highres")
    plot_results(prediction, index, "prediction")
print("Avg. PSNR of lowres images is %.4f" % (total_bicubic_psnr / 10))
print("Avg. PSNR of reconstructions is %.4f" % (total_test_psnr / 10))