diff --git a/examples/keras_io/tensorflow/vision/mixup.py b/examples/keras_io/tensorflow/vision/mixup.py
new file mode 100644
index 000000000..2666d3f84
--- /dev/null
+++ b/examples/keras_io/tensorflow/vision/mixup.py
@@ -0,0 +1,229 @@
+"""
+Title: MixUp augmentation for image classification
+Author: [Sayak Paul](https://twitter.com/RisingSayak)
+Date created: 2021/03/06
+Last modified: 2021/03/06
+Description: Data augmentation using the mixup technique for image classification.
+Accelerator: GPU
+"""
+"""
+## Introduction
+"""
+
+"""
+_mixup_ is a *domain-agnostic* data augmentation technique proposed in [mixup: Beyond Empirical Risk Minimization](https://arxiv.org/abs/1710.09412)
+by Zhang et al. It's implemented with the following formulas:
+
+![](https://i.ibb.co/DRyHYww/image.png)
+
+(Note that the lambda values are values within the [0, 1] range and are sampled from the
+[Beta distribution](https://en.wikipedia.org/wiki/Beta_distribution).)
+
+The technique is quite systematically named: we are literally mixing up the features and
+their corresponding labels. Implementation-wise, it's simple. Neural networks are prone
+to [memorizing corrupt labels](https://arxiv.org/abs/1611.03530). mixup relaxes this by
+combining different features with one another (the same happens for the labels too) so that
+a network does not get overconfident about the relationship between the features and
+their labels.
+
+mixup is especially useful when we are not sure about selecting a set of augmentation
+transforms for a given dataset (medical imaging datasets, for example). mixup can be
+extended to a variety of data modalities such as computer vision, natural language
+processing, speech, and so on.
+"""
+
+"""
+## Setup
+"""
+
+import numpy as np
+import tensorflow as tf
+import matplotlib.pyplot as plt
+from keras_core import layers
+import keras_core as keras
+
+"""
+## Prepare the dataset
+
+In this example, we will be using the [FashionMNIST](https://github.com/zalandoresearch/fashion-mnist) dataset. But this same recipe can
+be used for other classification datasets as well.
+"""
+
+(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
+
+x_train = x_train.astype("float32") / 255.0
+x_train = np.reshape(x_train, (-1, 28, 28, 1))
+y_train = tf.one_hot(y_train, 10)
+
+x_test = x_test.astype("float32") / 255.0
+x_test = np.reshape(x_test, (-1, 28, 28, 1))
+y_test = tf.one_hot(y_test, 10)
+
+"""
+## Define hyperparameters
+"""
+
+AUTO = tf.data.AUTOTUNE
+BATCH_SIZE = 64
+EPOCHS = 10
+
+"""
+## Convert the data into TensorFlow `Dataset` objects
+"""
+
+# Put aside a few samples to create our validation set
+val_samples = 2000
+x_val, y_val = x_train[:val_samples], y_train[:val_samples]
+new_x_train, new_y_train = x_train[val_samples:], y_train[val_samples:]
+
+train_ds_one = (
+    tf.data.Dataset.from_tensor_slices((new_x_train, new_y_train))
+    .shuffle(BATCH_SIZE * 100)
+    .batch(BATCH_SIZE)
+)
+train_ds_two = (
+    tf.data.Dataset.from_tensor_slices((new_x_train, new_y_train))
+    .shuffle(BATCH_SIZE * 100)
+    .batch(BATCH_SIZE)
+)
+# Because we will be mixing up the images and their corresponding labels, we will be
+# combining two shuffled datasets from the same training data.
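+# As a quick sanity check (an addition to the original recipe; `sanity_images` and
+# `sanity_labels` are throwaway names used only here), we can pull one batch from one of
+# the two streams and confirm what each dataset yields before zipping them together below.
+sanity_images, sanity_labels = next(iter(train_ds_one))
+print("Image batch shape:", sanity_images.shape)  # expected: (64, 28, 28, 1)
+print("Label batch shape:", sanity_labels.shape)  # expected: (64, 10)
+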
+train_ds = tf.data.Dataset.zip((train_ds_one, train_ds_two))
+
+val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(BATCH_SIZE)
+
+test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)
+
+"""
+## Define the mixup technique function
+
+To perform the mixup routine, we create new virtual datasets using the training data from
+the same dataset and apply a lambda value within the [0, 1] range sampled from a [Beta distribution](https://en.wikipedia.org/wiki/Beta_distribution),
+such that, for example, `new_x = lambda * x1 + (1 - lambda) * x2` (where
+`x1` and `x2` are images), and the same equation is applied to the labels as well.
+"""
+
+
+def sample_beta_distribution(size, concentration_0=0.2, concentration_1=0.2):
+    gamma_1_sample = tf.random.gamma(shape=[size], alpha=concentration_1)
+    gamma_2_sample = tf.random.gamma(shape=[size], alpha=concentration_0)
+    return gamma_1_sample / (gamma_1_sample + gamma_2_sample)
+
+
+def mix_up(ds_one, ds_two, alpha=0.2):
+    # Unpack two datasets
+    images_one, labels_one = ds_one
+    images_two, labels_two = ds_two
+    batch_size = tf.shape(images_one)[0]
+
+    # Sample lambda and reshape it to do the mixup
+    l = sample_beta_distribution(batch_size, alpha, alpha)
+    x_l = tf.reshape(l, (batch_size, 1, 1, 1))
+    y_l = tf.reshape(l, (batch_size, 1))
+
+    # Perform mixup on both images and labels by combining a pair of images/labels
+    # (one from each dataset) into one image/label
+    images = images_one * x_l + images_two * (1 - x_l)
+    labels = labels_one * y_l + labels_two * (1 - y_l)
+    return (images, labels)
+
+
+"""
+**Note** that here, we are combining two images to create a single one. Theoretically,
+we can combine as many as we want, but that comes at an increased computation cost. In
+certain cases, it may not improve performance either.
+"""
+
+"""
+## Visualize the new augmented dataset
+"""
+
+# First create the new dataset using our `mix_up` utility
+train_ds_mu = train_ds.map(
+    lambda ds_one, ds_two: mix_up(ds_one, ds_two, alpha=0.2), num_parallel_calls=AUTO
+)
+
+# Let's preview 9 samples from the dataset
+sample_images, sample_labels = next(iter(train_ds_mu))
+plt.figure(figsize=(10, 10))
+for i, (image, label) in enumerate(zip(sample_images[:9], sample_labels[:9])):
+    ax = plt.subplot(3, 3, i + 1)
+    plt.imshow(image.numpy().squeeze())
+    print(label.numpy().tolist())
+    plt.axis("off")
+
+"""
+## Model building
+"""
+
+
+def get_training_model():
+    model = keras.Sequential(
+        [
+            layers.Conv2D(16, (5, 5), activation="relu", input_shape=(28, 28, 1)),
+            layers.MaxPooling2D(pool_size=(2, 2)),
+            layers.Conv2D(32, (5, 5), activation="relu"),
+            layers.MaxPooling2D(pool_size=(2, 2)),
+            layers.Dropout(0.2),
+            layers.GlobalAveragePooling2D(),
+            layers.Dense(128, activation="relu"),
+            layers.Dense(10, activation="softmax"),
+        ]
+    )
+    return model
+
+
+"""
+For the sake of reproducibility, we serialize the initial random weights of our shallow
+network.
+"""
+
+initial_model = get_training_model()
+initial_model.save_weights("initial_weights.weights.h5")
+
+"""
+## 1. Train the model with the mixed up dataset
+"""
+
+model = get_training_model()
+model.load_weights("initial_weights.weights.h5")
+model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
+model.fit(train_ds_mu, validation_data=val_ds, epochs=EPOCHS)
+_, test_acc = model.evaluate(test_ds)
+print("Test accuracy: {:.2f}%".format(test_acc * 100))
+
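+"""
+The introduction claimed that mixup keeps a network from becoming overconfident about
+the relationship between features and labels. As a quick, informal probe (this is an
+addition to the original recipe, not a rigorous measurement, and the `probe_*` names are
+just local helpers), we can look at the largest softmax probability the mixup-trained
+model assigns to a few test images; values noticeably below 1.0 suggest softer
+predictions.
+"""
+
+probe_images, _ = next(iter(test_ds))
+probe_probs = model.predict(probe_images[:8], verbose=0)
+print("Max softmax probabilities:", np.round(np.max(probe_probs, axis=1), 3))
+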
+"""
+## 2. Train the model *without* the mixed up dataset
+"""
+
+model = get_training_model()
+model.load_weights("initial_weights.weights.h5")
+model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
+# Notice that we are NOT using the mixed up dataset here
+model.fit(train_ds_one, validation_data=val_ds, epochs=EPOCHS)
+_, test_acc = model.evaluate(test_ds)
+print("Test accuracy: {:.2f}%".format(test_acc * 100))
+
+"""
+Readers are encouraged to try out mixup on different datasets from different domains and
+experiment with the lambda parameter. You are strongly advised to check out the
+[original paper](https://arxiv.org/abs/1710.09412) as well; the authors present several ablation studies on mixup,
+showing how it can improve generalization, as well as their results from combining
+more than two images to create a single one (a sketch of a three-way variant is included
+at the end of this example).
+"""
+
+"""
+## Notes
+
+* With mixup, you can create synthetic examples, especially when you lack a large
+dataset, without incurring high computational costs.
+* [Label smoothing](https://www.pyimagesearch.com/2019/12/30/label-smoothing-with-keras-tensorflow-and-deep-learning/) and mixup usually do not work well together because label smoothing
+already modifies the hard labels by some factor.
+* mixup does not work well when you are using [Supervised Contrastive
+Learning](https://arxiv.org/abs/2004.11362) (SCL) since SCL expects the true labels
+during its pre-training phase.
+* A few other benefits of mixup (as described in the [paper](https://arxiv.org/abs/1710.09412)) include robustness to
+adversarial examples and stabilized GAN (Generative Adversarial Network) training.
+* There are a number of data augmentation techniques that extend mixup, such as
+[CutMix](https://arxiv.org/abs/1905.04899) and [AugMix](https://arxiv.org/abs/1912.02781).
+"""
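+
+"""
+## Appendix: sketching a three-way mixup
+
+The closing note above mentions combining more than two samples. The helper below is a
+hypothetical sketch (an addition to the original example, shown for illustration only)
+of how the `mix_up` utility could be extended to three inputs: three Gamma(alpha) draws,
+once normalized, amount to sampling the mixing weights from a Dirichlet distribution.
+"""
+
+
+def three_way_mix_up(ds_one, ds_two, ds_three, alpha=0.2):
+    # Unpack the three (images, labels) batches
+    images_one, labels_one = ds_one
+    images_two, labels_two = ds_two
+    images_three, labels_three = ds_three
+    batch_size = tf.shape(images_one)[0]
+
+    # Normalized Gamma(alpha) samples are equivalent to Dirichlet(alpha, alpha, alpha)
+    # mixing weights that sum to 1 for every sample in the batch
+    g1 = tf.random.gamma(shape=[batch_size], alpha=alpha)
+    g2 = tf.random.gamma(shape=[batch_size], alpha=alpha)
+    g3 = tf.random.gamma(shape=[batch_size], alpha=alpha)
+    total = g1 + g2 + g3
+    w1, w2, w3 = g1 / total, g2 / total, g3 / total
+
+    # Reshape the weights so they broadcast over image and label dimensions
+    x_w1, x_w2, x_w3 = (tf.reshape(w, (batch_size, 1, 1, 1)) for w in (w1, w2, w3))
+    y_w1, y_w2, y_w3 = (tf.reshape(w, (batch_size, 1)) for w in (w1, w2, w3))
+
+    images = images_one * x_w1 + images_two * x_w2 + images_three * x_w3
+    labels = labels_one * y_w1 + labels_two * y_w2 + labels_three * y_w3
+    return (images, labels)
+
+
+# A zipped triplet of three shuffled copies of the training data could then be mapped
+# with this helper, in the same way `train_ds` was mapped with `mix_up` above.
\ No newline at end of file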