From fcda274b0099a5d6d3d8c3aea13ab6621defa55c Mon Sep 17 00:00:00 2001 From: Muhammad Anas Raza <63569834+anas-rz@users.noreply.github.com> Date: Tue, 18 Jul 2023 14:19:11 -0400 Subject: [PATCH] Convert to Keras Core: Token Learner (#528) * Create token_learner.py * Change tf_data and fix typos * typo * Fix Typo * move to backend_agnostic --- examples/keras_io/vision/token_learner.py | 511 ++++++++++++++++++++++ 1 file changed, 511 insertions(+) create mode 100644 examples/keras_io/vision/token_learner.py diff --git a/examples/keras_io/vision/token_learner.py b/examples/keras_io/vision/token_learner.py new file mode 100644 index 000000000..6cd726b58 --- /dev/null +++ b/examples/keras_io/vision/token_learner.py @@ -0,0 +1,511 @@ +""" +Title: Learning to tokenize in Vision Transformers +Authors: [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Sayak Paul](https://twitter.com/RisingSayak) (equal contribution) +Converted to Keras Core by: [Muhammad Anas Raza](https://anasrz.com) +Date created: 2021/12/10 +Last modified: 2023/07/18 +Description: Adaptively generating a smaller number of tokens for Vision Transformers. +Accelerator: GPU +""" +""" +## Introduction + +Vision Transformers ([Dosovitskiy et al.](https://arxiv.org/abs/2010.11929)) and many +other Transformer-based architectures ([Liu et al.](https://arxiv.org/abs/2103.14030), +[Yuan et al.](https://arxiv.org/abs/2101.11986), etc.) have shown strong results in +image recognition. The following provides a brief overview of the components involved in the +Vision Transformer architecture for image classification: + +* Extract small patches from input images. +* Linearly project those patches. +* Add positional embeddings to these linear projections. +* Run these projections through a series of Transformer ([Vaswani et al.](https://arxiv.org/abs/1706.03762)) +blocks. +* Finally, take the representation from the final Transformer block and add a +classification head. + +If we take 224x224 images and extract 16x16 patches, we get a total of 196 patches (also +called tokens) for each image. The number of patches increases as we increase the +resolution, leading to higher memory footprint. Could we use a reduced +number of patches without having to compromise performance? +Ryoo et al. investigate this question in +[TokenLearner: Adaptive Space-Time Tokenization for Videos](https://openreview.net/forum?id=z-l1kpDXs88). +They introduce a novel module called **TokenLearner** that can help reduce the number +of patches used by a Vision Transformer (ViT) in an adaptive manner. With TokenLearner +incorporated in the standard ViT architecture, they are able to reduce the amount of +compute (measured in FLOPS) used by the model. + +In this example, we implement the TokenLearner module and demonstrate its +performance with a mini ViT and the CIFAR-10 dataset. 
We make use of the following +references: + +* [Official TokenLearner code](https://github.com/google-research/scenic/blob/main/scenic/projects/token_learner/model.py) +* [Image Classification with ViTs on keras.io](https://keras.io/examples/vision/image_classification_with_vision_transformer/) +* [TokenLearner slides from NeurIPS 2021](https://nips.cc/media/neurips-2021/Slides/26578.pdf) +""" + + +""" +## Imports +""" + +import keras_core as keras +from keras_core import layers +from keras_core import ops +from tensorflow import data as tf_data + + +from datetime import datetime +import matplotlib.pyplot as plt +import numpy as np + +import math + +""" +## Hyperparameters + +Please feel free to change the hyperparameters and check your results. The best way to +develop intuition about the architecture is to experiment with it. +""" + +# DATA +BATCH_SIZE = 256 +AUTO = tf_data.AUTOTUNE +INPUT_SHAPE = (32, 32, 3) +NUM_CLASSES = 10 + +# OPTIMIZER +LEARNING_RATE = 1e-3 +WEIGHT_DECAY = 1e-4 + +# TRAINING +EPOCHS = 20 + +# AUGMENTATION +IMAGE_SIZE = 48 # We will resize input images to this size. +PATCH_SIZE = 6 # Size of the patches to be extracted from the input images. +NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2 + +# ViT ARCHITECTURE +LAYER_NORM_EPS = 1e-6 +PROJECTION_DIM = 128 +NUM_HEADS = 4 +NUM_LAYERS = 4 +MLP_UNITS = [ + PROJECTION_DIM * 2, + PROJECTION_DIM, +] + +# TOKENLEARNER +NUM_TOKENS = 4 + +""" +## Load and prepare the CIFAR-10 dataset +""" + +# Load the CIFAR-10 dataset. +(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data() +(x_train, y_train), (x_val, y_val) = ( + (x_train[:40000], y_train[:40000]), + (x_train[40000:], y_train[40000:]), +) +print(f"Training samples: {len(x_train)}") +print(f"Validation samples: {len(x_val)}") +print(f"Testing samples: {len(x_test)}") + +# Convert to tf.data.Dataset objects. +train_ds = tf_data.Dataset.from_tensor_slices((x_train, y_train)) +train_ds = train_ds.shuffle(BATCH_SIZE * 100).batch(BATCH_SIZE).prefetch(AUTO) + +val_ds = tf_data.Dataset.from_tensor_slices((x_val, y_val)) +val_ds = val_ds.batch(BATCH_SIZE).prefetch(AUTO) + +test_ds = tf_data.Dataset.from_tensor_slices((x_test, y_test)) +test_ds = test_ds.batch(BATCH_SIZE).prefetch(AUTO) + +""" +## Data augmentation + +The augmentation pipeline consists of: + +- Rescaling +- Resizing +- Random cropping (fixed-sized or random sized) +- Random horizontal flipping +""" + +data_augmentation = keras.Sequential( + [ + layers.Rescaling(1 / 255.0), + layers.Resizing(INPUT_SHAPE[0] + 20, INPUT_SHAPE[0] + 20), + layers.RandomCrop(IMAGE_SIZE, IMAGE_SIZE), + layers.RandomFlip("horizontal"), + ], + name="data_augmentation", +) + +""" +Note that image data augmentation layers do not apply data transformations at inference time. +This means that when these layers are called with `training=False` they behave differently. Refer +[to the documentation](https://keras.io/api/layers/preprocessing_layers/image_augmentation/) for more +details. +""" + +""" +## Positional embedding module + +A [Transformer](https://arxiv.org/abs/1706.03762) architecture consists of **multi-head +self attention** layers and **fully-connected feed forward** networks (MLP) as the main +components. Both these components are _permutation invariant_: they're not aware of +feature order. + +To overcome this problem we inject tokens with positional information. The +`position_embedding` function adds this positional information to the linearly projected +tokens. 
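+
+For example, with `IMAGE_SIZE = 48` and `PATCH_SIZE = 6` there are `NUM_PATCHES = 64`
+positions, and each position gets a learned embedding of size `PROJECTION_DIM` (128)
+that is added to the corresponding projected patch.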
+""" + + +def position_embedding( + projected_patches, num_patches=NUM_PATCHES, projection_dim=PROJECTION_DIM +): + # Build the positions. + positions = ops.arange(start=0, stop=num_patches, step=1) + + # Encode the positions with an Embedding layer. + encoded_positions = layers.Embedding( + input_dim=num_patches, output_dim=projection_dim + )(positions) + + # Add encoded positions to the projected patches. + return projected_patches + encoded_positions + + +""" +## MLP block for Transformer + +This serves as the Fully Connected Feed Forward block for our Transformer. +""" + + +def mlp(x, dropout_rate, hidden_units): + # Iterate over the hidden units and + # add Dense => Dropout. + for units in hidden_units: + x = layers.Dense(units, activation=ops.gelu)(x) + x = layers.Dropout(dropout_rate)(x) + return x + + +""" +## TokenLearner module + +The following figure presents a pictorial overview of the module +([source](https://ai.googleblog.com/2021/12/improving-vision-transformer-efficiency.html)). + +![TokenLearner module GIF](https://blogger.googleusercontent.com/img/a/AVvXsEiylT3_nmd9-tzTnz3g3Vb4eTn-L5sOwtGJOad6t2we7FsjXSpbLDpuPrlInAhtE5hGCA_PfYTJtrIOKfLYLYGcYXVh1Ksfh_C1ZC-C8gw6GKtvrQesKoMrEA_LU_Gd5srl5-3iZDgJc1iyCELoXtfuIXKJ2ADDHOBaUjhU8lXTVdr2E7bCVaFgVHHkmA=w640-h208) + +The TokenLearner module takes as input an image-shaped tensor. It then passes it through +multiple single-channel convolutional layers extracting different spatial attention maps +focusing on different parts of the input. These attention maps are then element-wise +multiplied to the input and result is aggregated with pooling. This pooled output can be +trated as a summary of the input and has much lesser number of patches (8, for example) +than the original one (196, for example). + +Using multiple convolution layers helps with expressivity. Imposing a form of spatial +attention helps retain relevant information from the inputs. Both of these components are +crucial to make TokenLearner work, especially when we are significantly reducing the number of patches. +""" + + +def token_learner(inputs, number_of_tokens=NUM_TOKENS): + # Layer normalize the inputs. + x = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(inputs) # (B, H, W, C) + + # Applying Conv2D => Reshape => Permute + # The reshape and permute is done to help with the next steps of + # multiplication and Global Average Pooling. + attention_maps = keras.Sequential( + [ + # 3 layers of conv with gelu activation as suggested + # in the paper. + layers.Conv2D( + filters=number_of_tokens, + kernel_size=(3, 3), + activation=ops.gelu, + padding="same", + use_bias=False, + ), + layers.Conv2D( + filters=number_of_tokens, + kernel_size=(3, 3), + activation=ops.gelu, + padding="same", + use_bias=False, + ), + layers.Conv2D( + filters=number_of_tokens, + kernel_size=(3, 3), + activation=ops.gelu, + padding="same", + use_bias=False, + ), + # This conv layer will generate the attention maps + layers.Conv2D( + filters=number_of_tokens, + kernel_size=(3, 3), + activation="sigmoid", # Note sigmoid for [0, 1] output + padding="same", + use_bias=False, + ), + # Reshape and Permute + layers.Reshape((-1, number_of_tokens)), # (B, H*W, num_of_tokens) + layers.Permute((2, 1)), + ] + )( + x + ) # (B, num_of_tokens, H*W) + + # Reshape the input to align it with the output of the conv block. 
+    num_filters = inputs.shape[-1]
+    inputs = layers.Reshape((1, -1, num_filters))(inputs)  # inputs == (B, 1, H*W, C)
+
+    # Element-wise multiplication of the attention maps and the inputs.
+    attended_inputs = (
+        ops.expand_dims(attention_maps, axis=-1) * inputs
+    )  # (B, num_tokens, H*W, C)
+
+    # Global average pooling of the element-wise multiplication result.
+    outputs = ops.mean(attended_inputs, axis=2)  # (B, num_tokens, C)
+    return outputs
+
+
+"""
+## Transformer block
+"""
+
+
+def transformer(encoded_patches):
+    # Layer normalization 1.
+    x1 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(encoded_patches)
+
+    # Multi Head Self Attention layer 1.
+    attention_output = layers.MultiHeadAttention(
+        num_heads=NUM_HEADS, key_dim=PROJECTION_DIM, dropout=0.1
+    )(x1, x1)
+
+    # Skip connection 1.
+    x2 = layers.Add()([attention_output, encoded_patches])
+
+    # Layer normalization 2.
+    x3 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x2)
+
+    # MLP layer 1.
+    x4 = mlp(x3, hidden_units=MLP_UNITS, dropout_rate=0.1)
+
+    # Skip connection 2.
+    encoded_patches = layers.Add()([x4, x2])
+    return encoded_patches
+
+
+"""
+## ViT model with the TokenLearner module
+"""
+
+
+def create_vit_classifier(use_token_learner=True, token_learner_units=NUM_TOKENS):
+    inputs = layers.Input(shape=INPUT_SHAPE)  # (B, H, W, C)
+
+    # Augment data.
+    augmented = data_augmentation(inputs)
+
+    # Create patches and linearly project them.
+    projected_patches = layers.Conv2D(
+        filters=PROJECTION_DIM,
+        kernel_size=(PATCH_SIZE, PATCH_SIZE),
+        strides=(PATCH_SIZE, PATCH_SIZE),
+        padding="VALID",
+    )(augmented)
+    _, h, w, c = projected_patches.shape
+    projected_patches = layers.Reshape((h * w, c))(
+        projected_patches
+    )  # (B, number_patches, projection_dim)
+
+    # Add positional embeddings to the projected patches.
+    encoded_patches = position_embedding(
+        projected_patches
+    )  # (B, number_patches, projection_dim)
+    encoded_patches = layers.Dropout(0.1)(encoded_patches)
+
+    # Iterate over the number of layers and stack up blocks of
+    # Transformer.
+    for i in range(NUM_LAYERS):
+        # Add a Transformer block.
+        encoded_patches = transformer(encoded_patches)
+
+        # Add TokenLearner layer in the middle of the
+        # architecture. The paper suggests that anywhere
+        # between 1/2 and 3/4 of the way through the network works well.
+        if use_token_learner and i == NUM_LAYERS // 2:
+            _, hh, c = encoded_patches.shape
+            h = int(math.sqrt(hh))
+            encoded_patches = layers.Reshape((h, h, c))(
+                encoded_patches
+            )  # (B, h, h, projection_dim)
+            encoded_patches = token_learner(
+                encoded_patches, token_learner_units
+            )  # (B, num_tokens, c)
+
+    # Layer normalization and Global average pooling.
+    representation = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(encoded_patches)
+    representation = layers.GlobalAvgPool1D()(representation)
+
+    # Classify outputs.
+    outputs = layers.Dense(NUM_CLASSES, activation="softmax")(representation)
+
+    # Create the Keras model.
+    model = keras.Model(inputs=inputs, outputs=outputs)
+    return model
+
+
+"""
+As shown in the [TokenLearner paper](https://openreview.net/forum?id=z-l1kpDXs88), it is
+almost always advantageous to include the TokenLearner module in the middle of the
+network.
+"""
+
+"""
+## Training utility
+"""
+
+
+def run_experiment(model):
+    # Initialize the AdamW optimizer.
+    optimizer = keras.optimizers.AdamW(
+        learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY
+    )
+
+    # Compile the model with the optimizer, loss function
+    # and the metrics.
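+    # CIFAR-10 labels are integer class indices, which is why the sparse
+    # variants of the cross-entropy loss and accuracy metrics are used below.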
+    model.compile(
+        optimizer=optimizer,
+        loss="sparse_categorical_crossentropy",
+        metrics=[
+            keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
+            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
+        ],
+    )
+
+    # Define callbacks
+    checkpoint_filepath = "/tmp/checkpoint.weights.h5"
+    checkpoint_callback = keras.callbacks.ModelCheckpoint(
+        checkpoint_filepath,
+        monitor="val_accuracy",
+        save_best_only=True,
+        save_weights_only=True,
+    )
+
+    # Train the model.
+    _ = model.fit(
+        train_ds,
+        epochs=EPOCHS,
+        validation_data=val_ds,
+        callbacks=[checkpoint_callback],
+    )
+
+    model.load_weights(checkpoint_filepath)
+    _, accuracy, top_5_accuracy = model.evaluate(test_ds)
+    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
+    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")
+
+
+"""
+## Train and evaluate a ViT with TokenLearner
+"""
+
+vit_token_learner = create_vit_classifier()
+run_experiment(vit_token_learner)
+
+"""
+## Results
+
+We experimented with and without the TokenLearner inside the mini ViT we implemented
+(with the same hyperparameters presented in this example). Here are our results:
+
+| **TokenLearner** | **# tokens in TokenLearner** | **Top-1 Acc (Averaged across 5 runs)** | **GFLOPs** | **TensorBoard** |
+|:---:|:---:|:---:|:---:|:---:|
+| N | - | 56.112% | 0.0184 | [Link](https://tensorboard.dev/experiment/vkCwM49dQZ2RiK0ZT4mj7w/) |
+| Y | 8 | **56.55%** | **0.0153** | [Link](https://tensorboard.dev/experiment/vkCwM49dQZ2RiK0ZT4mj7w/) |
+| N | - | 56.37% | 0.0184 | [Link](https://tensorboard.dev/experiment/hdyJ4wznQROwqZTgbtmztQ/) |
+| Y | 4 | **56.4980%** | **0.0147** | [Link](https://tensorboard.dev/experiment/hdyJ4wznQROwqZTgbtmztQ/) |
+| N | - (# Transformer layers: 8) | 55.36% | 0.0359 | [Link](https://tensorboard.dev/experiment/sepBK5zNSaOtdCeEG6SV9w/) |
+
+TokenLearner is able to consistently outperform our mini ViT without the module. It is
+also interesting that it was able to outperform a deeper version of our mini ViT
+(with 8 layers). The authors report similar observations in the paper and attribute
+this to the adaptiveness of TokenLearner.
+
+One should also note that the FLOPs count **decreases** considerably with the addition of
+the TokenLearner module. With a lower FLOPs count, the TokenLearner module is able to
+deliver better results. This aligns very well with the authors' findings.
+
+Additionally, the authors [introduced](https://github.com/google-research/scenic/blob/main/scenic/projects/token_learner/model.py#L104)
+a newer version of the TokenLearner for smaller training data regimes. Quoting the authors:
+
+> Instead of using 4 conv. layers with small channels to implement spatial attention,
+  this version uses 2 grouped conv. layers with more channels. It also uses softmax
+  instead of sigmoid. We confirmed that this version works better when having limited
+  training data, such as training with ImageNet1K from scratch.
+
+We experimented with this module and in the following table we summarize the results:
+
+| **# Groups** | **# Tokens** | **Top-1 Acc** | **GFLOPs** | **TensorBoard** |
+|:---:|:---:|:---:|:---:|:---:|
+| 4 | 4 | 54.638% | 0.0149 | [Link](https://tensorboard.dev/experiment/KmfkGqAGQjikEw85phySmw/) |
+| 8 | 8 | 54.898% | 0.0146 | [Link](https://tensorboard.dev/experiment/0PpgYOq9RFWV9njX6NJQ2w/) |
+| 4 | 8 | 55.196% | 0.0149 | [Link](https://tensorboard.dev/experiment/WUkrHbZASdu3zrfmY4ETZg/) |
+
+Please note that we used the same hyperparameters presented in this example. Our
+implementation is available
+[in this notebook](https://github.com/ariG23498/TokenLearner/blob/master/TokenLearner-V1.1.ipynb).
+A rough sketch of this variant is also included at the end of this example.
+We acknowledge that the results with this new TokenLearner module are slightly worse
+than expected; this might be mitigated with hyperparameter tuning.
+
+*Note*: To compute the FLOPs of our models we used
+[this utility](https://github.com/AdityaKane2001/regnety/blob/main/regnety/utils/model_utils.py#L27)
+from [this repository](https://github.com/AdityaKane2001/regnety).
+"""
+
+"""
+## Number of parameters
+
+You may have noticed that adding the TokenLearner module increases the number of
+parameters of the base network. But that does not mean it is less efficient, as shown by
+[Dehghani et al.](https://arxiv.org/abs/2110.12894). Similar findings were reported
+by [Bello et al.](https://arxiv.org/abs/2103.07579) as well. The TokenLearner module
+helps reduce the FLOPs in the overall network, thereby helping to reduce the memory
+footprint.
+"""
+
+"""
+## Final notes
+
+* TokenFuser: The authors of the paper also propose another module named TokenFuser.
This +module helps in remapping the representation of the TokenLearner output back to its +original spatial resolution. To reuse the TokenLearner in the ViT architecture, the +TokenFuser is a must. We first learn the tokens from the TokenLearner, build a +representation of the tokens from a Transformer layer and then remap the representation +into the original spatial resolution, so that it can again be consumed by a TokenLearner. +Note here that you can only use the TokenLearner module once in entire ViT model if not +paired with the TokenFuser. +* Use of these modules for video: The authors also suggest that TokenFuser goes really +well with Vision Transformers for Videos ([Arnab et al.](https://arxiv.org/abs/2103.15691)). + +We are grateful to [JarvisLabs](https://jarvislabs.ai/) and +[Google Developers Experts](https://developers.google.com/programs/experts/) +program for helping with GPU credits. Also, we are thankful to Michael Ryoo (first +author of TokenLearner) for fruitful discussions. + +| Trained Model | Demo | +| :--: | :--: | +| [![Generic badge](https://img.shields.io/badge/🤗%20Model-TokenLearner-black.svg)](https://huggingface.co/keras-io/learning_to_tokenize_in_ViT) | [![Generic badge](https://img.shields.io/badge/🤗%20Spaces-TokenLearner-black.svg)](https://huggingface.co/spaces/keras-io/token_learner) | +"""