The architecture of the model is quite straightforward; the losses are the more complicated part. The network is a vision transformer followed by a linear layer that predicts the probability of each class. We use the vit-pytorch library to initialize the vision transformer.
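import torch
from torch import nn
from vit_pytorch import ViT

class VisionTransformer(nn.Module):
    def __init__(self, img_size, z_dim, num_classes):
        super(VisionTransformer, self).__init__()
        # The ViT outputs a z_dim-dimensional embedding rather than class scores.
        self.vit = ViT(
            image_size = img_size,
            patch_size = 32,
            num_classes = z_dim,
            dim = 1024,
            depth = 6,
            heads = 16,
            mlp_dim = 2048,
            dropout = 0.1,
            emb_dropout = 0.1
        )
        # Linear head that maps the embedding to one score per class.
        self.linear = nn.Linear(in_features=z_dim, out_features=num_classes)

    def forward(self, x):
        x = self.vit(x)
        x = self.linear(x)
        # Softmax over the class dimension turns the scores into probabilities.
        x = nn.functional.softmax(x, dim=1)
        return x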
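To make the configuration above concrete, here is a small sketch of how many patches the transformer sees. It assumes 224×224 RGB inputs (the image size used later in training); `patch_stats` is a hypothetical helper written for this illustration, not part of vit-pytorch.

```python
def patch_stats(image_size: int, patch_size: int, channels: int = 3):
    """Return (number of patches, values per flattened patch)."""
    # A ViT cuts the image into a grid of non-overlapping square patches,
    # so the image side has to be a multiple of the patch side.
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    num_patches = patches_per_side ** 2
    # Each patch is flattened to channels * height * width values.
    patch_dim = channels * patch_size * patch_size
    return num_patches, patch_dim


num_patches, patch_dim = patch_stats(image_size=224, patch_size=32)
print(num_patches, patch_dim)  # 49 3072
```

Each of the 49 flattened patches is then projected to `dim = 1024` before it enters the transformer layers, alongside one extra class token.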
There are two losses in the original paper.
- Instance Discrimination Loss
- Contrastive Learning Loss
Instance Discrimination Loss
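The instance discrimination loss \(L_{InsDis}\) is defined as follows:
\[L_{InsDis} = -\sum^{N}_{i=1}\sum^{N}_{c=1}y_{c}^{i}\log{P_{c}^{i}}\]
where \(c\) indexes the classes and \(i\) indexes the instances passed to the network in a batch. \(y_{c}^{i}\) is 1 if instance \(i\) belongs to class \(c\) and 0 otherwise, and \(P_{c}^{i}\) is the predicted probability of class \(c\) for instance \(i\).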
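class InstanceDiscriminationLoss(nn.Module):
    def __init__(self):
        super(InstanceDiscriminationLoss, self).__init__()
    def forward(self, predictions, targets=None):
        # predictions holds softmax probabilities; the epsilon avoids log(0).
        log_probs = torch.log(predictions + 1e-12)
        if targets is not None:
            # y is one-hot, so keep only the log-probability of the true class.
            log_probs = log_probs.gather(1, targets.view(-1, 1))
        return -torch.sum(log_probs)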
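As a quick sanity check of the formula, here is a minimal worked example in plain Python. The probabilities and labels are made up for illustration, and `instance_discrimination_loss` is a hypothetical helper rather than code from the paper.

```python
import math


def instance_discrimination_loss(probs, targets):
    """Sum of -log P_c^i, taken at the true class c of each instance i."""
    total = 0.0
    for p_row, target in zip(probs, targets):
        # With a one-hot y^i, only the true-class term of the inner sum survives.
        total += -math.log(p_row[target])
    return total


# Two instances, three classes (made-up numbers).
probs = [[0.7, 0.2, 0.1],
         [0.1, 0.8, 0.1]]
targets = [0, 1]
loss = instance_discrimination_loss(probs, targets)
print(round(loss, 4))  # 0.5798
```

In PyTorch the same quantity can be obtained with `nn.NLLLoss(reduction="sum")` applied to log-probabilities.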
Contrastive Learning Loss
The contrastive learning loss \(L_{CN}\) is defined as follows:
\[L_{CN} = -\sum^{N}_{i=1}z_{iA}^{T}z_{iB} + \sum^{N}_{i=1} \log{\left(e^{z_{iA}^{T}z_{iB}} + \sum_{z_{i}^{-}}{e^{z_{iA}^{T}z_{i}^{-}}}\right)}\]
Here \(z_{iA}\) and \(z_{iB}\) are features extracted from two augmented versions of the image \(x_{i}\), and \(z_{i}^{-}\) are the features of the negative samples, i.e. the other images. Let’s write an augmentation layer first.
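from torchvision import transforms

class RandomAugmentation(nn.Module):
    def __init__(self):
        super(RandomAugmentation, self).__init__()
        # The transform list is truncated in the source, so it is left empty here.
        self.augment = transforms.Compose([])

    def forward(self, x):
        # Apply the (random) transforms twice to obtain two views of the same batch.
        x_a = self.augment(x)
        x_b = self.augment(x)
        return x_a, x_b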
Next up is the loss class.
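class ContrastiveLearningLoss(nn.Module):
    def __init__(self):
        super(ContrastiveLearningLoss, self).__init__()

    def forward(self, z_a, z_b):
        n = len(z_a)
        # Alignment term: -sum_i z_iA^T z_iB pulls the two views of an image together.
        alignment = 0
        for i in range(n):
            alignment += -torch.dot(z_a[i], z_b[i])
        # Uniformity term: sum_i log(exp(positive) + sum of exp(negatives)).
        # Assumption: the negatives of image i are the B-views of the other images in the batch.
        uniformity_loss = 0
        for i in range(n):
            negative_sum = 0
            for j in range(n):
                if i != j:
                    negative_sum += torch.exp(torch.dot(z_a[i], z_b[j]))
            uniformity = torch.exp(torch.dot(z_a[i], z_b[i]))
            uniformity_loss += torch.log(negative_sum + uniformity)
        return alignment + uniformity_loss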
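The formula for \(L_{CN}\) can also be computed without Python loops. The sketch below is one possible vectorized version, under the assumption that the negatives for image \(i\) are the B-view embeddings of the other images in the same batch; `contrastive_loss` is a hypothetical function name.

```python
import torch


def contrastive_loss(z_a: torch.Tensor, z_b: torch.Tensor) -> torch.Tensor:
    """Vectorized L_CN for two (N, D) embedding matrices."""
    # sim[i, j] = z_a[i] . z_b[j]; the diagonal holds the positive pairs.
    sim = z_a @ z_b.T
    # Alignment term: -sum_i z_iA^T z_iB.
    alignment = -torch.diagonal(sim).sum()
    # logsumexp over row i = log(exp(positive) + sum of exp(negatives)).
    uniformity = torch.logsumexp(sim, dim=1).sum()
    return alignment + uniformity


torch.manual_seed(0)
z_a = torch.nn.functional.normalize(torch.randn(8, 10), dim=1)
z_b = torch.nn.functional.normalize(torch.randn(8, 10), dim=1)
loss = contrastive_loss(z_a, z_b)
print(loss.item())
```

Written this way, the loss is the same as a summed cross-entropy over the similarity matrix with the diagonal as the target, i.e. `nn.functional.cross_entropy(sim, torch.arange(N), reduction="sum")`.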
The preprocessing part is pretty straightforward. We define a custom dataset class that feeds the images to a data loader, along with a mapping for the 17 classes of the flower dataset.
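labels_2_idx = {
    "Colts Foot": 1,
    "Lily Valley": 11,
    "Tiger lily": 14,
    # The remaining class entries are not shown here.
}
# Invert the mapping to look up a class name by its index.
idx_2_labels = {}
for key, value in labels_2_idx.items():
    idx_2_labels[value] = key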
import os
from torch.utils.data import Dataset
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        # Assumes the annotations file lists one image file name per line.
        with open(annotations_file, "r") as f:
            self.img_labels = f.read().split()
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels[idx])
        image = read_image(img_path)
        # Assumes the 1360 images are listed in class order, 80 per class (1360 / 17).
        label = idx // 80
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
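from torch.utils.data import DataLoader, random_split

# The constructor arguments are cut off in the source and stay elided here.
dataset = CustomImageDataset(...)
# Hold out 136 of the 1360 images (10%) for validation.
train_set, val_set = random_split(dataset, [1360 - 136, 136])
train_dataloader = DataLoader(train_set, batch_size=64, shuffle=True)
val_dataloader = DataLoader(val_set, batch_size=64, shuffle=True)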
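The 1224/136 split can be tried out without the flower images. The sketch below uses a stand-in `TensorDataset` of the same size and passes a seeded generator so the split is reproducible. The seed value and the dummy dataset are assumptions made for illustration.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Stand-in dataset with 1360 items, mirroring the size of the flower dataset.
dummy = TensorDataset(torch.arange(1360))

val_size = 1360 // 10          # 136 validation items
train_size = 1360 - val_size   # 1224 training items

# A seeded generator makes the random split repeatable across runs.
generator = torch.Generator().manual_seed(42)  # assumed seed
train_subset, val_subset = random_split(
    dummy, [train_size, val_size], generator=generator
)
print(len(train_subset), len(val_subset))  # 1224 136
```

Fixing the seed keeps the validation images the same from run to run, which makes losses comparable between experiments.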
The training part is pretty straightforward as well. You just have to iterate through the batches and add up the two losses.
epochs = 12
visionTransformer = nn.DataParallel(VisionTransformer(224, 10, 17))
visionTransformer.cuda()
randomAugmentation = nn.DataParallel(RandomAugmentation())
randomAugmentation.cuda()
instanceDiscriminationLoss = nn.DataParallel(InstanceDiscriminationLoss())
contrastiveLearningLoss = nn.DataParallel(ContrastiveLearningLoss())
optimizer = torch.optim.AdamW(visionTransformer.parameters(), lr=0.001)
for epoch in range(epochs):
    aggregate_loss = 0
    for step, batch in enumerate(train_dataloader):
        if step % 10 == 0 and step != 0:
            print("Epoch:{} Step:{} Loss:{:.3f}".format(epoch, step, aggregate_loss/step))
        batch = batch[0].to(torch.float)
        # Clear the gradients left over from the previous step.
        optimizer.zero_grad()
        predictions = visionTransformer(batch)
        x_a, x_b = randomAugmentation(batch)
        # Embed both views with the ViT backbone and L2-normalize them.
        z_embeddings_a = nn.functional.normalize(visionTransformer.module.vit(x_a), dim=1)
        z_embeddings_b = nn.functional.normalize(visionTransformer.module.vit(x_b), dim=1)
        loss_1 = instanceDiscriminationLoss(predictions).mean()
        loss_2 = contrastiveLearningLoss(z_embeddings_a, z_embeddings_b).mean()
        total_loss = loss_1 + loss_2
        # Backpropagate the combined loss and update the weights.
        total_loss.backward()
        optimizer.step()
        aggregate_loss += total_loss.item()
    print("Total Epoch {} Loss : {}".format(epoch, aggregate_loss/len(train_dataloader)))
Well, that’s it for now. We’ve implemented the paper “Training Vision Transformers with Only 2040 Images” in PyTorch, and it seems to work. Thanks for having the patience to read.