From 4853a0a6f11278cd28635d0e88f39b3ba5ce93f6 Mon Sep 17 00:00:00 2001 From: Muhammad Anas Raza <63569834+anas-rz@users.noreply.github.com> Date: Mon, 17 Jul 2023 17:13:28 -0400 Subject: [PATCH] Converted to Keras Core: A Vision Transformer without Attention (#497) * Port ShiftViT to keras core * remove empty spaces * Reverted epochs --- .../keras_io/tensorflow/vision/shiftvit.py | 877 ++++++++++++++++++ 1 file changed, 877 insertions(+) create mode 100644 examples/keras_io/tensorflow/vision/shiftvit.py diff --git a/examples/keras_io/tensorflow/vision/shiftvit.py b/examples/keras_io/tensorflow/vision/shiftvit.py new file mode 100644 index 000000000..a77078d00 --- /dev/null +++ b/examples/keras_io/tensorflow/vision/shiftvit.py @@ -0,0 +1,877 @@ +""" +Title: A Vision Transformer without Attention +Author: [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ritwik Raha](https://twitter.com/ritwik_raha) +Converted to Keras Core: [Muhammad Anas Raza](https://anasrz.com) +Date created: 2022/02/24 +Last modified: 2023/07/15 +Description: A minimal implementation of ShiftViT. +Accelerator: GPU +""" +""" +## Introduction + +[Vision Transformers](https://arxiv.org/abs/2010.11929) (ViTs) have sparked a wave of +research at the intersection of Transformers and Computer Vision (CV). + +ViTs can simultaneously model long- and short-range dependencies, thanks to +the Multi-Head Self-Attention mechanism in the Transformer block. Many researchers believe +that the success of ViTs are purely due to the attention layer, and they seldom +think about other parts of the ViT model. + +In the academic paper +[When Shift Operation Meets Vision Transformer: An Extremely Simple Alternative to Attention Mechanism](https://arxiv.org/abs/2201.10801) +the authors propose to demystify the success of ViTs with the introduction of a **NO +PARAMETER** operation in place of the attention operation. They swap the attention +operation with a shifting operation. + +In this example, we minimally implement the paper with close alignement to the author's +[official implementation](https://github.com/microsoft/SPACH/blob/main/models/shiftvit.py). +""" + +""" +## Setup and imports +""" + +import os +os.environ["KERAS_BACKEND"] = "tensorflow" + +import numpy as np +import matplotlib.pyplot as plt + +import tensorflow as tf +import keras_core as keras +from keras_core import layers + + + +# Setting seed for reproducibiltiy +SEED = 42 +keras.utils.set_random_seed(SEED) + +""" +## Hyperparameters + +These are the hyperparameters that we have chosen for the experiment. +Please feel free to tune them. +""" + + +class Config(object): + # DATA + batch_size = 256 + buffer_size = batch_size * 2 + input_shape = (32, 32, 3) + num_classes = 10 + + # AUGMENTATION + image_size = 48 + + # ARCHITECTURE + patch_size = 4 + projected_dim = 96 + num_shift_blocks_per_stages = [2, 4, 8, 2] + epsilon = 1e-5 + stochastic_depth_rate = 0.2 + mlp_dropout_rate = 0.2 + num_div = 12 + shift_pixel = 1 + mlp_expand_ratio = 2 + + # OPTIMIZER + lr_start = 1e-5 + lr_max = 1e-3 + weight_decay = 1e-4 + + # TRAINING + epochs = 100 + + +config = Config() + +""" +## Load the CIFAR-10 dataset + +We use the CIFAR-10 dataset for our experiments. 
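+The last 10,000 images of the 50,000-image training split are held out as a
+validation set, while the 10,000-image test split is reserved for the final
+evaluation.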
+""" + +(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data() +(x_train, y_train), (x_val, y_val) = ( + (x_train[:40000], y_train[:40000]), + (x_train[40000:], y_train[40000:]), +) +print(f"Training samples: {len(x_train)}") +print(f"Validation samples: {len(x_val)}") +print(f"Testing samples: {len(x_test)}") + +AUTO = tf.data.AUTOTUNE +train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)) +train_ds = train_ds.shuffle(config.buffer_size).batch(config.batch_size).prefetch(AUTO) + +val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val)) +val_ds = val_ds.batch(config.batch_size).prefetch(AUTO) + +test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)) +test_ds = test_ds.batch(config.batch_size).prefetch(AUTO) + +""" +## Data Augmentation + +The augmentation pipeline consists of: + +- Rescaling +- Resizing +- Random cropping +- Random horizontal flipping + +_Note_: The image data augmentation layers do not apply +data transformations at inference time. This means that +when these layers are called with `training=False` they +behave differently. Refer to the +[documentation](https://keras.io/api/layers/preprocessing_layers/image_augmentation/) +for more details. +""" + + +def get_augmentation_model(): + """Build the data augmentation model.""" + data_augmentation = keras.Sequential( + [ + layers.Resizing(config.input_shape[0] + 20, config.input_shape[0] + 20), + layers.RandomCrop(config.image_size, config.image_size), + layers.RandomFlip("horizontal"), + layers.Rescaling(1 / 255.0), + ] + ) + return data_augmentation + + +""" +## The ShiftViT architecture + +In this section, we build the architecture proposed in +[the ShiftViT paper](https://arxiv.org/abs/2201.10801). + +| ![ShiftViT Architecture](https://i.imgur.com/CHU40HX.png) | +| :--: | +| Figure 1: The entire architecutre of ShiftViT. +[Source](https://arxiv.org/abs/2201.10801) | + +The architecture as shown in Fig. 1, is inspired by +[Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030). +Here the authors propose a modular architecture with 4 stages. Each stage works on its +own spatial size, creating a hierarchical architecture. + +An input image of size `HxWx3` is split into non-overlapping patches of size `4x4`. +This is done via the patchify layer which results in individual tokens of feature size `48` +(`4x4x3`). Each stage comprises two parts: + +1. Embedding Generation +2. Stacked Shift Blocks + +We discuss the stages and the modules in detail in what follows. + +_Note_: Compared to the [official implementation](https://github.com/microsoft/SPACH/blob/main/models/shiftvit.py) +we restructure some key components to better fit the Keras API. +""" + +""" +### The ShiftViT Block + +| ![ShiftViT block](https://i.imgur.com/IDe35vo.gif) | +| :--: | +| Figure 2: From the Model to a Shift Block. | + +Each stage in the ShiftViT architecture comprises of a Shift Block as shown in Fig 2. + +| ![Shift Vit Block](https://i.imgur.com/0q13pLu.png) | +| :--: | +| Figure 3: The Shift ViT Block. [Source](https://arxiv.org/abs/2201.10801) | + +The Shift Block as shown in Fig. 3, comprises of the following: + +1. Shift Operation +2. Linear Normalization +3. MLP Layer +""" + +""" +#### The MLP block + +The MLP block is intended to be a stack of densely-connected layers.s +""" + + +class MLP(layers.Layer): + """Get the MLP layer for each shift block. + + Args: + mlp_expand_ratio (int): The ratio with which the first feature map is expanded. 
+ mlp_dropout_rate (float): The rate for dropout. + """ + + def __init__(self, mlp_expand_ratio, mlp_dropout_rate, **kwargs): + super().__init__(**kwargs) + self.mlp_expand_ratio = mlp_expand_ratio + self.mlp_dropout_rate = mlp_dropout_rate + + def build(self, input_shape): + input_channels = input_shape[-1] + initial_filters = int(self.mlp_expand_ratio * input_channels) + + self.mlp = keras.Sequential( + [ + layers.Dense( + units=initial_filters, + activation=tf.nn.gelu, + ), + layers.Dropout(rate=self.mlp_dropout_rate), + layers.Dense(units=input_channels), + layers.Dropout(rate=self.mlp_dropout_rate), + ] + ) + + def call(self, x): + x = self.mlp(x) + return x + + +""" +#### The DropPath layer + +Stochastic depth is a regularization technique that randomly drops a set of +layers. During inference, the layers are kept as they are. It is very +similar to Dropout, but it operates on a block of layers rather +than on individual nodes present inside a layer. +""" + + +class DropPath(layers.Layer): + """Drop Path also known as the Stochastic Depth layer. + + Refernece: + - https://keras.io/examples/vision/cct/#stochastic-depth-for-regularization + - github.com:rwightman/pytorch-image-models + """ + + def __init__(self, drop_path_prob, **kwargs): + super().__init__(**kwargs) + self.drop_path_prob = drop_path_prob + + def call(self, x, training=False): + if training: + keep_prob = 1 - self.drop_path_prob + shape = (tf.shape(x)[0],) + (1,) * (len(x.shape) - 1) + random_tensor = keep_prob + tf.random.uniform(shape, 0, 1) + random_tensor = tf.floor(random_tensor) + return (x / keep_prob) * random_tensor + return x + + +""" +#### Block + +The most important operation in this paper is the **shift opperation**. In this section, +we describe the shift operation and compare it with its original implementation provided +by the authors. + +A generic feature map is assumed to have the shape `[N, H, W, C]`. Here we choose a +`num_div` parameter that decides the division size of the channels. The first 4 divisions +are shifted (1 pixel) in the left, right, up, and down direction. The remaining splits +are kept as is. After partial shifting the shifted channels are padded and the overflown +pixels are chopped off. This completes the partial shifting operation. + +In the original implementation, the code is approximately: + +```python +out[:, g * 0:g * 1, :, :-1] = x[:, g * 0:g * 1, :, 1:] # shift left +out[:, g * 1:g * 2, :, 1:] = x[:, g * 1:g * 2, :, :-1] # shift right +out[:, g * 2:g * 3, :-1, :] = x[:, g * 2:g * 3, 1:, :] # shift up +out[:, g * 3:g * 4, 1:, :] = x[:, g * 3:g * 4, :-1, :] # shift down + +out[:, g * 4:, :, :] = x[:, g * 4:, :, :] # no shift +``` + +In TensorFlow it would be infeasible for us to assign shifted channels to a tensor in the +middle of the training process. This is why we have resorted to the following procedure: + +1. Split the channels with the `num_div` parameter. +2. Select each of the first four spilts and shift and pad them in the respective +directions. +3. After shifting and padding, we concatenate the channel back. + +| ![Manim rendered animation for shift operation](https://i.imgur.com/PReeULP.gif) | +| :--: | +| Figure 4: The TensorFlow style shifting | + +The entire procedure is explained in the Fig. 4. +""" + + +class ShiftViTBlock(layers.Layer): + """A unit ShiftViT Block + + Args: + shift_pixel (int): The number of pixels to shift. Default to 1. + mlp_expand_ratio (int): The ratio with which MLP features are + expanded. Default to 2. 
+ mlp_dropout_rate (float): The dropout rate used in MLP. + num_div (int): The number of divisions of the feature map's channel. + Totally, 4/num_div of channels will be shifted. Defaults to 12. + epsilon (float): Epsilon constant. + drop_path_prob (float): The drop probability for drop path. + """ + + def __init__( + self, + epsilon, + drop_path_prob, + mlp_dropout_rate, + num_div=12, + shift_pixel=1, + mlp_expand_ratio=2, + **kwargs, + ): + super().__init__(**kwargs) + self.shift_pixel = shift_pixel + self.mlp_expand_ratio = mlp_expand_ratio + self.mlp_dropout_rate = mlp_dropout_rate + self.num_div = num_div + self.epsilon = epsilon + self.drop_path_prob = drop_path_prob + + def build(self, input_shape): + self.H = input_shape[1] + self.W = input_shape[2] + self.C = input_shape[3] + self.layer_norm = layers.LayerNormalization(epsilon=self.epsilon) + self.drop_path = ( + DropPath(drop_path_prob=self.drop_path_prob) + if self.drop_path_prob > 0.0 + else layers.Activation("linear") + ) + self.mlp = MLP( + mlp_expand_ratio=self.mlp_expand_ratio, + mlp_dropout_rate=self.mlp_dropout_rate, + ) + + def get_shift_pad(self, x, mode): + """Shifts the channels according to the mode chosen.""" + if mode == "left": + offset_height = 0 + offset_width = 0 + target_height = 0 + target_width = self.shift_pixel + elif mode == "right": + offset_height = 0 + offset_width = self.shift_pixel + target_height = 0 + target_width = self.shift_pixel + elif mode == "up": + offset_height = 0 + offset_width = 0 + target_height = self.shift_pixel + target_width = 0 + else: + offset_height = self.shift_pixel + offset_width = 0 + target_height = self.shift_pixel + target_width = 0 + crop = tf.image.crop_to_bounding_box( + x, + offset_height=offset_height, + offset_width=offset_width, + target_height=self.H - target_height, + target_width=self.W - target_width, + ) + shift_pad = tf.image.pad_to_bounding_box( + crop, + offset_height=offset_height, + offset_width=offset_width, + target_height=self.H, + target_width=self.W, + ) + return shift_pad + + def call(self, x, training=False): + # Split the feature maps + x_splits = tf.split(x, num_or_size_splits=self.C // self.num_div, axis=-1) + + # Shift the feature maps + x_splits[0] = self.get_shift_pad(x_splits[0], mode="left") + x_splits[1] = self.get_shift_pad(x_splits[1], mode="right") + x_splits[2] = self.get_shift_pad(x_splits[2], mode="up") + x_splits[3] = self.get_shift_pad(x_splits[3], mode="down") + + # Concatenate the shifted and unshifted feature maps + x = tf.concat(x_splits, axis=-1) + + # Add the residual connection + shortcut = x + x = shortcut + self.drop_path(self.mlp(self.layer_norm(x)), training=training) + return x + + +""" +### The ShiftViT blocks + +| ![Shift Blokcs](https://i.imgur.com/FKy5NnD.png) | +| :--: | +| Figure 5: Shift Blocks in the architecture. [Source](https://arxiv.org/abs/2201.10801) | + +Each stage of the architecture has shift blocks as shown in Fig.5. Each of these blocks +contain a variable number of stacked ShiftViT block (as built in the earlier section). + +Shift blocks are followed by a PatchMerging layer that scales down feature inputs. The +PatchMerging layer helps in the pyramidal structure of the model. +""" + +""" +#### The PatchMerging layer + +This layer merges the two adjacent tokens. This layer helps in scaling the features down +spatially and increasing the features up channel wise. We use a Conv2D layer to merge the +patches. +""" + + +class PatchMerging(layers.Layer): + """The Patch Merging layer. 
+ + Args: + epsilon (float): The epsilon constant. + """ + + def __init__(self, epsilon, **kwargs): + super().__init__(**kwargs) + self.epsilon = epsilon + + def build(self, input_shape): + filters = 2 * input_shape[-1] + self.reduction = layers.Conv2D( + filters=filters, kernel_size=2, strides=2, padding="same", use_bias=False + ) + self.layer_norm = layers.LayerNormalization(epsilon=self.epsilon) + + def call(self, x): + # Apply the patch merging algorithm on the feature maps + x = self.layer_norm(x) + x = self.reduction(x) + return x + + +""" +#### Stacked Shift Blocks + +Each stage will have a variable number of stacked ShiftViT Blocks, as suggested in +the paper. This is a generic layer that will contain the stacked shift vit blocks +with the patch merging layer as well. Combining the two operations (shift ViT +block and patch merging) is a design choice we picked for better code reusability. +""" + + +# Note: This layer will have a different depth of stacking +# for different stages on the model. +class StackedShiftBlocks(layers.Layer): + """The layer containing stacked ShiftViTBlocks. + + Args: + epsilon (float): The epsilon constant. + mlp_dropout_rate (float): The dropout rate used in the MLP block. + num_shift_blocks (int): The number of shift vit blocks for this stage. + stochastic_depth_rate (float): The maximum drop path rate chosen. + is_merge (boolean): A flag that determines the use of the Patch Merge + layer after the shift vit blocks. + num_div (int): The division of channels of the feature map. Defaults to 12. + shift_pixel (int): The number of pixels to shift. Defaults to 1. + mlp_expand_ratio (int): The ratio with which the initial dense layer of + the MLP is expanded Defaults to 2. + """ + + def __init__( + self, + epsilon, + mlp_dropout_rate, + num_shift_blocks, + stochastic_depth_rate, + is_merge, + num_div=12, + shift_pixel=1, + mlp_expand_ratio=2, + **kwargs, + ): + super().__init__(**kwargs) + self.epsilon = epsilon + self.mlp_dropout_rate = mlp_dropout_rate + self.num_shift_blocks = num_shift_blocks + self.stochastic_depth_rate = stochastic_depth_rate + self.is_merge = is_merge + self.num_div = num_div + self.shift_pixel = shift_pixel + self.mlp_expand_ratio = mlp_expand_ratio + + def build(self, input_shapes): + # Calculate stochastic depth probabilities. + # Reference: https://keras.io/examples/vision/cct/#the-final-cct-model + dpr = [ + x + for x in np.linspace( + start=0, stop=self.stochastic_depth_rate, num=self.num_shift_blocks + ) + ] + + # Build the shift blocks as a list of ShiftViT Blocks + self.shift_blocks = list() + for num in range(self.num_shift_blocks): + self.shift_blocks.append( + ShiftViTBlock( + num_div=self.num_div, + epsilon=self.epsilon, + drop_path_prob=dpr[num], + mlp_dropout_rate=self.mlp_dropout_rate, + shift_pixel=self.shift_pixel, + mlp_expand_ratio=self.mlp_expand_ratio, + ) + ) + if self.is_merge: + self.patch_merge = PatchMerging(epsilon=self.epsilon) + + def call(self, x, training=False): + for shift_block in self.shift_blocks: + x = shift_block(x, training=training) + if self.is_merge: + x = self.patch_merge(x) + return x + + +""" +## The ShiftViT model + +Build the ShiftViT custom model. +""" + + +class ShiftViTModel(keras.Model): + """The ShiftViT Model. + + Args: + data_augmentation (keras.Model): A data augmentation model. + projected_dim (int): The dimension to which the patches of the image are + projected. + patch_size (int): The patch size of the images. 
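+            With the configuration used in this example (48x48 augmented
+            images and `patch_size=4`), the patchify convolution produces a
+            12x12 grid of tokens, each with `projected_dim` channels.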
+ num_shift_blocks_per_stages (list[int]): A list of all the number of shit + blocks per stage. + epsilon (float): The epsilon constant. + mlp_dropout_rate (float): The dropout rate used in the MLP block. + stochastic_depth_rate (float): The maximum drop rate probability. + num_div (int): The number of divisions of the channesl of the feature + map. Defaults to 12. + shift_pixel (int): The number of pixel to shift. Default to 1. + mlp_expand_ratio (int): The ratio with which the initial mlp dense layer + is expanded to. Defaults to 2. + """ + + def __init__( + self, + data_augmentation, + projected_dim, + patch_size, + num_shift_blocks_per_stages, + epsilon, + mlp_dropout_rate, + stochastic_depth_rate, + num_div=12, + shift_pixel=1, + mlp_expand_ratio=2, + **kwargs, + ): + super().__init__(**kwargs) + self.data_augmentation = data_augmentation + self.patch_projection = layers.Conv2D( + filters=projected_dim, + kernel_size=patch_size, + strides=patch_size, + padding="same", + ) + self.stages = list() + for index, num_shift_blocks in enumerate(num_shift_blocks_per_stages): + if index == len(num_shift_blocks_per_stages) - 1: + # This is the last stage, do not use the patch merge here. + is_merge = False + else: + is_merge = True + # Build the stages. + self.stages.append( + StackedShiftBlocks( + epsilon=epsilon, + mlp_dropout_rate=mlp_dropout_rate, + num_shift_blocks=num_shift_blocks, + stochastic_depth_rate=stochastic_depth_rate, + is_merge=is_merge, + num_div=num_div, + shift_pixel=shift_pixel, + mlp_expand_ratio=mlp_expand_ratio, + ) + ) + self.global_avg_pool = layers.GlobalAveragePooling2D() + + def get_config(self): + config = super().get_config() + config.update( + { + "data_augmentation": self.data_augmentation, + "patch_projection": self.patch_projection, + "stages": self.stages, + "global_avg_pool": self.global_avg_pool, + } + ) + return config + + def _calculate_loss(self, data, training=False): + (images, labels) = data + + # Augment the images + augmented_images = self.data_augmentation(images, training=training) + + # Create patches and project the pathces. + projected_patches = self.patch_projection(augmented_images) + + # Pass through the stages + x = projected_patches + for stage in self.stages: + x = stage(x, training=training) + + # Get the logits. + logits = self.global_avg_pool(x) + + # Calculate the loss and return it. + total_loss = self.compute_loss(data, labels, logits) + return total_loss, labels, logits + + def train_step(self, inputs): + with tf.GradientTape() as tape: + total_loss, labels, logits = self._calculate_loss( + data=inputs, training=True + ) + + # Apply gradients. + train_vars = [ + self.data_augmentation.trainable_variables, + self.patch_projection.trainable_variables, + self.global_avg_pool.trainable_variables, + ] + train_vars = train_vars + [stage.trainable_variables for stage in self.stages] + + # Optimize the gradients. 
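+        # `train_vars` is a list of variable lists, so `tape.gradient` returns
+        # gradients with the same nested structure; the loop below flattens
+        # them into (grad, var) pairs before they are passed to
+        # `apply_gradients`.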
+ grads = tape.gradient(total_loss, train_vars) + trainable_variable_list = [] + for grad, var in zip(grads, train_vars): + for g, v in zip(grad, var): + trainable_variable_list.append((g, v)) + self.optimizer.apply_gradients(trainable_variable_list) + + # Update the metrics + for metric in self.metrics: + if metric.name == "loss": + metric.update_state(total_loss) + else: + metric.update_state(labels, logits) + + # Return a dict mapping metric names to current value + return {m.name: m.result() for m in self.metrics} + + def test_step(self, data): + loss, labels, logits = self._calculate_loss(data=data, training=False) + + # Update the metrics + for metric in self.metrics: + if metric.name == "loss": + metric.update_state(loss) + else: + metric.update_state(labels, logits) + + # Return a dict mapping metric names to current value + return {m.name: m.result() for m in self.metrics} + + +""" +## Instantiate the model +""" + +model = ShiftViTModel( + data_augmentation=get_augmentation_model(), + projected_dim=config.projected_dim, + patch_size=config.patch_size, + num_shift_blocks_per_stages=config.num_shift_blocks_per_stages, + epsilon=config.epsilon, + mlp_dropout_rate=config.mlp_dropout_rate, + stochastic_depth_rate=config.stochastic_depth_rate, + num_div=config.num_div, + shift_pixel=config.shift_pixel, + mlp_expand_ratio=config.mlp_expand_ratio, +) + +""" +## Learning rate schedule + +In many experiments, we want to warm up the model with a slowly increasing learning rate +and then cool down the model with a slowly decaying learning rate. In the warmup cosine +decay, the learning rate linearly increases for the warmup steps and then decays with a +cosine decay. +""" + + +# Some code is taken from: +# https://www.kaggle.com/ashusma/training-rfcx-tensorflow-tpu-effnet-b2. +class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule): + """A LearningRateSchedule that uses a warmup cosine decay schedule.""" + + def __init__(self, lr_start, lr_max, warmup_steps, total_steps): + """ + Args: + lr_start: The initial learning rate + lr_max: The maximum learning rate to which lr should increase to in + the warmup steps + warmup_steps: The number of steps for which the model warms up + total_steps: The total number of steps for the model training + """ + super().__init__() + self.lr_start = lr_start + self.lr_max = lr_max + self.warmup_steps = warmup_steps + self.total_steps = total_steps + self.pi = tf.constant(np.pi) + + def __call__(self, step): + # Check whether the total number of steps is larger than the warmup + # steps. If not, then throw a value error. + if self.total_steps < self.warmup_steps: + raise ValueError( + f"Total number of steps {self.total_steps} must be" + + f"larger or equal to warmup steps {self.warmup_steps}." + ) + + # `cos_annealed_lr` is a graph that increases to 1 from the initial + # step to the warmup step. After that this graph decays to -1 at the + # final step mark. + cos_annealed_lr = tf.cos( + self.pi + * (tf.cast(step, tf.float32) - self.warmup_steps) + / tf.cast(self.total_steps - self.warmup_steps, tf.float32) + ) + + # Shift the mean of the `cos_annealed_lr` graph to 1. Now the grpah goes + # from 0 to 2. Normalize the graph with 0.5 so that now it goes from 0 + # to 1. With the normalized graph we scale it with `lr_max` such that + # it goes from 0 to `lr_max` + learning_rate = 0.5 * self.lr_max * (1 + cos_annealed_lr) + + # Check whether warmup_steps is more than 0. + if self.warmup_steps > 0: + # Check whether lr_max is larger that lr_start. 
If not, throw a value + # error. + if self.lr_max < self.lr_start: + raise ValueError( + f"lr_start {self.lr_start} must be smaller or" + + f"equal to lr_max {self.lr_max}." + ) + + # Calculate the slope with which the learning rate should increase + # in the warumup schedule. The formula for slope is m = ((b-a)/steps) + slope = (self.lr_max - self.lr_start) / self.warmup_steps + + # With the formula for a straight line (y = mx+c) build the warmup + # schedule + warmup_rate = slope * tf.cast(step, tf.float32) + self.lr_start + + # When the current step is lesser that warmup steps, get the line + # graph. When the current step is greater than the warmup steps, get + # the scaled cos graph. + learning_rate = tf.where( + step < self.warmup_steps, warmup_rate, learning_rate + ) + + # When the current step is more that the total steps, return 0 else return + # the calculated graph. + return tf.where( + step > self.total_steps, 0.0, learning_rate, name="learning_rate" + ) + + +""" +## Compile and train the model +""" + +# Get the total number of steps for training. +total_steps = int((len(x_train) / config.batch_size) * config.epochs) + +# Calculate the number of steps for warmup. +warmup_epoch_percentage = 0.15 +warmup_steps = int(total_steps * warmup_epoch_percentage) + +# Initialize the warmupcosine schedule. +scheduled_lrs = WarmUpCosine( + lr_start=1e-5, + lr_max=1e-3, + warmup_steps=warmup_steps, + total_steps=total_steps, +) + +# Get the optimizer. +optimizer = keras.optimizers.AdamW( + learning_rate=scheduled_lrs, weight_decay=config.weight_decay +) + +# Compile and pretrain the model. +model.compile( + optimizer=optimizer, + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + metrics=[ + keras.metrics.SparseCategoricalAccuracy(name="accuracy"), + keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"), + ], +) + +# Train the model +history = model.fit( + train_ds, + epochs=config.epochs, + validation_data=val_ds, + callbacks=[ + keras.callbacks.EarlyStopping( + monitor="val_accuracy", + patience=5, + mode="auto", + ) + ], +) + +# Evaluate the model with the test dataset. +print("TESTING") +loss, acc_top1, acc_top5 = model.evaluate(test_ds) +print(f"Loss: {loss:0.2f}") +print(f"Top 1 test accuracy: {acc_top1*100:0.2f}%") +print(f"Top 5 test accuracy: {acc_top5*100:0.2f}%") + +""" +## Conclusion + +The most impactful contribution of the paper is not the novel architecture, but +the idea that hierarchical ViTs trained with no attention can perform quite well. This +opens up the question of how essential attention is to the performance of ViTs. + +For curious minds, we would suggest reading the +[ConvNexT](https://arxiv.org/abs/2201.03545) paper which attends more to the training +paradigms and architectural details of ViTs rather than providing a novel architecture +based on attention. + +Acknowledgements: + +- We would like to thank [PyImageSearch](https://pyimagesearch.com) for providing us with +resources that helped in the completion of this project. +- We would like to thank [JarvisLabs.ai](https://jarvislabs.ai/) for providing with the +GPU credits. +- We would like to thank [Manim Community](https://www.manim.community/) for the manim +library. +- A personal note of thanks to [Puja Roychowdhury](https://twitter.com/pleb_talks) for +helping us with the Learning Rate Schedule. +"""
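+
+"""
+As a quick optional check (a small sketch added here, not part of the original
+ShiftViT recipe), we can plot the warmup cosine schedule defined above to
+verify the linear warmup followed by the cosine decay.
+"""
+
+# Sample the schedule every 100 steps and plot it.
+step_samples = range(0, total_steps, 100)
+lrs = [float(scheduled_lrs(step)) for step in step_samples]
+
+plt.plot(step_samples, lrs)
+plt.xlabel("Step")
+plt.ylabel("Learning rate")
+plt.title("Warmup cosine decay schedule")
+plt.show()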