Training Vision Transformers with Only 2040 Images: Implementation

The architecture of the model is quite straightforward; it’s the losses that are more complicated. For the network, we use a vision transformer followed by a linear layer that predicts the probability of each class. We’ll use the vit-pytorch library to build the vision transformer.

import torch
import torch.nn as nn
from vit_pytorch import ViT

class VisionTransformer(nn.Module):
    def __init__(self, img_size, z_dim, num_classes):
        super(VisionTransformer, self).__init__()
        # ViT backbone; its "num_classes" output is reused as a
        # z_dim-dimensional embedding.
        self.vit = ViT(
            image_size = img_size,
            patch_size = 32,
            num_classes = z_dim,
            dim = 1024,
            depth = 6,
            heads = 16,
            mlp_dim = 2048,
            dropout = 0.1,
            emb_dropout = 0.1
        )
        # Linear head mapping the embedding to class probabilities.
        self.linear = nn.Linear(in_features=z_dim, out_features=num_classes)

    def forward(self, x):
        x = self.vit(x)
        x = self.linear(x)
        x = nn.functional.softmax(x, dim=1)
        return x
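As a quick sanity check, here’s a minimal sketch of a forward pass. The sizes are the ones used in the training loop later (224×224 images, a 10-dimensional embedding, 17 classes); the batch itself is made up.

model = VisionTransformer(img_size=224, z_dim=10, num_classes=17)
dummy_batch = torch.randn(4, 3, 224, 224)  # four random RGB "images"
probs = model(dummy_batch)
print(probs.shape)       # torch.Size([4, 17])
print(probs.sum(dim=1))  # each row sums to ~1 after the softmax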


There are two losses in the original paper.

  1. Instance Discrimination Loss
  2. Contrastive Learning Loss

Instance Discrimination Loss

The instance discrimination loss, \(L_{InsDis}\), is defined as follows:

\[L_{InsDis} = -\sum^{N}_{i=1}\sum^{N}_{c=1}y_{c}^{i}\log{P_{c}^{i}}\]

where \(c\) indexes the classes (in instance discrimination, every image is treated as its own class) and \(i\) indexes the instances in the batch passed to the network.

class InstanceDiscriminationLoss(nn.Module):
    def __init__(self):
        super(InstanceDiscriminationLoss, self).__init__()

    def forward(self, predictions):
        # With one-hot targets the formula above reduces to the negative
        # log-probability of the true class; here we simplify and sum the
        # negative log-probabilities over every class.
        return -torch.sum(torch.log(predictions))
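A quick usage example with made-up probabilities:

loss_fn = InstanceDiscriminationLoss()
dummy_probs = torch.tensor([[0.7, 0.2, 0.1],
                            [0.1, 0.8, 0.1]])
print(loss_fn(dummy_probs))  # scalar: -sum(log(p)) over all entries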

Contrastive Learning Loss

The contrastive learning loss \(L_{CN}\) is defined as follows:

\[L_{CN} = -\sum^{N}_{i=1}z_{iA}^{T}z_{iB} + \sum^{N}_{i=1}\log\left(e^{z_{iA}^{T}z_{iB}} + \sum_{j \neq i}e^{z_{iA}^{T}z_{j}^{-}}\right)\]

Here \(z_{iA}\) and \(z_{iB}\) are features extracted from two augmented versions of the image \(x_{i}\), and the \(z_{j}^{-}\) are negative samples, i.e. the embeddings of the other images in the batch. Let’s write an augmentation layer first.

import torchvision.transforms as transforms

class RandomAugmentation(nn.Module):
    def __init__(self):
        super(RandomAugmentation, self).__init__()
        # Assumed augmentations, typical for contrastive learning;
        # crops, flips and color jitter all preserve the image identity.
        self.augment = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
        ])

    def forward(self, x):
        # Two independent random augmentations of the same batch.
        x_a = self.augment(x)
        x_b = self.augment(x)
        return x_a, x_b
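Each call samples fresh random parameters, so the two views of the same batch will differ:

aug = RandomAugmentation()
images = torch.rand(4, 3, 224, 224)  # dummy batch with values in [0, 1]
view_a, view_b = aug(images)
print(torch.allclose(view_a, view_b))  # almost certainly False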

Next up is the loss class.

class ContrastiveLearningLoss(nn.Module):
    def __init__(self):
        super(ContrastiveLearningLoss, self).__init__()

    def forward(self, z_a, z_b):
        n = len(z_a)

        # Alignment term: pull the two views of each image together.
        alignment = 0
        for i in range(n):
            alignment += -torch.dot(z_a[i], z_b[i])

        # Uniformity term: push each embedding away from the other images
        # in the batch, which act as the negative samples.
        uniformity_loss = 0
        for i in range(n):
            negative_sum = 0
            for j in range(n):
                if i == j:
                    continue
                negative_sum += torch.exp(torch.dot(z_a[i], z_b[j]))
            positive = torch.exp(torch.dot(z_a[i], z_b[i]))
            uniformity_loss += torch.log(positive + negative_sum)
        return alignment + uniformity_loss
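A minimal sketch of calling the loss on random, L2-normalized embeddings, mirroring how the training loop normalizes the ViT features:

z_a = nn.functional.normalize(torch.randn(8, 10), dim=1)
z_b = nn.functional.normalize(torch.randn(8, 10), dim=1)
loss_fn = ContrastiveLearningLoss()
print(loss_fn(z_a, z_b))  # a single scalar for the whole batch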


The preprocessing part is pretty straightforward. We define a custom dataset class for loading the images into a DataLoader, and we define all 17 classes of the flower dataset.

labels_2_idx = {
    "Colts Foot": 1,
    "Lily Valley": 11,
    "Tiger lily": 14,
    # ... the remaining 14 flower classes follow the same pattern
}

# Invert the mapping so a class name can be looked up from its index.
idx_2_labels = {}
for key, value in labels_2_idx.items():
    idx_2_labels[value] = key

import os
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        # The annotations file lists one image filename per row.
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        # The 1360 flower images are stored consecutively, 80 per class,
        # so the label can be derived from the image index.
        label = int(np.floor(idx / 80))
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

# The constructor arguments here are placeholders; point them at your copy
# of the dataset.
dataset = CustomImageDataset(
    annotations_file="annotations.csv",
    img_dir="jpg",
    transform=transforms.Resize((224, 224)),
)

# Hold out 136 of the 1360 images (10%) for validation.
train_set, val_set = torch.utils.data.random_split(dataset, [1360 - 136, 136])

from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_set, batch_size=64, shuffle=True)
val_dataloader = DataLoader(val_set, batch_size=64, shuffle=True)
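A quick way to sanity-check the pipeline is to pull one batch and inspect it (the shape assumes the 224×224 resize above):

images, labels = next(iter(train_dataloader))
print(images.shape)  # expected: torch.Size([64, 3, 224, 224])
print(labels[:5])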


The training part is pretty straightforward as well. You just have to iterate through the batches and add up the two losses.

epochs = 12

visionTransformer = nn.DataParallel(VisionTransformer(224, 10, 17)).cuda()
randomAugmentation = nn.DataParallel(RandomAugmentation()).cuda()
instanceDiscriminationLoss = nn.DataParallel(InstanceDiscriminationLoss())
contrastiveLearningLoss = nn.DataParallel(ContrastiveLearningLoss())
optimizer = torch.optim.AdamW(visionTransformer.parameters(), lr=0.001)

for epoch in range(epochs):
    aggregate_loss = 0
    for step, batch in enumerate(train_dataloader):
        if step % 10 == 0 and step != 0:
            print("Epoch:{} Step:{} Loss:{:.3f}".format(epoch, step, aggregate_loss / step))
        images = batch[0].to(torch.float).cuda()

        # Class probabilities for the instance discrimination loss.
        predictions = visionTransformer(images)

        # Two augmented views, embedded and L2-normalized for the contrastive loss.
        x_a, x_b = randomAugmentation(images)
        z_embeddings_a = nn.functional.normalize(visionTransformer.module.vit(x_a), dim=1)
        z_embeddings_b = nn.functional.normalize(visionTransformer.module.vit(x_b), dim=1)

        loss_1 = instanceDiscriminationLoss(predictions).mean()
        loss_2 = contrastiveLearningLoss(z_embeddings_a, z_embeddings_b).mean()

        total_loss = loss_1 + loss_2

        # Backpropagate the combined loss and update the weights.
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        aggregate_loss += total_loss.item()
    print("Total Epoch {} Loss : {}".format(epoch, aggregate_loss / len(train_dataloader)))


Well, that’s it for now. We’ve implemented the paper “Training Vision Transformers with Only 2040 Images” in PyTorch, and it seems to work. Thanks for reading.
