Merge branch 'main' of github.com:keras-team/keras-core
This commit is contained in:
parent
ccb9db5b92
commit
16fc9cd173
492
examples/keras_io/tensorflow/vision/zero_dce.py
Normal file
492
examples/keras_io/tensorflow/vision/zero_dce.py
Normal file
@ -0,0 +1,492 @@
|
||||
"""
|
||||
Title: Zero-DCE for low-light image enhancement
|
||||
Author: [Soumik Rakshit](http://github.com/soumik12345)
|
||||
Converted to Keras Core by: [Soumik Rakshit](http://github.com/soumik12345)
|
||||
Date created: 2021/09/18
|
||||
Last modified: 2023/07/15
|
||||
Description: Implementing Zero-Reference Deep Curve Estimation for low-light image enhancement.
|
||||
Accelerator: GPU
|
||||
"""
|
||||
"""
|
||||
## Introduction
|
||||
|
||||
**Zero-Reference Deep Curve Estimation** or **Zero-DCE** formulates low-light image
|
||||
enhancement as the task of estimating an image-specific
|
||||
[*tonal curve*](https://en.wikipedia.org/wiki/Curve_(tonality)) with a deep neural network.
|
||||
In this example, we train a lightweight deep network, **DCE-Net**, to estimate
|
||||
pixel-wise and high-order tonal curves for dynamic range adjustment of a given image.
|
||||
|
||||
Zero-DCE takes a low-light image as input and produces high-order tonal curves as its output.
|
||||
These curves are then used for pixel-wise adjustment on the dynamic range of the input to
|
||||
obtain an enhanced image. The curve estimation process is done in such a way that it maintains
|
||||
the range of the enhanced image and preserves the contrast of neighboring pixels. This
|
||||
curve estimation is inspired by curves adjustment used in photo editing software such as
|
||||
Adobe Photoshop where users can adjust points throughout an image’s tonal range.
|
||||
|
||||
Zero-DCE is appealing because of its relaxed assumptions with regard to reference images:
|
||||
it does not require any input/output image pairs during training.
|
||||
This is achieved through a set of carefully formulated non-reference loss functions,
|
||||
which implicitly measure the enhancement quality and guide the training of the network.
|
||||
|
||||
### References
|
||||
|
||||
- [Zero-Reference Deep Curve Estimation for Low-Light Image Enhancement](https://arxiv.org/pdf/2001.06826.pdf)
|
||||
- [Curves adjustment in Adobe Photoshop](https://helpx.adobe.com/photoshop/using/curves-adjustment.html)
|
||||
"""
|
||||
|
||||
"""
|
||||
## Downloading LOLDataset
|
||||
|
||||
The **LoL Dataset** has been created for low-light image enhancement. It provides 485
|
||||
images for training and 15 for testing. Each image pair in the dataset consists of a
|
||||
low-light input image and its corresponding well-exposed reference image.
|
||||
"""
|
||||
|
||||
import os
|
||||
import random
|
||||
import numpy as np
|
||||
from glob import glob
|
||||
from PIL import Image, ImageOps
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
import keras_core as keras
|
||||
from keras_core import layers
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
"""shell
|
||||
wget https://huggingface.co/datasets/geekyrakshit/LoL-Dataset/resolve/main/lol_dataset.zip
|
||||
unzip -q lol_dataset.zip && rm lol_dataset.zip
|
||||
"""
|
||||
|
||||
"""
|
||||
## Creating a TensorFlow Dataset
|
||||
|
||||
We use 400 low-light images from the LoL Dataset training set for training, and we use
the remaining 85 low-light images for validation. We resize the images to size `256 x
256` to be used for both training and validation. Note that in order to train the DCE-Net,
we will not require the corresponding enhanced images.
|
||||
"""
|
||||
|
||||
IMAGE_SIZE = 256
|
||||
BATCH_SIZE = 16
|
||||
MAX_TRAIN_IMAGES = 400
|
||||
|
||||
|
||||
def load_data(image_path):
    """Read a PNG at `image_path`, resize it, and scale pixels to [0, 1]."""
    raw = tf.io.read_file(image_path)
    decoded = tf.image.decode_png(raw, channels=3)
    resized = tf.image.resize(images=decoded, size=[IMAGE_SIZE, IMAGE_SIZE])
    return resized / 255.0
|
||||
|
||||
|
||||
def data_generator(low_light_images):
    """Build a batched `tf.data` pipeline over a list of image file paths."""
    return (
        tf.data.Dataset.from_tensor_slices(low_light_images)
        .map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(BATCH_SIZE, drop_remainder=True)
    )
|
||||
|
||||
|
||||
# Split the LoL training images into train/validation path lists; the
# eval15 split is held out for qualitative inference at the end.
_all_low_light_paths = sorted(glob("./lol_dataset/our485/low/*"))
train_low_light_images = _all_low_light_paths[:MAX_TRAIN_IMAGES]
val_low_light_images = _all_low_light_paths[MAX_TRAIN_IMAGES:]
test_low_light_images = sorted(glob("./lol_dataset/eval15/low/*"))


train_dataset = data_generator(train_low_light_images)
val_dataset = data_generator(val_low_light_images)

print("Train Dataset:", train_dataset)
print("Validation Dataset:", val_dataset)
|
||||
|
||||
"""
|
||||
## The Zero-DCE Framework
|
||||
|
||||
The goal of DCE-Net is to estimate a set of best-fitting light-enhancement curves
|
||||
(LE-curves) given an input image. The framework then maps all pixels of the input’s RGB
|
||||
channels by applying the curves iteratively to obtain the final enhanced image.
|
||||
|
||||
### Understanding light-enhancement curves
|
||||
|
||||
A light-enhancement curve is a kind of curve that can map a low-light image
to its enhanced version automatically,
|
||||
where the self-adaptive curve parameters are solely dependent on the input image.
|
||||
When designing such a curve, three objectives should be taken into account:
|
||||
|
||||
- Each pixel value of the enhanced image should be in the normalized range `[0,1]`, in order to
|
||||
avoid information loss induced by overflow truncation.
|
||||
- It should be monotonous, to preserve the contrast between neighboring pixels.
|
||||
- The shape of this curve should be as simple as possible,
|
||||
and the curve should be differentiable to allow backpropagation.
|
||||
|
||||
The light-enhancement curve is separately applied to three RGB channels instead of solely on the
|
||||
illumination channel. The three-channel adjustment can better preserve the inherent color and reduce
|
||||
the risk of over-saturation.
|
||||
|
||||
![](https://li-chongyi.github.io/Zero-DCE_files/framework.png)
|
||||
|
||||
### DCE-Net
|
||||
|
||||
The DCE-Net is a lightweight deep neural network that learns the mapping between an input
|
||||
image and its best-fitting curve parameter maps. The input to the DCE-Net is a low-light
|
||||
image while the outputs are a set of pixel-wise curve parameter maps for corresponding
|
||||
higher-order curves. It is a plain CNN of seven convolutional layers with symmetrical
|
||||
concatenation. Each layer consists of 32 convolutional kernels of size 3×3 and stride 1
|
||||
followed by the ReLU activation function. The last convolutional layer is followed by the
|
||||
Tanh activation function, which produces 24 parameter maps for 8 iterations, where each
|
||||
iteration requires three curve parameter maps for the three channels.
|
||||
|
||||
![](https://i.imgur.com/HtIg34W.png)
|
||||
"""
|
||||
|
||||
|
||||
def build_dce_net():
    """Build DCE-Net: seven 3x3 conv layers with symmetric skip concatenations.

    The first six layers are identical 32-filter ReLU convolutions; the last
    layer uses tanh and emits 24 channels (8 iterations x 3 RGB curve maps).
    """

    def _conv_relu(tensor):
        # Fresh 32-filter, stride-1, same-padded ReLU conv per call.
        return layers.Conv2D(
            32, (3, 3), strides=(1, 1), activation="relu", padding="same"
        )(tensor)

    input_img = keras.Input(shape=[None, None, 3])
    conv1 = _conv_relu(input_img)
    conv2 = _conv_relu(conv1)
    conv3 = _conv_relu(conv2)
    conv4 = _conv_relu(conv3)
    conv5 = _conv_relu(layers.Concatenate(axis=-1)([conv4, conv3]))
    conv6 = _conv_relu(layers.Concatenate(axis=-1)([conv5, conv2]))
    x_r = layers.Conv2D(
        24, (3, 3), strides=(1, 1), activation="tanh", padding="same"
    )(layers.Concatenate(axis=-1)([conv6, conv1]))
    return keras.Model(inputs=input_img, outputs=x_r)
|
||||
|
||||
|
||||
"""
|
||||
## Loss functions
|
||||
|
||||
To enable zero-reference learning in DCE-Net, we use a set of differentiable
|
||||
zero-reference losses that allow us to evaluate the quality of enhanced images.
|
||||
"""
|
||||
|
||||
"""
|
||||
### Color constancy loss
|
||||
|
||||
The *color constancy loss* is used to correct the potential color deviations in the
|
||||
enhanced image.
|
||||
"""
|
||||
|
||||
|
||||
def color_constancy_loss(x):
    """Penalize deviations between the mean R, G, and B channel intensities."""
    mean_rgb = tf.reduce_mean(x, axis=(1, 2), keepdims=True)
    mean_r = mean_rgb[:, :, :, 0]
    mean_g = mean_rgb[:, :, :, 1]
    mean_b = mean_rgb[:, :, :, 2]
    diff_rg = tf.square(mean_r - mean_g)
    diff_rb = tf.square(mean_r - mean_b)
    diff_gb = tf.square(mean_b - mean_g)
    # NOTE: the pairwise differences are squared once above and again inside
    # the sqrt below; this matches the reference Zero-DCE implementation.
    return tf.sqrt(tf.square(diff_rg) + tf.square(diff_rb) + tf.square(diff_gb))
|
||||
|
||||
|
||||
"""
|
||||
### Exposure loss
|
||||
|
||||
To restrain under-/over-exposed regions, we use the *exposure control loss*.
|
||||
It measures the distance between the average intensity value of a local region
|
||||
and a preset well-exposedness level (set to `0.6`).
|
||||
"""
|
||||
|
||||
|
||||
def exposure_loss(x, mean_val=0.6):
    """Mean squared distance between local average intensity and `mean_val`.

    Local intensity is the channel-mean of `x` averaged over non-overlapping
    16x16 regions; `mean_val` is the preset well-exposedness level.
    """
    grayscale = tf.reduce_mean(x, axis=3, keepdims=True)
    pooled = tf.nn.avg_pool2d(grayscale, ksize=16, strides=16, padding="VALID")
    return tf.reduce_mean(tf.square(pooled - mean_val))
|
||||
|
||||
|
||||
"""
|
||||
### Illumination smoothness loss
|
||||
|
||||
To preserve the monotonicity relations between neighboring pixels, the
|
||||
*illumination smoothness loss* is added to each curve parameter map.
|
||||
"""
|
||||
|
||||
|
||||
def illumination_smoothness_loss(x):
    """Total-variation-style smoothness penalty on the curve parameter maps.

    Sums squared differences between vertically and horizontally adjacent
    values of `x`, normalizes each direction by an element count, and
    averages over the batch.
    """
    batch_size = tf.shape(x)[0]
    h_x = tf.shape(x)[1]
    w_x = tf.shape(x)[2]
    # NOTE(review): both counts are built from the width and channel axes
    # (shape[2], shape[3]); count_h does not use the height axis. This
    # mirrors the published Zero-DCE example — confirm against the paper's
    # reference code before changing.
    count_h = (tf.shape(x)[2] - 1) * tf.shape(x)[3]
    count_w = tf.shape(x)[2] * (tf.shape(x)[3] - 1)
    # Vertical and horizontal total variation of the maps.
    h_tv = tf.reduce_sum(tf.square((x[:, 1:, :, :] - x[:, : h_x - 1, :, :])))
    w_tv = tf.reduce_sum(tf.square((x[:, :, 1:, :] - x[:, :, : w_x - 1, :])))
    batch_size = tf.cast(batch_size, dtype=tf.float32)
    count_h = tf.cast(count_h, dtype=tf.float32)
    count_w = tf.cast(count_w, dtype=tf.float32)
    return 2 * (h_tv / count_h + w_tv / count_w) / batch_size
|
||||
|
||||
|
||||
"""
|
||||
### Spatial consistency loss
|
||||
|
||||
The *spatial consistency loss* encourages spatial coherence of the enhanced image by
|
||||
preserving the contrast between neighboring regions across the input image and its enhanced version.
|
||||
"""
|
||||
|
||||
|
||||
class SpatialConsistencyLoss(keras.losses.Loss):
    """Spatial consistency loss.

    Compares directional differences of region-pooled luminance between the
    original (`y_true`) and enhanced (`y_pred`) images, so that local
    contrast between neighboring regions is preserved by enhancement.
    """

    def __init__(self, **kwargs):
        # A per-region loss map is returned from `call`; reduction is left
        # to the caller (note: any extra kwargs are intentionally unused).
        super().__init__(reduction="none")

        # Fixed difference kernels that subtract a neighboring value in one
        # of the four directions (left/right/up/down).
        self.left_kernel = tf.constant(
            [[[[0, 0, 0]], [[-1, 1, 0]], [[0, 0, 0]]]], dtype=tf.float32
        )
        self.right_kernel = tf.constant(
            [[[[0, 0, 0]], [[0, 1, -1]], [[0, 0, 0]]]], dtype=tf.float32
        )
        self.up_kernel = tf.constant(
            [[[[0, -1, 0]], [[0, 1, 0]], [[0, 0, 0]]]], dtype=tf.float32
        )
        self.down_kernel = tf.constant(
            [[[[0, 0, 0]], [[0, 1, 0]], [[0, -1, 0]]]], dtype=tf.float32
        )

    def call(self, y_true, y_pred):
        """Return the per-region squared gradient mismatch map."""
        # Channel-mean gives single-channel luminance maps.
        original_mean = tf.reduce_mean(y_true, 3, keepdims=True)
        enhanced_mean = tf.reduce_mean(y_pred, 3, keepdims=True)
        # Average over 4x4 regions before comparing neighboring contrast.
        original_pool = tf.nn.avg_pool2d(
            original_mean, ksize=4, strides=4, padding="VALID"
        )
        enhanced_pool = tf.nn.avg_pool2d(
            enhanced_mean, ksize=4, strides=4, padding="VALID"
        )

        # Directional differences of the original image.
        d_original_left = tf.nn.conv2d(
            original_pool, self.left_kernel, strides=[1, 1, 1, 1], padding="SAME"
        )
        d_original_right = tf.nn.conv2d(
            original_pool, self.right_kernel, strides=[1, 1, 1, 1], padding="SAME"
        )
        d_original_up = tf.nn.conv2d(
            original_pool, self.up_kernel, strides=[1, 1, 1, 1], padding="SAME"
        )
        d_original_down = tf.nn.conv2d(
            original_pool, self.down_kernel, strides=[1, 1, 1, 1], padding="SAME"
        )

        # Directional differences of the enhanced image.
        d_enhanced_left = tf.nn.conv2d(
            enhanced_pool, self.left_kernel, strides=[1, 1, 1, 1], padding="SAME"
        )
        d_enhanced_right = tf.nn.conv2d(
            enhanced_pool, self.right_kernel, strides=[1, 1, 1, 1], padding="SAME"
        )
        d_enhanced_up = tf.nn.conv2d(
            enhanced_pool, self.up_kernel, strides=[1, 1, 1, 1], padding="SAME"
        )
        d_enhanced_down = tf.nn.conv2d(
            enhanced_pool, self.down_kernel, strides=[1, 1, 1, 1], padding="SAME"
        )

        # Squared mismatch of the two images' differences, summed over the
        # four directions.
        d_left = tf.square(d_original_left - d_enhanced_left)
        d_right = tf.square(d_original_right - d_enhanced_right)
        d_up = tf.square(d_original_up - d_enhanced_up)
        d_down = tf.square(d_original_down - d_enhanced_down)
        return d_left + d_right + d_up + d_down
|
||||
|
||||
|
||||
"""
|
||||
### Deep curve estimation model
|
||||
|
||||
We implement the Zero-DCE framework as a Keras subclassed model.
|
||||
"""
|
||||
|
||||
|
||||
class ZeroDCE(keras.Model):
    """Zero-DCE: trains DCE-Net with the four zero-reference losses.

    The model owns a DCE-Net instance; `call` returns the enhanced image,
    and custom `train_step`/`test_step` implement the weighted loss
    combination from the Zero-DCE paper.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # The underlying curve-estimation network being trained.
        self.dce_model = build_dce_net()

    def compile(self, learning_rate, **kwargs):
        """Configure the optimizer, the spatial loss helper, and metric trackers."""
        super().compile(**kwargs)
        self.optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
        self.spatial_constancy_loss = SpatialConsistencyLoss(reduction="none")
        # One Mean tracker per loss term (plus the total), surfaced via
        # the `metrics` property so Keras resets them between epochs.
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.illumination_smoothness_loss_tracker = keras.metrics.Mean(name="illumination_smoothness_loss")
        self.spatial_constancy_loss_tracker = keras.metrics.Mean(name="spatial_constancy_loss")
        self.color_constancy_loss_tracker = keras.metrics.Mean(name="color_constancy_loss")
        self.exposure_loss_tracker = keras.metrics.Mean(name="exposure_loss")

    @property
    def metrics(self):
        # Listing the trackers here lets Keras reset them automatically.
        return [
            self.total_loss_tracker,
            self.illumination_smoothness_loss_tracker,
            self.spatial_constancy_loss_tracker,
            self.color_constancy_loss_tracker,
            self.exposure_loss_tracker,
        ]

    def get_enhanced_image(self, data, output):
        """Apply the 8 estimated curve maps iteratively to the input image.

        `output` is the 24-channel DCE-Net output, split into 8 RGB curve
        parameter maps; each step applies the quadratic LE-curve
        x <- x + r * (x^2 - x).
        """
        r1 = output[:, :, :, :3]
        r2 = output[:, :, :, 3:6]
        r3 = output[:, :, :, 6:9]
        r4 = output[:, :, :, 9:12]
        r5 = output[:, :, :, 12:15]
        r6 = output[:, :, :, 15:18]
        r7 = output[:, :, :, 18:21]
        r8 = output[:, :, :, 21:24]
        x = data + r1 * (tf.square(data) - data)
        x = x + r2 * (tf.square(x) - x)
        x = x + r3 * (tf.square(x) - x)
        enhanced_image = x + r4 * (tf.square(x) - x)
        x = enhanced_image + r5 * (tf.square(enhanced_image) - enhanced_image)
        x = x + r6 * (tf.square(x) - x)
        x = x + r7 * (tf.square(x) - x)
        enhanced_image = x + r8 * (tf.square(x) - x)
        return enhanced_image

    def call(self, data):
        # Forward pass: estimate curve maps, then apply them to the input.
        dce_net_output = self.dce_model(data)
        return self.get_enhanced_image(data, dce_net_output)

    def compute_losses(self, data, output):
        """Compute the four weighted zero-reference losses and their total."""
        enhanced_image = self.get_enhanced_image(data, output)
        # Weights (200, 1, 5, 10) follow the example's loss combination.
        loss_illumination = 200 * illumination_smoothness_loss(output)
        loss_spatial_constancy = tf.reduce_mean(
            self.spatial_constancy_loss(enhanced_image, data)
        )
        loss_color_constancy = 5 * tf.reduce_mean(color_constancy_loss(enhanced_image))
        loss_exposure = 10 * tf.reduce_mean(exposure_loss(enhanced_image))
        total_loss = (
            loss_illumination
            + loss_spatial_constancy
            + loss_color_constancy
            + loss_exposure
        )

        return {
            "total_loss": total_loss,
            "illumination_smoothness_loss": loss_illumination,
            "spatial_constancy_loss": loss_spatial_constancy,
            "color_constancy_loss": loss_color_constancy,
            "exposure_loss": loss_exposure,
        }

    def train_step(self, data):
        """One optimization step over a batch of low-light images."""
        with tf.GradientTape() as tape:
            output = self.dce_model(data)
            losses = self.compute_losses(data, output)

        # Gradients flow only through the DCE-Net weights.
        gradients = tape.gradient(
            losses["total_loss"], self.dce_model.trainable_weights
        )
        self.optimizer.apply_gradients(zip(gradients, self.dce_model.trainable_weights))

        self.total_loss_tracker.update_state(losses["total_loss"])
        self.illumination_smoothness_loss_tracker.update_state(losses["illumination_smoothness_loss"])
        self.spatial_constancy_loss_tracker.update_state(losses["spatial_constancy_loss"])
        self.color_constancy_loss_tracker.update_state(losses["color_constancy_loss"])
        self.exposure_loss_tracker.update_state(losses["exposure_loss"])

        return {metric.name: metric.result() for metric in self.metrics}

    def test_step(self, data):
        """Evaluation step: compute and track losses without weight updates."""
        output = self.dce_model(data)
        losses = self.compute_losses(data, output)

        self.total_loss_tracker.update_state(losses["total_loss"])
        self.illumination_smoothness_loss_tracker.update_state(losses["illumination_smoothness_loss"])
        self.spatial_constancy_loss_tracker.update_state(losses["spatial_constancy_loss"])
        self.color_constancy_loss_tracker.update_state(losses["color_constancy_loss"])
        self.exposure_loss_tracker.update_state(losses["exposure_loss"])

        return {metric.name: metric.result() for metric in self.metrics}

    def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
        """While saving the weights, we simply save the weights of the DCE-Net"""
        self.dce_model.save_weights(
            filepath, overwrite=overwrite, save_format=save_format, options=options
        )

    def load_weights(self, filepath, by_name=False, skip_mismatch=False, options=None):
        """While loading the weights, we simply load the weights of the DCE-Net"""
        self.dce_model.load_weights(
            filepath=filepath,
            by_name=by_name,
            skip_mismatch=skip_mismatch,
            options=options,
        )
|
||||
|
||||
|
||||
"""
|
||||
## Training
|
||||
"""
|
||||
|
||||
# Instantiate and train the Zero-DCE model; `history` records per-epoch
# loss curves for the plots below.
zero_dce_model = ZeroDCE()
zero_dce_model.compile(learning_rate=1e-4)
history = zero_dce_model.fit(train_dataset, validation_data=val_dataset, epochs=100)
|
||||
|
||||
|
||||
def plot_result(item):
    """Plot the training and validation curves for one logged metric."""
    for series in (item, "val_" + item):
        plt.plot(history.history[series], label=series)
    plt.xlabel("Epochs")
    plt.ylabel(item)
    plt.title("Train and Validation {} Over Epochs".format(item), fontsize=14)
    plt.legend()
    plt.grid()
    plt.show()
|
||||
|
||||
|
||||
# Visualize each tracked loss component over training.
for metric_name in (
    "total_loss",
    "illumination_smoothness_loss",
    "spatial_constancy_loss",
    "color_constancy_loss",
    "exposure_loss",
):
    plot_result(metric_name)
|
||||
|
||||
"""
|
||||
## Inference
|
||||
"""
|
||||
|
||||
|
||||
def plot_results(images, titles, figure_size=(12, 12)):
    """Show the given images side by side, one subplot per image/title pair."""
    fig = plt.figure(figsize=figure_size)
    for index, image in enumerate(images):
        fig.add_subplot(1, len(images), index + 1).set_title(titles[index])
        _ = plt.imshow(image)
        plt.axis("off")
    plt.show()
|
||||
|
||||
|
||||
def infer(original_image):
    """Run the trained Zero-DCE model on a PIL image; return a PIL image."""
    array = keras.utils.img_to_array(original_image)
    array = array.astype("float32") / 255.0
    batch = np.expand_dims(array, axis=0)
    enhanced = zero_dce_model(batch)
    # Rescale to [0, 255] bytes and drop the batch dimension.
    pixels = tf.cast((enhanced[0, :, :, :] * 255), dtype=np.uint8)
    return Image.fromarray(pixels.numpy())
|
||||
|
||||
|
||||
"""
|
||||
### Inference on test images
|
||||
|
||||
We compare the test images from LOLDataset enhanced by Zero-DCE with images enhanced via
the `PIL.ImageOps.autocontrast()` function.
|
||||
|
||||
You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/low-light-image-enhancement)
|
||||
and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/low-light-image-enhancement).
|
||||
"""
|
||||
|
||||
# For every held-out test image, show the original, a PIL autocontrast
# baseline, and the Zero-DCE enhanced result side by side.
for low_light_file in test_low_light_images:
    source_image = Image.open(low_light_file)
    model_enhanced = infer(source_image)
    plot_results(
        [source_image, ImageOps.autocontrast(source_image), model_enhanced],
        ["Original", "PIL Autocontrast", "Enhanced"],
        (20, 12),
    )
|
@ -35,6 +35,30 @@ class SegmentSum(Operation):
|
||||
|
||||
@keras_core_export("keras_core.ops.segment_sum")
|
||||
def segment_sum(data, segment_ids, num_segments=None, sorted=False):
|
||||
"""Computes the sum of segments in a tensor.
|
||||
|
||||
Args:
|
||||
data: Input tensor.
|
||||
segment_ids: A 1-D tensor containing segment indices for each
|
||||
element in `data`.
|
||||
num_segments: An integer representing the total number of
|
||||
segments. If not specified, it is inferred from the maximum
|
||||
value in `segment_ids`.
|
||||
sorted: A boolean indicating whether `segment_ids` is sorted.
|
||||
Default is `False`.
|
||||
|
||||
Returns:
|
||||
A tensor containing the sum of segments, where each element
|
||||
represents the sum of the corresponding segment in `data`.
|
||||
|
||||
Example:
|
||||
```python
|
||||
>>> data = keras_core.ops.convert_to_tensor([1, 2, 3, 4, 5, 6])
|
||||
>>> segment_ids = keras_core.ops.convert_to_tensor([0, 1, 0, 1, 0, 1])
|
||||
>>> segment_sum(data, segment_ids)
|
||||
array([9 12], shape=(2,), dtype=int32)
|
||||
```
|
||||
"""
|
||||
if any_symbolic_tensors((data,)):
|
||||
return SegmentSum(num_segments, sorted).symbolic_call(data, segment_ids)
|
||||
return backend.math.segment_sum(
|
||||
@ -63,6 +87,29 @@ class TopK(Operation):
|
||||
|
||||
@keras_core_export("keras_core.ops.top_k")
def top_k(x, k, sorted=True):
    """Finds the top-k values and their indices in a tensor.

    Args:
        x: Input tensor.
        k: An integer representing the number of top elements to retrieve.
        sorted: A boolean indicating whether to sort the output in
            descending order. Defaults to `True`.

    Returns:
        A tuple containing two tensors. The first tensor contains the
        top-k values, and the second tensor contains the indices of the
        top-k values in the input tensor.

    Example:
    ```python
    >>> x = keras_core.ops.convert_to_tensor([5, 2, 7, 1, 9, 3])
    >>> values, indices = top_k(x, k=3)
    >>> print(values)
    array([9 7 5], shape=(3,), dtype=int32)
    >>> print(indices)
    array([4 2 0], shape=(3,), dtype=int32)
    ```
    """
    # During functional-model tracing, defer to the symbolic op.
    if any_symbolic_tensors((x,)):
        return TopK(k, sorted).symbolic_call(x)
    return backend.math.top_k(x, k, sorted)
|
||||
@ -82,6 +129,28 @@ class InTopK(Operation):
|
||||
|
||||
@keras_core_export("keras_core.ops.in_top_k")
def in_top_k(targets, predictions, k):
    """Checks if the targets are in the top-k predictions.

    Args:
        targets: A tensor of true labels.
        predictions: A tensor of predicted labels.
        k: An integer representing the number of predictions to consider.

    Returns:
        A boolean tensor of the same shape as `targets`, where each element
        indicates whether the corresponding target is in the top-k predictions.

    Example:
    ```python
    >>> targets = keras_core.ops.convert_to_tensor([2, 5, 3])
    >>> predictions = keras_core.ops.convert_to_tensor(
    [[0.1, 0.4, 0.6, 0.9, 0.5],
    [0.1, 0.7, 0.9, 0.8, 0.3],
    [0.1, 0.6, 0.9, 0.9, 0.5]])
    >>> in_top_k(targets, predictions, k=3)
    array([ True False True], shape=(3,), dtype=bool)
    ```
    """
    # During functional-model tracing, defer to the symbolic op.
    if any_symbolic_tensors((targets, predictions)):
        return InTopK(k).symbolic_call(targets, predictions)
    return backend.math.in_top_k(targets, predictions, k)
|
||||
@ -103,6 +172,27 @@ class Logsumexp(Operation):
|
||||
|
||||
@keras_core_export("keras_core.ops.logsumexp")
def logsumexp(x, axis=None, keepdims=False):
    """Computes the logarithm of sum of exponentials of elements in a tensor.

    Args:
        x: Input tensor.
        axis: An integer or a tuple of integers specifying the axis/axes
            along which to compute the sum. If `None`, the sum is computed
            over all elements. Defaults to `None`.
        keepdims: A boolean indicating whether to keep the dimensions of
            the input tensor when computing the sum. Defaults to `False`.

    Returns:
        A tensor containing the logarithm of the sum of exponentials of
        elements in `x`.

    Example:
    ```python
    >>> x = keras_core.ops.convert_to_tensor([1., 2., 3.])
    >>> logsumexp(x)
    array(3.407606, shape=(), dtype=float32)
    ```
    """
    # During functional-model tracing, defer to the symbolic op.
    if any_symbolic_tensors((x,)):
        return Logsumexp(axis, keepdims).symbolic_call(x)
    return backend.math.logsumexp(x, axis=axis, keepdims=keepdims)
|
||||
@ -152,6 +242,30 @@ class Qr(Operation):
|
||||
|
||||
@keras_core_export("keras_core.ops.qr")
def qr(x, mode="reduced"):
    """Computes the QR decomposition of a tensor.

    Args:
        x: Input tensor.
        mode: A string specifying the mode of the QR decomposition.
            - 'reduced': Returns the reduced QR decomposition. (default)
            - 'complete': Returns the complete QR decomposition.

    Returns:
        A tuple containing two tensors. The first tensor represents the
        orthogonal matrix Q, and the second tensor represents the upper
        triangular matrix R.

    Example:
    ```python
    >>> x = keras_core.ops.convert_to_tensor([[1., 2.], [3., 4.], [5., 6.]])
    >>> q, r = qr(x)
    >>> print(q)
    array([[-0.16903079 0.897085]
    [-0.5070925 0.2760267 ]
    [-0.8451542 -0.34503305]], shape=(3, 2), dtype=float32)
    ```
    """

    # During functional-model tracing, defer to the symbolic op.
    if any_symbolic_tensors((x,)):
        return Qr(mode=mode).symbolic_call(x)
    return backend.math.qr(x, mode=mode)
|
||||
|
@ -895,7 +895,7 @@ def conv_transpose(
|
||||
the output tensor. Can be a single integer to specify the same
|
||||
value for all spatial dimensions. The amount of output padding
|
||||
along a given dimension must be lower than the stride along that
|
||||
same dimension. If set to None (default), the output shape is
|
||||
same dimension. If set to `None` (default), the output shape is
|
||||
inferred.
|
||||
data_format: A string, either `"channels_last"` or `"channels_first"`.
|
||||
`data_format` determines the ordering of the dimensions in the
|
||||
@ -955,6 +955,38 @@ class OneHot(Operation):
|
||||
|
||||
@keras_core_export(["keras_core.ops.one_hot", "keras_core.ops.nn.one_hot"])
|
||||
def one_hot(x, num_classes, axis=-1, dtype=None):
|
||||
"""Converts integer tensor `x` into a one-hot tensor.
|
||||
|
||||
The one-hot encoding is a representation where each integer value is
|
||||
converted into a binary vector with a length equal to `num_classes`,
|
||||
and the index corresponding to the integer value is marked as 1, while
|
||||
all other indices are marked as 0.
|
||||
|
||||
Args:
|
||||
x : Integer tensor to be encoded. The shape can be
|
||||
arbitrary, but the dtype should be integer.
|
||||
num_classes: Number of classes for the one-hot encoding.
|
||||
axis: Axis along which the encoding is performed. Default is
|
||||
-1, which represents the last axis.
|
||||
dtype: (Optional) Data type of the output tensor. If not
|
||||
provided, it defaults to the default data type of the backend.
|
||||
|
||||
Returns:
|
||||
Integer tensor: One-hot encoded tensor with the same shape as `x`
|
||||
except for the specified `axis` dimension, which will have
|
||||
a length of `num_classes`. The dtype of the output tensor
|
||||
is determined by `dtype` or the default data type of the backend.
|
||||
|
||||
Example:
|
||||
```python
|
||||
>>> x = keras_core.ops.convert_to_tensor([1, 3, 2, 0])
|
||||
>>> one_hot(x, num_classes=4)
|
||||
array([[0. 1. 0. 0.]
|
||||
[0. 0. 0. 1.]
|
||||
[0. 0. 1. 0.]
|
||||
[1. 0. 0. 0.]], shape=(4, 4), dtype=float32)
|
||||
```
|
||||
"""
|
||||
if any_symbolic_tensors((x,)):
|
||||
return OneHot(num_classes, axis=axis, dtype=dtype).symbolic_call(x)
|
||||
return backend.nn.one_hot(
|
||||
@ -989,6 +1021,40 @@ class BinaryCrossentropy(Operation):
|
||||
]
|
||||
)
|
||||
def binary_crossentropy(target, output, from_logits=False):
|
||||
"""Computes binary cross-entropy loss between target and output tensor.
|
||||
|
||||
The binary cross-entropy loss is commonly used in binary
|
||||
classification tasks where each input sample belongs to one
|
||||
of the two classes. It measures the dissimilarity between the
|
||||
target and output probabilities or logits.
|
||||
|
||||
Args:
|
||||
target: The target tensor representing the true binary labels.
|
||||
Its shape should match the shape of the `output` tensor.
|
||||
output: The output tensor representing the predicted probabilities
|
||||
or logits. Its shape should match the shape of the
|
||||
`target` tensor.
|
||||
from_logits: (optional) Whether `output` is a tensor of logits or
|
||||
probabilities.
|
||||
Set it to `True` if `output` represents logits; otherwise,
|
||||
set it to `False` if `output` represents probabilities.
|
||||
Default is `False`.
|
||||
|
||||
Returns:
|
||||
Integer tensor: The computed binary cross-entropy loss between
|
||||
`target` and `output`.
|
||||
|
||||
Example:
|
||||
```python
|
||||
>>> target = keras_core.ops.convert_to_tensor([0, 1, 1, 0],
|
||||
dtype=float32)
|
||||
>>> output = keras_core.ops.convert_to_tensor([0.1, 0.9, 0.8, 0.2],
|
||||
dtype=float32)
|
||||
>>> binary_crossentropy(target, output)
|
||||
array([0.10536054 0.10536054 0.22314355 0.22314355],
|
||||
shape=(4,), dtype=float32)
|
||||
```
|
||||
"""
|
||||
if any_symbolic_tensors((target, output)):
|
||||
return BinaryCrossentropy(from_logits=from_logits).symbolic_call(
|
||||
target, output
|
||||
@ -1032,6 +1098,48 @@ class CategoricalCrossentropy(Operation):
|
||||
]
|
||||
)
|
||||
def categorical_crossentropy(target, output, from_logits=False, axis=-1):
|
||||
"""Computes categorical cross-entropy loss between target and output tensor.
|
||||
|
||||
The categorical cross-entropy loss is commonly used in multi-class
|
||||
classification tasks where each input sample can belong to one of
|
||||
multiple classes. It measures the dissimilarity
|
||||
between the target and output probabilities or logits.
|
||||
|
||||
Args:
|
||||
target: The target tensor representing the true categorical labels.
|
||||
Its shape should match the shape of the `output` tensor
|
||||
except for the last dimension.
|
||||
output: The output tensor representing the predicted probabilities
|
||||
or logits. Its shape should match the shape of the `target`
|
||||
tensor except for the last dimension.
|
||||
from_logits: (optional) Whether `output` is a tensor of logits or
|
||||
probabilities.
|
||||
Set it to `True` if `output` represents logits; otherwise,
|
||||
set it to `False` if `output` represents probabilities.
|
||||
Default is `False`.
|
||||
axis: (optional) The axis along which the categorical cross-entropy
|
||||
is computed.
|
||||
Default is -1, which corresponds to the last dimension of
|
||||
the tensors.
|
||||
|
||||
Returns:
|
||||
Integer tensor: The computed categorical cross-entropy loss between
|
||||
`target` and `output`.
|
||||
|
||||
Example:
|
||||
```python
|
||||
>>> target = keras_core.ops.convert_to_tensor([[1, 0, 0],
|
||||
[0, 1, 0],
|
||||
[0, 0, 1]],
|
||||
dtype=float32)
|
||||
>>> output = keras_core.ops.convert_to_tensor([[0.9, 0.05, 0.05],
|
||||
[0.1, 0.8, 0.1],
|
||||
[0.2, 0.3, 0.5]],
|
||||
dtype=float32)
|
||||
>>> categorical_crossentropy(target, output)
|
||||
array([0.10536054 0.22314355 0.6931472 ], shape=(3,), dtype=float32)
|
||||
```
|
||||
"""
|
||||
if any_symbolic_tensors((target, output)):
|
||||
return CategoricalCrossentropy(
|
||||
from_logits=from_logits, axis=axis
|
||||
@ -1078,6 +1186,46 @@ class SparseCategoricalCrossentropy(Operation):
|
||||
]
|
||||
)
|
||||
def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
|
||||
"""Computes sparse categorical cross-entropy loss.
|
||||
|
||||
The sparse categorical cross-entropy loss is similar to categorical
|
||||
cross-entropy, but it is used when the target tensor contains integer
|
||||
class labels instead of one-hot encoded vectors. It measures the
|
||||
dissimilarity between the target and output probabilities or logits.
|
||||
|
||||
Args:
|
||||
target: The target tensor representing the true class labels as integers.
|
||||
Its shape should match the shape of the `output` tensor except
|
||||
for the last dimension.
|
||||
output: The output tensor representing the predicted probabilities
|
||||
or logits.
|
||||
Its shape should match the shape of the `target` tensor except
|
||||
for the last dimension.
|
||||
from_logits: (optional) Whether `output` is a tensor of logits
|
||||
or probabilities.
|
||||
Set it to `True` if `output` represents logits; otherwise,
|
||||
set it to `False` if `output` represents probabilities.
|
||||
Default is `False`.
|
||||
axis: (optional) The axis along which the sparse categorical
|
||||
cross-entropy is computed.
|
||||
Default is -1, which corresponds to the last dimension
|
||||
of the tensors.
|
||||
|
||||
Returns:
|
||||
Integer tensor: The computed sparse categorical cross-entropy
|
||||
loss between `target` and `output`.
|
||||
|
||||
Example:
|
||||
```python
|
||||
>>> target = keras_core.ops.convert_to_tensor([0, 1, 2], dtype=int32)
|
||||
>>> output = keras_core.ops.convert_to_tensor([[0.9, 0.05, 0.05],
|
||||
[0.1, 0.8, 0.1],
|
||||
[0.2, 0.3, 0.5]],
|
||||
dtype=float32)
|
||||
>>> sparse_categorical_crossentropy(target, output)
|
||||
array([0.10536056 0.22314355 0.6931472 ], shape=(3,), dtype=float32)
|
||||
```
|
||||
"""
|
||||
if any_symbolic_tensors((target, output)):
|
||||
return SparseCategoricalCrossentropy(
|
||||
from_logits=from_logits, axis=axis
|
||||
|
Loading…
Reference in New Issue
Block a user