Convert to Keras 3: super_resolution_sub_pixel
example (#18505)
* add: `super_resolution_sub_pixel` example * add: old version * Revert "add: old version" This reverts commit 6fc71031f5df0e7ec10a5cdbfcff7914331ffae3. * add: accelerator * remove: hardcoded backend * add: comment on TF * update: use python3 style init
This commit is contained in:
parent
bba4728980
commit
eb125174fc
407
examples/keras_io/vision/super_resolution_sub_pixel.py
Normal file
407
examples/keras_io/vision/super_resolution_sub_pixel.py
Normal file
@ -0,0 +1,407 @@
|
|||||||
|
"""
|
||||||
|
Title: Image Super-Resolution using an Efficient Sub-Pixel CNN
|
||||||
|
Author: [Xingyu Long](https://github.com/xingyu-long)
|
||||||
|
Converted to Keras 3 by: [Md Awsfalur Rahman](https://awsaf49.github.io)
|
||||||
|
Date created: 2020/07/28
|
||||||
|
Last modified: 2020/08/27
|
||||||
|
Description: Implementing Super-Resolution using Efficient sub-pixel model on BSDS500.
|
||||||
|
Accelerator: GPU
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
## Introduction
|
||||||
|
|
||||||
|
ESPCN (Efficient Sub-Pixel CNN), proposed by [Shi, 2016](https://arxiv.org/abs/1609.05158)
|
||||||
|
is a model that reconstructs a high-resolution version of an image given a low-resolution
|
||||||
|
version.
|
||||||
|
It leverages efficient "sub-pixel convolution" layers, which learn an array of
|
||||||
|
image upscaling filters.
|
||||||
|
|
||||||
|
In this code example, we will implement the model from the paper and train it on a small
|
||||||
|
dataset,
|
||||||
|
[BSDS500](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/resources.html).
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""
|
||||||
|
## Setup
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
import os

import keras as keras
from keras import layers
from keras import ops
from keras.utils import load_img
from keras.utils import array_to_img
from keras.utils import img_to_array
from keras.preprocessing import image_dataset_from_directory
import tensorflow as tf  # only for data preprocessing
import numpy as np

from IPython.display import display
|
||||||
|
|
||||||
|
"""
|
||||||
|
## Load data: BSDS500 dataset
|
||||||
|
|
||||||
|
### Download dataset
|
||||||
|
|
||||||
|
We use the built-in `keras.utils.get_file` utility to retrieve the dataset.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Location of the BSDS500 archive; split over two literals only to keep the
# line short -- the resulting URL is unchanged.
dataset_url = (
    "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/"
    "BSR_bsds500.tgz"
)
# Fetch and unpack the archive, then point `root_dir` at the data folder
# inside the extracted tree.
data_dir = keras.utils.get_file(origin=dataset_url, fname="BSR", untar=True)
root_dir = os.path.join(data_dir, "BSDS500/data")
|
||||||
|
|
||||||
|
"""
|
||||||
|
We create training and validation datasets via `image_dataset_from_directory`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Geometry of the training samples: high-resolution crops are `crop_size`
# pixels square and the model input is `upscale_factor` times smaller.
crop_size = 300
upscale_factor = 3
input_size = crop_size // upscale_factor
batch_size = 8

# Options shared by both splits; the identical seed and split fraction keep the
# training and validation subsets consistent with each other.
_dataset_kwargs = dict(
    batch_size=batch_size,
    image_size=(crop_size, crop_size),
    validation_split=0.2,
    seed=1337,
    label_mode=None,
)

train_ds = image_dataset_from_directory(
    root_dir, subset="training", **_dataset_kwargs
)

valid_ds = image_dataset_from_directory(
    root_dir, subset="validation", **_dataset_kwargs
)
|
||||||
|
|
||||||
|
"""
|
||||||
|
We rescale the images to take values in the range [0, 1].
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def scaling(input_image):
    """Divide `input_image` by 255.0, mapping the [0, 255] range onto [0, 1]."""
    return input_image / 255.0
|
||||||
|
|
||||||
|
|
||||||
|
# Normalize both splits: pixel values go from (0, 255) to (0, 1).
train_ds, valid_ds = train_ds.map(scaling), valid_ds.map(scaling)
|
||||||
|
|
||||||
|
"""
|
||||||
|
Let's visualize a few sample images:
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Render every image of the first training batch.
for sample_batch in train_ds.take(1):
    for sample in sample_batch:
        display(array_to_img(sample))
|
||||||
|
|
||||||
|
"""
|
||||||
|
We prepare a dataset of test image paths that we will use for
|
||||||
|
visual evaluation at the end of this example.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Folder holding the held-out test images.
dataset = os.path.join(root_dir, "images")
test_path = os.path.join(dataset, "test")

# Full paths of all `.jpg` files in the test folder, in sorted order.
test_img_paths = sorted(
    os.path.join(test_path, fname)
    for fname in os.listdir(test_path)
    if fname.endswith(".jpg")
)
|
||||||
|
|
||||||
|
"""
|
||||||
|
## Crop and resize images
|
||||||
|
|
||||||
|
Let's process image data.
|
||||||
|
First, we convert our images from the RGB color space to the
|
||||||
|
[YUV colour space](https://en.wikipedia.org/wiki/YUV).
|
||||||
|
|
||||||
|
For the input data (low-resolution images),
|
||||||
|
we crop the image, retrieve the `y` channel (luminance),
|
||||||
|
and resize it with the `area` method (use `BICUBIC` if you use PIL).
|
||||||
|
We only consider the luminance channel
|
||||||
|
in the YUV color space because humans are more sensitive to
|
||||||
|
luminance change.
|
||||||
|
|
||||||
|
For the target data (high-resolution images), we just crop the image
|
||||||
|
and retrieve the `y` channel.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Use TF Ops to process.
|
||||||
|
def process_input(input, input_size, upscale_factor):
    """Build the low-resolution model input from an RGB image tensor.

    The image is converted to YUV, the first (`y`) channel is kept, and that
    channel is resized to `input_size` x `input_size` with the `area` method.

    Args:
        input: RGB image tensor whose last axis holds the three channels.
        input_size: Side length of the square output.
        upscale_factor: Accepted for interface compatibility; not used here.

    Returns:
        The resized `y` channel.
    """
    yuv = tf.image.rgb_to_yuv(input)
    channel_axis = len(yuv.shape) - 1
    luminance = tf.split(yuv, 3, axis=channel_axis)[0]
    return tf.image.resize(luminance, [input_size, input_size], method="area")
|
||||||
|
|
||||||
|
|
||||||
|
def process_target(input):
    """Return the `y` channel of an RGB image tensor after conversion to YUV.

    Args:
        input: RGB image tensor whose last axis holds the three channels.

    Returns:
        The first of the three channels of the YUV image, with the channel
        axis kept.
    """
    yuv = tf.image.rgb_to_yuv(input)
    channel_axis = len(yuv.shape) - 1
    return tf.split(yuv, 3, axis=channel_axis)[0]
|
||||||
|
|
||||||
|
|
||||||
|
def _to_input_target_pair(image):
    """Map one image batch to its (low-resolution input, target) pair."""
    return process_input(image, input_size, upscale_factor), process_target(image)


# Turn each split into (input, target) pairs and prefetch them.
train_ds = train_ds.map(_to_input_target_pair).prefetch(buffer_size=32)

valid_ds = valid_ds.map(_to_input_target_pair).prefetch(buffer_size=32)
|
||||||
|
|
||||||
|
"""
|
||||||
|
Let's take a look at the input and target data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Show the low-resolution inputs of the first batch, then their targets.
for lowres_batch, target_batch in train_ds.take(1):
    for lowres in lowres_batch:
        display(array_to_img(lowres))
    for target in target_batch:
        display(array_to_img(target))
|
||||||
|
|
||||||
|
"""
|
||||||
|
## Build a model
|
||||||
|
|
||||||
|
Compared to the paper, we add one more layer and we use the `relu` activation function
|
||||||
|
instead of `tanh`.
|
||||||
|
It achieves better performance even though we train the model for fewer epochs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class DepthToSpace(layers.Layer):
    """Rearrange blocks of channels into spatial blocks (sub-pixel shuffle).

    A 4-D input unpacked as `(batch, height, width, depth)` is turned into an
    output of shape
    `(batch, height * block_size, width * block_size, depth // block_size**2)`.

    Args:
        block_size: Spatial upscaling factor applied to height and width.
        **kwargs: Standard `keras.layers.Layer` arguments (e.g. `name`),
            forwarded unchanged to the base class.
    """

    def __init__(self, block_size, **kwargs):
        super().__init__(**kwargs)
        self.block_size = block_size

    def call(self, input):
        batch, height, width, depth = ops.shape(input)
        depth = depth // (self.block_size**2)

        # Split the channel axis into (block_row, block_col, remaining depth).
        x = ops.reshape(
            input, [batch, height, width, self.block_size, self.block_size, depth]
        )
        # Move each block axis next to the spatial axis it expands:
        # (batch, height, block_row, width, block_col, depth).
        x = ops.transpose(x, [0, 1, 3, 2, 4, 5])
        # Merge the adjacent axes into the enlarged height and width.
        x = ops.reshape(
            x, [batch, height * self.block_size, width * self.block_size, depth]
        )
        return x

    def get_config(self):
        """Return the layer config, including `block_size`, for serialization."""
        config = super().get_config()
        config.update({"block_size": self.block_size})
        return config
|
||||||
|
|
||||||
|
|
||||||
|
def get_model(upscale_factor=3, channels=1):
    """Build the super-resolution network.

    Four `relu` convolutions are followed by a `DepthToSpace` layer that
    enlarges the spatial dimensions by `upscale_factor`.

    Args:
        upscale_factor: Block size handed to `DepthToSpace`.
        channels: Number of channels in the model input.

    Returns:
        An uncompiled `keras.Model`.
    """
    shared_conv_options = dict(
        activation="relu",
        kernel_initializer="orthogonal",
        padding="same",
    )
    # (filters, kernel_size) of each convolution, in application order. The
    # last one emits `upscale_factor**2` feature maps per output channel.
    conv_specs = (
        (64, 5),
        (64, 3),
        (32, 3),
        (channels * (upscale_factor**2), 3),
    )

    inputs = keras.Input(shape=(None, None, channels))
    features = inputs
    for filters, kernel_size in conv_specs:
        features = layers.Conv2D(filters, kernel_size, **shared_conv_options)(
            features
        )
    outputs = DepthToSpace(upscale_factor)(features)

    return keras.Model(inputs, outputs)
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
## Define utility functions
|
||||||
|
|
||||||
|
We need to define several utility functions to monitor our results:
|
||||||
|
|
||||||
|
- `plot_results` to plot and save an image.
|
||||||
|
- `get_lowres_image` to convert an image to its low-resolution version.
|
||||||
|
- `upscale_image` to turn a low-resolution image to
|
||||||
|
a high-resolution version reconstructed by the model.
|
||||||
|
In this function, we use the `y` channel from the YUV color space
|
||||||
|
as input to the model and then combine the output with the
|
||||||
|
other channels to obtain an RGB image.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes
|
||||||
|
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
|
||||||
|
import PIL
|
||||||
|
|
||||||
|
|
||||||
|
def plot_results(img, prefix, title):
    """Plot the result with zoom-in area.

    The figure is written to `<prefix>-<title>.png` in the current working
    directory, shown, and then closed.

    Args:
        img: Image accepted by `img_to_array`.
        prefix: Value placed at the start of the output file name.
        title: Plot title; also used in the output file name.
    """
    img_array = img_to_array(img)
    img_array = img_array.astype("float32") / 255.0

    # Create a new figure with a default 111 subplot.
    fig, ax = plt.subplots()
    ax.imshow(img_array[::-1], origin="lower")

    plt.title(title)
    # zoom-factor: 2.0, location: upper-left
    axins = zoomed_inset_axes(ax, 2, loc=2)
    axins.imshow(img_array[::-1], origin="lower")

    # Specify the limits.
    x1, x2, y1, y2 = 200, 300, 100, 200
    # Apply the x-limits.
    axins.set_xlim(x1, x2)
    # Apply the y-limits.
    axins.set_ylim(y1, y2)

    plt.yticks(visible=False)
    plt.xticks(visible=False)

    # Make the line.
    mark_inset(ax, axins, loc1=1, loc2=3, fc="none", ec="blue")
    plt.savefig(f"{prefix}-{title}.png")
    plt.show()
    # This function is called many times (training callback and evaluation
    # loop); close the figure so open figures do not pile up in memory.
    plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
def get_lowres_image(img, upscale_factor):
    """Return low-resolution image to use as model input.

    Both entries of `img.size` are integer-divided by `upscale_factor` and the
    image is resized to that size with bicubic resampling.
    """
    full_width, full_height = img.size[0], img.size[1]
    lowres_size = (full_width // upscale_factor, full_height // upscale_factor)
    return img.resize(lowres_size, PIL.Image.BICUBIC)
|
||||||
|
|
||||||
|
|
||||||
|
def upscale_image(model, img):
    """Predict the result based on input image and restore the image as RGB.

    Only the `Y` channel of the YCbCr version of `img` goes through `model`;
    the `Cb` and `Cr` channels are resized with bicubic resampling to the size
    of the predicted channel and merged back in.
    """
    luma, cb, cr = img.convert("YCbCr").split()

    # Scale the luminance to [0, 1] and add a leading batch axis.
    luma_array = img_to_array(luma).astype("float32") / 255.0
    prediction = model.predict(np.expand_dims(luma_array, axis=0))

    # Bring the first (only) prediction back to the 0-255 range.
    out_img_y = prediction[0]
    out_img_y *= 255.0
    out_img_y = out_img_y.clip(0, 255)

    # Restore the image in RGB color space.
    rows, cols = np.shape(out_img_y)[0], np.shape(out_img_y)[1]
    out_img_y = out_img_y.reshape((rows, cols))
    out_img_y = PIL.Image.fromarray(np.uint8(out_img_y), mode="L")

    target_size = out_img_y.size
    out_img_cb = cb.resize(target_size, PIL.Image.BICUBIC)
    out_img_cr = cr.resize(target_size, PIL.Image.BICUBIC)

    merged = PIL.Image.merge("YCbCr", (out_img_y, out_img_cb, out_img_cr))
    return merged.convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
## Define callbacks to monitor training
|
||||||
|
|
||||||
|
The `ESPCNCallback` object will compute and display
|
||||||
|
the [PSNR](https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio) metric.
|
||||||
|
This is the main metric we use to evaluate super-resolution performance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class ESPCNCallback(keras.callbacks.Callback):
    """Print the mean PSNR of the test batches and plot a sample prediction.

    PSNR is computed per test batch as `10 * log10(1 / loss)` and averaged at
    the end of each epoch. Every 20 epochs the model's prediction for one
    low-resolution test image is plotted and saved with `plot_results`.
    """

    def __init__(self):
        super().__init__()
        self.test_img = get_lowres_image(load_img(test_img_paths[0]), upscale_factor)
        # Created here as well as in `on_epoch_begin` so that test-batch hooks
        # fired outside an epoch do not hit a missing attribute.
        self.psnr = []

    # Store PSNR value in each epoch.
    def on_epoch_begin(self, epoch, logs=None):
        self.psnr = []

    def on_epoch_end(self, epoch, logs=None):
        # With no test batches this epoch there is nothing to average; report
        # NaN directly instead of letting numpy warn about an empty list.
        mean_psnr = np.mean(self.psnr) if self.psnr else float("nan")
        print("Mean PSNR for epoch: %.2f" % mean_psnr)
        if epoch % 20 == 0:
            prediction = upscale_image(self.model, self.test_img)
            plot_results(prediction, "epoch-" + str(epoch), "prediction")

    def on_test_batch_end(self, batch, logs=None):
        loss = (logs or {}).get("loss")
        if loss is None:
            return
        # A zero loss would divide by zero; the PSNR is unbounded in that case.
        if loss > 0:
            self.psnr.append(10 * math.log10(1 / loss))
        else:
            self.psnr.append(float("inf"))
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Define `ModelCheckpoint` and `EarlyStopping` callbacks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Stop training when the training loss has not improved for 10 epochs.
early_stopping_callback = keras.callbacks.EarlyStopping(patience=10, monitor="loss")

checkpoint_filepath = "/tmp/checkpoint.keras"

# Save the full model whenever the training loss reaches a new minimum.
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    monitor="loss",
    mode="min",
    save_best_only=True,
    save_weights_only=False,
)

model = get_model(upscale_factor=upscale_factor, channels=1)
model.summary()

callbacks = [
    ESPCNCallback(),
    early_stopping_callback,
    model_checkpoint_callback,
]
loss_fn = keras.losses.MeanSquaredError()
optimizer = keras.optimizers.Adam(learning_rate=0.001)
|
||||||
|
|
||||||
|
"""
|
||||||
|
## Train the model
|
||||||
|
"""
|
||||||
|
|
||||||
|
epochs = 100

model.compile(loss=loss_fn, optimizer=optimizer)

model.fit(
    train_ds,
    validation_data=valid_ds,
    epochs=epochs,
    callbacks=callbacks,
    verbose=2,
)

# Restore the weights stored by the checkpoint callback, i.e. the ones that
# are considered the best.
model.load_weights(checkpoint_filepath)
|
||||||
|
|
||||||
|
"""
|
||||||
|
## Run model prediction and plot the results
|
||||||
|
|
||||||
|
Let's compute the reconstructed version of a few images and save the results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
total_bicubic_psnr = 0.0
total_test_psnr = 0.0

# Images used for the final comparison.
eval_img_paths = test_img_paths[50:60]

for index, test_img_path in enumerate(eval_img_paths):
    img = load_img(test_img_path)
    lowres_input = get_lowres_image(img, upscale_factor)
    # Use dimensions that are exact multiples of the low-resolution size so the
    # reference, the resized baseline and the prediction all share one shape.
    w = lowres_input.size[0] * upscale_factor
    h = lowres_input.size[1] * upscale_factor
    highres_img = img.resize((w, h))
    prediction = upscale_image(model, lowres_input)
    lowres_img = lowres_input.resize((w, h))
    lowres_img_arr = img_to_array(lowres_img)
    highres_img_arr = img_to_array(highres_img)
    predict_img_arr = img_to_array(prediction)
    bicubic_psnr = tf.image.psnr(lowres_img_arr, highres_img_arr, max_val=255)
    test_psnr = tf.image.psnr(predict_img_arr, highres_img_arr, max_val=255)

    total_bicubic_psnr += bicubic_psnr
    total_test_psnr += test_psnr

    print(
        "PSNR of low resolution image and high resolution image is %.4f" % bicubic_psnr
    )
    print("PSNR of predict and high resolution is %.4f" % test_psnr)
    plot_results(lowres_img, index, "lowres")
    plot_results(highres_img, index, "highres")
    plot_results(prediction, index, "prediction")

# Average over the images actually evaluated; the slice can yield fewer than
# ten paths. `max(..., 1)` avoids a division by zero when it yields none.
num_eval_images = max(len(eval_img_paths), 1)
print("Avg. PSNR of lowres images is %.4f" % (total_bicubic_psnr / num_eval_images))
print("Avg. PSNR of reconstructions is %.4f" % (total_test_psnr / num_eval_images))
|
Loading…
Reference in New Issue
Block a user