Training Vision Transformers with Only 2040 Images: Implementation






The architecture of the model is quite straightforward; it’s the losses that are more complicated. The network is a vision transformer followed by a linear layer. The vision transformer maps each image to a z_dim-dimensional feature vector, and the linear layer turns that vector into class probabilities. The same features are reused later by the contrastive loss. We use the vit-pytorch library to initialize the vision transformer.

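import torch
from torch import nn
from vit_pytorch import ViT

class VisionTransformer(nn.Module):
    def __init__(self, img_size, z_dim, num_classes):
        super(VisionTransformer, self).__init__()
        # ViT backbone: maps an image to a z_dim-dimensional feature vector.
        self.vit = ViT(
            image_size = img_size,
            patch_size = 32,
            num_classes = z_dim,
            dim = 1024,
            depth = 6,
            heads = 16,
            mlp_dim = 2048,
            dropout = 0.1,
            emb_dropout = 0.1
        )
        # Linear head: one score per class.
        self.linear = nn.Linear(in_features=z_dim, out_features=num_classes)

    def forward(self, x):
        x = self.vit(x)
        x = self.linear(x)
        # Normalise over the class dimension to get probabilities.
        x = nn.functional.softmax(x, dim=-1)
        return x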
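To see what the classification head produces without installing vit-pytorch, the sketch below replaces the ViT with random 10-dimensional features. These features are a stand-in, not the real backbone. The sketch then applies a linear layer and a softmax as in VisionTransformer. Passing dim=-1 normalises explicitly over the class dimension. Calling softmax without dim relies on a deprecated implicit choice and emits a warning.

```python
import torch
from torch import nn

torch.manual_seed(0)

z_dim, num_classes, batch_size = 10, 17, 4

# Stand-in for the ViT output: random z_dim-dimensional features per image.
features = torch.randn(batch_size, z_dim)

# Same kind of head as VisionTransformer: linear layer, then softmax over classes.
linear = nn.Linear(in_features=z_dim, out_features=num_classes)
probs = nn.functional.softmax(linear(features), dim=-1)

print(probs.shape)        # torch.Size([4, 17])
print(probs.sum(dim=-1))  # each row sums to 1, up to floating-point error
```

Each row of probs is a probability distribution over the 17 classes. These are the \(P_{c}^{i}\) values used by the instance discrimination loss.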


There are two losses in the original paper.

  1. Instance Discrimination Loss
  2. Contrastive Learning Loss

Instance Discrimination Loss

The instance discrimination loss, \(L_{InsDis}\), is defined as follows:

\[L_{InsDis} = -\sum^{N}_{i=1}\sum^{N}_{c=1}y_{c}^{i}\log{P_{c}^{i}}\]

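Here \(i\) runs over the \(N\) instances (images) passed to the network, which in practice arrive in batches, and \(c\) runs over the classes. \(P_{c}^{i}\) is the predicted probability that instance \(i\) belongs to class \(c\). \(y_{c}^{i}\) is 1 for the true class of instance \(i\) and 0 otherwise, so the inner sum just picks out \(-\log\) of the probability assigned to the correct class. In instance discrimination, each image is treated as its own class, which is why both sums run to \(N\).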
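A small numeric check of the formula helps here. With a one-hot \(y\), the double sum equals the summed negative log-likelihood of the correct classes, which PyTorch's nll_loss computes from log-probabilities. The probabilities below are made up for illustration. As in instance discrimination, each of the three instances is its own class.

```python
import torch
import torch.nn.functional as F

# Made-up predicted probabilities P[i, c] for N = 3 instances and N = 3 classes.
P = torch.tensor([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1],
                  [0.3, 0.3, 0.4]])
N = P.shape[0]

# Instance discrimination: instance i has class i, so y is the identity matrix.
targets = torch.arange(N)
y = F.one_hot(targets, num_classes=N).float()

# Direct evaluation of L_InsDis = -sum_i sum_c y_c^i log P_c^i.
loss_formula = -(y * torch.log(P)).sum()

# The same quantity via PyTorch's negative log-likelihood with sum reduction.
loss_nll = F.nll_loss(torch.log(P), targets, reduction="sum")

print(loss_formula.item(), loss_nll.item())  # both -(log 0.7 + log 0.8 + log 0.4), about 1.4961
```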

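class InstanceDiscriminationLoss(nn.Module):
    def __init__(self):
        super(InstanceDiscriminationLoss, self).__init__()

    def forward(self, predictions, targets):
        # predictions: (batch, classes) probabilities; targets: (batch,) class indices.
        # The one-hot y picks out P_c^i for the true class c of each instance i.
        true_class_probs = predictions.gather(1, targets.unsqueeze(1)).squeeze(1)
        # Sum -log P over the batch; the clamp guards against log(0).
        return -torch.sum(torch.log(true_class_probs.clamp(min=1e-12)))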
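In the formula above, every image is its own class, so the target for an image is just its position in the dataset. This means the linear head needs one output per training image. One way to get such targets is to return each item's index alongside its image. The sketch below does this with a hypothetical IndexedDataset wrapper, using a toy TensorDataset in place of the flower images.

```python
import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset

class IndexedDataset(Dataset):
    """Hypothetical wrapper: returns (image, index) so the index is the instance label."""

    def __init__(self, base):
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        image = self.base[idx][0]  # drop the base dataset's own label
        return image, idx

# Toy stand-in for the flower images: 6 random 3x8x8 "images" with dummy labels.
base = TensorDataset(torch.randn(6, 3, 8, 8), torch.zeros(6, dtype=torch.long))
dataset = IndexedDataset(base)

loader = DataLoader(dataset, batch_size=4, shuffle=False)
images, instance_ids = next(iter(loader))
print(instance_ids)  # tensor([0, 1, 2, 3])
```

With this wrapper, the head would be sized with num_classes=len(dataset), and instance_ids would play the role of the one-hot \(y\).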

Contrastive Learning Loss

The contrastive learning loss, \(L_{CN}\), is defined as follows:

\[L_{CN} = -\sum^{N}_{i=1}z_{iA}^{T}z_{iB} + \sum^{N}_{i=1} \log\left(e^{z_{iA}^{T}z_{iB}} + \sum_{z_{i}^{-}}e^{z_{iA}^{T}z_{i}^{-}}\right)\]
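The formula can be checked numerically. In the sketch below, \(z_{iA}\) and \(z_{iB}\) are random normalised stand-ins for the features of two augmented views of image \(x_{i}\). The negatives \(z_{i}^{-}\) of image \(i\) are assumed to be the second-view features \(z_{jB}\) of the other images \(j \neq i\) in the batch. The sketch evaluates \(L_{CN}\) term by term with explicit loops and compares the result with PyTorch's cross-entropy on the similarity matrix, which computes the same sum.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, dim = 5, 8

# L2-normalised features for two augmented views of N images (random stand-ins).
z_a = F.normalize(torch.randn(N, dim), dim=1)
z_b = F.normalize(torch.randn(N, dim), dim=1)

# Term-by-term evaluation of L_CN with explicit loops.
loss_loop = torch.tensor(0.0)
for i in range(N):
    positive = z_a[i] @ z_b[i]  # z_iA^T z_iB
    denominator = torch.exp(positive)
    for j in range(N):
        if j != i:
            denominator = denominator + torch.exp(z_a[i] @ z_b[j])  # negatives z_i^-
    loss_loop = loss_loop - positive + torch.log(denominator)

# The same loss as cross-entropy over the similarity matrix, with each row's
# diagonal entry (the positive pair) as its target class.
sim = z_a @ z_b.T
loss_ce = F.cross_entropy(sim, torch.arange(N), reduction="sum")

print(loss_loop.item(), loss_ce.item())  # equal up to floating-point error
```

Seen this way, \(L_{CN}\) is a classification problem in which each image must pick out its own second view among all second views in the batch.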

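Here \(z_{iA}\) and \(z_{iB}\) are the features extracted from two independently augmented versions of the image \(x_{i}\). The \(z_{i}^{-}\) are features of other images in the batch, which serve as negatives. The first term rewards agreement between the two views of the same image, and the second penalizes similarity to the negatives. Let’s write an augmentation layer first.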

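from torchvision import transforms

class RandomAugmentation(nn.Module):
    def __init__(self, img_size=224):
        super(RandomAugmentation, self).__init__()
        # Assumed augmentations: the original list of transforms is not given.
        self.augment = transforms.Compose([
            transforms.RandomResizedCrop(img_size),
            transforms.RandomHorizontalFlip(),
        ])

    def forward(self, x):
        # Augment each image on its own so every sample gets fresh random parameters;
        # calling the transforms once on the whole batch would crop/flip all images alike.
        x_a = torch.stack([self.augment(img) for img in x])
        x_b = torch.stack([self.augment(img) for img in x])
        return x_a, x_b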
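There is one subtlety with torchvision transforms. Applying one to a whole (B, C, H, W) batch tensor at once draws a single set of random parameters for the entire batch, so every image is cropped or flipped the same way. The sketch below shows per-sample randomness without a Python loop, using a hypothetical random_hflip_per_sample helper. It covers only horizontal flips and uses plain torch ops.

```python
import torch

def random_hflip_per_sample(x, p=0.5):
    """Hypothetical helper: flip each image of a (B, C, H, W) batch with probability p, independently."""
    flip = torch.rand(x.shape[0], device=x.device) < p  # one coin toss per image
    flipped = torch.flip(x, dims=[-1])                  # mirror along the width axis
    # Broadcast the per-image decision over channels, height and width.
    return torch.where(flip[:, None, None, None], flipped, x)

torch.manual_seed(0)
batch = torch.arange(2 * 1 * 2 * 3, dtype=torch.float).reshape(2, 1, 2, 3)

# Two independently augmented views of the same batch.
x_a = random_hflip_per_sample(batch)
x_b = random_hflip_per_sample(batch)
print(x_a.shape, x_b.shape)  # torch.Size([2, 1, 2, 3]) torch.Size([2, 1, 2, 3])
```

Each image in x_a and x_b is either unchanged or mirrored, decided separately per image and per view.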

Next up is the loss class.

class ContrastiveLearningLoss(nn.Module):
    def __init__(self):
        super(ContrastiveLearningLoss, self).__init__()

    def forward(self, z_a, z_b):
        # sim[i, j] = z_a[i]^T z_b[j]: the diagonal holds the positive pairs (z_iA, z_iB),
        # and the off-diagonal entries of row i are the negatives z_i^-.
        sim = z_a @ z_b.T

        # Alignment term: -sum_i z_iA^T z_iB
        alignment = -torch.diagonal(sim).sum()
        # Uniformity term: sum_i log(e^{z_iA^T z_iB} + sum_{j != i} e^{z_iA^T z_jB}),
        # computed stably with logsumexp over each row.
        uniformity_loss = torch.logsumexp(sim, dim=1).sum()
        return alignment + uniformity_loss
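A quick sanity check on the loss uses a standalone function, here a hypothetical contrastive_loss, that computes \(L_{CN}\) from a similarity matrix on random normalised features. When the two views are identical, each positive similarity is the largest entry in its row. Shuffling the pairs only permutes the columns, so the log-sum-exp term stays the same while the positive term gets smaller. The aligned loss should therefore be strictly lower.

```python
import torch
import torch.nn.functional as F

def contrastive_loss(z_a, z_b):
    """Hypothetical standalone L_CN: -sum_i z_iA^T z_iB + sum_i log sum_j exp(z_iA^T z_jB)."""
    sim = z_a @ z_b.T
    return -torch.diagonal(sim).sum() + torch.logsumexp(sim, dim=1).sum()

torch.manual_seed(0)
z = F.normalize(torch.randn(8, 16), dim=1)

aligned = contrastive_loss(z, z)                   # both views agree perfectly
shuffled = contrastive_loss(z, z.roll(1, dims=0))  # each image paired with a different one
print(aligned.item() < shuffled.item())            # True
```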


The preprocessing is straightforward. We define a custom dataset class that reads each image and its label so the data can be fed through a DataLoader. We also define a mapping between the class names of the 17-class flower dataset and their indices.

labels_2_idx = {
    "Colts Foot":1,
    "Lily Valley":11,
    "Tiger lily":14,
    # ... the remaining 14 classes are not listed here
}
idx_2_labels = {}

for key, value in labels_2_idx.items():
    idx_2_labels[value] = key

import os

from torch.utils.data import Dataset
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        # The annotations file lists one image filename per line.
        with open(annotations_file, "r") as f:
            self.img_labels = f.read().split()
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels[idx])
        image = read_image(img_path)  # uint8 tensor of shape (C, H, W)
        # Assumes the files are listed in class order: 1360 images / 17 classes = 80 per class.
        label = idx // 80
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
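The 1360 images are split evenly into 17 classes. A file list sorted by class therefore has 1360 / 17 = 80 consecutive images per class, and the label of the image at position idx is idx // 80. This relies on the sorted ordering, which is an assumption about the file list. The sketch below checks the arithmetic with a hypothetical label_from_index helper.

```python
from collections import Counter

NUM_IMAGES, NUM_CLASSES = 1360, 17
IMAGES_PER_CLASS = NUM_IMAGES // NUM_CLASSES  # 80

def label_from_index(idx):
    """Hypothetical helper: class label of the idx-th image, assuming class-sorted files."""
    return idx // IMAGES_PER_CLASS

counts = Counter(label_from_index(idx) for idx in range(NUM_IMAGES))
print(IMAGES_PER_CLASS)                            # 80
print(label_from_index(79), label_from_index(80))  # 0 1
print(len(counts), set(counts.values()))           # 17 {80}
```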

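from torch.utils.data import DataLoader, random_split
from torchvision import transforms

# Placeholder paths (hypothetical): the original arguments are not given in the source.
dataset = CustomImageDataset("path/to/files.txt", "path/to/images",
                             transform=transforms.Resize((224, 224)))  # assumed: ViT expects 224x224
# Hold out 136 of the 1360 images for validation.
train_set, val_set = random_split(dataset, [1360 - 136, 136])

train_dataloader = DataLoader(train_set, batch_size=64, shuffle=True)
val_dataloader = DataLoader(val_set, batch_size=64, shuffle=False)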
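For reference, the split holds out 136 of the 1360 images (10%) for validation and leaves 1224 for training. With a batch size of 64, that gives 20 training batches per epoch (19 full batches plus one of 8 images) and 3 validation batches. The sketch below confirms those numbers using a dummy TensorDataset of the same size in place of the real images.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Dummy dataset with the same number of items as the flower dataset (1360).
dummy = TensorDataset(torch.arange(1360))
train_set, val_set = random_split(dummy, [1360 - 136, 136],
                                  generator=torch.Generator().manual_seed(0))

train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
val_loader = DataLoader(val_set, batch_size=64, shuffle=False)

print(len(train_set), len(val_set))        # 1224 136
print(len(train_loader), len(val_loader))  # 20 3
```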


The training part is straightforward as well. For each batch, we add up the two losses, backpropagate the total, and take an optimizer step.

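device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epochs = 12

# 224x224 input, 10-dimensional ViT output (z_dim), 17 flower classes.
visionTransformer = nn.DataParallel(VisionTransformer(224, 10, 17)).to(device)
randomAugmentation = RandomAugmentation()
# The losses have no parameters, so they are not wrapped in DataParallel; wrapping the
# contrastive loss would split the batch across GPUs and shrink the set of negatives.
instanceDiscriminationLoss = InstanceDiscriminationLoss()
contrastiveLearningLoss = ContrastiveLearningLoss()
optimizer = torch.optim.AdamW(visionTransformer.parameters(), lr=0.001)

for epoch in range(epochs):
    visionTransformer.train()
    aggregate_loss = 0
    for step, (images, labels) in enumerate(train_dataloader):
        if step % 10 == 0 and step != 0:
            print("Epoch:{} Step:{} Loss:{:.3f}".format(epoch, step, aggregate_loss / step))
        # read_image yields uint8 values in [0, 255]; scale to floats in [0, 1].
        images = images.to(device, dtype=torch.float) / 255.0
        labels = labels.to(device)

        predictions = visionTransformer(images)
        x_a, x_b = randomAugmentation(images)
        # L2-normalised ViT features of the two augmented views.
        z_embeddings_a = nn.functional.normalize(visionTransformer.module.vit(x_a), dim=1)
        z_embeddings_b = nn.functional.normalize(visionTransformer.module.vit(x_b), dim=1)

        loss_1 = instanceDiscriminationLoss(predictions, labels)
        loss_2 = contrastiveLearningLoss(z_embeddings_a, z_embeddings_b)
        total_loss = loss_1 + loss_2

        # Backpropagate the combined loss and update the weights.
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        aggregate_loss += total_loss.item()
    print("Total Epoch {} Loss : {}".format(epoch, aggregate_loss / len(train_dataloader)))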
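Stripped of the ViT and the data pipeline, one optimisation step of the combined objective can be sketched as follows. Several placeholders keep it fast on a CPU: a tiny linear encoder stands in for the ViT, the two views come from adding small noise instead of real augmentations, and the inputs and labels are random. The instance term uses log_softmax on logits, a numerically safer equivalent of taking the log of softmax probabilities.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, in_dim, z_dim, num_classes = 8, 32, 10, 17

# Placeholder encoder and classification head (stand-ins for the ViT and linear layer).
encoder = nn.Linear(in_dim, z_dim)
head = nn.Linear(z_dim, num_classes)
optimizer = torch.optim.AdamW(list(encoder.parameters()) + list(head.parameters()), lr=1e-3)

x = torch.randn(batch_size, in_dim)                    # placeholder inputs
labels = torch.randint(0, num_classes, (batch_size,))  # placeholder labels

# Instance discrimination term: summed negative log-likelihood of the true classes.
log_probs = F.log_softmax(head(encoder(x)), dim=-1)
loss_1 = F.nll_loss(log_probs, labels, reduction="sum")

# Contrastive term on two noisy "views" of the same inputs, L2-normalised.
z_a = F.normalize(encoder(x + 0.1 * torch.randn_like(x)), dim=1)
z_b = F.normalize(encoder(x + 0.1 * torch.randn_like(x)), dim=1)
sim = z_a @ z_b.T
loss_2 = -torch.diagonal(sim).sum() + torch.logsumexp(sim, dim=1).sum()

total_loss = loss_1 + loss_2
optimizer.zero_grad()
total_loss.backward()  # gradients from both losses reach the shared encoder
before = encoder.weight.detach().clone()
optimizer.step()
print(torch.equal(before, encoder.weight))  # False: the step updated the encoder
```

Because both terms depend on the encoder, a single backward pass through their sum trains it on both objectives at once.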


Well, that’s it for now. We’ve implemented the model and the two losses from the paper “Training Vision Transformers with Only 2040 Images” in PyTorch. Thanks for reading.

