Converted to Keras Core: A Vision Transformer without Attention (#497)

* Port ShiftViT to keras core

* remove empty spaces

* Reverted epochs
"""
Title: A Vision Transformer without Attention
Author: [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ritwik Raha](https://twitter.com/ritwik_raha)
Converted to Keras Core: [Muhammad Anas Raza](https://anasrz.com)
Date created: 2022/02/24
Last modified: 2023/07/15
Description: A minimal implementation of ShiftViT.
Accelerator: GPU
"""
"""
## Introduction
[Vision Transformers](https://arxiv.org/abs/2010.11929) (ViTs) have sparked a wave of
research at the intersection of Transformers and Computer Vision (CV).
ViTs can simultaneously model long- and short-range dependencies, thanks to
the Multi-Head Self-Attention mechanism in the Transformer block. Many researchers believe
that the success of ViTs is purely due to the attention layer, and they seldom
think about other parts of the ViT model.
In the academic paper
[When Shift Operation Meets Vision Transformer: An Extremely Simple Alternative to Attention Mechanism](https://arxiv.org/abs/2201.10801)
the authors propose to demystify the success of ViTs with the introduction of a **NO
PARAMETER** operation in place of the attention operation. They swap the attention
operation with a shifting operation.
In this example, we minimally implement the paper with close alignment to the authors'
[official implementation](https://github.com/microsoft/SPACH/blob/main/models/shiftvit.py).
"""
"""
## Setup and imports
"""
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import keras_core as keras
from keras_core import layers
# Setting seed for reproducibility
SEED = 42
keras.utils.set_random_seed(SEED)
"""
## Hyperparameters
These are the hyperparameters that we have chosen for the experiment.
Please feel free to tune them.
"""
class Config(object):
# DATA
batch_size = 256
buffer_size = batch_size * 2
input_shape = (32, 32, 3)
num_classes = 10
# AUGMENTATION
image_size = 48
# ARCHITECTURE
patch_size = 4
projected_dim = 96
num_shift_blocks_per_stages = [2, 4, 8, 2]
epsilon = 1e-5
stochastic_depth_rate = 0.2
mlp_dropout_rate = 0.2
num_div = 12
shift_pixel = 1
mlp_expand_ratio = 2
# OPTIMIZER
lr_start = 1e-5
lr_max = 1e-3
weight_decay = 1e-4
# TRAINING
epochs = 100
config = Config()
"""
## Load the CIFAR-10 dataset
We use the CIFAR-10 dataset for our experiments.
"""
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
(x_train, y_train), (x_val, y_val) = (
(x_train[:40000], y_train[:40000]),
(x_train[40000:], y_train[40000:]),
)
print(f"Training samples: {len(x_train)}")
print(f"Validation samples: {len(x_val)}")
print(f"Testing samples: {len(x_test)}")
AUTO = tf.data.AUTOTUNE
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = train_ds.shuffle(config.buffer_size).batch(config.batch_size).prefetch(AUTO)
val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_ds = val_ds.batch(config.batch_size).prefetch(AUTO)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_ds = test_ds.batch(config.batch_size).prefetch(AUTO)
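"""
As a quick sanity check (a minimal sketch, not part of the original training pipeline),
each batch yielded by `train_ds` carries raw `uint8` CIFAR-10 images and integer labels:

```python
images, labels = next(iter(train_ds))
print(images.shape, images.dtype)  # (256, 32, 32, 3) uint8
print(labels.shape, labels.dtype)  # (256, 1) uint8
```
"""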
"""
## Data Augmentation
The augmentation pipeline consists of:
- Resizing
- Random cropping
- Random horizontal flipping
- Rescaling
_Note_: The image data augmentation layers do not apply
data transformations at inference time. This means that
when these layers are called with `training=False` they
behave differently. Refer to the
[documentation](https://keras.io/api/layers/preprocessing_layers/image_augmentation/)
for more details.
"""
def get_augmentation_model():
"""Build the data augmentation model."""
data_augmentation = keras.Sequential(
[
layers.Resizing(config.input_shape[0] + 20, config.input_shape[0] + 20),
layers.RandomCrop(config.image_size, config.image_size),
layers.RandomFlip("horizontal"),
layers.Rescaling(1 / 255.0),
]
)
return data_augmentation
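"""
The random layers above only apply their stochastic transformations when called with
`training=True`. A minimal sketch (with hypothetical inputs) of the two behaviours:

```python
aug = get_augmentation_model()
images = tf.random.uniform((4, 32, 32, 3), maxval=255.0)

# Training mode: resize, random crop, random flip, then rescale to [0, 1].
augmented = aug(images, training=True)

# Inference mode: the Random* layers behave deterministically.
deterministic = aug(images, training=False)

print(augmented.shape)  # (4, 48, 48, 3)
```
"""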
"""
## The ShiftViT architecture
In this section, we build the architecture proposed in
[the ShiftViT paper](https://arxiv.org/abs/2201.10801).
| ![ShiftViT Architecture](https://i.imgur.com/CHU40HX.png) |
| :--: |
| Figure 1: The entire architecture of ShiftViT.
[Source](https://arxiv.org/abs/2201.10801) |
The architecture as shown in Fig. 1, is inspired by
[Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030).
Here the authors propose a modular architecture with 4 stages. Each stage works on its
own spatial size, creating a hierarchical architecture.
An input image of size `HxWx3` is split into non-overlapping patches of size `4x4`.
This is done via the patchify layer, which results in individual tokens of feature size
`48` (`4x4x3`); see the short sketch below. Each stage comprises two parts:
1. Embedding Generation
2. Stacked Shift Blocks
We discuss the stages and the modules in detail in what follows.
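To make the patchify arithmetic concrete, here is a small sketch (using the
hyperparameters defined above, not part of the original pipeline): a `48x48x3`
augmented image split into `4x4` patches gives a `12x12` grid of tokens, each projected
to `projected_dim` channels.

```python
patchify = layers.Conv2D(
    filters=config.projected_dim,   # 96
    kernel_size=config.patch_size,  # 4
    strides=config.patch_size,
    padding="same",
)
tokens = patchify(tf.random.normal((1, config.image_size, config.image_size, 3)))
print(tokens.shape)  # (1, 12, 12, 96)
```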
_Note_: Compared to the [official implementation](https://github.com/microsoft/SPACH/blob/main/models/shiftvit.py)
we restructure some key components to better fit the Keras API.
"""
"""
### The ShiftViT Block
| ![ShiftViT block](https://i.imgur.com/IDe35vo.gif) |
| :--: |
| Figure 2: From the Model to a Shift Block. |
Each stage in the ShiftViT architecture comprises a Shift Block, as shown in Fig. 2.
| ![Shift Vit Block](https://i.imgur.com/0q13pLu.png) |
| :--: |
| Figure 3: The Shift ViT Block. [Source](https://arxiv.org/abs/2201.10801) |
The Shift Block, as shown in Fig. 3, comprises the following:
1. Shift Operation
2. Layer Normalization
3. MLP Layer
"""
"""
#### The MLP block
The MLP block is intended to be a stack of densely-connected layers.
"""
class MLP(layers.Layer):
"""Get the MLP layer for each shift block.
Args:
mlp_expand_ratio (int): The ratio with which the first feature map is expanded.
mlp_dropout_rate (float): The rate for dropout.
"""
def __init__(self, mlp_expand_ratio, mlp_dropout_rate, **kwargs):
super().__init__(**kwargs)
self.mlp_expand_ratio = mlp_expand_ratio
self.mlp_dropout_rate = mlp_dropout_rate
def build(self, input_shape):
input_channels = input_shape[-1]
initial_filters = int(self.mlp_expand_ratio * input_channels)
self.mlp = keras.Sequential(
[
layers.Dense(
units=initial_filters,
activation=tf.nn.gelu,
),
layers.Dropout(rate=self.mlp_dropout_rate),
layers.Dense(units=input_channels),
layers.Dropout(rate=self.mlp_dropout_rate),
]
)
def call(self, x):
x = self.mlp(x)
return x
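"""
A quick shape check (a sketch, assuming a `96`-channel feature map): the MLP expands the
channels by `mlp_expand_ratio` internally and projects them back to the input width, so
the output shape matches the input shape.

```python
mlp = MLP(mlp_expand_ratio=2, mlp_dropout_rate=0.2)
print(mlp(tf.random.normal((1, 12, 12, 96))).shape)  # (1, 12, 12, 96)
```
"""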
"""
#### The DropPath layer
Stochastic depth is a regularization technique that randomly drops a set of
layers. During inference, the layers are kept as they are. It is very
similar to Dropout, but it operates on a block of layers rather
than on individual nodes present inside a layer.
"""
class DropPath(layers.Layer):
"""Drop Path also known as the Stochastic Depth layer.
    References:
        - https://keras.io/examples/vision/cct/#stochastic-depth-for-regularization
        - https://github.com/rwightman/pytorch-image-models
"""
def __init__(self, drop_path_prob, **kwargs):
super().__init__(**kwargs)
self.drop_path_prob = drop_path_prob
def call(self, x, training=False):
if training:
keep_prob = 1 - self.drop_path_prob
shape = (tf.shape(x)[0],) + (1,) * (len(x.shape) - 1)
random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
random_tensor = tf.floor(random_tensor)
return (x / keep_prob) * random_tensor
return x
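"""
A small sketch of the behaviour (with hypothetical values): when `training=True`,
roughly `drop_path_prob` of the samples in a batch are zeroed out entirely, and the
surviving samples are scaled by `1 / keep_prob` so the expected value is unchanged.

```python
drop_path = DropPath(drop_path_prob=0.5)
ones = tf.ones((8, 2, 2, 3))
out = drop_path(ones, training=True)
# Per-sample maximum is either 0.0 (dropped) or 2.0 (kept and rescaled).
print(tf.reduce_max(out, axis=[1, 2, 3]))
```
"""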
"""
#### Block
The most important operation in this paper is the **shift operation**. In this section,
we describe the shift operation and compare it with its original implementation provided
by the authors.
A generic feature map is assumed to have the shape `[N, H, W, C]`. Here we choose a
`num_div` parameter that decides the division size of the channels. The first 4 divisions
are shifted by 1 pixel in the left, right, up, and down directions. The remaining splits
are kept as is. After partial shifting, the shifted channels are padded and the
overflowing pixels are chopped off. This completes the partial shifting operation.
In the original implementation, the code is approximately:
```python
out[:, g * 0:g * 1, :, :-1] = x[:, g * 0:g * 1, :, 1:] # shift left
out[:, g * 1:g * 2, :, 1:] = x[:, g * 1:g * 2, :, :-1] # shift right
out[:, g * 2:g * 3, :-1, :] = x[:, g * 2:g * 3, 1:, :] # shift up
out[:, g * 3:g * 4, 1:, :] = x[:, g * 3:g * 4, :-1, :] # shift down
out[:, g * 4:, :, :] = x[:, g * 4:, :, :] # no shift
```
In TensorFlow, tensors are immutable, so we cannot assign shifted channels in place as
the original implementation does. Instead, we resort to the following procedure:
1. Split the channels with the `num_div` parameter.
2. Select each of the first four splits and shift and pad them in the respective
directions.
3. After shifting and padding, we concatenate the channels back.
| ![Manim rendered animation for shift operation](https://i.imgur.com/PReeULP.gif) |
| :--: |
| Figure 4: The TensorFlow style shifting |
The entire procedure is illustrated in Fig. 4.
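As a toy illustration of the crop-and-pad idea (a sketch, not taken from the
implementation below), a one-pixel left shift of a single-channel `3x3` map can be
written as a crop followed by zero padding:

```python
toy = tf.reshape(tf.range(9.0), (1, 3, 3, 1))
# Drop the first column, then pad a zero column back on the right.
cropped = tf.image.crop_to_bounding_box(
    toy, offset_height=0, offset_width=1, target_height=3, target_width=2
)
shifted_left = tf.image.pad_to_bounding_box(
    cropped, offset_height=0, offset_width=0, target_height=3, target_width=3
)
print(tf.squeeze(shifted_left))
# [[1. 2. 0.]
#  [4. 5. 0.]
#  [7. 8. 0.]]
```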
"""
class ShiftViTBlock(layers.Layer):
"""A unit ShiftViT Block
Args:
        shift_pixel (int): The number of pixels to shift. Defaults to 1.
mlp_expand_ratio (int): The ratio with which MLP features are
            expanded. Defaults to 2.
mlp_dropout_rate (float): The dropout rate used in MLP.
num_div (int): The number of divisions of the feature map's channel.
            In total, 4/num_div of the channels will be shifted. Defaults to 12.
epsilon (float): Epsilon constant.
drop_path_prob (float): The drop probability for drop path.
"""
def __init__(
self,
epsilon,
drop_path_prob,
mlp_dropout_rate,
num_div=12,
shift_pixel=1,
mlp_expand_ratio=2,
**kwargs,
):
super().__init__(**kwargs)
self.shift_pixel = shift_pixel
self.mlp_expand_ratio = mlp_expand_ratio
self.mlp_dropout_rate = mlp_dropout_rate
self.num_div = num_div
self.epsilon = epsilon
self.drop_path_prob = drop_path_prob
def build(self, input_shape):
self.H = input_shape[1]
self.W = input_shape[2]
self.C = input_shape[3]
self.layer_norm = layers.LayerNormalization(epsilon=self.epsilon)
self.drop_path = (
DropPath(drop_path_prob=self.drop_path_prob)
if self.drop_path_prob > 0.0
else layers.Activation("linear")
)
self.mlp = MLP(
mlp_expand_ratio=self.mlp_expand_ratio,
mlp_dropout_rate=self.mlp_dropout_rate,
)
def get_shift_pad(self, x, mode):
"""Shifts the channels according to the mode chosen."""
if mode == "left":
offset_height = 0
offset_width = 0
target_height = 0
target_width = self.shift_pixel
elif mode == "right":
offset_height = 0
offset_width = self.shift_pixel
target_height = 0
target_width = self.shift_pixel
elif mode == "up":
offset_height = 0
offset_width = 0
target_height = self.shift_pixel
target_width = 0
else:
offset_height = self.shift_pixel
offset_width = 0
target_height = self.shift_pixel
target_width = 0
crop = tf.image.crop_to_bounding_box(
x,
offset_height=offset_height,
offset_width=offset_width,
target_height=self.H - target_height,
target_width=self.W - target_width,
)
shift_pad = tf.image.pad_to_bounding_box(
crop,
offset_height=offset_height,
offset_width=offset_width,
target_height=self.H,
target_width=self.W,
)
return shift_pad
def call(self, x, training=False):
# Split the feature maps
x_splits = tf.split(x, num_or_size_splits=self.C // self.num_div, axis=-1)
# Shift the feature maps
x_splits[0] = self.get_shift_pad(x_splits[0], mode="left")
x_splits[1] = self.get_shift_pad(x_splits[1], mode="right")
x_splits[2] = self.get_shift_pad(x_splits[2], mode="up")
x_splits[3] = self.get_shift_pad(x_splits[3], mode="down")
# Concatenate the shifted and unshifted feature maps
x = tf.concat(x_splits, axis=-1)
# Add the residual connection
shortcut = x
x = shortcut + self.drop_path(self.mlp(self.layer_norm(x)), training=training)
return x
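"""
As a quick smoke test (a sketch, not part of the original example), a `ShiftViTBlock`
preserves the spatial and channel dimensions of its input:

```python
demo_block = ShiftViTBlock(
    epsilon=config.epsilon,
    drop_path_prob=0.1,
    mlp_dropout_rate=config.mlp_dropout_rate,
    num_div=config.num_div,
    shift_pixel=config.shift_pixel,
    mlp_expand_ratio=config.mlp_expand_ratio,
)
# A 48x48 image patchified with 4x4 patches gives a 12x12 grid of 96-d tokens.
print(demo_block(tf.random.normal((1, 12, 12, 96))).shape)  # (1, 12, 12, 96)
```
"""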
"""
### The ShiftViT blocks
| ![Shift Blocks](https://i.imgur.com/FKy5NnD.png) |
| :--: |
| Figure 5: Shift Blocks in the architecture. [Source](https://arxiv.org/abs/2201.10801) |
Each stage of the architecture has shift blocks as shown in Fig. 5. Each of these blocks
contains a variable number of stacked ShiftViT blocks (as built in the earlier section).
Shift blocks are followed by a PatchMerging layer that scales down the feature maps. The
PatchMerging layer gives the model its hierarchical (pyramidal) structure.
"""
"""
#### The PatchMerging layer
This layer merges two adjacent tokens. It scales the features down spatially and
increases them channel-wise. We use a Conv2D layer to merge the patches.
"""
class PatchMerging(layers.Layer):
"""The Patch Merging layer.
Args:
epsilon (float): The epsilon constant.
"""
def __init__(self, epsilon, **kwargs):
super().__init__(**kwargs)
self.epsilon = epsilon
def build(self, input_shape):
filters = 2 * input_shape[-1]
self.reduction = layers.Conv2D(
filters=filters, kernel_size=2, strides=2, padding="same", use_bias=False
)
self.layer_norm = layers.LayerNormalization(epsilon=self.epsilon)
def call(self, x):
# Apply the patch merging algorithm on the feature maps
x = self.layer_norm(x)
x = self.reduction(x)
return x
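"""
A shape sketch (illustrative only): `PatchMerging` halves the spatial resolution and
doubles the channel count.

```python
merge = PatchMerging(epsilon=config.epsilon)
print(merge(tf.random.normal((1, 12, 12, 96))).shape)  # (1, 6, 6, 192)
```
"""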
"""
#### Stacked Shift Blocks
Each stage will have a variable number of stacked ShiftViT Blocks, as suggested in
the paper. This is a generic layer that contains the stacked ShiftViT blocks
together with the patch merging layer. Combining the two operations (ShiftViT
block and patch merging) is a design choice we picked for better code reusability.
"""
# Note: This layer will have a different depth of stacking
# for different stages on the model.
class StackedShiftBlocks(layers.Layer):
"""The layer containing stacked ShiftViTBlocks.
Args:
epsilon (float): The epsilon constant.
mlp_dropout_rate (float): The dropout rate used in the MLP block.
num_shift_blocks (int): The number of shift vit blocks for this stage.
stochastic_depth_rate (float): The maximum drop path rate chosen.
is_merge (boolean): A flag that determines the use of the Patch Merge
layer after the shift vit blocks.
num_div (int): The division of channels of the feature map. Defaults to 12.
shift_pixel (int): The number of pixels to shift. Defaults to 1.
mlp_expand_ratio (int): The ratio with which the initial dense layer of
            the MLP is expanded. Defaults to 2.
"""
def __init__(
self,
epsilon,
mlp_dropout_rate,
num_shift_blocks,
stochastic_depth_rate,
is_merge,
num_div=12,
shift_pixel=1,
mlp_expand_ratio=2,
**kwargs,
):
super().__init__(**kwargs)
self.epsilon = epsilon
self.mlp_dropout_rate = mlp_dropout_rate
self.num_shift_blocks = num_shift_blocks
self.stochastic_depth_rate = stochastic_depth_rate
self.is_merge = is_merge
self.num_div = num_div
self.shift_pixel = shift_pixel
self.mlp_expand_ratio = mlp_expand_ratio
def build(self, input_shapes):
# Calculate stochastic depth probabilities.
# Reference: https://keras.io/examples/vision/cct/#the-final-cct-model
dpr = [
x
for x in np.linspace(
start=0, stop=self.stochastic_depth_rate, num=self.num_shift_blocks
)
]
# Build the shift blocks as a list of ShiftViT Blocks
self.shift_blocks = list()
for num in range(self.num_shift_blocks):
self.shift_blocks.append(
ShiftViTBlock(
num_div=self.num_div,
epsilon=self.epsilon,
drop_path_prob=dpr[num],
mlp_dropout_rate=self.mlp_dropout_rate,
shift_pixel=self.shift_pixel,
mlp_expand_ratio=self.mlp_expand_ratio,
)
)
if self.is_merge:
self.patch_merge = PatchMerging(epsilon=self.epsilon)
def call(self, x, training=False):
for shift_block in self.shift_blocks:
x = shift_block(x, training=training)
if self.is_merge:
x = self.patch_merge(x)
return x
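"""
Putting the two pieces together, a single stage (a sketch with two shift blocks and
patch merging enabled) downsamples the feature map exactly like `PatchMerging` alone:

```python
stage = StackedShiftBlocks(
    epsilon=config.epsilon,
    mlp_dropout_rate=config.mlp_dropout_rate,
    num_shift_blocks=2,
    stochastic_depth_rate=config.stochastic_depth_rate,
    is_merge=True,
)
print(stage(tf.random.normal((1, 12, 12, 96))).shape)  # (1, 6, 6, 192)
```
"""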
"""
## The ShiftViT model
Build the ShiftViT custom model.
"""
class ShiftViTModel(keras.Model):
"""The ShiftViT Model.
Args:
data_augmentation (keras.Model): A data augmentation model.
projected_dim (int): The dimension to which the patches of the image are
projected.
patch_size (int): The patch size of the images.
        num_shift_blocks_per_stages (list[int]): A list of the number of shift
            blocks per stage.
epsilon (float): The epsilon constant.
mlp_dropout_rate (float): The dropout rate used in the MLP block.
stochastic_depth_rate (float): The maximum drop rate probability.
        num_div (int): The number of divisions of the channels of the feature
map. Defaults to 12.
        shift_pixel (int): The number of pixels to shift. Defaults to 1.
        mlp_expand_ratio (int): The ratio by which the initial dense layer of the
            MLP is expanded. Defaults to 2.
"""
def __init__(
self,
data_augmentation,
projected_dim,
patch_size,
num_shift_blocks_per_stages,
epsilon,
mlp_dropout_rate,
stochastic_depth_rate,
num_div=12,
shift_pixel=1,
mlp_expand_ratio=2,
**kwargs,
):
super().__init__(**kwargs)
self.data_augmentation = data_augmentation
self.patch_projection = layers.Conv2D(
filters=projected_dim,
kernel_size=patch_size,
strides=patch_size,
padding="same",
)
self.stages = list()
for index, num_shift_blocks in enumerate(num_shift_blocks_per_stages):
if index == len(num_shift_blocks_per_stages) - 1:
# This is the last stage, do not use the patch merge here.
is_merge = False
else:
is_merge = True
# Build the stages.
self.stages.append(
StackedShiftBlocks(
epsilon=epsilon,
mlp_dropout_rate=mlp_dropout_rate,
num_shift_blocks=num_shift_blocks,
stochastic_depth_rate=stochastic_depth_rate,
is_merge=is_merge,
num_div=num_div,
shift_pixel=shift_pixel,
mlp_expand_ratio=mlp_expand_ratio,
)
)
self.global_avg_pool = layers.GlobalAveragePooling2D()
def get_config(self):
config = super().get_config()
config.update(
{
"data_augmentation": self.data_augmentation,
"patch_projection": self.patch_projection,
"stages": self.stages,
"global_avg_pool": self.global_avg_pool,
}
)
return config
def _calculate_loss(self, data, training=False):
(images, labels) = data
# Augment the images
augmented_images = self.data_augmentation(images, training=training)
        # Create patches and project the patches.
projected_patches = self.patch_projection(augmented_images)
# Pass through the stages
x = projected_patches
for stage in self.stages:
x = stage(x, training=training)
# Get the logits.
logits = self.global_avg_pool(x)
# Calculate the loss and return it.
total_loss = self.compute_loss(data, labels, logits)
return total_loss, labels, logits
def train_step(self, inputs):
with tf.GradientTape() as tape:
total_loss, labels, logits = self._calculate_loss(
data=inputs, training=True
)
# Apply gradients.
train_vars = [
self.data_augmentation.trainable_variables,
self.patch_projection.trainable_variables,
self.global_avg_pool.trainable_variables,
]
train_vars = train_vars + [stage.trainable_variables for stage in self.stages]
# Optimize the gradients.
grads = tape.gradient(total_loss, train_vars)
trainable_variable_list = []
for grad, var in zip(grads, train_vars):
for g, v in zip(grad, var):
trainable_variable_list.append((g, v))
self.optimizer.apply_gradients(trainable_variable_list)
# Update the metrics
for metric in self.metrics:
if metric.name == "loss":
metric.update_state(total_loss)
else:
metric.update_state(labels, logits)
# Return a dict mapping metric names to current value
return {m.name: m.result() for m in self.metrics}
def test_step(self, data):
loss, labels, logits = self._calculate_loss(data=data, training=False)
# Update the metrics
for metric in self.metrics:
if metric.name == "loss":
metric.update_state(loss)
else:
metric.update_state(labels, logits)
# Return a dict mapping metric names to current value
return {m.name: m.result() for m in self.metrics}
"""
## Instantiate the model
"""
model = ShiftViTModel(
data_augmentation=get_augmentation_model(),
projected_dim=config.projected_dim,
patch_size=config.patch_size,
num_shift_blocks_per_stages=config.num_shift_blocks_per_stages,
epsilon=config.epsilon,
mlp_dropout_rate=config.mlp_dropout_rate,
stochastic_depth_rate=config.stochastic_depth_rate,
num_div=config.num_div,
shift_pixel=config.shift_pixel,
mlp_expand_ratio=config.mlp_expand_ratio,
)
"""
## Learning rate schedule
In many experiments, we want to warm up the model with a slowly increasing learning rate
and then cool down the model with a slowly decaying learning rate. In the warmup cosine
decay, the learning rate linearly increases for the warmup steps and then decays with a
cosine decay.
"""
# Some code is taken from:
# https://www.kaggle.com/ashusma/training-rfcx-tensorflow-tpu-effnet-b2.
class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
"""A LearningRateSchedule that uses a warmup cosine decay schedule."""
def __init__(self, lr_start, lr_max, warmup_steps, total_steps):
"""
Args:
lr_start: The initial learning rate
lr_max: The maximum learning rate to which lr should increase to in
the warmup steps
warmup_steps: The number of steps for which the model warms up
total_steps: The total number of steps for the model training
"""
super().__init__()
self.lr_start = lr_start
self.lr_max = lr_max
self.warmup_steps = warmup_steps
self.total_steps = total_steps
self.pi = tf.constant(np.pi)
def __call__(self, step):
# Check whether the total number of steps is larger than the warmup
# steps. If not, then throw a value error.
if self.total_steps < self.warmup_steps:
raise ValueError(
f"Total number of steps {self.total_steps} must be"
+ f"larger or equal to warmup steps {self.warmup_steps}."
)
# `cos_annealed_lr` is a graph that increases to 1 from the initial
# step to the warmup step. After that this graph decays to -1 at the
# final step mark.
cos_annealed_lr = tf.cos(
self.pi
* (tf.cast(step, tf.float32) - self.warmup_steps)
/ tf.cast(self.total_steps - self.warmup_steps, tf.float32)
)
        # Shift the mean of the `cos_annealed_lr` graph to 1. Now the graph goes
# from 0 to 2. Normalize the graph with 0.5 so that now it goes from 0
# to 1. With the normalized graph we scale it with `lr_max` such that
# it goes from 0 to `lr_max`
learning_rate = 0.5 * self.lr_max * (1 + cos_annealed_lr)
# Check whether warmup_steps is more than 0.
if self.warmup_steps > 0:
            # Check whether lr_max is larger than lr_start. If not, throw a value
# error.
if self.lr_max < self.lr_start:
raise ValueError(
f"lr_start {self.lr_start} must be smaller or"
+ f"equal to lr_max {self.lr_max}."
)
# Calculate the slope with which the learning rate should increase
            # in the warmup schedule. The formula for the slope is m = (b - a) / steps.
slope = (self.lr_max - self.lr_start) / self.warmup_steps
# With the formula for a straight line (y = mx+c) build the warmup
# schedule
warmup_rate = slope * tf.cast(step, tf.float32) + self.lr_start
            # When the current step is less than the warmup steps, get the line
# graph. When the current step is greater than the warmup steps, get
# the scaled cos graph.
learning_rate = tf.where(
step < self.warmup_steps, warmup_rate, learning_rate
)
        # When the current step is more than the total steps, return 0, else return
# the calculated graph.
return tf.where(
step > self.total_steps, 0.0, learning_rate, name="learning_rate"
)
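"""
To visualize the schedule (a sketch with hypothetical step counts: 1000 total steps, 150
of them warmup), the learning rate rises linearly to `lr_max` and then follows a cosine
decay towards zero:

```python
demo_schedule = WarmUpCosine(
    lr_start=1e-5, lr_max=1e-3, warmup_steps=150, total_steps=1000
)
lrs = [float(demo_schedule(step)) for step in range(1000)]
plt.plot(lrs)
plt.xlabel("Step")
plt.ylabel("Learning rate")
plt.show()
```
"""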
"""
## Compile and train the model
"""
# Get the total number of steps for training.
total_steps = int((len(x_train) / config.batch_size) * config.epochs)
# Calculate the number of steps for warmup.
warmup_epoch_percentage = 0.15
warmup_steps = int(total_steps * warmup_epoch_percentage)
# Initialize the warmupcosine schedule.
scheduled_lrs = WarmUpCosine(
    lr_start=config.lr_start,
    lr_max=config.lr_max,
warmup_steps=warmup_steps,
total_steps=total_steps,
)
# Get the optimizer.
optimizer = keras.optimizers.AdamW(
learning_rate=scheduled_lrs, weight_decay=config.weight_decay
)
# Compile and train the model.
model.compile(
optimizer=optimizer,
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[
keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
],
)
# Train the model
history = model.fit(
train_ds,
epochs=config.epochs,
validation_data=val_ds,
callbacks=[
keras.callbacks.EarlyStopping(
monitor="val_accuracy",
patience=5,
mode="auto",
)
],
)
# Evaluate the model with the test dataset.
print("TESTING")
loss, acc_top1, acc_top5 = model.evaluate(test_ds)
print(f"Loss: {loss:0.2f}")
print(f"Top 1 test accuracy: {acc_top1*100:0.2f}%")
print(f"Top 5 test accuracy: {acc_top5*100:0.2f}%")
"""
## Conclusion
The most impactful contribution of the paper is not the novel architecture, but
the idea that hierarchical ViTs trained with no attention can perform quite well. This
opens up the question of how essential attention is to the performance of ViTs.
For curious minds, we would suggest reading the
[ConvNeXt](https://arxiv.org/abs/2201.03545) paper, which attends more to the training
paradigms and architectural details of ViTs rather than providing a novel architecture
based on attention.
Acknowledgements:
- We would like to thank [PyImageSearch](https://pyimagesearch.com) for providing us with
resources that helped in the completion of this project.
- We would like to thank [JarvisLabs.ai](https://jarvislabs.ai/) for providing us with the
GPU credits.
- We would like to thank [Manim Community](https://www.manim.community/) for the manim
library.
- A personal note of thanks to [Puja Roychowdhury](https://twitter.com/pleb_talks) for
helping us with the Learning Rate Schedule.
"""