Convert ddim (#743)

* import ddim.py from keras-io

* convert keras to keras_core

1. convert keras/tf.keras to keras_core
2. replace most of tf ops except tf.image, tf.data, and tf.GrapdientTap
3. fix ops.concat out of range issue
This commit is contained in:
Koan-Sin Tan 2023-08-19 01:37:56 +07:00 committed by Francois Chollet
parent 5a22094d04
commit ccd4c0e135
2 changed files with 891 additions and 1 deletions

@ -0,0 +1,889 @@
"""
Title: Denoising Diffusion Implicit Models
Author: [András Béres](https://www.linkedin.com/in/andras-beres-789190210)
Date created: 2022/06/24
Last modified: 2022/06/24
Description: Generating images of flowers with denoising diffusion implicit models.
Accelerator: GPU
"""
"""
## Introduction
### What are diffusion models?
Recently, [denoising diffusion models](https://arxiv.org/abs/2006.11239), including
[score-based generative models](https://arxiv.org/abs/1907.05600), gained popularity as a
powerful class of generative models, that can [rival](https://arxiv.org/abs/2105.05233)
even [generative adversarial networks (GANs)](https://arxiv.org/abs/1406.2661) in image
synthesis quality. They tend to generate more diverse samples, while being stable to
train and easy to scale. Recent large diffusion models, such as
[DALL-E 2](https://openai.com/dall-e-2/) and [Imagen](https://imagen.research.google/),
have shown incredible text-to-image generation capability. One of their drawbacks is
however, that they are slower to sample from, because they require multiple forward passes
for generating an image.
Diffusion refers to the process of turning a structured signal (an image) into noise
step-by-step. By simulating diffusion, we can generate noisy images from our training
images, and can train a neural network to try to denoise them. Using the trained network
we can simulate the opposite of diffusion, reverse diffusion, which is the process of an
image emerging from noise.
![diffusion process gif](https://i.imgur.com/dipPOfa.gif)
One-sentence summary: **diffusion models are trained to denoise noisy images, and can
generate images by iteratively denoising pure noise.**
### Goal of this example
This code example intends to be a minimal but feature-complete (with a generation quality
metric) implementation of diffusion models, with modest compute requirements and
reasonable performance. My implementation choices and hyperparameter tuning were done
with these goals in mind.
Since currently the literature of diffusion models is
[mathematically quite complex](https://arxiv.org/abs/2206.00364)
with multiple theoretical frameworks
([score matching](https://arxiv.org/abs/1907.05600),
[differential equations](https://arxiv.org/abs/2011.13456),
[Markov chains](https://arxiv.org/abs/2006.11239)) and sometimes even
[conflicting notations (see Appendix C.2)](https://arxiv.org/abs/2010.02502),
it can be daunting trying to understand
them. My view of these models in this example will be that they learn to separate a
noisy image into its image and Gaussian noise components.
In this example I made effort to break down all long mathematical expressions into
digestible pieces and gave all variables explanatory names. I also included numerous
links to relevant literature to help interested readers dive deeper into the topic, in
the hope that this code example will become a good starting point for practitioners
learning about diffusion models.
In the following sections, we will implement a continuous time version of
[Denoising Diffusion Implicit Models (DDIMs)](https://arxiv.org/abs/2010.02502)
with deterministic sampling.
"""
"""
## Setup
"""
import math
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
import keras_core as keras
from keras_core import layers
from keras_core import ops
"""
## Hyperparameters
"""
# data
dataset_name = "oxford_flowers102"
dataset_repetitions = 5
num_epochs = 1 # train for at least 50 epochs for good results
image_size = 64
# KID = Kernel Inception Distance, see related section
kid_image_size = 75
kid_diffusion_steps = 5
plot_diffusion_steps = 20
# sampling
min_signal_rate = 0.02
max_signal_rate = 0.95
# architecture
embedding_dims = 32
embedding_max_frequency = 1000.0
widths = [32, 64, 96, 128]
block_depth = 2
# optimization
batch_size = 64
ema = 0.999
learning_rate = 1e-3
weight_decay = 1e-4
"""
## Data pipeline
We will use the
[Oxford Flowers 102](https://www.tensorflow.org/datasets/catalog/oxford_flowers102)
dataset for
generating images of flowers, which is a diverse natural dataset containing around 8,000
images. Unfortunately the official splits are imbalanced, as most of the images are
contained in the test split. We create new splits (80% train, 20% validation) using the
[Tensorflow Datasets slicing API](https://www.tensorflow.org/datasets/splits). We apply
center crops as preprocessing, and repeat the dataset multiple times (reason given in the
next section).
"""
def preprocess_image(data):
# center crop image
height = ops.shape(data["image"])[0]
width = ops.shape(data["image"])[1]
crop_size = ops.minimum(height, width)
image = tf.image.crop_to_bounding_box(
data["image"],
(height - crop_size) // 2,
(width - crop_size) // 2,
crop_size,
crop_size,
)
# resize and clip
# for image downsampling it is important to turn on antialiasing
image = tf.image.resize(
image, size=[image_size, image_size], antialias=True
)
return ops.clip(image / 255.0, 0.0, 1.0)
def prepare_dataset(split):
# the validation dataset is shuffled as well, because data order matters
# for the KID estimation
return (
tfds.load(dataset_name, split=split, shuffle_files=True)
.map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
.cache()
.repeat(dataset_repetitions)
.shuffle(10 * batch_size)
.batch(batch_size, drop_remainder=True)
.prefetch(buffer_size=tf.data.AUTOTUNE)
)
# load dataset
train_dataset = prepare_dataset("train[:80%]+validation[:80%]+test[:80%]")
val_dataset = prepare_dataset("train[80%:]+validation[80%:]+test[80%:]")
"""
## Kernel inception distance
[Kernel Inception Distance (KID)](https://arxiv.org/abs/1801.01401) is an image quality
metric which was proposed as a replacement for the popular
[Frechet Inception Distance (FID)](https://arxiv.org/abs/1706.08500).
I prefer KID to FID because it is simpler to
implement, can be estimated per-batch, and is computationally lighter. More details
[here](https://keras.io/examples/generative/gan_ada/#kernel-inception-distance).
In this example, the images are evaluated at the minimal possible resolution of the
Inception network (75x75 instead of 299x299), and the metric is only measured on the
validation set for computational efficiency. We also limit the number of sampling steps
at evaluation to 5 for the same reason.
Since the dataset is relatively small, we go over the train and validation splits
multiple times per epoch, because the KID estimation is noisy and compute-intensive, so
we want to evaluate only after many iterations, but for many iterations.
"""
@keras.saving.register_keras_serializable()
class KID(keras.metrics.Metric):
def __init__(self, name, **kwargs):
super().__init__(name=name, **kwargs)
# KID is estimated per batch and is averaged across batches
self.kid_tracker = keras.metrics.Mean(name="kid_tracker")
# a pretrained InceptionV3 is used without its classification layer
# transform the pixel values to the 0-255 range, then use the same
# preprocessing as during pretraining
self.encoder = keras.Sequential(
[
keras.Input(shape=(image_size, image_size, 3)),
layers.Rescaling(255.0),
layers.Resizing(height=kid_image_size, width=kid_image_size),
layers.Lambda(keras.applications.inception_v3.preprocess_input),
keras.applications.InceptionV3(
include_top=False,
input_shape=(kid_image_size, kid_image_size, 3),
weights="imagenet",
),
layers.GlobalAveragePooling2D(),
],
name="inception_encoder",
)
def polynomial_kernel(self, features_1, features_2):
feature_dimensions = ops.cast(ops.shape(features_1)[1], dtype="float32")
return (
features_1 @ ops.transpose(features_2) / feature_dimensions + 1.0
) ** 3.0
def update_state(self, real_images, generated_images, sample_weight=None):
real_features = self.encoder(real_images, training=False)
generated_features = self.encoder(generated_images, training=False)
# compute polynomial kernels using the two sets of features
kernel_real = self.polynomial_kernel(real_features, real_features)
kernel_generated = self.polynomial_kernel(
generated_features, generated_features
)
kernel_cross = self.polynomial_kernel(real_features, generated_features)
# estimate the squared maximum mean discrepancy using the average kernel values
batch_size = real_features.shape[0]
batch_size_f = ops.cast(batch_size, dtype="float32")
mean_kernel_real = ops.sum(
kernel_real * (1.0 - ops.eye(batch_size))
) / (batch_size_f * (batch_size_f - 1.0))
mean_kernel_generated = ops.sum(
kernel_generated * (1.0 - ops.eye(batch_size))
) / (batch_size_f * (batch_size_f - 1.0))
mean_kernel_cross = ops.mean(kernel_cross)
kid = mean_kernel_real + mean_kernel_generated - 2.0 * mean_kernel_cross
# update the average KID estimate
self.kid_tracker.update_state(kid)
def result(self):
return self.kid_tracker.result()
def reset_state(self):
self.kid_tracker.reset_state()
"""
## Network architecture
Here we specify the architecture of the neural network that we will use for denoising. We
build a [U-Net](https://arxiv.org/abs/1505.04597) with identical input and output
dimensions. U-Net is a popular semantic segmentation architecture, whose main idea is
that it progressively downsamples and then upsamples its input image, and adds skip
connections between layers having the same resolution. These help with gradient flow and
avoid introducing a representation bottleneck, unlike usual
[autoencoders](https://www.deeplearningbook.org/contents/autoencoders.html). Based on
this, one can view
[diffusion models as denoising autoencoders](https://benanne.github.io/2022/01/31/diffusion.html)
without a bottleneck.
The network takes two inputs, the noisy images and the variances of their noise
components. The latter is required since denoising a signal requires different operations
at different levels of noise. We transform the noise variances using sinusoidal
embeddings, similarly to positional encodings used both in
[transformers](https://arxiv.org/abs/1706.03762) and
[NeRF](https://arxiv.org/abs/2003.08934). This helps the network to be
[highly sensitive](https://arxiv.org/abs/2006.10739) to the noise level, which is
crucial for good performance. We implement sinusoidal embeddings using a
[Lambda layer](https://keras.io/api/layers/core_layers/lambda/).
Some other considerations:
* We build the network using the
[Keras Functional API](https://keras.io/guides/functional_api/), and use
[closures](https://twitter.com/fchollet/status/1441927912836321280) to build blocks of
layers in a consistent style.
* [Diffusion models](https://arxiv.org/abs/2006.11239) embed the index of the timestep of
the diffusion process instead of the noise variance, while
[score-based models (Table 1)](https://arxiv.org/abs/2206.00364)
usually use some function of the noise level. I
prefer the latter so that we can change the sampling schedule at inference time, without
retraining the network.
* [Diffusion models](https://arxiv.org/abs/2006.11239) input the embedding to each
convolution block separately. We only input it at the start of the network for
simplicity, which in my experience barely decreases performance, because the skip and
residual connections help the information propagate through the network properly.
* In the literature it is common to use
[attention layers](https://keras.io/api/layers/attention_layers/multi_head_attention/)
at lower resolutions for better global coherence. I omitted it for simplicity.
* We disable the learnable center and scale parameters of the batch normalization layers,
since the following convolution layers make them redundant.
* We initialize the last convolution's kernel to all zeros as a good practice, making the
network predict only zeros after initialization, which is the mean of its targets. This
will improve behaviour at the start of training and make the mean squared error loss
start at exactly 1.
"""
@keras.saving.register_keras_serializable()
def sinusoidal_embedding(x):
embedding_min_frequency = 1.0
frequencies = ops.exp(
ops.linspace(
ops.log(embedding_min_frequency),
ops.log(embedding_max_frequency),
embedding_dims // 2,
)
)
angular_speeds = ops.cast(2.0 * math.pi * frequencies, "float32")
embeddings = ops.concatenate(
[ops.sin(angular_speeds * x), ops.cos(angular_speeds * x)], axis=3
)
return embeddings
def ResidualBlock(width):
def apply(x):
input_width = x.shape[3]
if input_width == width:
residual = x
else:
residual = layers.Conv2D(width, kernel_size=1)(x)
x = layers.BatchNormalization(center=False, scale=False)(x)
x = layers.Conv2D(
width, kernel_size=3, padding="same", activation="swish"
)(x)
x = layers.Conv2D(width, kernel_size=3, padding="same")(x)
x = layers.Add()([x, residual])
return x
return apply
def DownBlock(width, block_depth):
def apply(x):
x, skips = x
for _ in range(block_depth):
x = ResidualBlock(width)(x)
skips.append(x)
x = layers.AveragePooling2D(pool_size=2)(x)
return x
return apply
def UpBlock(width, block_depth):
def apply(x):
x, skips = x
x = layers.UpSampling2D(size=2, interpolation="bilinear")(x)
for _ in range(block_depth):
x = layers.Concatenate()([x, skips.pop()])
x = ResidualBlock(width)(x)
return x
return apply
def get_network(image_size, widths, block_depth):
noisy_images = keras.Input(shape=(image_size, image_size, 3))
noise_variances = keras.Input(shape=(1, 1, 1))
e = layers.Lambda(sinusoidal_embedding, output_shape=(1, 1, 32))(
noise_variances
)
e = layers.UpSampling2D(size=image_size, interpolation="nearest")(e)
x = layers.Conv2D(widths[0], kernel_size=1)(noisy_images)
x = layers.Concatenate()([x, e])
skips = []
for width in widths[:-1]:
x = DownBlock(width, block_depth)([x, skips])
for _ in range(block_depth):
x = ResidualBlock(widths[-1])(x)
for width in reversed(widths[:-1]):
x = UpBlock(width, block_depth)([x, skips])
x = layers.Conv2D(3, kernel_size=1, kernel_initializer="zeros")(x)
return keras.Model([noisy_images, noise_variances], x, name="residual_unet")
"""
This showcases the power of the Functional API. Note how we built a relatively complex
U-Net with skip connections, residual blocks, multiple inputs, and sinusoidal embeddings
in 80 lines of code!
"""
"""
## Diffusion model
### Diffusion schedule
Let us say, that a diffusion process starts at time = 0, and ends at time = 1. This
variable will be called diffusion time, and can be either discrete (common in diffusion
models) or continuous (common in score-based models). I choose the latter, so that the
number of sampling steps can be changed at inference time.
We need to have a function that tells us at each point in the diffusion process the noise
levels and signal levels of the noisy image corresponding to the actual diffusion time.
This will be called the diffusion schedule (see `diffusion_schedule()`).
This schedule outputs two quantities: the `noise_rate` and the `signal_rate`
(corresponding to sqrt(1 - alpha) and sqrt(alpha) in the DDIM paper, respectively). We
generate the noisy image by weighting the random noise and the training image by their
corresponding rates and adding them together.
Since the (standard normal) random noises and the (normalized) images both have zero mean
and unit variance, the noise rate and signal rate can be interpreted as the standard
deviation of their components in the noisy image, while the squares of their rates can be
interpreted as their variance (or their power in the signal processing sense). The rates
will always be set so that their squared sum is 1, meaning that the noisy images will
always have unit variance, just like its unscaled components.
We will use a simplified, continuous version of the
[cosine schedule (Section 3.2)](https://arxiv.org/abs/2102.09672),
that is quite commonly used in the literature.
This schedule is symmetric, slow towards the start and end of the diffusion process, and
it also has a nice geometric interpretation, using the
[trigonometric properties of the unit circle](https://en.wikipedia.org/wiki/Unit_circle#/media/File:Circle-trig6.svg):
![diffusion schedule gif](https://i.imgur.com/JW9W0fA.gif)
### Training process
The training procedure (see `train_step()` and `denoise()`) of denoising diffusion models
is the following: we sample random diffusion times uniformly, and mix the training images
with random gaussian noises at rates corresponding to the diffusion times. Then, we train
the model to separate the noisy image to its two components.
Usually, the neural network is trained to predict the unscaled noise component, from
which the predicted image component can be calculated using the signal and noise rates.
Pixelwise
[mean squared error](https://keras.io/api/losses/regression_losses/#mean_squared_error-function) should
be used theoretically, however I recommend using
[mean absolute error](https://keras.io/api/losses/regression_losses/#mean_absolute_error-function)
instead (similarly to
[this](https://github.com/lucidrains/denoising-diffusion-pytorch/blob/master/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L371)
implementation), which produces better results on this dataset.
### Sampling (reverse diffusion)
When sampling (see `reverse_diffusion()`), at each step we take the previous estimate of
the noisy image and separate it into image and noise using our network. Then we recombine
these components using the signal and noise rate of the following step.
Though a similar view is shown in
[Equation 12 of DDIMs](https://arxiv.org/abs/2010.02502), I believe the above explanation
of the sampling equation is not widely known.
This example only implements the deterministic sampling procedure from DDIM, which
corresponds to *eta = 0* in the paper. One can also use stochastic sampling (in which
case the model becomes a
[Denoising Diffusion Probabilistic Model (DDPM)](https://arxiv.org/abs/2006.11239)),
where a part of the predicted noise is
replaced with the same or larger amount of random noise
([see Equation 16 and below](https://arxiv.org/abs/2010.02502)).
Stochastic sampling can be used without retraining the network (since both models are
trained the same way), and it can improve sample quality, while on the other hand
requiring more sampling steps usually.
"""
@keras.saving.register_keras_serializable()
class DiffusionModel(keras.Model):
def __init__(self, image_size, widths, block_depth):
super().__init__()
self.normalizer = layers.Normalization()
self.network = get_network(image_size, widths, block_depth)
self.ema_network = keras.models.clone_model(self.network)
def compile(self, **kwargs):
super().compile(**kwargs)
self.noise_loss_tracker = keras.metrics.Mean(name="n_loss")
self.image_loss_tracker = keras.metrics.Mean(name="i_loss")
self.kid = KID(name="kid")
@property
def metrics(self):
return [self.noise_loss_tracker, self.image_loss_tracker, self.kid]
def denormalize(self, images):
# convert the pixel values back to 0-1 range
images = self.normalizer.mean + images * self.normalizer.variance**0.5
return ops.clip(images, 0.0, 1.0)
def diffusion_schedule(self, diffusion_times):
# diffusion times -> angles
start_angle = ops.cast(ops.arccos(max_signal_rate), "float32")
end_angle = ops.cast(ops.arccos(min_signal_rate), "float32")
diffusion_angles = start_angle + diffusion_times * (
end_angle - start_angle
)
# angles -> signal and noise rates
signal_rates = ops.cos(diffusion_angles)
noise_rates = ops.sin(diffusion_angles)
# note that their squared sum is always: sin^2(x) + cos^2(x) = 1
return noise_rates, signal_rates
def denoise(self, noisy_images, noise_rates, signal_rates, training):
# the exponential moving average weights are used at evaluation
if training:
network = self.network
else:
network = self.ema_network
# predict noise component and calculate the image component using it
pred_noises = network(
[noisy_images, noise_rates**2], training=training
)
pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates
return pred_noises, pred_images
def reverse_diffusion(self, initial_noise, diffusion_steps):
# reverse diffusion = sampling
num_images = initial_noise.shape[0]
step_size = 1.0 / diffusion_steps
# important line:
# at the first sampling step, the "noisy image" is pure noise
# but its signal rate is assumed to be nonzero (min_signal_rate)
next_noisy_images = initial_noise
for step in range(diffusion_steps):
noisy_images = next_noisy_images
# separate the current noisy image to its components
diffusion_times = ops.ones((num_images, 1, 1, 1)) - step * step_size
noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
pred_noises, pred_images = self.denoise(
noisy_images, noise_rates, signal_rates, training=False
)
# network used in eval mode
# remix the predicted components using the next signal and noise rates
next_diffusion_times = diffusion_times - step_size
next_noise_rates, next_signal_rates = self.diffusion_schedule(
next_diffusion_times
)
next_noisy_images = (
next_signal_rates * pred_images + next_noise_rates * pred_noises
)
# this new noisy image will be used in the next step
return pred_images
def generate(self, num_images, diffusion_steps):
# noise -> images -> denormalized images
initial_noise = keras.random.normal(
shape=(num_images, image_size, image_size, 3)
)
generated_images = self.reverse_diffusion(
initial_noise, diffusion_steps
)
generated_images = self.denormalize(generated_images)
return generated_images
def train_step(self, images):
# normalize images to have standard deviation of 1, like the noises
images = self.normalizer(images, training=True)
noises = keras.random.normal(
shape=(batch_size, image_size, image_size, 3)
)
# sample uniform random diffusion times
diffusion_times = keras.random.uniform(
shape=(batch_size, 1, 1, 1), minval=0.0, maxval=1.0
)
noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
# mix the images with noises accordingly
noisy_images = signal_rates * images + noise_rates * noises
with tf.GradientTape() as tape:
# train the network to separate noisy images to their components
pred_noises, pred_images = self.denoise(
noisy_images, noise_rates, signal_rates, training=True
)
noise_loss = self.loss(noises, pred_noises) # used for training
image_loss = self.loss(images, pred_images) # only used as metric
gradients = tape.gradient(noise_loss, self.network.trainable_weights)
self.optimizer.apply_gradients(
zip(gradients, self.network.trainable_weights)
)
self.noise_loss_tracker.update_state(noise_loss)
self.image_loss_tracker.update_state(image_loss)
# track the exponential moving averages of weights
for weight, ema_weight in zip(
self.network.weights, self.ema_network.weights
):
ema_weight.assign(ema * ema_weight + (1 - ema) * weight)
# KID is not measured during the training phase for computational efficiency
return {m.name: m.result() for m in self.metrics[:-1]}
def test_step(self, images):
# normalize images to have standard deviation of 1, like the noises
images = self.normalizer(images, training=False)
noises = keras.random.normal(
shape=(batch_size, image_size, image_size, 3)
)
# sample uniform random diffusion times
diffusion_times = keras.random.uniform(
shape=(batch_size, 1, 1, 1), minval=0.0, maxval=1.0
)
noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
# mix the images with noises accordingly
noisy_images = signal_rates * images + noise_rates * noises
# use the network to separate noisy images to their components
pred_noises, pred_images = self.denoise(
noisy_images, noise_rates, signal_rates, training=False
)
noise_loss = self.loss(noises, pred_noises)
image_loss = self.loss(images, pred_images)
self.image_loss_tracker.update_state(image_loss)
self.noise_loss_tracker.update_state(noise_loss)
# measure KID between real and generated images
# this is computationally demanding, kid_diffusion_steps has to be small
images = self.denormalize(images)
generated_images = self.generate(
num_images=batch_size, diffusion_steps=kid_diffusion_steps
)
self.kid.update_state(images, generated_images)
return {m.name: m.result() for m in self.metrics}
def plot_images(self, epoch=None, logs=None, num_rows=3, num_cols=6):
# plot random generated images for visual evaluation of generation quality
generated_images = self.generate(
num_images=num_rows * num_cols,
diffusion_steps=plot_diffusion_steps,
)
plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0))
for row in range(num_rows):
for col in range(num_cols):
index = row * num_cols + col
plt.subplot(num_rows, num_cols, index + 1)
plt.imshow(generated_images[index])
plt.axis("off")
plt.tight_layout()
plt.show()
plt.close()
"""
## Training
"""
# create and compile the model
model = DiffusionModel(image_size, widths, block_depth)
# below tensorflow 2.9:
# pip install tensorflow_addons
# import tensorflow_addons as tfa
# optimizer=tfa.optimizers.AdamW
model.compile(
optimizer=keras.optimizers.AdamW(
learning_rate=learning_rate, weight_decay=weight_decay
),
loss=keras.losses.mean_absolute_error,
)
# pixelwise mean absolute error is used as loss
# save the best model based on the validation KID metric
checkpoint_path = "checkpoints/diffusion_model.weights.h5"
checkpoint_callback = keras.callbacks.ModelCheckpoint(
filepath=checkpoint_path,
save_weights_only=True,
monitor="val_kid",
mode="min",
save_best_only=True,
)
# calculate mean and variance of training dataset for normalization
model.normalizer.adapt(train_dataset)
# run training and plot generated images periodically
model.fit(
train_dataset,
epochs=num_epochs,
validation_data=val_dataset,
callbacks=[
keras.callbacks.LambdaCallback(on_epoch_end=model.plot_images),
checkpoint_callback,
],
)
"""
## Inference
"""
# load the best model and generate images
model.load_weights(checkpoint_path)
model.plot_images()
"""
## Results
By running the training for at least 50 epochs (takes 2 hours on a T4 GPU and 30 minutes
on an A100 GPU), one can get high quality image generations using this code example.
The evolution of a batch of images over a 80 epoch training (color artifacts are due to
GIF compression):
![flowers training gif](https://i.imgur.com/FSCKtZq.gif)
Images generated using between 1 and 20 sampling steps from the same initial noise:
![flowers sampling steps gif](https://i.imgur.com/tM5LyH3.gif)
Interpolation (spherical) between initial noise samples:
![flowers interpolation gif](https://i.imgur.com/hk5Hd5o.gif)
Deterministic sampling process (noisy images on top, predicted images on bottom, 40
steps):
![flowers deterministic generation gif](https://i.imgur.com/wCvzynh.gif)
Stochastic sampling process (noisy images on top, predicted images on bottom, 80 steps):
![flowers stochastic generation gif](https://i.imgur.com/kRXOGzd.gif)
Trained model and demo available on HuggingFace:
| Trained Model | Demo |
| :--: | :--: |
| [![model badge](https://img.shields.io/badge/%F0%9F%A4%97%20Model-DDIM-black.svg)](https://huggingface.co/keras-io/denoising-diffusion-implicit-models) | [![spaces badge](https://img.shields.io/badge/%F0%9F%A4%97%20Spaces-DDIM-black.svg)](https://huggingface.co/spaces/keras-io/denoising-diffusion-implicit-models) |
"""
"""
## Lessons learned
During preparation for this code example I have run numerous experiments using
[this repository](https://github.com/beresandras/clear-diffusion-keras).
In this section I list
the lessons learned and my recommendations in my subjective order of importance.
### Algorithmic tips
* **min. and max. signal rates**: I found the min. signal rate to be an important
hyperparameter. Setting it too low will make the generated images oversaturated, while
setting it too high will make them undersaturated. I recommend tuning it carefully. Also,
setting it to 0 will lead to a division by zero error. The max. signal rate can be set to
1, but I found that setting it lower slightly improves generation quality.
* **loss function**: While large models tend to use mean squared error (MSE) loss, I
recommend using mean absolute error (MAE) on this dataset. In my experience MSE loss
generates more diverse samples (it also seems to hallucinate more
[Section 3](https://arxiv.org/abs/2111.05826)), while MAE loss leads to smoother images.
I recommend trying both.
* **weight decay**: I did occasionally run into diverged trainings when scaling up the
model, and found that weight decay helps in avoiding instabilities at a low performance
cost. This is why I use
[AdamW](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/experimental/AdamW)
instead of [Adam](https://keras.io/api/optimizers/adam/) in this example.
* **exponential moving average of weights**: This helps to reduce the variance of the KID
metric, and helps in averaging out short-term changes during training.
* **image augmentations**: Though I did not use image augmentations in this example, in
my experience adding horizontal flips to the training increases generation performance,
while random crops do not. Since we use a supervised denoising loss, overfitting can be
an issue, so image augmentations might be important on small datasets. One should also be
careful not to use
[leaky augmentations](https://keras.io/examples/generative/gan_ada/#invertible-data-augmentation),
which can be done following
[this method (end of Section 5)](https://arxiv.org/abs/2206.00364) for instance.
* **data normalization**: In the literature the pixel values of images are usually
converted to the -1 to 1 range. For theoretical correctness, I normalize the images to
have zero mean and unit variance instead, exactly like the random noises.
* **noise level input**: I chose to input the noise variance to the network, as it is
symmetrical under our sampling schedule. One could also input the noise rate (similar
performance), the signal rate (lower performance), or even the
[log-signal-to-noise ratio (Appendix B.1)](https://arxiv.org/abs/2107.00630)
(did not try, as its range is highly
dependent on the min. and max. signal rates, and would require adjusting the min.
embedding frequency accordingly).
* **gradient clipping**: Using global gradient clipping with a value of 1 can help with
training stability for large models, but decreased performance significantly in my
experience.
* **residual connection downscaling**: For
[deeper models (Appendix B)](https://arxiv.org/abs/2205.11487), scaling the residual
connections with 1/sqrt(2) can be helpful, but did not help in my case.
* **learning rate**: For me, [Adam optimizer's](https://keras.io/api/optimizers/adam/)
default learning rate of 1e-3 worked very well, but lower learning rates are more common
in the [literature (Tables 11-13)](https://arxiv.org/abs/2105.05233).
### Architectural tips
* **sinusoidal embedding**: Using sinusoidal embeddings on the noise level input of the
network is crucial for good performance. I recommend setting the min. embedding frequency
to the reciprocal of the range of this input, and since we use the noise variance in this
example, it can be left always at 1. The max. embedding frequency controls the smallest
change in the noise variance that the network will be sensitive to, and the embedding
dimensions set the number of frequency components in the embedding. In my experience the
performance is not too sensitive to these values.
* **skip connections**: Using skip connections in the network architecture is absolutely
critical, without them the model will fail to learn to denoise at a good performance.
* **residual connections**: In my experience residual connections also significantly
improve performance, but this might be due to the fact that we only input the noise
level embeddings to the first layer of the network instead of to all of them.
* **normalization**: When scaling up the model, I did occasionally encounter diverged
trainings, using normalization layers helped to mitigate this issue. In the literature it
is common to use
[GroupNormalization](https://www.tensorflow.org/addons/api_docs/python/tfa/layers/GroupNormalization)
(with 8 groups for example) or
[LayerNormalization](https://keras.io/api/layers/normalization_layers/layer_normalization/)
in the network, I however chose to use
[BatchNormalization](https://keras.io/api/layers/normalization_layers/batch_normalization/),
as it gave similar benefits in my experiments but was computationally lighter.
* **activations**: The choice of activation functions had a larger effect on generation
quality than I expected. In my experiments using non-monotonic activation functions
outperformed monotonic ones (such as
[ReLU](https://www.tensorflow.org/api_docs/python/tf/keras/activations/relu)), with
[Swish](https://www.tensorflow.org/api_docs/python/tf/keras/activations/swish) performing
the best (this is also what [Imagen uses, page 41](https://arxiv.org/abs/2205.11487)).
* **attention**: As mentioned earlier, it is common in the literature to use
[attention layers](https://keras.io/api/layers/attention_layers/multi_head_attention/) at low
resolutions for better global coherence. I omitted them for simplicity.
* **upsampling**:
[Bilinear and nearest neighbour upsampling](https://keras.io/api/layers/reshaping_layers/up_sampling2d/)
in the network performed similarly, however I did not try
[transposed convolutions](https://keras.io/api/layers/convolution_layers/convolution2d_transpose/).
For a similar list about GANs check out
[this Keras tutorial](https://keras.io/examples/generative/gan_ada/#gan-tips-and-tricks).
"""
"""
## What to try next?
If you would like to dive in deeper to the topic, I recommend checking out
[this repository](https://github.com/beresandras/clear-diffusion-keras) that I created in
preparation for this code example, which implements a wider range of features in a
similar style, such as:
* stochastic sampling
* second-order sampling based on the
[differential equation view of DDIMs (Equation 13)](https://arxiv.org/abs/2010.02502)
* more diffusion schedules
* more network output types: predicting image or
[velocity (Appendix D)](https://arxiv.org/abs/2202.00512) instead of noise
* more datasets
"""
"""
## Related works
* [Score-based generative modeling](https://yang-song.github.io/blog/2021/score/)
(blogpost)
* [What are diffusion models?](https://lilianweng.github.io/posts/2021-07-11-diffusion-models/)
(blogpost)
* [Annotated diffusion model](https://huggingface.co/blog/annotated-diffusion) (blogpost)
* [CVPR 2022 tutorial on diffusion models](https://cvpr2022-tutorial-diffusion-models.github.io/)
(slides available)
* [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364):
attempts unifying diffusion methods under a common framework
* High-level video overviews: [1](https://www.youtube.com/watch?v=yTAMrHVG1ew),
[2](https://www.youtube.com/watch?v=344w5h24-h8)
* Detailed technical videos: [1](https://www.youtube.com/watch?v=fbLgFrlTnGU),
[2](https://www.youtube.com/watch?v=W-O7AZNzbzQ)
* Score-based generative models: [NCSN](https://arxiv.org/abs/1907.05600),
[NCSN+](https://arxiv.org/abs/2006.09011), [NCSN++](https://arxiv.org/abs/2011.13456)
* Denoising diffusion models: [DDPM](https://arxiv.org/abs/2006.11239),
[DDIM](https://arxiv.org/abs/2010.02502), [DDPM+](https://arxiv.org/abs/2102.09672),
[DDPM++](https://arxiv.org/abs/2105.05233)
* Large diffusion models: [GLIDE](https://arxiv.org/abs/2112.10741),
[DALL-E 2](https://arxiv.org/abs/2204.06125/), [Imagen](https://arxiv.org/abs/2205.11487)
"""

@ -69,7 +69,8 @@ class Concatenate(Merge):
if axis != concat_axis and axis_value == 1:
del reduced_inputs_shapes[i][axis]
del reduced_inputs_shapes[i][self.axis]
if len(reduced_inputs_shapes[i]) > self.axis:
del reduced_inputs_shapes[i][self.axis]
shape_set.add(tuple(reduced_inputs_shapes[i]))
if len(shape_set) != 1: