Building an LLM from scratch part 6 - Finetuning for Text Classification
- Author: Matt Du-Feu (@mattdufeu)
The first part of the chapter covers preparing the data. As in the previous chapter, we need training, validation and test data. We're using the SMS Spam Collection from the UC Irvine Machine Learning Repository. This dataset is imbalanced, with many more "not spam" messages than "spam" ones (4,825 versus 747). Again, I won't plagiarise the code from the book, but we end up with three DataLoaders (train, validation and test).
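The book's own code isn't reproduced here. As a rough idea of what happens before the DataLoaders are built, here is a minimal plain-Python sketch. It assumes a 70/10/20 split with a fixed seed, and it pads with token ID 50256 (GPT-2's `<|endoftext|>` token) so that every message in a batch has the same length. The helper names `random_split` and `pad_to_length` are hypothetical.

```python
import random

PAD_TOKEN_ID = 50256  # GPT-2's <|endoftext|> ID, assumed here as the padding token


def random_split(examples, train_frac=0.7, val_frac=0.1, seed=123):
    """Shuffle (text, label) pairs and split them into train/validation/test lists.

    The fractions and seed are illustrative assumptions, not necessarily the book's values.
    """
    shuffled = list(examples)
    random.Random(seed).shuffle(shuffled)
    train_end = int(len(shuffled) * train_frac)
    val_end = train_end + int(len(shuffled) * val_frac)
    return shuffled[:train_end], shuffled[train_end:val_end], shuffled[val_end:]


def pad_to_length(token_ids, max_length, pad_id=PAD_TOKEN_ID):
    """Truncate or right-pad a list of token IDs so every message has the same length."""
    truncated = token_ids[:max_length]
    return truncated + [pad_id] * (max_length - len(truncated))


# Toy data: label 0 = "not spam", 1 = "spam"
data = [("message %d" % i, i % 2) for i in range(10)]
train_set, val_set, test_set = random_split(data)
print(len(train_set), len(val_set), len(test_set))  # 7 1 2
print(pad_to_length([15496, 11, 995], 5))  # [15496, 11, 995, 50256, 50256]
```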
Creating the Classification model
This is the incredible part.
Recall from the previous chapters that we built an LLM architecture whose final Linear output layer maps 768 hidden units to the 50,257 possible tokens in our vocabulary.
We don't want 50,257 possible outputs. We want two: "spam" and "not spam". So we simply swap the output layer:
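```python
import torch

num_classes = 2  # "spam" and "not spam"
# Replace the 768 -> 50,257 vocabulary projection with a 768 -> 2 classification head
model.out_head = torch.nn.Linear(
    in_features=BASE_CONFIG["emb_dim"],
    out_features=num_classes)
```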
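The sketch below shows what the swap does to the output shape. It uses a hypothetical toy model, `ToyGPT`, in place of the chapter's GPT model. Only the attribute name `out_head` and the `emb_dim` config key match the code above. The tiny sizes are assumptions chosen so the example runs instantly; the real model uses 768 and 50,257.

```python
import torch

# Hypothetical, deliberately tiny config; GPT-2 small uses emb_dim=768, vocab_size=50257
BASE_CONFIG = {"vocab_size": 100, "emb_dim": 16}


class ToyGPT(torch.nn.Module):
    """Stand-in for the chapter's GPT model: token embedding followed by an output head."""

    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = torch.nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.out_head = torch.nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)

    def forward(self, idx):
        return self.out_head(self.tok_emb(idx))


torch.manual_seed(0)
model = ToyGPT(BASE_CONFIG)
tokens = torch.randint(0, BASE_CONFIG["vocab_size"], (4, 10))  # batch of 4, 10 tokens each
print(model(tokens).shape)  # torch.Size([4, 10, 100]): one score per vocabulary token

num_classes = 2
model.out_head = torch.nn.Linear(in_features=BASE_CONFIG["emb_dim"], out_features=num_classes)
print(model(tokens).shape)  # torch.Size([4, 10, 2]): one score per class
```

Everything before the head is untouched. Only the last projection changes size.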
Because the new head's weights are randomly initialised, feeding this model some text to classify is no better than random guessing. We need to teach it how to perform classification, which is what fine-tuning does.
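The model still produces class scores at every token position, but we want one prediction per message. This sketch assumes that only the last token's output is used. That is one common approach: with causal attention, the last position is the only one that has attended to the entire message. The function name `predict_label` and the random stand-in logits are hypothetical.

```python
import torch


def predict_label(logits):
    """Turn per-token class scores (batch, seq_len, num_classes) into one label per message.

    Assumes the last token's output summarises the whole sequence (causal attention).
    """
    last_token_logits = logits[:, -1, :]  # (batch, num_classes)
    return torch.argmax(last_token_logits, dim=-1)  # 0 = "not spam", 1 = "spam"


# Random stand-in for the output of an untrained 2-class head
torch.manual_seed(0)
logits = torch.randn(4, 10, 2)
print(predict_label(logits))  # four labels, each 0 or 1, essentially arbitrary
```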
Fine-Tuning
Fine-tuning is essentially another training run, but with data specific to the task you want the model to specialise in.
We don't want to train every part of the model, as that would throw away all our hard work. So instead, after loading the OpenAI GPT-2 weights, we freeze the model with:
```python
# Freeze every pretrained parameter so the optimiser won't update it
for param in model.parameters():
    param.requires_grad = False
```
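Note: you actually do this freezing before swapping out the output head. The new head's parameters are created after the freeze, so they keep the default requires_grad=True and get trained.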
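You can check directly why the order matters. The minimal sketch below uses a hypothetical toy model rather than the real GPT. It freezes all parameters and then swaps in a new head. A freshly created `torch.nn.Linear` has `requires_grad=True` by default, so only the new head ends up trainable. The helper `trainable_names` is hypothetical.

```python
import torch

# Hypothetical toy model standing in for the pretrained GPT
model = torch.nn.Module()
model.tok_emb = torch.nn.Embedding(100, 16)
model.out_head = torch.nn.Linear(16, 100, bias=False)


def trainable_names(module):
    """List the names of parameters that will receive gradients."""
    return [name for name, p in module.named_parameters() if p.requires_grad]


# 1. Freeze everything that came from pretraining
for param in model.parameters():
    param.requires_grad = False

# 2. Swap in the classification head; its new parameters are trainable by default
model.out_head = torch.nn.Linear(16, 2)

print(trainable_names(model))  # ['out_head.weight', 'out_head.bias']
```

Doing the steps the other way round would freeze the new head too, and nothing would train.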
The book mentions that you could technically train just the output layer, but that also making the last transformer block and the final LayerNorm module trainable improves results. So we unfreeze those two with:
```python
# Unfreeze the last transformer block and the final LayerNorm
for param in model.trf_blocks[-1].parameters():
    param.requires_grad = True
for param in model.final_norm.parameters():
    param.requires_grad = True
```
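Here is a minimal sketch that puts the freeze, swap and unfreeze steps together and shows how small the trainable fraction is. It uses a hypothetical toy model that mirrors the attribute names used above (`trf_blocks`, `final_norm`, `out_head`). To keep it short, plain `Linear` layers stand in for the transformer blocks. That substitution and the tiny sizes are assumptions.

```python
import torch

EMB_DIM = 16  # toy size; GPT-2 small uses 768


class ToyGPT(torch.nn.Module):
    """Hypothetical stand-in exposing the same attribute names as the chapter's model."""

    def __init__(self, n_layers=4, vocab_size=100):
        super().__init__()
        self.tok_emb = torch.nn.Embedding(vocab_size, EMB_DIM)
        # Linear layers stand in for transformer blocks to keep the sketch short
        self.trf_blocks = torch.nn.Sequential(
            *[torch.nn.Linear(EMB_DIM, EMB_DIM) for _ in range(n_layers)])
        self.final_norm = torch.nn.LayerNorm(EMB_DIM)
        self.out_head = torch.nn.Linear(EMB_DIM, vocab_size, bias=False)


def count_params(module, trainable_only=False):
    """Count parameters, optionally only those with requires_grad=True."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad or not trainable_only)


model = ToyGPT()
for param in model.parameters():                 # freeze the pretrained weights
    param.requires_grad = False
model.out_head = torch.nn.Linear(EMB_DIM, 2)     # new, trainable classification head
for param in model.trf_blocks[-1].parameters():  # unfreeze the last block
    param.requires_grad = True
for param in model.final_norm.parameters():      # unfreeze the final LayerNorm
    param.requires_grad = True

print(count_params(model, trainable_only=True), "of", count_params(model))  # 338 of 2754
```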
We then set up a cross-entropy loss function (computed per batch in this instance) and train on our SMS dataset.
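Here is a minimal sketch of the per-batch loss and a single training step, assuming that only the last token's logits are used for classification. `calc_loss_batch` and the stand-in model are hypothetical names for illustration. AdamW with these hyperparameters is also my assumption, not necessarily what the book uses.

```python
import torch
import torch.nn.functional as F


def calc_loss_batch(input_batch, target_batch, model):
    """Cross-entropy loss for one batch, using only the last token's logits."""
    logits = model(input_batch)[:, -1, :]  # (batch, num_classes)
    return F.cross_entropy(logits, target_batch)


# Hypothetical stand-in model: token embedding + 2-class head
torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Embedding(100, 16), torch.nn.Linear(16, 2))
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)

input_batch = torch.randint(0, 100, (8, 12))  # 8 messages of 12 token IDs
target_batch = torch.randint(0, 2, (8,))      # 0 = "not spam", 1 = "spam"

# One training step: only parameters with requires_grad=True receive gradients
optimizer.zero_grad()
loss = calc_loss_batch(input_batch, target_batch, model)
loss.backward()
optimizer.step()
print(loss.item())
```

In the real setup this step runs over every batch of the training DataLoader for a few epochs.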
Even on my 3060 Ti, it only took 30 seconds to perform the fine-tuning, resulting in over 95% accuracy.
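The accuracy figure comes from comparing predicted and true labels. Here is a minimal sketch of that calculation. It assumes a DataLoader that yields `(input_batch, target_batch)` pairs and uses last-token classification. `calc_accuracy` and the toy `AlwaysSpam` model are hypothetical.

```python
import torch


@torch.no_grad()
def calc_accuracy(data_loader, model):
    """Fraction of messages whose predicted class matches the label."""
    model.eval()
    correct, total = 0, 0
    for input_batch, target_batch in data_loader:
        predictions = torch.argmax(model(input_batch)[:, -1, :], dim=-1)
        correct += (predictions == target_batch).sum().item()
        total += target_batch.numel()
    return correct / total


class AlwaysSpam(torch.nn.Module):
    """Hypothetical model that always scores class 1 ("spam") highest."""

    def forward(self, input_batch):
        logits = torch.zeros(input_batch.shape[0], input_batch.shape[1], 2)
        logits[..., 1] = 1.0
        return logits


# A list of (inputs, labels) batches behaves like a DataLoader for this purpose
loader = [(torch.zeros(3, 5, dtype=torch.long), torch.tensor([1, 0, 1])),
          (torch.zeros(1, 5, dtype=torch.long), torch.tensor([0]))]
print(calc_accuracy(loader, AlwaysSpam()))  # 0.5
```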
Summary
This felt a little like magic to me. How can replacing just the output layer convert a model that predicts the next token into one that can classify spam?
After a few chats with an LLM, I think I've put the pieces together. The original LLM encodes knowledge about the language it was trained on, and its later layers encode more complex language features, such as relationships between words and how surrounding words affect meaning.
The output layer "just" translates that knowledge into the dimension we want. Originally this was "which token of the vocabulary comes next". By swapping this output layer and training the new parameters, we're building on top of all that prior knowledge rather than changing it.
I'm guessing that retraining the last transformer block and the final LayerNorm module improves results because the new knowledge can be spread over more parameters. Whatever higher-complexity reasoning is lost in the "next-token" space isn't relevant to (or at least doesn't hurt much on) the classification task.
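In equation form, let $\mathbf{h} \in \mathbb{R}^{768}$ be the output of the final LayerNorm at the last token position (using the last token is an assumption here). The swap replaces one linear map with another. The original head in the book's GPT has no bias term, while the new `torch.nn.Linear` adds one by default:

$$
\begin{aligned}
\text{next-token head: } & \mathbf{z} = W_{\text{vocab}}\,\mathbf{h}, && W_{\text{vocab}} \in \mathbb{R}^{50257 \times 768} \quad (38{,}597{,}376 \text{ parameters}) \\
\text{classification head: } & \mathbf{z} = W_{\text{cls}}\,\mathbf{h} + \mathbf{b}, && W_{\text{cls}} \in \mathbb{R}^{2 \times 768},\ \mathbf{b} \in \mathbb{R}^{2} \quad (2 \cdot 768 + 2 = 1{,}538 \text{ parameters})
\end{aligned}
$$

The representation $\mathbf{h}$ is what carries the prior knowledge. The head only decides how to read it out.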