Merge branch 'main' of github.com:keras-team/keras-core
This commit is contained in:
parent
ccb9db5b92
commit
16fc9cd173
492
examples/keras_io/tensorflow/vision/zero_dce.py
Normal file
@ -0,0 +1,492 @@
"""
Title: Zero-DCE for low-light image enhancement
Author: [Soumik Rakshit](http://github.com/soumik12345)
Converted to Keras Core by: [Soumik Rakshit](http://github.com/soumik12345)
Date created: 2021/09/18
Last modified: 2023/07/15
Description: Implementing Zero-Reference Deep Curve Estimation for low-light image enhancement.
Accelerator: GPU
"""

"""
## Introduction

**Zero-Reference Deep Curve Estimation** or **Zero-DCE** formulates low-light image
enhancement as the task of estimating an image-specific
[*tonal curve*](https://en.wikipedia.org/wiki/Curve_(tonality)) with a deep neural network.
In this example, we train a lightweight deep network, **DCE-Net**, to estimate
pixel-wise and high-order tonal curves for dynamic range adjustment of a given image.

Zero-DCE takes a low-light image as input and produces high-order tonal curves as its output.
These curves are then used for pixel-wise adjustment on the dynamic range of the input to
obtain an enhanced image. The curve estimation process is done in such a way that it maintains
the range of the enhanced image and preserves the contrast of neighboring pixels. This
curve estimation is inspired by the curves adjustment used in photo editing software such as
Adobe Photoshop, where users can adjust points throughout an image’s tonal range.

Zero-DCE is appealing because of its relaxed assumptions with regard to reference images:
it does not require any input/output image pairs during training.
This is achieved through a set of carefully formulated non-reference loss functions,
which implicitly measure the enhancement quality and guide the training of the network.

### References

- [Zero-Reference Deep Curve Estimation for Low-Light Image Enhancement](https://arxiv.org/pdf/2001.06826.pdf)
- [Curves adjustment in Adobe Photoshop](https://helpx.adobe.com/photoshop/using/curves-adjustment.html)
"""

"""
## Downloading LOLDataset

The **LoL Dataset** has been created for low-light image enhancement. It provides 485
images for training and 15 for testing. Each image pair in the dataset consists of a
low-light input image and its corresponding well-exposed reference image.
"""

import os
import random
import numpy as np
from glob import glob
from PIL import Image, ImageOps
import matplotlib.pyplot as plt

import keras_core as keras
from keras_core import layers

import tensorflow as tf

"""shell
wget https://huggingface.co/datasets/geekyrakshit/LoL-Dataset/resolve/main/lol_dataset.zip
unzip -q lol_dataset.zip && rm lol_dataset.zip
"""

"""
## Creating a TensorFlow Dataset

We use 400 low-light images from the LoL Dataset training set for training, and we use
the remaining 85 low-light images for validation. We resize the images to size
`256 x 256` to be used for both training and validation. Note that in order to train the
DCE-Net, we will not require the corresponding enhanced images.
"""

IMAGE_SIZE = 256
BATCH_SIZE = 16
MAX_TRAIN_IMAGES = 400


def load_data(image_path):
    image = tf.io.read_file(image_path)
    image = tf.image.decode_png(image, channels=3)
    image = tf.image.resize(images=image, size=[IMAGE_SIZE, IMAGE_SIZE])
    image = image / 255.0
    return image


def data_generator(low_light_images):
    dataset = tf.data.Dataset.from_tensor_slices((low_light_images))
    dataset = dataset.map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
    return dataset


train_low_light_images = sorted(glob("./lol_dataset/our485/low/*"))[:MAX_TRAIN_IMAGES]
val_low_light_images = sorted(glob("./lol_dataset/our485/low/*"))[MAX_TRAIN_IMAGES:]
test_low_light_images = sorted(glob("./lol_dataset/eval15/low/*"))


train_dataset = data_generator(train_low_light_images)
val_dataset = data_generator(val_low_light_images)

print("Train Dataset:", train_dataset)
print("Validation Dataset:", val_dataset)
"""
|
||||||
|
## The Zero-DCE Framework
|
||||||
|
|
||||||
|
The goal of DCE-Net is to estimate a set of best-fitting light-enhancement curves
|
||||||
|
(LE-curves) given an input image. The framework then maps all pixels of the input’s RGB
|
||||||
|
channels by applying the curves iteratively to obtain the final enhanced image.
|
||||||
|
|
||||||
|
### Understanding light-enhancement curves
|
||||||
|
|
||||||
|
A ligh-enhancement curve is a kind of curve that can map a low-light image
|
||||||
|
to its enhanced version automatically,
|
||||||
|
where the self-adaptive curve parameters are solely dependent on the input image.
|
||||||
|
When designing such a curve, three objectives should be taken into account:
|
||||||
|
|
||||||
|
- Each pixel value of the enhanced image should be in the normalized range `[0,1]`, in order to
|
||||||
|
avoid information loss induced by overflow truncation.
|
||||||
|
- It should be monotonous, to preserve the contrast between neighboring pixels.
|
||||||
|
- The shape of this curve should be as simple as possible,
|
||||||
|
and the curve should be differentiable to allow backpropagation.
|
||||||
|
|
||||||
|
The light-enhancement curve is separately applied to three RGB channels instead of solely on the
|
||||||
|
illumination channel. The three-channel adjustment can better preserve the inherent color and reduce
|
||||||
|
the risk of over-saturation.
|
||||||
|
|
||||||
|
![](https://li-chongyi.github.io/Zero-DCE_files/framework.png)
|
||||||
|
|
||||||
|
### DCE-Net
|
||||||
|
|
||||||
|
The DCE-Net is a lightweight deep neural network that learns the mapping between an input
|
||||||
|
image and its best-fitting curve parameter maps. The input to the DCE-Net is a low-light
|
||||||
|
image while the outputs are a set of pixel-wise curve parameter maps for corresponding
|
||||||
|
higher-order curves. It is a plain CNN of seven convolutional layers with symmetrical
|
||||||
|
concatenation. Each layer consists of 32 convolutional kernels of size 3×3 and stride 1
|
||||||
|
followed by the ReLU activation function. The last convolutional layer is followed by the
|
||||||
|
Tanh activation function, which produces 24 parameter maps for 8 iterations, where each
|
||||||
|
iteration requires three curve parameter maps for the three channels.
|
||||||
|
|
||||||
|
![](https://i.imgur.com/HtIg34W.png)
|
||||||
|
"""


def build_dce_net():
    input_img = keras.Input(shape=[None, None, 3])
    conv1 = layers.Conv2D(
        32, (3, 3), strides=(1, 1), activation="relu", padding="same"
    )(input_img)
    conv2 = layers.Conv2D(
        32, (3, 3), strides=(1, 1), activation="relu", padding="same"
    )(conv1)
    conv3 = layers.Conv2D(
        32, (3, 3), strides=(1, 1), activation="relu", padding="same"
    )(conv2)
    conv4 = layers.Conv2D(
        32, (3, 3), strides=(1, 1), activation="relu", padding="same"
    )(conv3)
    int_con1 = layers.Concatenate(axis=-1)([conv4, conv3])
    conv5 = layers.Conv2D(
        32, (3, 3), strides=(1, 1), activation="relu", padding="same"
    )(int_con1)
    int_con2 = layers.Concatenate(axis=-1)([conv5, conv2])
    conv6 = layers.Conv2D(
        32, (3, 3), strides=(1, 1), activation="relu", padding="same"
    )(int_con2)
    int_con3 = layers.Concatenate(axis=-1)([conv6, conv1])
    x_r = layers.Conv2D(24, (3, 3), strides=(1, 1), activation="tanh", padding="same")(
        int_con3
    )
    return keras.Model(inputs=input_img, outputs=x_r)


"""
## Loss functions

To enable zero-reference learning in DCE-Net, we use a set of differentiable
zero-reference losses that allow us to evaluate the quality of enhanced images.
"""

"""
### Color constancy loss

The *color constancy loss* is used to correct the potential color deviations in the
enhanced image.
"""


def color_constancy_loss(x):
    mean_rgb = tf.reduce_mean(x, axis=(1, 2), keepdims=True)
    mr, mg, mb = mean_rgb[:, :, :, 0], mean_rgb[:, :, :, 1], mean_rgb[:, :, :, 2]
    d_rg = tf.square(mr - mg)
    d_rb = tf.square(mr - mb)
    d_gb = tf.square(mb - mg)
    return tf.sqrt(tf.square(d_rg) + tf.square(d_rb) + tf.square(d_gb))


"""
### Exposure loss

To restrain under-/over-exposed regions, we use the *exposure control loss*.
It measures the distance between the average intensity value of a local region
and a preset well-exposedness level (set to `0.6`).
"""


def exposure_loss(x, mean_val=0.6):
    x = tf.reduce_mean(x, axis=3, keepdims=True)
    mean = tf.nn.avg_pool2d(x, ksize=16, strides=16, padding="VALID")
    return tf.reduce_mean(tf.square(mean - mean_val))


"""
### Illumination smoothness loss

To preserve the monotonicity relations between neighboring pixels, the
*illumination smoothness loss* is added to each curve parameter map.
"""


def illumination_smoothness_loss(x):
    batch_size = tf.shape(x)[0]
    h_x = tf.shape(x)[1]
    w_x = tf.shape(x)[2]
    count_h = (tf.shape(x)[2] - 1) * tf.shape(x)[3]
    count_w = tf.shape(x)[2] * (tf.shape(x)[3] - 1)
    h_tv = tf.reduce_sum(tf.square((x[:, 1:, :, :] - x[:, : h_x - 1, :, :])))
    w_tv = tf.reduce_sum(tf.square((x[:, :, 1:, :] - x[:, :, : w_x - 1, :])))
    batch_size = tf.cast(batch_size, dtype=tf.float32)
    count_h = tf.cast(count_h, dtype=tf.float32)
    count_w = tf.cast(count_w, dtype=tf.float32)
    return 2 * (h_tv / count_h + w_tv / count_w) / batch_size


"""
### Spatial consistency loss

The *spatial consistency loss* encourages spatial coherence of the enhanced image by
preserving the contrast between neighboring regions across the input image and its enhanced version.
"""


class SpatialConsistencyLoss(keras.losses.Loss):
    def __init__(self, **kwargs):
        super().__init__(reduction="none")

        self.left_kernel = tf.constant(
            [[[[0, 0, 0]], [[-1, 1, 0]], [[0, 0, 0]]]], dtype=tf.float32
        )
        self.right_kernel = tf.constant(
            [[[[0, 0, 0]], [[0, 1, -1]], [[0, 0, 0]]]], dtype=tf.float32
        )
        self.up_kernel = tf.constant(
            [[[[0, -1, 0]], [[0, 1, 0]], [[0, 0, 0]]]], dtype=tf.float32
        )
        self.down_kernel = tf.constant(
            [[[[0, 0, 0]], [[0, 1, 0]], [[0, -1, 0]]]], dtype=tf.float32
        )

    def call(self, y_true, y_pred):
        original_mean = tf.reduce_mean(y_true, 3, keepdims=True)
        enhanced_mean = tf.reduce_mean(y_pred, 3, keepdims=True)
        original_pool = tf.nn.avg_pool2d(
            original_mean, ksize=4, strides=4, padding="VALID"
        )
        enhanced_pool = tf.nn.avg_pool2d(
            enhanced_mean, ksize=4, strides=4, padding="VALID"
        )

        d_original_left = tf.nn.conv2d(
            original_pool, self.left_kernel, strides=[1, 1, 1, 1], padding="SAME"
        )
        d_original_right = tf.nn.conv2d(
            original_pool, self.right_kernel, strides=[1, 1, 1, 1], padding="SAME"
        )
        d_original_up = tf.nn.conv2d(
            original_pool, self.up_kernel, strides=[1, 1, 1, 1], padding="SAME"
        )
        d_original_down = tf.nn.conv2d(
            original_pool, self.down_kernel, strides=[1, 1, 1, 1], padding="SAME"
        )

        d_enhanced_left = tf.nn.conv2d(
            enhanced_pool, self.left_kernel, strides=[1, 1, 1, 1], padding="SAME"
        )
        d_enhanced_right = tf.nn.conv2d(
            enhanced_pool, self.right_kernel, strides=[1, 1, 1, 1], padding="SAME"
        )
        d_enhanced_up = tf.nn.conv2d(
            enhanced_pool, self.up_kernel, strides=[1, 1, 1, 1], padding="SAME"
        )
        d_enhanced_down = tf.nn.conv2d(
            enhanced_pool, self.down_kernel, strides=[1, 1, 1, 1], padding="SAME"
        )

        d_left = tf.square(d_original_left - d_enhanced_left)
        d_right = tf.square(d_original_right - d_enhanced_right)
        d_up = tf.square(d_original_up - d_enhanced_up)
        d_down = tf.square(d_original_down - d_enhanced_down)
        return d_left + d_right + d_up + d_down


"""
### Deep curve estimation model

We implement the Zero-DCE framework as a Keras subclassed model.
"""


class ZeroDCE(keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dce_model = build_dce_net()

    def compile(self, learning_rate, **kwargs):
        super().compile(**kwargs)
        self.optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
        self.spatial_constancy_loss = SpatialConsistencyLoss(reduction="none")
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.illumination_smoothness_loss_tracker = keras.metrics.Mean(
            name="illumination_smoothness_loss"
        )
        self.spatial_constancy_loss_tracker = keras.metrics.Mean(
            name="spatial_constancy_loss"
        )
        self.color_constancy_loss_tracker = keras.metrics.Mean(
            name="color_constancy_loss"
        )
        self.exposure_loss_tracker = keras.metrics.Mean(name="exposure_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.illumination_smoothness_loss_tracker,
            self.spatial_constancy_loss_tracker,
            self.color_constancy_loss_tracker,
            self.exposure_loss_tracker,
        ]

    def get_enhanced_image(self, data, output):
        r1 = output[:, :, :, :3]
        r2 = output[:, :, :, 3:6]
        r3 = output[:, :, :, 6:9]
        r4 = output[:, :, :, 9:12]
        r5 = output[:, :, :, 12:15]
        r6 = output[:, :, :, 15:18]
        r7 = output[:, :, :, 18:21]
        r8 = output[:, :, :, 21:24]
        x = data + r1 * (tf.square(data) - data)
        x = x + r2 * (tf.square(x) - x)
        x = x + r3 * (tf.square(x) - x)
        enhanced_image = x + r4 * (tf.square(x) - x)
        x = enhanced_image + r5 * (tf.square(enhanced_image) - enhanced_image)
        x = x + r6 * (tf.square(x) - x)
        x = x + r7 * (tf.square(x) - x)
        enhanced_image = x + r8 * (tf.square(x) - x)
        return enhanced_image

    def call(self, data):
        dce_net_output = self.dce_model(data)
        return self.get_enhanced_image(data, dce_net_output)

    def compute_losses(self, data, output):
        enhanced_image = self.get_enhanced_image(data, output)
        loss_illumination = 200 * illumination_smoothness_loss(output)
        loss_spatial_constancy = tf.reduce_mean(
            self.spatial_constancy_loss(enhanced_image, data)
        )
        loss_color_constancy = 5 * tf.reduce_mean(color_constancy_loss(enhanced_image))
        loss_exposure = 10 * tf.reduce_mean(exposure_loss(enhanced_image))
        total_loss = (
            loss_illumination
            + loss_spatial_constancy
            + loss_color_constancy
            + loss_exposure
        )

        return {
            "total_loss": total_loss,
            "illumination_smoothness_loss": loss_illumination,
            "spatial_constancy_loss": loss_spatial_constancy,
            "color_constancy_loss": loss_color_constancy,
            "exposure_loss": loss_exposure,
        }

    def train_step(self, data):
        with tf.GradientTape() as tape:
            output = self.dce_model(data)
            losses = self.compute_losses(data, output)

        gradients = tape.gradient(
            losses["total_loss"], self.dce_model.trainable_weights
        )
        self.optimizer.apply_gradients(zip(gradients, self.dce_model.trainable_weights))

        self.total_loss_tracker.update_state(losses["total_loss"])
        self.illumination_smoothness_loss_tracker.update_state(
            losses["illumination_smoothness_loss"]
        )
        self.spatial_constancy_loss_tracker.update_state(
            losses["spatial_constancy_loss"]
        )
        self.color_constancy_loss_tracker.update_state(losses["color_constancy_loss"])
        self.exposure_loss_tracker.update_state(losses["exposure_loss"])

        return {metric.name: metric.result() for metric in self.metrics}

    def test_step(self, data):
        output = self.dce_model(data)
        losses = self.compute_losses(data, output)

        self.total_loss_tracker.update_state(losses["total_loss"])
        self.illumination_smoothness_loss_tracker.update_state(
            losses["illumination_smoothness_loss"]
        )
        self.spatial_constancy_loss_tracker.update_state(
            losses["spatial_constancy_loss"]
        )
        self.color_constancy_loss_tracker.update_state(losses["color_constancy_loss"])
        self.exposure_loss_tracker.update_state(losses["exposure_loss"])

        return {metric.name: metric.result() for metric in self.metrics}

    def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
        """While saving the weights, we simply save the weights of the DCE-Net."""
        self.dce_model.save_weights(
            filepath, overwrite=overwrite, save_format=save_format, options=options
        )

    def load_weights(self, filepath, by_name=False, skip_mismatch=False, options=None):
        """While loading the weights, we simply load the weights of the DCE-Net."""
        self.dce_model.load_weights(
            filepath=filepath,
            by_name=by_name,
            skip_mismatch=skip_mismatch,
            options=options,
        )
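The unrolled `r1`…`r8` logic in `get_enhanced_image` is equivalent to applying the quadratic LE-curve in a loop over the eight 3-channel slices of the DCE-Net output. A standalone NumPy sketch of that equivalence (not part of the training script; shapes and values here are illustrative dummies, not model predictions):

```python
import numpy as np


def enhance(data, output, iterations=8):
    # Consume the 24 curve parameter maps as 8 successive quadratic LE-curves,
    # three channels (one per RGB channel) per iteration.
    x = data
    for i in range(iterations):
        r = output[..., 3 * i : 3 * (i + 1)]
        x = x + r * (np.square(x) - x)
    return x


data = np.full((1, 4, 4, 3), 0.3)      # dummy "image" batch with values in [0, 1]
output = np.full((1, 4, 4, 24), -0.5)  # dummy curve maps in [-1, 1] (tanh range)
enhanced = enhance(data, output)
assert enhanced.shape == data.shape
assert enhanced.min() >= 0.0 and enhanced.max() <= 1.0
```

With a negative curve parameter, each iteration brightens mid-range pixels while keeping the output within `[0, 1]`, which is exactly the behavior the unrolled version produces.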


"""
## Training
"""

zero_dce_model = ZeroDCE()
zero_dce_model.compile(learning_rate=1e-4)
history = zero_dce_model.fit(train_dataset, validation_data=val_dataset, epochs=100)


def plot_result(item):
    plt.plot(history.history[item], label=item)
    plt.plot(history.history["val_" + item], label="val_" + item)
    plt.xlabel("Epochs")
    plt.ylabel(item)
    plt.title("Train and Validation {} Over Epochs".format(item), fontsize=14)
    plt.legend()
    plt.grid()
    plt.show()


plot_result("total_loss")
plot_result("illumination_smoothness_loss")
plot_result("spatial_constancy_loss")
plot_result("color_constancy_loss")
plot_result("exposure_loss")


"""
## Inference
"""


def plot_results(images, titles, figure_size=(12, 12)):
    fig = plt.figure(figsize=figure_size)
    for i in range(len(images)):
        fig.add_subplot(1, len(images), i + 1).set_title(titles[i])
        _ = plt.imshow(images[i])
        plt.axis("off")
    plt.show()


def infer(original_image):
    image = keras.utils.img_to_array(original_image)
    image = image.astype("float32") / 255.0
    image = np.expand_dims(image, axis=0)
    output_image = zero_dce_model(image)
    output_image = tf.cast((output_image[0, :, :, :] * 255), dtype=tf.uint8)
    output_image = Image.fromarray(output_image.numpy())
    return output_image


"""
### Inference on test images

We compare the test images from LOLDataset enhanced by Zero-DCE with images enhanced via
the `PIL.ImageOps.autocontrast()` function.

You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/low-light-image-enhancement)
and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/low-light-image-enhancement).
"""


for val_image_file in test_low_light_images:
    original_image = Image.open(val_image_file)
    enhanced_image = infer(original_image)
    plot_results(
        [original_image, ImageOps.autocontrast(original_image), enhanced_image],
        ["Original", "PIL Autocontrast", "Enhanced"],
        (20, 12),
    )
@ -35,6 +35,30 @@ class SegmentSum(Operation):

@keras_core_export("keras_core.ops.segment_sum")
def segment_sum(data, segment_ids, num_segments=None, sorted=False):
    """Computes the sum of segments in a tensor.

    Args:
        data: Input tensor.
        segment_ids: A 1-D tensor containing segment indices for each
            element in `data`.
        num_segments: An integer representing the total number of
            segments. If not specified, it is inferred from the maximum
            value in `segment_ids`.
        sorted: A boolean indicating whether `segment_ids` is sorted.
            Default is `False`.

    Returns:
        A tensor containing the sum of segments, where each element
        represents the sum of the corresponding segment in `data`.

    Example:

    ```python
    >>> data = keras_core.ops.convert_to_tensor([1, 2, 3, 4, 5, 6])
    >>> segment_ids = keras_core.ops.convert_to_tensor([0, 1, 0, 1, 0, 1])
    >>> segment_sum(data, segment_ids)
    array([9 12], shape=(2,), dtype=int32)
    ```
    """
    if any_symbolic_tensors((data,)):
        return SegmentSum(num_segments, sorted).symbolic_call(data, segment_ids)
    return backend.math.segment_sum(
@ -63,6 +87,29 @@ class TopK(Operation):

@keras_core_export("keras_core.ops.top_k")
def top_k(x, k, sorted=True):
    """Finds the top-k values and their indices in a tensor.

    Args:
        x: Input tensor.
        k: An integer representing the number of top elements to retrieve.
        sorted: A boolean indicating whether to sort the output in
            descending order. Default is `True`.

    Returns:
        A tuple containing two tensors. The first tensor contains the
        top-k values, and the second tensor contains the indices of the
        top-k values in the input tensor.

    Example:

    ```python
    >>> x = keras_core.ops.convert_to_tensor([5, 2, 7, 1, 9, 3])
    >>> values, indices = top_k(x, k=3)
    >>> print(values)
    array([9 7 5], shape=(3,), dtype=int32)
    >>> print(indices)
    array([4 2 0], shape=(3,), dtype=int32)
    ```
    """
    if any_symbolic_tensors((x,)):
        return TopK(k, sorted).symbolic_call(x)
    return backend.math.top_k(x, k, sorted)
@ -82,6 +129,28 @@ class InTopK(Operation):

@keras_core_export("keras_core.ops.in_top_k")
def in_top_k(targets, predictions, k):
    """Checks if the targets are in the top-k predictions.

    Args:
        targets: A tensor of true labels.
        predictions: A tensor of predicted labels.
        k: An integer representing the number of predictions to consider.

    Returns:
        A boolean tensor of the same shape as `targets`, where each element
        indicates whether the corresponding target is in the top-k predictions.

    Example:

    ```python
    >>> targets = keras_core.ops.convert_to_tensor([2, 5, 3])
    >>> predictions = keras_core.ops.convert_to_tensor(
    ...     [[0.1, 0.4, 0.6, 0.9, 0.5],
    ...      [0.1, 0.7, 0.9, 0.8, 0.3],
    ...      [0.1, 0.6, 0.9, 0.9, 0.5]])
    >>> in_top_k(targets, predictions, k=3)
    array([ True False True], shape=(3,), dtype=bool)
    ```
    """
    if any_symbolic_tensors((targets, predictions)):
        return InTopK(k).symbolic_call(targets, predictions)
    return backend.math.in_top_k(targets, predictions, k)
@ -103,6 +172,27 @@ class Logsumexp(Operation):

@keras_core_export("keras_core.ops.logsumexp")
def logsumexp(x, axis=None, keepdims=False):
    """Computes the logarithm of sum of exponentials of elements in a tensor.

    Args:
        x: Input tensor.
        axis: An integer or a tuple of integers specifying the axis/axes
            along which to compute the sum. If `None`, the sum is computed
            over all elements. Default is `None`.
        keepdims: A boolean indicating whether to keep the dimensions of
            the input tensor when computing the sum. Default is `False`.

    Returns:
        A tensor containing the logarithm of the sum of exponentials of
        elements in `x`.

    Example:

    ```python
    >>> x = keras_core.ops.convert_to_tensor([1., 2., 3.])
    >>> logsumexp(x)
    array(3.407606, shape=(), dtype=float32)
    ```
    """
    if any_symbolic_tensors((x,)):
        return Logsumexp(axis, keepdims).symbolic_call(x)
    return backend.math.logsumexp(x, axis=axis, keepdims=keepdims)
@ -152,6 +242,30 @@ class Qr(Operation):

@keras_core_export("keras_core.ops.qr")
def qr(x, mode="reduced"):
    """Computes the QR decomposition of a tensor.

    Args:
        x: Input tensor.
        mode: A string specifying the mode of the QR decomposition.
            - 'reduced': Returns the reduced QR decomposition. (default)
            - 'complete': Returns the complete QR decomposition.

    Returns:
        A tuple containing two tensors. The first tensor represents the
        orthogonal matrix Q, and the second tensor represents the upper
        triangular matrix R.

    Example:

    ```python
    >>> x = keras_core.ops.convert_to_tensor([[1., 2.], [3., 4.], [5., 6.]])
    >>> q, r = qr(x)
    >>> print(q)
    array([[-0.16903079  0.897085  ]
           [-0.5070925   0.2760267 ]
           [-0.8451542  -0.34503305]], shape=(3, 2), dtype=float32)
    ```
    """

    if any_symbolic_tensors((x,)):
        return Qr(mode=mode).symbolic_call(x)
    return backend.math.qr(x, mode=mode)
@ -895,7 +895,7 @@ def conv_transpose(
         the output tensor. Can be a single integer to specify the same
         value for all spatial dimensions. The amount of output padding
         along a given dimension must be lower than the stride along that
-        same dimension. If set to None (default), the output shape is
+        same dimension. If set to `None` (default), the output shape is
         inferred.
     data_format: A string, either `"channels_last"` or `"channels_first"`.
         `data_format` determines the ordering of the dimensions in the
@@ -955,6 +955,38 @@ class OneHot(Operation):
 @keras_core_export(["keras_core.ops.one_hot", "keras_core.ops.nn.one_hot"])
 def one_hot(x, num_classes, axis=-1, dtype=None):
+    """Converts integer tensor `x` into a one-hot tensor.
+
+    The one-hot encoding is a representation where each integer value is
+    converted into a binary vector with a length equal to `num_classes`,
+    and the index corresponding to the integer value is marked as 1, while
+    all other indices are marked as 0.
+
+    Args:
+        x: Integer tensor to be encoded. The shape can be
+            arbitrary, but the dtype should be integer.
+        num_classes: Number of classes for the one-hot encoding.
+        axis: Axis along which the encoding is performed. Defaults to
+            `-1`, which represents the last axis.
+        dtype: (Optional) Data type of the output tensor. If not
+            provided, it defaults to the default data type of the backend.
+
+    Returns:
+        Tensor: One-hot encoded tensor with the same shape as `x`
+        except for the specified `axis` dimension, which will have
+        a length of `num_classes`. The dtype of the output tensor
+        is determined by `dtype` or the default data type of the backend.
+
+    Example:
+
+    ```python
+    >>> x = keras_core.ops.convert_to_tensor([1, 3, 2, 0])
+    >>> one_hot(x, num_classes=4)
+    array([[0. 1. 0. 0.]
+           [0. 0. 0. 1.]
+           [0. 0. 1. 0.]
+           [1. 0. 0. 0.]], shape=(4, 4), dtype=float32)
+    ```
+    """
     if any_symbolic_tensors((x,)):
         return OneHot(num_classes, axis=axis, dtype=dtype).symbolic_call(x)
     return backend.nn.one_hot(
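The docstring's encoding rule (an identity-matrix row lookup along the class axis) can be sanity-checked with a small NumPy sketch. `one_hot_np` is a hypothetical stand-in for the backend implementation, not keras-core's code:

```python
import numpy as np

def one_hot_np(x, num_classes, axis=-1, dtype="float32"):
    """Sketch of one-hot encoding: row i of eye(num_classes) is class i."""
    out = np.eye(num_classes, dtype=dtype)[np.asarray(x)]
    # The encoded dimension is appended last; move it to `axis` if needed.
    if axis != -1:
        out = np.moveaxis(out, -1, axis)
    return out

print(one_hot_np([1, 3, 2, 0], num_classes=4))
```

This reproduces the docstring example: a `(4, 4)` array with a single 1 per row at index 1, 3, 2, and 0 respectively.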
@@ -989,6 +1021,40 @@ class BinaryCrossentropy(Operation):
     ]
 )
 def binary_crossentropy(target, output, from_logits=False):
+    """Computes binary cross-entropy loss between target and output tensor.
+
+    The binary cross-entropy loss is commonly used in binary
+    classification tasks where each input sample belongs to one
+    of the two classes. It measures the dissimilarity between the
+    target and output probabilities or logits.
+
+    Args:
+        target: The target tensor representing the true binary labels.
+            Its shape should match the shape of the `output` tensor.
+        output: The output tensor representing the predicted probabilities
+            or logits. Its shape should match the shape of the
+            `target` tensor.
+        from_logits: (optional) Whether `output` is a tensor of logits or
+            probabilities. Set it to `True` if `output` represents logits;
+            otherwise, set it to `False` if `output` represents
+            probabilities. Defaults to `False`.
+
+    Returns:
+        Tensor: The computed binary cross-entropy loss between
+        `target` and `output`.
+
+    Example:
+
+    ```python
+    >>> target = keras_core.ops.convert_to_tensor([0, 1, 1, 0],
+    ...     dtype="float32")
+    >>> output = keras_core.ops.convert_to_tensor([0.1, 0.9, 0.8, 0.2],
+    ...     dtype="float32")
+    >>> binary_crossentropy(target, output)
+    array([0.10536054 0.10536054 0.22314355 0.22314355],
+        shape=(4,), dtype=float32)
+    ```
+    """
     if any_symbolic_tensors((target, output)):
         return BinaryCrossentropy(from_logits=from_logits).symbolic_call(
             target, output
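The loss the docstring describes is the elementwise formula `-(t*log(p) + (1-t)*log(1-p))`. A minimal NumPy sketch (a hypothetical reference implementation, not keras-core's backend code) reproduces the example values above:

```python
import numpy as np

def binary_crossentropy_np(target, output, from_logits=False, eps=1e-7):
    """Sketch of elementwise binary cross-entropy."""
    target = np.asarray(target, dtype="float64")
    output = np.asarray(output, dtype="float64")
    if from_logits:
        # Convert logits to probabilities with the sigmoid.
        output = 1.0 / (1.0 + np.exp(-output))
    # Clip to avoid log(0).
    output = np.clip(output, eps, 1.0 - eps)
    return -(target * np.log(output) + (1.0 - target) * np.log(1.0 - output))

print(binary_crossentropy_np([0, 1, 1, 0], [0.1, 0.9, 0.8, 0.2]))
```

Each entry is `-log(0.9)` or `-log(0.8)`, matching the `[0.105..., 0.105..., 0.223..., 0.223...]` array in the docstring example.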
@@ -1032,6 +1098,48 @@ class CategoricalCrossentropy(Operation):
     ]
 )
 def categorical_crossentropy(target, output, from_logits=False, axis=-1):
+    """Computes categorical cross-entropy loss between target and output tensor.
+
+    The categorical cross-entropy loss is commonly used in multi-class
+    classification tasks where each input sample can belong to one of
+    multiple classes. It measures the dissimilarity between the target
+    and output probabilities or logits.
+
+    Args:
+        target: The target tensor representing the true categorical labels.
+            Its shape should match the shape of the `output` tensor
+            except for the last dimension.
+        output: The output tensor representing the predicted probabilities
+            or logits. Its shape should match the shape of the `target`
+            tensor except for the last dimension.
+        from_logits: (optional) Whether `output` is a tensor of logits or
+            probabilities. Set it to `True` if `output` represents logits;
+            otherwise, set it to `False` if `output` represents
+            probabilities. Defaults to `False`.
+        axis: (optional) The axis along which the categorical cross-entropy
+            is computed. Defaults to `-1`, which corresponds to the last
+            dimension of the tensors.
+
+    Returns:
+        Tensor: The computed categorical cross-entropy loss between
+        `target` and `output`.
+
+    Example:
+
+    ```python
+    >>> target = keras_core.ops.convert_to_tensor([[1, 0, 0],
+    ...                                            [0, 1, 0],
+    ...                                            [0, 0, 1]],
+    ...     dtype="float32")
+    >>> output = keras_core.ops.convert_to_tensor([[0.9, 0.05, 0.05],
+    ...                                            [0.1, 0.8, 0.1],
+    ...                                            [0.2, 0.3, 0.5]],
+    ...     dtype="float32")
+    >>> categorical_crossentropy(target, output)
+    array([0.10536054 0.22314355 0.6931472 ], shape=(3,), dtype=float32)
+    ```
+    """
     if any_symbolic_tensors((target, output)):
        return CategoricalCrossentropy(
            from_logits=from_logits, axis=axis
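For one-hot targets, the loss above reduces to `-(target * log(output)).sum(axis)`, i.e. the negative log-probability assigned to the true class. A NumPy sketch (hypothetical helper names; the optional softmax for `from_logits=True` is an assumption about the logits path):

```python
import numpy as np

def categorical_crossentropy_np(target, output, from_logits=False,
                                axis=-1, eps=1e-7):
    """Sketch of categorical cross-entropy for one-hot targets."""
    target = np.asarray(target, dtype="float64")
    output = np.asarray(output, dtype="float64")
    if from_logits:
        # Softmax along `axis` turns logits into probabilities.
        output = np.exp(output - output.max(axis=axis, keepdims=True))
        output = output / output.sum(axis=axis, keepdims=True)
    output = np.clip(output, eps, 1.0)
    # Negative log-likelihood of the true class distribution.
    return -(target * np.log(output)).sum(axis=axis)

print(categorical_crossentropy_np([[1, 0, 0], [0, 1, 0], [0, 0, 1]],
                                  [[0.9, 0.05, 0.05],
                                   [0.1, 0.8, 0.1],
                                   [0.2, 0.3, 0.5]]))
```

The three per-sample losses are `-log(0.9)`, `-log(0.8)`, and `-log(0.5)`, matching the docstring example.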
@@ -1078,6 +1186,46 @@ class SparseCategoricalCrossentropy(Operation):
     ]
 )
 def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
+    """Computes sparse categorical cross-entropy loss.
+
+    The sparse categorical cross-entropy loss is similar to categorical
+    cross-entropy, but it is used when the target tensor contains integer
+    class labels instead of one-hot encoded vectors. It measures the
+    dissimilarity between the target and output probabilities or logits.
+
+    Args:
+        target: The target tensor representing the true class labels as
+            integers. Its shape should match the shape of the `output`
+            tensor except for the last dimension.
+        output: The output tensor representing the predicted probabilities
+            or logits. Its shape should match the shape of the `target`
+            tensor except for the last dimension.
+        from_logits: (optional) Whether `output` is a tensor of logits
+            or probabilities. Set it to `True` if `output` represents
+            logits; otherwise, set it to `False` if `output` represents
+            probabilities. Defaults to `False`.
+        axis: (optional) The axis along which the sparse categorical
+            cross-entropy is computed. Defaults to `-1`, which corresponds
+            to the last dimension of the tensors.
+
+    Returns:
+        Tensor: The computed sparse categorical cross-entropy
+        loss between `target` and `output`.
+
+    Example:
+
+    ```python
+    >>> target = keras_core.ops.convert_to_tensor([0, 1, 2], dtype="int32")
+    >>> output = keras_core.ops.convert_to_tensor([[0.9, 0.05, 0.05],
+    ...                                            [0.1, 0.8, 0.1],
+    ...                                            [0.2, 0.3, 0.5]],
+    ...     dtype="float32")
+    >>> sparse_categorical_crossentropy(target, output)
+    array([0.10536056 0.22314355 0.6931472 ], shape=(3,), dtype=float32)
+    ```
+    """
     if any_symbolic_tensors((target, output)):
         return SparseCategoricalCrossentropy(
             from_logits=from_logits, axis=axis
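The sparse variant skips the one-hot step: each integer label directly indexes the predicted probability of the true class. A NumPy sketch for 2D probability inputs (a hypothetical helper, not keras-core's implementation):

```python
import numpy as np

def sparse_categorical_crossentropy_np(target, output, eps=1e-7):
    """Sketch of sparse categorical cross-entropy for 2D probabilities."""
    target = np.asarray(target)
    output = np.asarray(output, dtype="float64")
    output = np.clip(output, eps, 1.0)
    # Pick the predicted probability of the true class for each sample.
    picked = output[np.arange(output.shape[0]), target]
    return -np.log(picked)

print(sparse_categorical_crossentropy_np([0, 1, 2],
                                         [[0.9, 0.05, 0.05],
                                          [0.1, 0.8, 0.1],
                                          [0.2, 0.3, 0.5]]))
```

With integer targets `[0, 1, 2]` this yields the same `-log(0.9)`, `-log(0.8)`, `-log(0.5)` values as the one-hot example above.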