The lecture on building GPT by Andrej Karpathy is a MUST-watch for everyone working in NLP. In this post, I’ll be summarizing my notes from the course. You can watch it here.
GPT stands for Generative Pre-trained Transformer. The one we’ll be building in this post is very similar to the actual technology behind ChatGPT. Let’s get to it.
Setting up the data
Our task will be next-character prediction. For learning purposes, we will use a ~1M-character dataset composed of Shakespeare’s texts. In the end, we’ll have a Shakespearean GPT model that can output plays that read (almost) like the original ones.
The flow for preprocessing the dataset and initializing the model is similar to the one covered in previous posts. Here’s a summary.
From the raw text, we build the vocabulary: the set of unique characters, which is then sorted. Then, we tokenize the input text, converting the raw strings (play lines) into integers. Each integer is the index of a character in this sorted vocabulary and is used to pick the row of the embedding matrix that should be retrieved for model input.
- By the way, character-level tokenization is just one possibility. Modern alternatives work at the subword level, such as byte-pair encoding (which is what ChatGPT uses).
Then, we split the dataset into train and test sets so we can check, at the end of training, whether our model generalizes well or is just overfitting (memorizing) the training data.
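For reference, here’s a minimal sketch of this preprocessing step in PyTorch, assuming the Shakespeare corpus has already been loaded into a string called text (variable names mirror the lecture’s code, but treat this as an illustration rather than the exact nanoGPT source):

```python
import torch

# text: the full ~1M-character Shakespeare corpus, loaded as a single string
chars = sorted(set(text))                      # ordered vocabulary of unique characters
stoi = {ch: i for i, ch in enumerate(chars)}   # character -> integer index
itos = {i: ch for ch, i in stoi.items()}       # integer index -> character

encode = lambda s: [stoi[c] for c in s]             # string -> list of integers
decode = lambda ids: ''.join(itos[i] for i in ids)  # list of integers -> string

data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))                       # hold out the last 10% for evaluation
train_data, test_data = data[:n], data[n:]
```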
The context length matters
There are multiple ways to frame the next-character prediction problem. We can use the naïve bigram approach (covered in this post), where only the immediately previous token is used to predict the next one.
The problem with this is that the context window is super small. As previous tokens don’t interact with each other, predictions are very limited. There’s no mechanism (like attention, recurrence, or convolutions) that allows the model to use information from other tokens in the sequence.
So, let’s develop a more complex model that allows for a bigger context window. For this, we set a block size, a parameter that indicates the maximum number of previous characters considered for predicting the next one. The context can then slowly build up, and tokens are integrated gradually, helping avoid information loss (the importance of this was extensively covered in this post).
Our block size is equal to eight, so we’ll use up to eight characters to predict the next. In a sequence, we can then have up to eight different contextual examples, as the first character can be used to predict the second; the first two to predict the third, and so on:
Furthermore, we’ll be using random mini-batches of examples in the training process. The sequences are processed independently but grouped for efficiency. In this setting, if we have, e.g., 4 sequences, each will have its own 8 respective examples:
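Here’s a small sketch of how these (context, target) pairs can be drawn, assuming the train_data tensor from the preprocessing sketch above (the batching code in the lecture is similar in spirit):

```python
import torch

block_size = 8   # maximum context length
batch_size = 4   # number of independent sequences per mini-batch

def get_batch(data):
    # pick batch_size random starting positions in the data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i + block_size] for i in ix])          # inputs
    y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])  # targets, shifted by one
    return x, y

xb, yb = get_batch(train_data)
# each row of xb packs block_size training examples:
for t in range(block_size):
    context, target = xb[0, :t + 1], yb[0, t]  # the first t+1 chars predict the next one
```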
Causal masking
Given these examples, there are some different ways we can fuse the representation of previous characters to predict the next one.
First, we have to keep in mind that information should only flow from the previous context to the current timestep. For example, the token at the 5th location shouldn’t communicate with tokens in the 6th, 7th, and so on, as these are “future” tokens — they come after the 5th token.
We want the 5th token to interact only with tokens 1 through 4, ensuring that at each step, the model is only using information from the past. This is implemented through causal masking, which is exactly what we did when we created the examples from a single sequence (refer to the last image).
Considering this, the easiest way to make the correct tokens communicate is to average the embeddings of the preceding characters.
- So if I’m the 5th token, I can gather my own and my predecessors' embeddings, average them out, and use the resulting feature vector to try and predict the next character.
Because we’re averaging tokens, disregarding their relative orders, we call this approach a “bag of words (or characters).” The cool thing about it is that we can implement the averaging with matrix multiplication for all examples at once.
Specifically, we can use a lower triangular matrix, where each row determines how much of the previous characters to include. If we set each row to sum to 1, the matrix will then take the average of the previous elements when multiplied with another matrix.
- First row, consider only first character: Take 100% of 2 and 0% of 6 and 6, resulting in 2. Then, take 100% of 7 and 0% of 4 and 5, resulting in 7.
- Second row, consider first two characters: Take the mean of 2 and 6, resulting in 4. Then, take the mean of 7 and 4, resulting in 5.5.
If we create an 8x8 matrix wei (one row per contextual example, matching our block size) and multiply it by x, what we get is a masked version of x, where each row is a weighted aggregation of the past elements.
In this initial approach, each past element is weighted equally. Also, wei is the same for every sequence:
We can then use wei to fuse information from past tokens, making the embeddings more contextually aware.
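Here’s a minimal sketch of that masked-averaging trick, with shapes chosen just for illustration (x stands in for the token embeddings of a single 8-character sequence):

```python
import torch

T, C = 8, 32                   # sequence length (block size), embedding size
x = torch.randn(T, C)          # token embeddings for one sequence

wei = torch.tril(torch.ones(T, T))        # lower triangular matrix of ones
wei = wei / wei.sum(dim=1, keepdim=True)  # each row sums to 1 -> averaging weights

out = wei @ x   # (T, T) @ (T, C) -> (T, C): row t is the average of x[0..t]
```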
Self-attention
I think you can see where this is going. In the previous case, wei was simply averaging all the past characters, giving each element the exact same weight. Moreover, wei remained the same for every input, regardless of their specific contents.
In self-attention, what changes is that in the averaging process, different past characters can have different weights. Furthermore, we will have multiple wei’s, each one being calculated specifically for each input sequence.
In summary: self-attention allows for computing a unique set of weights for each input sequence. In this setting, different tokens can find past tokens more or less interesting in a data-dependent fashion.
Queries, keys, and values
In practice, the way we obtain each respective wei (the matrix of attention weights) is by using queries and keys:
- Every token in the sequence emits two vectors based on learned linear projections of its features: the query (q), meaning “What am I looking for?” and the key (k), “What do I contain?”
- A dot product is taken between q’s and the respective k’s of past tokens in the sequence. The results indicate the affinity between them.
- If q and k are aligned, they interact a lot, so the resulting dot product is higher. This way, that past token contributes more to the aggregation than the other past tokens in the sequence.
This operation outputs attention scores (logits), which are then passed through a softmax to obtain probabilities. The probabilities are the actual attention weights that are used in the aggregation process.
- We also set the masked (future) positions, i.e., the zeros in the lower triangular mask, to -infinity before applying the softmax, so that they end up with exactly zero attention weight.
Here’s a visualization of wei. You can see it is now actually a set of 4 masks, each slightly different and adapted to each of our 4 examples:
One more detail: we don't actually aggregate the exact token embeddings; we aggregate their values, which also come from a learned linear projection based on their original embeddings.
- Think of it like this: the original feature vector, x, is like private information of the token; it’s its identity.
- So, in an attention head: q is “Here’s what I’m interested in,” k is “Here’s what I have,” and v is “If you find me interesting, here’s what I’ll communicate to you” (what gets aggregated).
We can then use each wei to fuse information from past tokens in a data-dependent manner, making the learned representations way more expressive.
Moreover, it’s important to mention that attention scores should be scaled by dividing them by the square root of the head dimension. This keeps the variance of the scores close to 1 before the softmax is applied (a complete attention head is sketched after the two bullets below).
- Without scaling, the scores could vary widely, causing the softmax function to converge toward a one-hot encoding.
- This would lead each node to aggregate information predominantly from a single node, limiting the attention mechanism.
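Putting queries, keys, values, causal masking, and scaling together, here’s a sketch of a single self-attention head in the spirit of the lecture’s code (hyperparameter names like head_size are illustrative, not canonical):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Head(nn.Module):
    """One head of causal self-attention."""
    def __init__(self, n_embd, head_size, block_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):                    # x: (B, T, n_embd)
        B, T, C = x.shape
        k = self.key(x)                      # (B, T, head_size): "what do I contain?"
        q = self.query(x)                    # (B, T, head_size): "what am I looking for?"
        # affinities, scaled by sqrt(head_size) to keep their variance near 1
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5           # (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # causal mask
        wei = F.softmax(wei, dim=-1)         # attention weights (each row sums to 1)
        v = self.value(x)                    # (B, T, head_size): "what I'll communicate"
        return wei @ v                       # (B, T, head_size)
```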
Now I think you can understand the equation below!
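For reference, it’s the scaled dot-product attention equation from the original Transformer paper, where Q, K, and V stack the queries, keys, and values, and d_k is the head dimension:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V$$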
Attention is a communication mechanism
Now that we understand attention, let’s reinforce some intuitions behind it.
We can think of attention as a communication mechanism between nodes in a directed graph (where the direction of the arrows matters).
In this graph, every node has an associated vector of information and can aggregate information via a weighted sum of the other nodes that point to it.
When processing sequences, we have a directed graph for each of the example sequences.
- Considering the character sequences we’ve been processing, we have 4 pools of 8 nodes each, which are processed independently.
The connections of the graph are established in an autoregressive way: the first node is pointed to only by itself, the second node is pointed to by node 1 and by itself, and so on, up to node 8, which is pointed to by all eight nodes. This mirrors the causal masking we’ve been talking about.
Here’s an example considering a sequence of 4 nodes:
As attention acts over a set of vectors, nodes don’t inherently know where they are in the sequence. That’s why we use something called positional encodings: vectors that are added to the original token embeddings to introduce the notion of “who came first” (this won’t be covered here in detail; all you need to know is that it introduces ordering).
- In other words, there is a specific vector that, when added to a node’s embedding, makes it “know” that it’s the first; the same goes for the second, third, and so on.
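A minimal sketch of how this is typically done with a learned position-embedding table, as in the lecture’s model (names like n_embd are illustrative):

```python
import torch
import torch.nn as nn

vocab_size, n_embd, block_size = 65, 32, 8

token_embedding_table = nn.Embedding(vocab_size, n_embd)     # one row per character
position_embedding_table = nn.Embedding(block_size, n_embd)  # one row per position

idx = torch.randint(vocab_size, (4, block_size))  # a batch of token indices, shape (B, T)
tok_emb = token_embedding_table(idx)                          # (B, T, n_embd): "what" the token is
pos_emb = position_embedding_table(torch.arange(block_size))  # (T, n_embd): "where" it sits
x = tok_emb + pos_emb   # broadcast add: each token now "knows" its position in the block
```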
By the way, this autoregressive style doesn’t always have to be present. As attention is a communication mechanism, we can set the arrows however we want.
For example, we could instead model a problem where every token attends to every other token, which is commonly done in sentence-level sentiment analysis. You can also expand this to any other kind of graph, where nodes are other entities, such as cities, patient visits, or whatever you want.
Cross-attention
When we want to pull information from a separate source of nodes, we use something called cross-attention. “Cross” because the keys and values come from a different source than the source of the queries.
This is useful when you want to make different data modalities talk to each other, e.g., allowing image features to enrich textual representations or vice versa. You can read more about cross-attention here.
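As a rough sketch (this part is not built in the lecture, which uses a decoder only): cross-attention is the same computation as self-attention, except the keys and values are projected from a second tensor, here called context, which could hold encoder outputs or image features:

```python
import torch
import torch.nn as nn

class CrossAttentionHead(nn.Module):
    """One head of cross-attention: queries from x, keys and values from context."""
    def __init__(self, n_embd, head_size):
        super().__init__()
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)

    def forward(self, x, context):             # x: (B, T, n_embd), context: (B, S, n_embd)
        q = self.query(x)                      # queries come from our own sequence
        k = self.key(context)                  # keys come from the other source
        v = self.value(context)                # values come from the other source
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5   # (B, T, S) affinities
        wei = wei.softmax(dim=-1)              # no causal mask: attend to all of context
        return wei @ v                         # (B, T, head_size)
```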
Multi-head attention
What if our tokens have a lot to talk about? So far, we’ve been considering a single channel for the tokens. Couldn’t we open more communication channels so we can capture different nuances on how other tokens might contribute to new representations?
Yes, we can! The way we achieve this is by applying multiple attention operations (or attention heads) in parallel over the inputs and then concatenating the results, which become the attention layer’s final output. We call this multi-head attention; in our case, with causal masking, multi-head self-attention. A sketch follows the list below.
- One detail: the embedding dimension is divided across the heads. However, each head still processes its own learned projection of the entire input vector.
- So, if there are 4 heads and the input vectors have 32 dimensions, each head will project the entire input into a smaller, 8-dimensional subspace. Each head then independently performs attention on this reduced 8-dimensional representation.
- Afterwards, these independent outputs are concatenated, resulting in a combined 32-dimensional output that matches the original input size.
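A sketch of multi-head attention, reusing the Head module sketched earlier and following the lecture’s overall structure (the final projection layer mixes the concatenated head outputs; parameter names are illustrative):

```python
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Several causal self-attention heads in parallel, concatenated and projected."""
    def __init__(self, n_embd, num_heads, block_size):
        super().__init__()
        head_size = n_embd // num_heads   # e.g. 32-dim input, 4 heads -> 8 dims per head
        self.heads = nn.ModuleList(
            [Head(n_embd, head_size, block_size) for _ in range(num_heads)]
        )
        self.proj = nn.Linear(n_embd, n_embd)  # mix the concatenated head outputs

    def forward(self, x):                                     # x: (B, T, n_embd)
        out = torch.cat([h(x) for h in self.heads], dim=-1)   # (B, T, n_embd) again
        return self.proj(out)
```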
Letting tokens think
Via self-attention, we allowed tokens to “talk to each other.” However, they still need to process — or “think about” — the information they’ve gathered.
Self-attention helps tokens exchange information, but there’s still more work to be done to digest what they’ve gathered. This is done in a per-token (or per-node) feedforward layer.
In the feedforward layer, each token can “reflect” on the new information and update its own representation accordingly. Furthermore, the nonlinearity in the layer allows the network to model more complex relationships in the data.
In practice, the model architecture alternates between “communicate” (attention) and “compute” (feedforward) layers until we reach the final layer.
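A sketch of the per-token feedforward (“compute”) layer in the style of the lecture; the 4x expansion factor follows the original Transformer paper:

```python
import torch.nn as nn

class FeedForward(nn.Module):
    """A simple per-token MLP: every position is processed independently."""
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),  # expand
            nn.ReLU(),                      # nonlinearity lets tokens "think"
            nn.Linear(4 * n_embd, n_embd),  # project back to the embedding size
        )

    def forward(self, x):   # x: (B, T, n_embd), applied position-wise
        return self.net(x)
```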
Optimization
There are still some things we should implement in our architecture to help with the training process.
The first is adding skip/residual connections. This involves adding back the input x after computing a transformation F(x), so the block outputs x + F(x) instead of just F(x).
- Early in training, the residual branches contribute very little, so each block acts roughly as an identity mapping, passing inputs mostly unchanged. In summary, they kind of “do nothing”. As optimization proceeds, these branches start contributing more actively, allowing information to “fork off” into specific computations.
- Over time, the transformations in F(x) improve, and the residual connections help integrate these learned transformations with the original input.
- Residual connections allow gradients to flow easily through addition nodes, enabling gradients from the loss to propagate directly to the input while also branching through the residual blocks. This helps prevent vanishing gradients and ensures more stable learning.
The second optimization maneuver is implementing layer normalization, which normalizes the features within each layer on a per-example basis to avoid gradient instability. Unlike batch norm — which normalizes across the entire batch — layer norm computes statistics based on the feature vectors of each instance, making it more effective for models where batch stats vary a lot.
The third and last one is dropout, which involves deactivating a random percentage of neurons at each training step. This makes the network function like an ensemble of subnetworks, as different subsets of neurons are trained together across iterations. By doing so, dropout reduces reliance on any single neuron and helps prevent overfitting.
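Putting the three ideas together, here’s a sketch of a full Transformer block in the pre-norm style used in the lecture, reusing the MultiHeadAttention and FeedForward sketches from above (dropout is also typically applied inside those modules; it’s shown only once here for brevity):

```python
import torch.nn as nn

class Block(nn.Module):
    """Transformer block: communicate (attention), then compute (feedforward)."""
    def __init__(self, n_embd, num_heads, block_size, dropout=0.2):
        super().__init__()
        self.sa = MultiHeadAttention(n_embd, num_heads, block_size)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)  # per-example feature normalization
        self.ln2 = nn.LayerNorm(n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # residual connections: add each branch's output back onto x
        x = x + self.dropout(self.sa(self.ln1(x)))
        x = x + self.dropout(self.ffwd(self.ln2(x)))
        return x
```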
Generating Shakespearean text
After setting up the GPT model and training it on our data, we get some pretty nice outputs, way better than the ones from the bigram model!
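For reference, here’s a rough sketch of the autoregressive sampling loop that produces such text, assuming model maps a (B, T) tensor of token indices to (B, T, vocab_size) logits and that block_size and decode are as defined earlier:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def generate(model, idx, max_new_tokens, block_size):
    # idx: (B, T) tensor of token indices forming the current context
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -block_size:]              # crop to the last block_size tokens
        logits = model(idx_cond)                     # (B, T, vocab_size)
        logits = logits[:, -1, :]                    # only the last time step matters here
        probs = F.softmax(logits, dim=-1)            # logits -> probabilities
        idx_next = torch.multinomial(probs, num_samples=1)  # sample the next character
        idx = torch.cat([idx, idx_next], dim=1)      # append it and continue
    return idx

# e.g., start from a single token and sample 500 new characters:
# start = torch.zeros((1, 1), dtype=torch.long)
# print(decode(generate(model, start, 500, block_size)[0].tolist()))
```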
And that’s all! We’ve covered the whole process of developing and training a GPT-like model from scratch. All the code is available in the nanoGPT GitHub repo.
Here are some other details:
Encoder-decoder architecture
Encoders encode information. They take text and create vectorized representations. No causal masking is involved, meaning that every token is allowed to attend to every other token (both past and future). This way, the model can output global representations from text.
Decoders, on the other hand, decode information. They take these global representations and decode them to generate text. Here, causal masking is essential because it ensures that each token only attends to previous tokens and itself, which is crucial for autoregressive tasks like text generation.
In machine translation, this setup is essential. The encoder first creates a global representation of a text in, say, French, encoding the entire input sentence. This encoded information is then passed to the decoder, which uses both the encoded French input and its own past outputs to sequentially predict the next token in the target language.
How ChatGPT works
GPT is a decoder-only model pre-trained on a large corpus of internet text. It uses chunks of words as tokens, not single characters like in our Shakespearean model.
By the end of pre-training, it learns how to babble some text. At this point, its outputs are simply whatever is statistically most likely given the text it has seen during training; it acts more like a document completer.
- For example, if given a prompt of questions, it will likely generate other similar questions or responses based on its understanding of question patterns.
After this, it undergoes a fine-tuning stage to properly become an assistant. This involves further training on data formatted as question-answer pairs. Next, human labelers rank the model’s responses from best to worst, and these rankings are used to train a reward model.
Finally, a reinforcement learning algorithm fine-tunes the response generation, optimizing it to produce answers that are likely to score highly based on the reward model’s evaluations.
And that's all. See you next time.