In this notebook I show how to use PyTorch and the Hugging Face Transformers library to fine-tune a pre-trained transformer model for natural language inference (NLI). In NLI the aim is to model the inferential relationship between two or more given sentences. In particular, given two sentences - the premise p and the hypothesis h - the task is to determine whether h is entailed by p, whether the two sentences contradict each other, or whether there is no inferential relationship between them (neutral).
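As a concrete illustration, here is how the three relations can look for a single premise. The integer mapping shown (0/1/2) is the one used by the MultiNLI dataset we load below; the example sentences themselves are made up:

```python
# MultiNLI encodes the three NLI relations as integer labels.
label_names = {0: "entailment", 1: "neutral", 2: "contradiction"}

premise = "A man is playing a guitar on stage."
examples = [
    ("A man is performing music.", 0),   # entailment: follows from the premise
    ("The concert is sold out.", 1),     # neutral: neither follows nor conflicts
    ("Nobody is playing anything.", 2),  # contradiction: cannot both be true
]

for hypothesis, label in examples:
    print(f"{label_names[label]:>13}: {hypothesis}")
```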

So let's get started! First we need to install the required Python libraries using the following command.

!pip3 install torch transformers datasets

We will then import the needed libraries. We are using the DistilBERT model for this task, so we need to import the DistilBERT model class designed for sequence classification and the corresponding tokeniser.

import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, logging
import datasets
from tqdm import tqdm
import numpy as np
logging.set_verbosity_error()

Let's load the MultiNLI dataset using the Hugging Face Datasets library. For this demonstration we use only the training and validation data, and we further limit the training data to just 20,000 sentence pairs. This will not be enough to train a good-quality model, but it speeds up the demonstration. You can change the values here or use the whole dataset; just be aware that fine-tuning on the full data will take a long time.

nli_data = datasets.load_dataset("multi_nli")

train_data = nli_data['train'][:20000] # limiting the training set size to 20,000 for demo purposes
train_labels = train_data['label']

dev_data = nli_data['validation_matched']
val_labels = dev_data['label']

Next we will initialise the tokeniser and tokenise our training and validation data. Notice that we are passing two lists of sentences for both the training and validation sets. This is because in NLI we are classifying pairs of sentences: the premise and the hypothesis.

tokeniser = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
train_encodings = tokeniser(train_data['premise'], train_data['hypothesis'], truncation=True, padding=True)
val_encodings = tokeniser(dev_data['premise'], dev_data['hypothesis'], truncation=True, padding=True)

Once the data has been tokenised we will create an NLIDataset object for our data. Here we are creating a subclass that inherits from torch.utils.data.Dataset.

class NLIDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.encodings.input_ids)
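The class can be sanity-checked with toy data before plugging in the real encodings. The snippet below repeats the class so it runs on its own; the tiny token ids and labels are made up for illustration:

```python
import torch

class NLIDataset(torch.utils.data.Dataset):
    # Same class as above, repeated here so the snippet is self-contained.
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.encodings['input_ids'])

# Toy encodings: two "sentences" of three token ids each (values made up).
toy_encodings = {'input_ids': [[101, 2054, 102], [101, 2003, 102]],
                 'attention_mask': [[1, 1, 1], [1, 1, 1]]}
toy_labels = [0, 2]

ds = NLIDataset(toy_encodings, toy_labels)
item = ds[1]
print(len(ds), sorted(item.keys()), item['labels'].item())
```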

Once we've defined our dataset class we can initialise the training and validation datasets with our tokenised sentence pairs and labels. We will then create DataLoader objects for the training and validation data.

train_dataset = NLIDataset(train_encodings, train_labels)
val_dataset = NLIDataset(val_encodings, val_labels)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

Now, before we can start training, we need to set up our model and optimiser. We first set the device, using CUDA if a GPU is available. We then load the pre-trained DistilBERT model, specifying the number of classes we are classifying into.

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=3)
model.to(device)
model.train()
optim = AdamW(model.parameters(), lr=5e-5)

Now we are ready to train the model. In this demonstration we are fine-tuning for just three epochs, but you can change the value to something more meaningful if you like. Note that you could also use the Transformers Trainer class to fine-tune the model but I've chosen to use native PyTorch instead.

epochs = 3
for epoch in range(epochs):
    all_losses = []

    for batch in tqdm(train_loader, total=len(train_loader), desc="Epoch: {}/{}".format(epoch+1, epochs)):
        optim.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optim.step()
        all_losses.append(loss.item())
        
    print("\nMean loss: {:<.4f}".format(np.mean(all_losses)))
Epoch: 1/3: 100%|██████████| 1250/1250 [15:31<00:00,  1.34it/s]
Mean loss: 0.8789
Epoch: 2/3: 100%|██████████| 1250/1250 [15:27<00:00,  1.35it/s]
Mean loss: 0.5912
Epoch: 3/3: 100%|██████████| 1250/1250 [15:27<00:00,  1.35it/s]
Mean loss: 0.3316

Once the model has been trained we can evaluate it to get the validation accuracy for our model.

model.eval()
with torch.no_grad():
    eval_preds = []
    eval_labels = []

    for batch in tqdm(val_loader, total=len(val_loader)):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        outputs = model(input_ids, attention_mask=attention_mask)
        preds = outputs.logits.argmax(dim=-1)
        eval_preds.append(preds.cpu().numpy())
        eval_labels.append(batch['labels'].cpu().numpy())

print("\nValidation accuracy: {:6.2f}".format(100 * (np.concatenate(eval_labels) == np.concatenate(eval_preds)).mean()))
100%|██████████| 614/614 [02:26<00:00,  4.18it/s]
Validation accuracy:  69.00

Now we are all done. As you can see, the results are far from state of the art when only a fraction of the training data is used.

Hope you enjoyed this demo. Feel free to contact me if you have any questions.