Paper: https://arxiv.org/pdf/2201.10728v1.pdf
Github: https://github.com/niranjankrishna-acad/Training-Vision-Transformers-with-Only-2040-Images
Architecture
The architecture of the model is quite straightforward; it’s the losses that are more involved. For the network, we use a vision transformer combined with a linear layer that predicts the probability of each class. We’ll use the vit-pytorch library to initialize the vision transformer.
import torch
import torch.nn as nn
from vit_pytorch import ViT

class VisionTransformer(nn.Module):
    def __init__(self, img_size, z_dim, num_classes):
        super(VisionTransformer, self).__init__()
        # ViT backbone from vit-pytorch; its output is used as a
        # z_dim-dimensional embedding.
        self.vit = ViT(
            image_size = img_size,
            patch_size = 32,
            num_classes = z_dim,
            dim = 1024,
            depth = 6,
            heads = 16,
            mlp_dim = 2048,
            dropout = 0.1,
            emb_dropout = 0.1
        )
        # Linear head that maps the embedding to class probabilities.
        self.linear = nn.Linear(in_features=z_dim, out_features=num_classes)

    def forward(self, x):
        x = self.vit(x)
        x = self.linear(x)
        x = nn.functional.softmax(x, dim=1)
        return x
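As a quick sanity check (a minimal sketch using the same sizes as the training section below), we can push a dummy batch through the model and confirm it returns one probability per class:

model = VisionTransformer(img_size=224, z_dim=10, num_classes=17)
dummy = torch.randn(2, 3, 224, 224)   # two random RGB images
probs = model(dummy)
print(probs.shape)                    # torch.Size([2, 17])
print(probs.sum(dim=1))               # each row sums to ~1 after the softmax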
Loss
There are two losses in the original paper.
- Instance Discrimination Loss
- Contrastive Learning Loss
Instance Discrimination Loss
The instance discrimination loss \(L_{InsDis}\) is defined as follows:
\[L_{InsDis} = -\sum^{N}_{i=1}\sum^{N}_{c=1}y_{c}^{i}\log{P_{c}^{i}}\]
where \(c\) runs over the classes and \(i\) runs over the instances in a batch; \(y_{c}^{i}\) is the one-hot label for instance \(i\) and \(P_{c}^{i}\) the corresponding predicted probability.
class InstanceDiscriminationLoss(nn.Module):
    def __init__(self):
        super(InstanceDiscriminationLoss, self).__init__()

    def forward(self, predictions):
        # Negative log-likelihood summed over the predicted probabilities.
        return -torch.sum(torch.log(predictions))
Contrastive Learning Loss
The contrastive learning loss \(L_{CN}\) is defined as follows:
\[L_{CN} = -\sum^{N}_{i=1}z_{iA}^{T}z_{iB} + \sum^{N}_{i=1} \log\left(e^{z_{iA}^{T}z_{iB}} + \sum_{j \neq i} e^{z_{iA}^{T}z_{j}^{-}}\right)\]
Here \(z_{iA}\) and \(z_{iB}\) are features extracted from two augmented views of the image \(x_{i}\), and the \(z_{j}^{-}\) are negative samples, i.e. features of the other images in the batch. Let’s write an augmentation layer first.
from torchvision import transforms

class RandomAugmentation(nn.Module):
    def __init__(self):
        super(RandomAugmentation, self).__init__()
        # Two independent passes through these random transforms
        # produce two different views of the same image.
        self.augment = transforms.Compose(
            [
                transforms.RandomRotation(35),
                transforms.ColorJitter(),
            ]
        )

    def forward(self, x):
        x_a = self.augment(x)
        x_b = self.augment(x)
        return x_a, x_b
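A quick check (a minimal sketch with a random placeholder batch) that the layer returns two equally shaped views:

aug = RandomAugmentation()
imgs = torch.rand(4, 3, 224, 224)   # placeholder batch of 4 RGB images in [0, 1]
view_a, view_b = aug(imgs)
print(view_a.shape, view_b.shape)   # both torch.Size([4, 3, 224, 224])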
Next up is the loss class.
class ContrastiveLearningLoss(nn.Module):
    def __init__(self):
        super(ContrastiveLearningLoss, self).__init__()

    def forward(self, z_a, z_b):
        n = len(z_a)
        # Alignment term: -sum_i z_iA^T z_iB
        alignment = 0
        for i in range(n):
            alignment += -torch.dot(z_a[i], z_b[i])
        # Uniformity term: sum_i log(exp(z_iA^T z_iB) + sum_{j != i} exp(z_iA^T z_j^-)),
        # where the negatives are the other instances in the batch.
        uniformity_loss = 0
        for i in range(n):
            negative_sum = 0
            for j in range(n):
                if i == j:
                    continue
                negative_sum += torch.exp(torch.dot(z_a[i], z_b[j]))
            positive = torch.exp(torch.dot(z_a[i], z_b[i]))
            uniformity_loss += torch.log(positive + negative_sum)
        return alignment + uniformity_loss
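To sanity-check the loss (a minimal sketch; the embedding size of 10 mirrors the z_dim used later), we can feed it two sets of L2-normalized random embeddings:

z_a = nn.functional.normalize(torch.randn(4, 10), dim=1)
z_b = nn.functional.normalize(torch.randn(4, 10), dim=1)
loss_fn = ContrastiveLearningLoss()
print(loss_fn(z_a, z_b))   # a single scalar tensor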
Preprocessing
The preprocessing part is pretty straightforward. We define a custom dataset class for loading the data into a DataLoader, and we define all 17 classes of the flower dataset.
labels_2_idx = {
    "Buttercup": 0,
    "Colts Foot": 1,
    "Daffodil": 2,
    "Daisy": 3,
    "Dandelion": 4,
    "Fritilary": 5,
    "Iris": 6,
    "Pansy": 7,
    "Sunflower": 8,
    "Windflower": 9,
    "Snowdrop": 10,
    "Lily Valley": 11,
    "Bluebell": 12,
    "Crocus": 13,
    "Tiger lily": 14,
    "Tulip": 15,
    "Cowslip": 16
}
idx_2_labels = {}
for key, value in labels_2_idx.items():
    idx_2_labels[value] = key
import os
from torch.utils.data import Dataset
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        # The annotations file lists one image filename per line.
        self.img_labels = open(annotations_file, "r").read().splitlines()
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels[idx])
        image = read_image(img_path)
        # The flower images are stored 80 per class in order, so the class
        # index can be recovered from the image index.
        label = idx // 80
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
dataset = CustomImageDataset(
    "data/files.txt",
    "data",
    # resize every image to the ViT's 224x224 input so samples can be batched
    transform=transforms.Resize((224, 224))
)
train_set, val_set = torch.utils.data.random_split(dataset, [1360 - 136, 136])

from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_set, batch_size=64, shuffle=True)
val_dataloader = DataLoader(val_set, batch_size=64, shuffle=True)
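As a quick check of the data pipeline (a minimal sketch; whether index 0 really is a Buttercup depends on the ordering of data/files.txt), we can index a single sample and decode its label:

image, label = dataset[0]
print(image.shape)           # torch.Size([3, 224, 224]) after the resize transform
print(idx_2_labels[label])   # "Buttercup" if the first 80 files belong to class 0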
Training
The training part is pretty straightforward as well. We iterate through the batches, compute both losses, and add them up before backpropagating.
epochs = 12

visionTransformer = nn.DataParallel(VisionTransformer(224, 10, 17))
visionTransformer.cuda()
randomAugmentation = nn.DataParallel(RandomAugmentation())
randomAugmentation.cuda()
instanceDiscriminationLoss = nn.DataParallel(InstanceDiscriminationLoss())
instanceDiscriminationLoss.cuda()
contrastiveLearningLoss = nn.DataParallel(ContrastiveLearningLoss())
contrastiveLearningLoss.cuda()

optimizer = torch.optim.AdamW(visionTransformer.parameters(), lr=0.001)

for epoch in range(epochs):
    aggregate_loss = 0
    for step, batch in enumerate(train_dataloader):
        if step % 10 == 0 and step != 0:
            print("Epoch:{} Step:{} Loss:{:.3f}".format(epoch, step, aggregate_loss / step))
        # Scale uint8 pixel values to [0, 1] floats.
        batch = batch[0].to(torch.float) / 255.0
        # Class probabilities for the instance discrimination loss.
        predictions = visionTransformer(batch)
        # Two augmented views and their L2-normalized embeddings for the contrastive loss.
        x_a, x_b = randomAugmentation(batch)
        z_embeddings_a = nn.functional.normalize(visionTransformer.module.vit(x_a), dim=1)
        z_embeddings_b = nn.functional.normalize(visionTransformer.module.vit(x_b), dim=1)
        loss_1 = instanceDiscriminationLoss(predictions).mean()
        loss_2 = contrastiveLearningLoss(z_embeddings_a, z_embeddings_b).mean()
        total_loss = loss_1 + loss_2
        optimizer.zero_grad()
        total_loss.backward()
        aggregate_loss += total_loss.item()
        optimizer.step()
    print("Total Epoch {} Loss : {}".format(epoch, aggregate_loss / len(train_dataloader)))
Conclusion
Well, that’s it for now. We’ve implemented the paper “Training Vision Transformers with Only 2040 Images” in PyTorch, and it seems to work. Thanks for having the patience to read.