diff --git a/examples/keras_io/tensorflow/generative/ddim.py b/examples/keras_io/tensorflow/generative/ddim.py
new file mode 100644
index 000000000..7e4caa935
--- /dev/null
+++ b/examples/keras_io/tensorflow/generative/ddim.py
@@ -0,0 +1,889 @@
+"""
+Title: Denoising Diffusion Implicit Models
+Author: [András Béres](https://www.linkedin.com/in/andras-beres-789190210)
+Date created: 2022/06/24
+Last modified: 2022/06/24
+Description: Generating images of flowers with denoising diffusion implicit models.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+### What are diffusion models?
+
+Recently, [denoising diffusion models](https://arxiv.org/abs/2006.11239), including
+[score-based generative models](https://arxiv.org/abs/1907.05600), have gained popularity
+as a powerful class of generative models that can [rival](https://arxiv.org/abs/2105.05233)
+even [generative adversarial networks (GANs)](https://arxiv.org/abs/1406.2661) in image
+synthesis quality. They tend to generate more diverse samples, while being stable to
+train and easy to scale. Recent large diffusion models, such as
+[DALL-E 2](https://openai.com/dall-e-2/) and [Imagen](https://imagen.research.google/),
+have shown incredible text-to-image generation capability. One of their drawbacks,
+however, is that they are slower to sample from, because they require multiple forward
+passes to generate an image.
+
+Diffusion refers to the process of turning a structured signal (an image) into noise
+step-by-step. By simulating diffusion, we can generate noisy images from our training
+images, and train a neural network to denoise them. Using the trained network we can then
+simulate the opposite of diffusion, reverse diffusion, which is the process of an image
+emerging from noise.
+
+![diffusion process gif](https://i.imgur.com/dipPOfa.gif)
+
+One-sentence summary: **diffusion models are trained to denoise noisy images, and can
+generate images by iteratively denoising pure noise.**
+
+### Goal of this example
+
+This code example intends to be a minimal but feature-complete (with a generation quality
+metric) implementation of diffusion models, with modest compute requirements and
+reasonable performance. My implementation choices and hyperparameter tuning were done
+with these goals in mind.
+
+Since the current literature on diffusion models is
+[mathematically quite complex](https://arxiv.org/abs/2206.00364),
+with multiple theoretical frameworks
+([score matching](https://arxiv.org/abs/1907.05600),
+[differential equations](https://arxiv.org/abs/2011.13456),
+[Markov chains](https://arxiv.org/abs/2006.11239)) and sometimes even
+[conflicting notations (see Appendix C.2)](https://arxiv.org/abs/2010.02502),
+it can be daunting to try to understand these models. My view of them in this example is
+that they learn to separate a noisy image into its image and Gaussian noise components.
+
+In this example I made an effort to break down all long mathematical expressions into
+digestible pieces and gave all variables explanatory names. I also included numerous
+links to relevant literature to help interested readers dive deeper into the topic, in
+the hope that this code example will become a good starting point for practitioners
+learning about diffusion models.
+
+In the following sections, we will implement a continuous-time version of
+[Denoising Diffusion Implicit Models (DDIMs)](https://arxiv.org/abs/2010.02502)
+with deterministic sampling.
+""" + +""" +## Setup +""" + +import math +import matplotlib.pyplot as plt +import tensorflow as tf +import tensorflow_datasets as tfds + +import keras_core as keras +from keras_core import layers +from keras_core import ops + +""" +## Hyperparameters +""" + +# data +dataset_name = "oxford_flowers102" +dataset_repetitions = 5 +num_epochs = 1 # train for at least 50 epochs for good results +image_size = 64 +# KID = Kernel Inception Distance, see related section +kid_image_size = 75 +kid_diffusion_steps = 5 +plot_diffusion_steps = 20 + +# sampling +min_signal_rate = 0.02 +max_signal_rate = 0.95 + +# architecture +embedding_dims = 32 +embedding_max_frequency = 1000.0 +widths = [32, 64, 96, 128] +block_depth = 2 + +# optimization +batch_size = 64 +ema = 0.999 +learning_rate = 1e-3 +weight_decay = 1e-4 + +""" +## Data pipeline + +We will use the +[Oxford Flowers 102](https://www.tensorflow.org/datasets/catalog/oxford_flowers102) +dataset for +generating images of flowers, which is a diverse natural dataset containing around 8,000 +images. Unfortunately the official splits are imbalanced, as most of the images are +contained in the test split. We create new splits (80% train, 20% validation) using the +[Tensorflow Datasets slicing API](https://www.tensorflow.org/datasets/splits). We apply +center crops as preprocessing, and repeat the dataset multiple times (reason given in the +next section). +""" + + +def preprocess_image(data): + # center crop image + height = ops.shape(data["image"])[0] + width = ops.shape(data["image"])[1] + crop_size = ops.minimum(height, width) + image = tf.image.crop_to_bounding_box( + data["image"], + (height - crop_size) // 2, + (width - crop_size) // 2, + crop_size, + crop_size, + ) + + # resize and clip + # for image downsampling it is important to turn on antialiasing + image = tf.image.resize( + image, size=[image_size, image_size], antialias=True + ) + return ops.clip(image / 255.0, 0.0, 1.0) + + +def prepare_dataset(split): + # the validation dataset is shuffled as well, because data order matters + # for the KID estimation + return ( + tfds.load(dataset_name, split=split, shuffle_files=True) + .map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE) + .cache() + .repeat(dataset_repetitions) + .shuffle(10 * batch_size) + .batch(batch_size, drop_remainder=True) + .prefetch(buffer_size=tf.data.AUTOTUNE) + ) + + +# load dataset +train_dataset = prepare_dataset("train[:80%]+validation[:80%]+test[:80%]") +val_dataset = prepare_dataset("train[80%:]+validation[80%:]+test[80%:]") + +""" +## Kernel inception distance + +[Kernel Inception Distance (KID)](https://arxiv.org/abs/1801.01401) is an image quality +metric which was proposed as a replacement for the popular +[Frechet Inception Distance (FID)](https://arxiv.org/abs/1706.08500). +I prefer KID to FID because it is simpler to +implement, can be estimated per-batch, and is computationally lighter. More details +[here](https://keras.io/examples/generative/gan_ada/#kernel-inception-distance). + +In this example, the images are evaluated at the minimal possible resolution of the +Inception network (75x75 instead of 299x299), and the metric is only measured on the +validation set for computational efficiency. We also limit the number of sampling steps +at evaluation to 5 for the same reason. 
+ +Since the dataset is relatively small, we go over the train and validation splits +multiple times per epoch, because the KID estimation is noisy and compute-intensive, so +we want to evaluate only after many iterations, but for many iterations. + +""" + + +@keras.saving.register_keras_serializable() +class KID(keras.metrics.Metric): + def __init__(self, name, **kwargs): + super().__init__(name=name, **kwargs) + + # KID is estimated per batch and is averaged across batches + self.kid_tracker = keras.metrics.Mean(name="kid_tracker") + + # a pretrained InceptionV3 is used without its classification layer + # transform the pixel values to the 0-255 range, then use the same + # preprocessing as during pretraining + self.encoder = keras.Sequential( + [ + keras.Input(shape=(image_size, image_size, 3)), + layers.Rescaling(255.0), + layers.Resizing(height=kid_image_size, width=kid_image_size), + layers.Lambda(keras.applications.inception_v3.preprocess_input), + keras.applications.InceptionV3( + include_top=False, + input_shape=(kid_image_size, kid_image_size, 3), + weights="imagenet", + ), + layers.GlobalAveragePooling2D(), + ], + name="inception_encoder", + ) + + def polynomial_kernel(self, features_1, features_2): + feature_dimensions = ops.cast(ops.shape(features_1)[1], dtype="float32") + return ( + features_1 @ ops.transpose(features_2) / feature_dimensions + 1.0 + ) ** 3.0 + + def update_state(self, real_images, generated_images, sample_weight=None): + real_features = self.encoder(real_images, training=False) + generated_features = self.encoder(generated_images, training=False) + + # compute polynomial kernels using the two sets of features + kernel_real = self.polynomial_kernel(real_features, real_features) + kernel_generated = self.polynomial_kernel( + generated_features, generated_features + ) + kernel_cross = self.polynomial_kernel(real_features, generated_features) + + # estimate the squared maximum mean discrepancy using the average kernel values + batch_size = real_features.shape[0] + batch_size_f = ops.cast(batch_size, dtype="float32") + mean_kernel_real = ops.sum( + kernel_real * (1.0 - ops.eye(batch_size)) + ) / (batch_size_f * (batch_size_f - 1.0)) + mean_kernel_generated = ops.sum( + kernel_generated * (1.0 - ops.eye(batch_size)) + ) / (batch_size_f * (batch_size_f - 1.0)) + mean_kernel_cross = ops.mean(kernel_cross) + kid = mean_kernel_real + mean_kernel_generated - 2.0 * mean_kernel_cross + + # update the average KID estimate + self.kid_tracker.update_state(kid) + + def result(self): + return self.kid_tracker.result() + + def reset_state(self): + self.kid_tracker.reset_state() + + +""" +## Network architecture + +Here we specify the architecture of the neural network that we will use for denoising. We +build a [U-Net](https://arxiv.org/abs/1505.04597) with identical input and output +dimensions. U-Net is a popular semantic segmentation architecture, whose main idea is +that it progressively downsamples and then upsamples its input image, and adds skip +connections between layers having the same resolution. These help with gradient flow and +avoid introducing a representation bottleneck, unlike usual +[autoencoders](https://www.deeplearningbook.org/contents/autoencoders.html). Based on +this, one can view +[diffusion models as denoising autoencoders](https://benanne.github.io/2022/01/31/diffusion.html) +without a bottleneck. + +The network takes two inputs, the noisy images and the variances of their noise +components. 
The latter is required since denoising a signal requires different operations +at different levels of noise. We transform the noise variances using sinusoidal +embeddings, similarly to positional encodings used both in +[transformers](https://arxiv.org/abs/1706.03762) and +[NeRF](https://arxiv.org/abs/2003.08934). This helps the network to be +[highly sensitive](https://arxiv.org/abs/2006.10739) to the noise level, which is +crucial for good performance. We implement sinusoidal embeddings using a +[Lambda layer](https://keras.io/api/layers/core_layers/lambda/). + +Some other considerations: + +* We build the network using the +[Keras Functional API](https://keras.io/guides/functional_api/), and use +[closures](https://twitter.com/fchollet/status/1441927912836321280) to build blocks of +layers in a consistent style. +* [Diffusion models](https://arxiv.org/abs/2006.11239) embed the index of the timestep of +the diffusion process instead of the noise variance, while +[score-based models (Table 1)](https://arxiv.org/abs/2206.00364) +usually use some function of the noise level. I +prefer the latter so that we can change the sampling schedule at inference time, without +retraining the network. +* [Diffusion models](https://arxiv.org/abs/2006.11239) input the embedding to each +convolution block separately. We only input it at the start of the network for +simplicity, which in my experience barely decreases performance, because the skip and +residual connections help the information propagate through the network properly. +* In the literature it is common to use +[attention layers](https://keras.io/api/layers/attention_layers/multi_head_attention/) +at lower resolutions for better global coherence. I omitted it for simplicity. +* We disable the learnable center and scale parameters of the batch normalization layers, +since the following convolution layers make them redundant. +* We initialize the last convolution's kernel to all zeros as a good practice, making the +network predict only zeros after initialization, which is the mean of its targets. This +will improve behaviour at the start of training and make the mean squared error loss +start at exactly 1. 
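+
+To build some intuition for the sinusoidal embedding of the noise variance described
+above, here is a small, self-contained sketch (plain NumPy, toy dimensions; the real
+implementation below uses `keras_core.ops` and the hyperparameters defined earlier):
+
+```python
+import numpy as np
+
+
+def toy_sinusoidal_embedding(noise_variance, dims=8, max_frequency=1000.0):
+    # geometrically spaced frequencies between 1 and max_frequency
+    frequencies = np.exp(np.linspace(np.log(1.0), np.log(max_frequency), dims // 2))
+    angles = 2.0 * np.pi * frequencies * noise_variance
+    return np.concatenate([np.sin(angles), np.cos(angles)])
+
+
+# nearby noise variances map to clearly different embedding vectors,
+# which helps the network stay sensitive to the noise level
+print(toy_sinusoidal_embedding(0.10).round(2))
+print(toy_sinusoidal_embedding(0.11).round(2))
+```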
+""" + + +@keras.saving.register_keras_serializable() +def sinusoidal_embedding(x): + embedding_min_frequency = 1.0 + frequencies = ops.exp( + ops.linspace( + ops.log(embedding_min_frequency), + ops.log(embedding_max_frequency), + embedding_dims // 2, + ) + ) + angular_speeds = ops.cast(2.0 * math.pi * frequencies, "float32") + embeddings = ops.concatenate( + [ops.sin(angular_speeds * x), ops.cos(angular_speeds * x)], axis=3 + ) + return embeddings + + +def ResidualBlock(width): + def apply(x): + input_width = x.shape[3] + if input_width == width: + residual = x + else: + residual = layers.Conv2D(width, kernel_size=1)(x) + x = layers.BatchNormalization(center=False, scale=False)(x) + x = layers.Conv2D( + width, kernel_size=3, padding="same", activation="swish" + )(x) + x = layers.Conv2D(width, kernel_size=3, padding="same")(x) + x = layers.Add()([x, residual]) + return x + + return apply + + +def DownBlock(width, block_depth): + def apply(x): + x, skips = x + for _ in range(block_depth): + x = ResidualBlock(width)(x) + skips.append(x) + x = layers.AveragePooling2D(pool_size=2)(x) + return x + + return apply + + +def UpBlock(width, block_depth): + def apply(x): + x, skips = x + x = layers.UpSampling2D(size=2, interpolation="bilinear")(x) + for _ in range(block_depth): + x = layers.Concatenate()([x, skips.pop()]) + x = ResidualBlock(width)(x) + return x + + return apply + + +def get_network(image_size, widths, block_depth): + noisy_images = keras.Input(shape=(image_size, image_size, 3)) + noise_variances = keras.Input(shape=(1, 1, 1)) + + e = layers.Lambda(sinusoidal_embedding, output_shape=(1, 1, 32))( + noise_variances + ) + e = layers.UpSampling2D(size=image_size, interpolation="nearest")(e) + + x = layers.Conv2D(widths[0], kernel_size=1)(noisy_images) + x = layers.Concatenate()([x, e]) + + skips = [] + for width in widths[:-1]: + x = DownBlock(width, block_depth)([x, skips]) + + for _ in range(block_depth): + x = ResidualBlock(widths[-1])(x) + + for width in reversed(widths[:-1]): + x = UpBlock(width, block_depth)([x, skips]) + + x = layers.Conv2D(3, kernel_size=1, kernel_initializer="zeros")(x) + + return keras.Model([noisy_images, noise_variances], x, name="residual_unet") + + +""" +This showcases the power of the Functional API. Note how we built a relatively complex +U-Net with skip connections, residual blocks, multiple inputs, and sinusoidal embeddings +in 80 lines of code! +""" + +""" +## Diffusion model + +### Diffusion schedule + +Let us say, that a diffusion process starts at time = 0, and ends at time = 1. This +variable will be called diffusion time, and can be either discrete (common in diffusion +models) or continuous (common in score-based models). I choose the latter, so that the +number of sampling steps can be changed at inference time. + +We need to have a function that tells us at each point in the diffusion process the noise +levels and signal levels of the noisy image corresponding to the actual diffusion time. +This will be called the diffusion schedule (see `diffusion_schedule()`). + +This schedule outputs two quantities: the `noise_rate` and the `signal_rate` +(corresponding to sqrt(1 - alpha) and sqrt(alpha) in the DDIM paper, respectively). We +generate the noisy image by weighting the random noise and the training image by their +corresponding rates and adding them together. 
+
+Since the (standard normal) random noises and the (normalized) images both have zero mean
+and unit variance, the noise rate and signal rate can be interpreted as the standard
+deviation of their components in the noisy image, while the squares of their rates can be
+interpreted as their variance (or their power in the signal processing sense). The rates
+will always be set so that their squared sum is 1, meaning that the noisy images will
+always have unit variance, just like their unscaled components.
+
+We will use a simplified, continuous version of the
+[cosine schedule (Section 3.2)](https://arxiv.org/abs/2102.09672),
+which is quite commonly used in the literature.
+This schedule is symmetric, slow towards the start and end of the diffusion process, and
+it also has a nice geometric interpretation, using the
+[trigonometric properties of the unit circle](https://en.wikipedia.org/wiki/Unit_circle#/media/File:Circle-trig6.svg):
+
+![diffusion schedule gif](https://i.imgur.com/JW9W0fA.gif)
+
+### Training process
+
+The training procedure (see `train_step()` and `denoise()`) of denoising diffusion models
+is the following: we sample random diffusion times uniformly, and mix the training images
+with random Gaussian noise at rates corresponding to the diffusion times. Then, we train
+the model to separate the noisy image into its two components.
+
+Usually, the neural network is trained to predict the unscaled noise component, from
+which the predicted image component can be calculated using the signal and noise rates.
+Pixelwise
+[mean squared error](https://keras.io/api/losses/regression_losses/#mean_squared_error-function) should
+be used in theory; however, I recommend using
+[mean absolute error](https://keras.io/api/losses/regression_losses/#mean_absolute_error-function)
+instead (similarly to
+[this](https://github.com/lucidrains/denoising-diffusion-pytorch/blob/master/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L371)
+implementation), which produces better results on this dataset.
+
+### Sampling (reverse diffusion)
+
+When sampling (see `reverse_diffusion()`), at each step we take the previous estimate of
+the noisy image and separate it into image and noise using our network. Then we recombine
+these components using the signal and noise rates of the following step.
+
+Though a similar view is shown in
+[Equation 12 of DDIMs](https://arxiv.org/abs/2010.02502), I believe the above explanation
+of the sampling equation is not widely known.
+
+This example only implements the deterministic sampling procedure from DDIM, which
+corresponds to *eta = 0* in the paper. One can also use stochastic sampling (in which
+case the model becomes a
+[Denoising Diffusion Probabilistic Model (DDPM)](https://arxiv.org/abs/2006.11239)),
+where a part of the predicted noise is
+replaced with the same or a larger amount of random noise
+([see Equation 16 and below](https://arxiv.org/abs/2010.02502)).
+
+Stochastic sampling can be used without retraining the network (since both models are
+trained the same way), and it can improve sample quality, though it usually requires more
+sampling steps.
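+
+To make the update rule concrete, here is a toy, self-contained sketch of the
+deterministic (eta = 0) sampling loop described above. The `predict_noise` function is
+just a placeholder standing in for the trained network; the real loop lives in
+`reverse_diffusion()` below:
+
+```python
+import numpy as np
+
+
+def toy_schedule(t, min_signal_rate=0.02, max_signal_rate=0.95):
+    # same idea as `diffusion_schedule()` below: diffusion time -> angle -> rates
+    angle = np.arccos(max_signal_rate) + t * (
+        np.arccos(min_signal_rate) - np.arccos(max_signal_rate)
+    )
+    return np.sin(angle), np.cos(angle)  # noise_rate, signal_rate
+
+
+def predict_noise(noisy_image, noise_variance):
+    # placeholder: the real model predicts the noise component with the U-Net
+    return 0.5 * noisy_image
+
+
+noisy_image = np.random.normal(size=(64, 64, 3))  # pure noise at t = 1
+t, step_size = 1.0, 1.0 / 20
+for _ in range(20):
+    noise_rate, signal_rate = toy_schedule(t)
+    pred_noise = predict_noise(noisy_image, noise_rate**2)
+    pred_image = (noisy_image - noise_rate * pred_noise) / signal_rate
+    # remix the predicted components at the rates of the next (less noisy) step
+    next_noise_rate, next_signal_rate = toy_schedule(t - step_size)
+    noisy_image = next_signal_rate * pred_image + next_noise_rate * pred_noise
+    t -= step_size
+```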
+""" + + +@keras.saving.register_keras_serializable() +class DiffusionModel(keras.Model): + def __init__(self, image_size, widths, block_depth): + super().__init__() + + self.normalizer = layers.Normalization() + self.network = get_network(image_size, widths, block_depth) + self.ema_network = keras.models.clone_model(self.network) + + def compile(self, **kwargs): + super().compile(**kwargs) + + self.noise_loss_tracker = keras.metrics.Mean(name="n_loss") + self.image_loss_tracker = keras.metrics.Mean(name="i_loss") + self.kid = KID(name="kid") + + @property + def metrics(self): + return [self.noise_loss_tracker, self.image_loss_tracker, self.kid] + + def denormalize(self, images): + # convert the pixel values back to 0-1 range + images = self.normalizer.mean + images * self.normalizer.variance**0.5 + return ops.clip(images, 0.0, 1.0) + + def diffusion_schedule(self, diffusion_times): + # diffusion times -> angles + start_angle = ops.cast(ops.arccos(max_signal_rate), "float32") + end_angle = ops.cast(ops.arccos(min_signal_rate), "float32") + + diffusion_angles = start_angle + diffusion_times * ( + end_angle - start_angle + ) + + # angles -> signal and noise rates + signal_rates = ops.cos(diffusion_angles) + noise_rates = ops.sin(diffusion_angles) + # note that their squared sum is always: sin^2(x) + cos^2(x) = 1 + + return noise_rates, signal_rates + + def denoise(self, noisy_images, noise_rates, signal_rates, training): + # the exponential moving average weights are used at evaluation + if training: + network = self.network + else: + network = self.ema_network + + # predict noise component and calculate the image component using it + pred_noises = network( + [noisy_images, noise_rates**2], training=training + ) + pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates + + return pred_noises, pred_images + + def reverse_diffusion(self, initial_noise, diffusion_steps): + # reverse diffusion = sampling + num_images = initial_noise.shape[0] + step_size = 1.0 / diffusion_steps + + # important line: + # at the first sampling step, the "noisy image" is pure noise + # but its signal rate is assumed to be nonzero (min_signal_rate) + next_noisy_images = initial_noise + for step in range(diffusion_steps): + noisy_images = next_noisy_images + + # separate the current noisy image to its components + diffusion_times = ops.ones((num_images, 1, 1, 1)) - step * step_size + noise_rates, signal_rates = self.diffusion_schedule(diffusion_times) + pred_noises, pred_images = self.denoise( + noisy_images, noise_rates, signal_rates, training=False + ) + # network used in eval mode + + # remix the predicted components using the next signal and noise rates + next_diffusion_times = diffusion_times - step_size + next_noise_rates, next_signal_rates = self.diffusion_schedule( + next_diffusion_times + ) + next_noisy_images = ( + next_signal_rates * pred_images + next_noise_rates * pred_noises + ) + # this new noisy image will be used in the next step + + return pred_images + + def generate(self, num_images, diffusion_steps): + # noise -> images -> denormalized images + initial_noise = keras.random.normal( + shape=(num_images, image_size, image_size, 3) + ) + generated_images = self.reverse_diffusion( + initial_noise, diffusion_steps + ) + generated_images = self.denormalize(generated_images) + return generated_images + + def train_step(self, images): + # normalize images to have standard deviation of 1, like the noises + images = self.normalizer(images, training=True) + noises = keras.random.normal( + 
shape=(batch_size, image_size, image_size, 3) + ) + + # sample uniform random diffusion times + diffusion_times = keras.random.uniform( + shape=(batch_size, 1, 1, 1), minval=0.0, maxval=1.0 + ) + noise_rates, signal_rates = self.diffusion_schedule(diffusion_times) + # mix the images with noises accordingly + noisy_images = signal_rates * images + noise_rates * noises + + with tf.GradientTape() as tape: + # train the network to separate noisy images to their components + pred_noises, pred_images = self.denoise( + noisy_images, noise_rates, signal_rates, training=True + ) + + noise_loss = self.loss(noises, pred_noises) # used for training + image_loss = self.loss(images, pred_images) # only used as metric + + gradients = tape.gradient(noise_loss, self.network.trainable_weights) + self.optimizer.apply_gradients( + zip(gradients, self.network.trainable_weights) + ) + + self.noise_loss_tracker.update_state(noise_loss) + self.image_loss_tracker.update_state(image_loss) + + # track the exponential moving averages of weights + for weight, ema_weight in zip( + self.network.weights, self.ema_network.weights + ): + ema_weight.assign(ema * ema_weight + (1 - ema) * weight) + + # KID is not measured during the training phase for computational efficiency + return {m.name: m.result() for m in self.metrics[:-1]} + + def test_step(self, images): + # normalize images to have standard deviation of 1, like the noises + images = self.normalizer(images, training=False) + noises = keras.random.normal( + shape=(batch_size, image_size, image_size, 3) + ) + + # sample uniform random diffusion times + diffusion_times = keras.random.uniform( + shape=(batch_size, 1, 1, 1), minval=0.0, maxval=1.0 + ) + noise_rates, signal_rates = self.diffusion_schedule(diffusion_times) + # mix the images with noises accordingly + noisy_images = signal_rates * images + noise_rates * noises + + # use the network to separate noisy images to their components + pred_noises, pred_images = self.denoise( + noisy_images, noise_rates, signal_rates, training=False + ) + + noise_loss = self.loss(noises, pred_noises) + image_loss = self.loss(images, pred_images) + + self.image_loss_tracker.update_state(image_loss) + self.noise_loss_tracker.update_state(noise_loss) + + # measure KID between real and generated images + # this is computationally demanding, kid_diffusion_steps has to be small + images = self.denormalize(images) + generated_images = self.generate( + num_images=batch_size, diffusion_steps=kid_diffusion_steps + ) + self.kid.update_state(images, generated_images) + + return {m.name: m.result() for m in self.metrics} + + def plot_images(self, epoch=None, logs=None, num_rows=3, num_cols=6): + # plot random generated images for visual evaluation of generation quality + generated_images = self.generate( + num_images=num_rows * num_cols, + diffusion_steps=plot_diffusion_steps, + ) + + plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0)) + for row in range(num_rows): + for col in range(num_cols): + index = row * num_cols + col + plt.subplot(num_rows, num_cols, index + 1) + plt.imshow(generated_images[index]) + plt.axis("off") + plt.tight_layout() + plt.show() + plt.close() + + +""" +## Training +""" + +# create and compile the model +model = DiffusionModel(image_size, widths, block_depth) +# below tensorflow 2.9: +# pip install tensorflow_addons +# import tensorflow_addons as tfa +# optimizer=tfa.optimizers.AdamW +model.compile( + optimizer=keras.optimizers.AdamW( + learning_rate=learning_rate, weight_decay=weight_decay + ), + 
loss=keras.losses.mean_absolute_error, +) +# pixelwise mean absolute error is used as loss + +# save the best model based on the validation KID metric +checkpoint_path = "checkpoints/diffusion_model.weights.h5" +checkpoint_callback = keras.callbacks.ModelCheckpoint( + filepath=checkpoint_path, + save_weights_only=True, + monitor="val_kid", + mode="min", + save_best_only=True, +) + +# calculate mean and variance of training dataset for normalization +model.normalizer.adapt(train_dataset) + +# run training and plot generated images periodically +model.fit( + train_dataset, + epochs=num_epochs, + validation_data=val_dataset, + callbacks=[ + keras.callbacks.LambdaCallback(on_epoch_end=model.plot_images), + checkpoint_callback, + ], +) + +""" +## Inference +""" + +# load the best model and generate images +model.load_weights(checkpoint_path) +model.plot_images() + +""" +## Results + +By running the training for at least 50 epochs (takes 2 hours on a T4 GPU and 30 minutes +on an A100 GPU), one can get high quality image generations using this code example. + +The evolution of a batch of images over a 80 epoch training (color artifacts are due to +GIF compression): + +![flowers training gif](https://i.imgur.com/FSCKtZq.gif) + +Images generated using between 1 and 20 sampling steps from the same initial noise: + +![flowers sampling steps gif](https://i.imgur.com/tM5LyH3.gif) + +Interpolation (spherical) between initial noise samples: + +![flowers interpolation gif](https://i.imgur.com/hk5Hd5o.gif) + +Deterministic sampling process (noisy images on top, predicted images on bottom, 40 +steps): + +![flowers deterministic generation gif](https://i.imgur.com/wCvzynh.gif) + +Stochastic sampling process (noisy images on top, predicted images on bottom, 80 steps): + +![flowers stochastic generation gif](https://i.imgur.com/kRXOGzd.gif) + +Trained model and demo available on HuggingFace: + +| Trained Model | Demo | +| :--: | :--: | +| [![model badge](https://img.shields.io/badge/%F0%9F%A4%97%20Model-DDIM-black.svg)](https://huggingface.co/keras-io/denoising-diffusion-implicit-models) | [![spaces badge](https://img.shields.io/badge/%F0%9F%A4%97%20Spaces-DDIM-black.svg)](https://huggingface.co/spaces/keras-io/denoising-diffusion-implicit-models) | +""" + +""" +## Lessons learned + +During preparation for this code example I have run numerous experiments using +[this repository](https://github.com/beresandras/clear-diffusion-keras). +In this section I list +the lessons learned and my recommendations in my subjective order of importance. + +### Algorithmic tips + +* **min. and max. signal rates**: I found the min. signal rate to be an important +hyperparameter. Setting it too low will make the generated images oversaturated, while +setting it too high will make them undersaturated. I recommend tuning it carefully. Also, +setting it to 0 will lead to a division by zero error. The max. signal rate can be set to +1, but I found that setting it lower slightly improves generation quality. +* **loss function**: While large models tend to use mean squared error (MSE) loss, I +recommend using mean absolute error (MAE) on this dataset. In my experience MSE loss +generates more diverse samples (it also seems to hallucinate more +[Section 3](https://arxiv.org/abs/2111.05826)), while MAE loss leads to smoother images. +I recommend trying both. 
+* **weight decay**: I did occasionally run into diverged trainings when scaling up the
+model, and found that weight decay helps avoid instabilities at a low performance cost.
+This is why I use
+[AdamW](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/experimental/AdamW)
+instead of [Adam](https://keras.io/api/optimizers/adam/) in this example.
+* **exponential moving average of weights**: This helps reduce the variance of the KID
+metric, and helps average out short-term changes during training.
+* **image augmentations**: Though I did not use image augmentations in this example, in
+my experience adding horizontal flips to the training increases generation performance,
+while random crops do not. Since we use a supervised denoising loss, overfitting can be
+an issue, so image augmentations might be important on small datasets. One should also be
+careful not to use
+[leaky augmentations](https://keras.io/examples/generative/gan_ada/#invertible-data-augmentation),
+which can be avoided by following
+[this method (end of Section 5)](https://arxiv.org/abs/2206.00364), for instance.
+* **data normalization**: In the literature the pixel values of images are usually
+converted to the -1 to 1 range. For theoretical correctness, I normalize the images to
+have zero mean and unit variance instead, exactly like the random noises.
+* **noise level input**: I chose to input the noise variance to the network, as it is
+symmetrical under our sampling schedule. One could also input the noise rate (similar
+performance), the signal rate (lower performance), or even the
+[log-signal-to-noise ratio (Appendix B.1)](https://arxiv.org/abs/2107.00630)
+(which I did not try, as its range is highly dependent on the min. and max. signal rates,
+and would require adjusting the min. embedding frequency accordingly).
+* **gradient clipping**: Using global gradient clipping with a value of 1 can help with
+training stability for large models, but it decreased performance significantly in my
+experience.
+* **residual connection downscaling**: For
+[deeper models (Appendix B)](https://arxiv.org/abs/2205.11487), scaling the residual
+connections by 1/sqrt(2) can be helpful, but it did not help in my case.
+* **learning rate**: For me, the [Adam optimizer's](https://keras.io/api/optimizers/adam/)
+default learning rate of 1e-3 worked very well, but lower learning rates are more common
+in the [literature (Tables 11-13)](https://arxiv.org/abs/2105.05233).
+
+### Architectural tips
+
+* **sinusoidal embedding**: Using sinusoidal embeddings on the noise level input of the
+network is crucial for good performance. I recommend setting the min. embedding frequency
+to the reciprocal of the range of this input, and since we use the noise variance in this
+example, it can always be left at 1. The max. embedding frequency controls the smallest
+change in the noise variance that the network will be sensitive to, and the embedding
+dimensions set the number of frequency components in the embedding. In my experience the
+performance is not too sensitive to these values.
+* **skip connections**: Using skip connections in the network architecture is absolutely
+critical; without them the model will fail to learn to denoise well.
+* **residual connections**: In my experience residual connections also significantly
+improve performance, but this might be due to the fact that we only input the noise
+level embeddings to the first layer of the network instead of to all of them.
+* **normalization**: When scaling up the model, I did occasionally encounter diverged
+trainings; using normalization layers helped mitigate this issue. In the literature it
+is common to use
+[GroupNormalization](https://www.tensorflow.org/addons/api_docs/python/tfa/layers/GroupNormalization)
+(with 8 groups, for example) or
+[LayerNormalization](https://keras.io/api/layers/normalization_layers/layer_normalization/)
+in the network. I chose to use
+[BatchNormalization](https://keras.io/api/layers/normalization_layers/batch_normalization/)
+instead, as it gave similar benefits in my experiments but was computationally lighter.
+* **activations**: The choice of activation functions had a larger effect on generation
+quality than I expected. In my experiments, non-monotonic activation functions
+outperformed monotonic ones (such as
+[ReLU](https://www.tensorflow.org/api_docs/python/tf/keras/activations/relu)), with
+[Swish](https://www.tensorflow.org/api_docs/python/tf/keras/activations/swish) performing
+the best (this is also what [Imagen uses, page 41](https://arxiv.org/abs/2205.11487)).
+* **attention**: As mentioned earlier, it is common in the literature to use
+[attention layers](https://keras.io/api/layers/attention_layers/multi_head_attention/) at low
+resolutions for better global coherence. I omitted them for simplicity.
+* **upsampling**:
+[Bilinear and nearest neighbour upsampling](https://keras.io/api/layers/reshaping_layers/up_sampling2d/)
+in the network performed similarly; however, I did not try
+[transposed convolutions](https://keras.io/api/layers/convolution_layers/convolution2d_transpose/).
+
+For a similar list about GANs, check out
+[this Keras tutorial](https://keras.io/examples/generative/gan_ada/#gan-tips-and-tricks).
+"""
+
+"""
+## What to try next?
+ +If you would like to dive in deeper to the topic, I recommend checking out +[this repository](https://github.com/beresandras/clear-diffusion-keras) that I created in +preparation for this code example, which implements a wider range of features in a +similar style, such as: + +* stochastic sampling +* second-order sampling based on the +[differential equation view of DDIMs (Equation 13)](https://arxiv.org/abs/2010.02502) +* more diffusion schedules +* more network output types: predicting image or +[velocity (Appendix D)](https://arxiv.org/abs/2202.00512) instead of noise +* more datasets +""" + +""" +## Related works + +* [Score-based generative modeling](https://yang-song.github.io/blog/2021/score/) +(blogpost) +* [What are diffusion models?](https://lilianweng.github.io/posts/2021-07-11-diffusion-models/) +(blogpost) +* [Annotated diffusion model](https://huggingface.co/blog/annotated-diffusion) (blogpost) +* [CVPR 2022 tutorial on diffusion models](https://cvpr2022-tutorial-diffusion-models.github.io/) +(slides available) +* [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364): +attempts unifying diffusion methods under a common framework +* High-level video overviews: [1](https://www.youtube.com/watch?v=yTAMrHVG1ew), +[2](https://www.youtube.com/watch?v=344w5h24-h8) +* Detailed technical videos: [1](https://www.youtube.com/watch?v=fbLgFrlTnGU), +[2](https://www.youtube.com/watch?v=W-O7AZNzbzQ) +* Score-based generative models: [NCSN](https://arxiv.org/abs/1907.05600), +[NCSN+](https://arxiv.org/abs/2006.09011), [NCSN++](https://arxiv.org/abs/2011.13456) +* Denoising diffusion models: [DDPM](https://arxiv.org/abs/2006.11239), +[DDIM](https://arxiv.org/abs/2010.02502), [DDPM+](https://arxiv.org/abs/2102.09672), +[DDPM++](https://arxiv.org/abs/2105.05233) +* Large diffusion models: [GLIDE](https://arxiv.org/abs/2112.10741), +[DALL-E 2](https://arxiv.org/abs/2204.06125/), [Imagen](https://arxiv.org/abs/2205.11487) + + +""" diff --git a/keras_core/layers/merging/concatenate.py b/keras_core/layers/merging/concatenate.py index 5d32ebdbe..bdbd5130c 100644 --- a/keras_core/layers/merging/concatenate.py +++ b/keras_core/layers/merging/concatenate.py @@ -69,7 +69,8 @@ class Concatenate(Merge): if axis != concat_axis and axis_value == 1: del reduced_inputs_shapes[i][axis] - del reduced_inputs_shapes[i][self.axis] + if len(reduced_inputs_shapes[i]) > self.axis: + del reduced_inputs_shapes[i][self.axis] shape_set.add(tuple(reduced_inputs_shapes[i])) if len(shape_set) != 1: