Attention#

At its core, attention is a mechanism that allows a token vector to include information about the context of that token (i.e. the surrounding token vectors). This way the token vector can “pay attention” to the surrounding token vectors.

As an example consider the token sequence \(T_0, T_1, T_2\) (like “an”, “example”, “sequence”). It seems sensible that token \(T_2\) should have some information about \(T_0\) and \(T_1\) if we want to model the sequence successfully.

As usual, we will need a few imports from torch:

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(42)
<torch._C.Generator at 0x7f737c1c00d0>

Linear Combinations#

We will continue working with our example sequence \(T_0, T_1, T_2\). We will call the vectors that represent the tokens \(\mathbf{x_0}, \mathbf{x_1}\) and \(\mathbf{x_2}\) respectively.

Let’s say that every token is represented by a vector of dimension \(5\), i.e. the entire sequence is represented by a tensor of dimension \(3\times 5\).

We will use a random tensor throughout this section; in reality, the tensor would be the output of the layer that comes before the attention layer:

X = torch.randn(3, 5)
print(X)
tensor([[ 0.3367,  0.1288,  0.2345,  0.2303, -1.1229],
        [-0.1863,  2.2082, -0.6380,  0.4617,  0.2674],
        [ 0.5349,  0.8094,  1.1103, -1.6898, -0.9890]])

Now, let’s say that we would like the token vectors to be able to “pay attention” to each other. The simplest way to accomplish this would be to calculate the averages of the token vectors.

For example, if we would like to get the information contained in \(T_0\) and \(T_1\), we might compute the average of \(\mathbf{x_0}\) and \(\mathbf{x_1}\):

avg = 1 / 2 * X[0] + 1 / 2 * X[1]
print(avg)
tensor([ 0.0752,  1.1685, -0.2018,  0.3460, -0.4278])

Similarly, if we would like to combine the information in \(T_0, T_1\) and \(T_2\), we might compute the average of \(\mathbf{x_0}, \mathbf{x_1}\) and \(\mathbf{x_2}\):

avg = 1 / 3 * X[0] + 1 / 3 * X[1] + 1 / 3 * X[2]
print(avg)
tensor([ 0.2284,  1.0488,  0.2356, -0.3326, -0.6148])

This doesn’t look like a great idea, primarily because not every token is equally important to every other token. For example, if we have the sentence “The bat flew out of the cave”, then “bat” should probably pay more attention to “cave” than to “of”.

Therefore, we want to have data-driven weights in our linear combination, i.e. we would like to compute arbitrary linear combinations like:

\(w_0 \cdot \mathbf{x_0} + w_1 \cdot \mathbf{x_1} + w_2 \cdot \mathbf{x_2}\)

where the weights \(w_0, w_1\) and \(w_2\) are computed from the data rather than fixed in advance.

That is, the input of a hypothetical attention layer would be a tensor containing the vectors \(\mathbf{x_0}, \mathbf{x_1}\) and \(\mathbf{x_2}\), while the output would be another tensor containing new vectors \(\mathbf{y_0}, \mathbf{y_1}\) and \(\mathbf{y_2}\) that are certain linear combinations of the input vectors:

\(w_{00} \cdot \mathbf{x_0} + w_{01} \cdot \mathbf{x_1} + w_{02} \cdot \mathbf{x_2} = \mathbf{y_0}\)

\(w_{10} \cdot \mathbf{x_0} + w_{11} \cdot \mathbf{x_1} + w_{12} \cdot \mathbf{x_2} = \mathbf{y_1}\)

\(w_{20} \cdot \mathbf{x_0} + w_{21} \cdot \mathbf{x_1} + w_{22} \cdot \mathbf{x_2} = \mathbf{y_2}\)

Note that both \(\mathbf{x_0}, \mathbf{x_1}, \mathbf{x_2}\) and \(\mathbf{y_0}, \mathbf{y_1}, \mathbf{y_2}\) are representations of the token sequence \(T_0, T_1, T_2\), but \(\mathbf{y_0}, \mathbf{y_1}, \mathbf{y_2}\) is in some sense better than \(\mathbf{x_0}, \mathbf{x_1}, \mathbf{x_2}\) because the token vectors have more context.

In other words, attention “mixes” the input vectors together, producing new output vectors that have, in a way, “communicated” with each other.

The big question now is how to compute the attention weights.

Naive Attention#

Let’s take a stab at a very naive attention mechanism. The idea would be to calculate the similarity of every token with every other token. We could use the dot product for this:

for i in range(3):
    for j in range(3):
        dot_product = torch.dot(X[i], X[j])
        print(f"attention({i}, {j}) = {round(dot_product.item(), 2)}")
attention(0, 0) = 1.5
attention(0, 1) = -0.12
attention(0, 2) = 1.27
attention(1, 0) = -0.12
attention(1, 1) = 5.6
attention(1, 2) = -0.07
attention(2, 0) = 1.27
attention(2, 1) = -0.07
attention(2, 2) = 6.01

This can of course be done much more efficiently via matrix multiplication:

S = torch.matmul(X, X.transpose(0, 1))
print(S)
tensor([[ 1.4988, -0.1217,  1.2659],
        [-0.1217,  5.6025, -0.0653],
        [ 1.2659, -0.0653,  6.0074]])

This yields a matrix of attention scores. It would clearly be beneficial if, for each token, the attention scores were between \(0\) and \(1\) and summed to \(1\).

We can accomplish this using the softmax function, yielding a matrix of attention weights:

W = torch.softmax(S, dim=1)
print(W)
tensor([[0.5025, 0.0994, 0.3981],
        [0.0032, 0.9933, 0.0034],
        [0.0086, 0.0023, 0.9891]])

Now, we have a matrix of “attention weights” indicating how much attention vector \(\mathbf{x_i}\) should pay to vector \(\mathbf{x_j}\). These are our data-driven weights - we can now compute the linear combinations using them.

Let’s start with \(\mathbf{y_0}\). Remember that \(\mathbf{y_0} = w_{00} \cdot \mathbf{x_0} + w_{01} \cdot \mathbf{x_1} + w_{02} \cdot \mathbf{x_2}\):

y_0 = W[0, 0] * X[0] + W[0, 1] * X[1] + W[0, 2] * X[2]
print(y_0)
tensor([ 0.3636,  0.6064,  0.4964, -0.5111, -0.9314])

Next, we compute \(\mathbf{y_1}\). Remember that \(\mathbf{y_1} = w_{10} \cdot \mathbf{x_0} + w_{11} \cdot \mathbf{x_1} + w_{12} \cdot \mathbf{x_2}\):

y_1 = W[1, 0] * X[0] + W[1, 1] * X[1] + W[1, 2] * X[2]
print(y_1)
tensor([-0.1822,  2.1967, -0.6292,  0.4535,  0.2585])

Finally, we compute \(\mathbf{y_2}\). Remember that \(\mathbf{y_2} = w_{20} \cdot \mathbf{x_0} + w_{21} \cdot \mathbf{x_1} + w_{22} \cdot \mathbf{x_2}\):

y_2 = W[2, 0] * X[0] + W[2, 1] * X[1] + W[2, 2] * X[2]
print(y_2)
tensor([ 0.5315,  0.8067,  1.0987, -1.6683, -0.9873])

Again, we can achieve this much more efficiently through matrix multiplication. For this, we define \(W\) as the matrix containing the weights, \(X\) as the matrix where row \(i\) is the vector \(\mathbf{x_i}\), and \(Y\) as the matrix where row \(i\) is the vector \(\mathbf{y_i}\).

Then, we have \(Y = WX\):

Y = torch.matmul(W, X)
print(Y)
tensor([[ 0.3636,  0.6064,  0.4964, -0.5111, -0.9314],
        [-0.1822,  2.1967, -0.6292,  0.4535,  0.2585],
        [ 0.5315,  0.8067,  1.0987, -1.6683, -0.9873]])

We made some good progress! Theoretically, we could use this attention as is; however, this naive attention won’t perform well in practice. The reason is that we need to be able to differentiate between “information that a token vector represents” and “information that a token vector is interested in”.

For example, if a token vector encodes that it is the subject of a sentence, it will tend to pay high attention to other subjects in the sentence (primarily to itself). Instead, it should likely pay attention to token vectors that encode predicates, articles, and other relevant elements of the sentence.

We acknowledge that this is a rather hand-wavy intuition, and token vectors often represent inscrutable high-dimensional concepts that have little to no real analogy in linguistics. Despite this, in practice, the overall idea of differentiating between “information that a token vector represents” and “information that a token vector is interested in” works better than our current naive attention.

Key and Query Vectors#

Here, we introduce the first important components of the real self-attention mechanism - the key vectors and query vectors.

Each token vector generates a key vector and a query vector. The key vector reflects the aforementioned “information that the token represents”, while the query vector represents the “information the token is interested in”.

To continue our informal example, a token might have the key vector “I am the subject of the sentence” and the query vector “I am interested in the predicate of the sentence”. Again, in reality, key and query vectors won’t be this interpretable and will represent inscrutable high-dimensional concepts that the language model learned during training.

We can also use a different, albeit somewhat vague, intuition to understand key and query vectors by borrowing concepts from databases.

In this context, a “query” is analogous to a search query in a database. It represents the current token the model is trying to understand. The query is then used to probe (i.e. to “query”) the other parts of the input sequence to determine how much attention to pay to them.

The “key” is similar to a database key. In the attention mechanism, each token then has an associated key that is used to match with the query.

Basically, the query is used to look up the keys and determine their relevance to the query.

The way we compute the key vectors and query vectors from the token vectors is surprisingly simple - we do this via linear layers.

For our example, we will create random weight matrices - in reality, these would be parameters that our neural network has to learn. Let’s call the matrix that will produce the key vectors \(W_K\) and the matrix that will produce the query vectors \(W_Q\):

W_K = torch.randn(5, 4)
W_Q = torch.randn(5, 4)

We can now compute the key vectors. Here is how we would obtain the key vector \(\mathbf{k_0}\):

k_0 = torch.matmul(X[0], W_K)
print(k_0)
tensor([ 0.6023, -0.7260,  1.1799,  0.2383])

To compute all key vectors at the same time, we can again use matrix multiplication:

K = torch.matmul(X, W_K)
print(K)
tensor([[ 0.6023, -0.7260,  1.1799,  0.2383],
        [-0.6521,  4.4224, -3.7460, -1.2657],
        [-0.7106, -4.3429,  4.2984, -2.3664]])

We can also compute the query vectors. Here is how we would obtain the query vector \(\mathbf{q_0}\):

q_0 = torch.matmul(X[0], W_Q)
print(q_0)
tensor([-1.6964,  1.3355, -0.5133,  0.0674])

To get all query vectors efficiently, we can use - you guessed it - matrix multiplication:

Q = torch.matmul(X, W_Q)
print(Q)
tensor([[-1.6964,  1.3355, -0.5133,  0.0674],
        [ 1.6595, -0.4445, -0.1917,  1.7729],
        [-0.1650, -2.9899, -3.8893,  1.2756]])

Now, we will obtain the attention scores in a similar way as before. The big difference is that, instead of computing the similarity of the token vectors with each other, we will compute the similarity between the key vectors and the query vectors:

S = torch.matmul(Q, K.transpose(0, 1))
print(S)
tensor([[-2.5809,  8.8498, -6.9600],
        [ 1.5185, -4.5739, -4.2686],
        [-2.2135, -0.1601, -6.6347]])

A minor, but important, technical detail is that we need to scale the attention scores before applying the softmax. The reason is that the dot products grow in magnitude with the vector dimension, and large scores push the softmax towards a one-hot distribution where its gradients become vanishingly small. To counteract this, we divide the scores by \(\sqrt{d}\), where \(d\) is the dimension of a key/query vector:

S_scaled = S / (4**0.5)
print(S_scaled)
tensor([[-1.2905,  4.4249, -3.4800],
        [ 0.7593, -2.2869, -2.1343],
        [-1.1067, -0.0801, -3.3173]])
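To get a feel for why \(\sqrt{d}\) is the right factor, we can check empirically that the dot product of two random \(d\)-dimensional vectors with unit-variance components has a standard deviation of roughly \(\sqrt{d}\). This is a small illustrative experiment (we use a separate random generator so that it doesn’t affect the other values in this section):

g = torch.Generator().manual_seed(0)
d = 512
a = torch.randn(10000, d, generator=g)
b = torch.randn(10000, d, generator=g)
dots = (a * b).sum(dim=-1)
print(dots.std())             # roughly sqrt(512), i.e. about 22.6
print((dots / d**0.5).std())  # roughly 1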

Now, we again apply the softmax to obtain attention weights (we use the unscaled scores here; in the final computation below, the scaling is applied before the softmax):

W = torch.softmax(S, dim=1)
print(W)
tensor([[1.0857e-05, 9.9999e-01, 1.3611e-07],
        [9.9470e-01, 2.2480e-03, 3.0506e-03],
        [1.1356e-01, 8.8508e-01, 1.3650e-03]])

We still face one problem - our tokens can currently “look into the future”. For example, token \(T_0\) can “see” \(T_1\) and \(T_2\).

However, this won’t be possible at inference time, as we can’t attend to tokens that haven’t been generated yet. Therefore, we should prevent it during training as well.

We will “mask” the attention scores of future tokens.

To do this, we define the following mask:

mask = torch.tril(torch.ones(S.shape[0], S.shape[0]))
print(mask)
tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

Next, we set all the scores where mask == 0 to -inf. This way, when we apply the softmax, these entries will be set to 0:

S_masked = S.masked_fill(mask == 0, float("-inf"))
print(S_masked)
tensor([[-2.5809,    -inf,    -inf],
        [ 1.5185, -4.5739,    -inf],
        [-2.2135, -0.1601, -6.6347]])

Again, we scale the masked scores and apply the softmax to obtain the attention weights:

S_masked_scaled = S_masked / (4 ** 0.5)
print(S_masked_scaled)
tensor([[-1.2905,    -inf,    -inf],
        [ 0.7593, -2.2869,    -inf],
        [-1.1067, -0.0801, -3.3173]])
W_masked = torch.softmax(S_masked_scaled, dim=1)
print(W_masked)
tensor([[1.0000, 0.0000, 0.0000],
        [0.9546, 0.0454, 0.0000],
        [0.2563, 0.7156, 0.0281]])

Additionally, it is common to apply dropout to the attention weights at this point. Note that during training, nn.Dropout rescales the surviving entries by \(1 / (1 - p)\), which is why the remaining weight below has doubled:

dropout = nn.Dropout(0.5)
dropout(W_masked)
tensor([[0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.5126, 0.0000, 0.0000]])

Value Vectors#

Right now, we would apply the weights to the token vectors directly.

It turns out that the attention mechanism performs even better if we introduce one more indirection: we compute value vectors from the token vectors and apply the attention weights to these value vectors instead.

To borrow from databases one more time:

In this context, the “value” is similar to the value in a key-value pair in a database (representing the actual content of the items). Once we determine which keys (and thus which parts of the input) are most relevant to the query (the item we are currently looking at), we retrieve the corresponding values.

The computation of the value vectors is carried out in exactly the same way as the computation of the key and query vectors:

W_V = torch.randn(5, 4)
V = torch.matmul(X, W_V)
print(V)
tensor([[ 0.3301,  1.8359, -1.3448,  0.7947],
        [-0.1512, -0.5678,  0.8648,  4.8368],
        [ 2.6772, -1.3256, -3.2423, -0.3151]])

We can now apply the attention weights to the values to get the final result:

R = torch.matmul(W_masked, V)
print(R)
tensor([[ 0.3301,  1.8359, -1.3448,  0.7947],
        [ 0.3082,  1.7268, -1.2445,  0.9781],
        [ 0.0517,  0.0270,  0.1831,  3.6559]])

This particular attention mechanism is called scaled dot product attention and is by far the most common attention mechanism used today.
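Written as a single formula, the computation we just performed (ignoring dropout) is:

\(\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d}} + M\right) V\)

where the softmax is applied row-wise, \(d\) is the dimension of the key/query vectors, and \(M\) is the causal mask whose entries are \(0\) for allowed positions and \(-\infty\) for future positions.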

Let’s verify that our understanding of scaled dot product attention is correct by using the scaled_dot_product_attention function from PyTorch:

from torch.nn.functional import scaled_dot_product_attention

R_torch = scaled_dot_product_attention(Q, K, V, is_causal=True)
print(R_torch)
tensor([[ 0.3301,  1.8359, -1.3448,  0.7947],
        [ 0.3082,  1.7268, -1.2445,  0.9781],
        [ 0.0517,  0.0270,  0.1831,  3.6559]])

Looks good!
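We can also compare the two results programmatically as an additional sanity check:

print(torch.allclose(R, R_torch, atol=1e-6))  # should print True - the results agree up to floating point error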

The Batch Dimension#

There is one more detail that we need to address: incorporating the batch dimension into our calculations.

While this may not be particularly exciting, we will go through the calculations, as this serves as a nice example of how the batch dimension doesn’t complicate things conceptually. You simply need to get used to the fact that there is always an additional dimension in front of your tensors.

Let’s initialize a random tensor that represents a situation where we have a batch size of 2, each batch element has 3 tokens, and each token is represented by a vector of size 5.

Note that in this section we will postfix all variables with _b to emphasize that we are dealing with an additional batch dimension here:

X_b = torch.randn(2, 3, 5)
print(X_b)
tensor([[[ 0.3449, -1.4241, -0.1163, -0.9727,  0.9585],
         [ 1.6192,  1.4506,  0.2695,  0.2625, -1.4391],
         [ 0.5214,  0.3488,  0.9676, -0.4657,  1.1179]],

        [[-1.2956,  0.0503, -0.5855, -0.3900,  0.0358],
         [ 0.1206, -0.8057,  0.2080, -1.1586, -0.9637],
         [-0.3750,  0.8033, -0.5188, -1.5013, -1.9267]]])

Next, we need to create the weights for the queries, keys and values. Instead of initializing matrices and then doing the calculations manually, let’s initialize linear layers and then call the layers (since this will be closer to what we will do in the end):

W_Q_b = nn.Linear(5, 4, bias=False)
W_K_b = nn.Linear(5, 4, bias=False)
W_V_b = nn.Linear(5, 4, bias=False)
K_b = W_K_b(X_b)
Q_b = W_Q_b(X_b)
V_b = W_V_b(X_b)

print(K_b.shape, Q_b.shape, V_b.shape)
torch.Size([2, 3, 4]) torch.Size([2, 3, 4]) torch.Size([2, 3, 4])

Finally, we calculate the attention scores. Note that we need to be careful when transposing the key matrix, as our actual token dimensions are now dimensions 1 and 2 (with 0 being the batch dimension):

S_b = torch.matmul(Q_b, K_b.transpose(1, 2))
print(S_b)
tensor([[[ 0.0351,  0.3447,  0.2260],
         [ 0.1568, -2.1882, -0.7753],
         [ 0.0279, -0.2793, -0.1115]],

        [[-0.1849,  0.3218,  0.4802],
         [ 0.4504, -0.5067, -0.4989],
         [ 0.9851, -1.1174, -1.2743]]], grad_fn=<UnsafeViewBackward0>)

We initialize the mask as usual:

mask_b = torch.tril(torch.ones(3, 3))
print(mask_b)
tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

We mask the respective scores:

S_masked_b = S_b.masked_fill(mask_b == 0, float("-inf"))
print(S_masked_b)
tensor([[[ 0.0351,    -inf,    -inf],
         [ 0.1568, -2.1882,    -inf],
         [ 0.0279, -0.2793, -0.1115]],

        [[-0.1849,    -inf,    -inf],
         [ 0.4504, -0.5067,    -inf],
         [ 0.9851, -1.1174, -1.2743]]], grad_fn=<MaskedFillBackward0>)

Finally, we scale the scores and apply the softmax:

S_masked_scaled_b = S_masked_b / K_b.shape[-1] ** 0.5
W_b = torch.softmax(S_masked_scaled_b, dim=-1)
print(W_b)
tensor([[[1.0000, 0.0000, 0.0000],
         [0.7636, 0.2364, 0.0000],
         [0.3584, 0.3074, 0.3343]],

        [[1.0000, 0.0000, 0.0000],
         [0.6174, 0.3826, 0.0000],
         [0.5979, 0.2090, 0.1932]]], grad_fn=<SoftmaxBackward0>)

And the output is:

R_b = torch.matmul(W_b, V_b)
print(R_b.shape)
torch.Size([2, 3, 4])

Implementing a SelfAttention Layer#

Let’s now package our calculations into a nice self-contained SelfAttention layer:

class SelfAttention(nn.Module):
    def __init__(self, d_in, d_out, c_len, dropout):
        super().__init__()

        self.W_Q = nn.Linear(d_in, d_out, bias=False)
        self.W_K = nn.Linear(d_in, d_out, bias=False)
        self.W_V = nn.Linear(d_in, d_out, bias=False)
        self.dropout = nn.Dropout(dropout)

        # causal mask for sequences of up to c_len tokens
        self.register_buffer("mask", torch.tril(torch.ones(c_len, c_len)))

    def forward(self, X):
        batch_size, num_tokens, d_in = X.shape

        Q = self.W_Q(X)
        K = self.W_K(X)
        V = self.W_V(X)

        # masked and scaled attention scores
        # (we slice the mask so that sequences shorter than c_len also work)
        S = torch.matmul(Q, K.transpose(1, 2))
        S_masked = S.masked_fill(self.mask[:num_tokens, :num_tokens] == 0, float("-inf"))
        S_masked_scaled = S_masked / K.shape[-1] ** 0.5
        W = torch.softmax(S_masked_scaled, dim=-1)

        W = self.dropout(W)

        # the output vectors are weighted sums of the value vectors
        R = torch.matmul(W, V)

        return R

layer = SelfAttention(d_in=5, d_out=4, c_len=3, dropout=0.5)

Let’s test that the layer works as expected:

X = torch.randn(2, 3, 5)
print(X.shape)
torch.Size([2, 3, 5])
Y = layer(X)
print(Y.shape)
torch.Size([2, 3, 4])

Note that, in practice, the number of input dimensions and output dimensions will usually be the same, i.e. the self-attention layer doesn’t change the dimensionality of the tensor; it only “mixes” the elements of the tensor together.

This has the nice benefit of allowing us to stack multiple self-attention layers on top of each other without having to worry about the shapes.
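For example, here is a small illustrative sketch in which two such layers are chained (reusing the X from above):

layer_1 = SelfAttention(d_in=5, d_out=5, c_len=3, dropout=0.0)
layer_2 = SelfAttention(d_in=5, d_out=5, c_len=3, dropout=0.0)
print(layer_2(layer_1(X)).shape)  # the shape of X is preserved
torch.Size([2, 3, 5])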

Multi-Head Attention#

The final piece we are missing is multi-head attention.

Basically, instead of having a single self-attention layer, we use multiple self-attention layers (“heads”), each with its own weights, and combine their outputs. Each head can learn to attend to different features, which gives the model richer representations:

X = torch.randn(2, 3, 5)

d_in = 5
d_out = 4
c_len = 3
n_heads = 2

heads = [SelfAttention(d_in, d_out, c_len, 0.0) for _ in range(n_heads)]

result = [head(X) for head in heads]
print([head_out.shape for head_out in result])
[torch.Size([2, 3, 4]), torch.Size([2, 3, 4])]

Next, we combine the head results:

head_out_combined = torch.cat(result, dim=-1)
print(head_out_combined.shape)
torch.Size([2, 3, 8])

Note that in practice we don’t want the multi-head attention layer to blow up the size of the tensor.

Therefore, we reduce the value of \(d_{out}\) per head; typically, it is set to the desired total output dimension divided by the number of heads. Here, we halve \(d_{out}\) so that the two concatenated heads together produce an output of dimension \(4\), just like a single head did before:

d_out = 2
heads = [SelfAttention(d_in, d_out, c_len, 0.0) for _ in range(n_heads)]

head_out_combined = torch.cat([head(X) for head in heads], dim=-1)
print(head_out_combined.shape)
torch.Size([2, 3, 4])

While this technically already works, it is computationally expensive since the heads are processed sequentially. Instead, we could process them in parallel by computing the outputs for all attention heads at the same time.

Here is what the entire process looks like:

n_tokens = 3
d_in = 4
d_out = 4
n_heads = 2
head_dim = d_out // n_heads

Let X be the input tensor:

X = torch.randn(2, n_tokens, d_in)
X.shape
torch.Size([2, 3, 4])

We initialize the weight matrices W_K, W_Q and W_V (each of shape (d_in, d_out)):

W_K = torch.randn(d_in, d_out)
W_Q = torch.randn(d_in, d_out)
W_V = torch.randn(d_in, d_out)

We compute the key and query matrices:

K = torch.matmul(X, W_K)
Q = torch.matmul(X, W_Q)
print(K.shape, Q.shape)
torch.Size([2, 3, 4]) torch.Size([2, 3, 4])

We reshape the key and query matrices so that the heads get their own dimension, and then transpose so that the tensors have shape (batch, heads, tokens, head_dim). This way, the matrix multiplications that follow are carried out independently for each head:

K_view = K.view(2, n_tokens, n_heads, head_dim)
print(K_view.shape)
torch.Size([2, 3, 2, 2])
Q_view = Q.view(2, n_tokens, n_heads, head_dim)
print(Q_view.shape)
torch.Size([2, 3, 2, 2])
K_view = K_view.transpose(1, 2)
print(K_view.shape)
torch.Size([2, 2, 3, 2])
Q_view = Q_view.transpose(1, 2)
print(Q_view.shape)
torch.Size([2, 2, 3, 2])

We obtain the attention scores:

S = torch.matmul(Q_view, K_view.transpose(2, 3))
print(S.shape)
torch.Size([2, 2, 3, 3])
print(S[0])
tensor([[[ -2.8630,  -0.1614,   1.2478],
         [-10.0620,   0.9408,   3.7090],
         [  2.8880,   0.4650,  -1.3942]],

        [[ 11.3074,   4.0007,  -6.2850],
         [ -4.6397,  -7.2396,   3.7868],
         [ -4.1246,  -0.4083,   2.0658]]])

We obtain the mask matrix and use it to mask out the future scores:

mask = torch.tril(torch.ones(n_tokens, n_tokens))
print(mask)
tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
S_masked = S.masked_fill(mask == 0, float("-inf"))
print(S_masked[0])
tensor([[[ -2.8630,     -inf,     -inf],
         [-10.0620,   0.9408,     -inf],
         [  2.8880,   0.4650,  -1.3942]],

        [[ 11.3074,     -inf,     -inf],
         [ -4.6397,  -7.2396,     -inf],
         [ -4.1246,  -0.4083,   2.0658]]])

We scale the scores (strictly speaking, each head’s dot products are taken over head_dim dimensions, so the conventional scaling factor would be the square root of head_dim; we divide by \(\sqrt{d_{out}}\) here to keep the code simple):

S_masked_scaled = S_masked / K.shape[-1] ** 0.5
print(S_masked_scaled)
tensor([[[[-1.4315,    -inf,    -inf],
          [-5.0310,  0.4704,    -inf],
          [ 1.4440,  0.2325, -0.6971]],

         [[ 5.6537,    -inf,    -inf],
          [-2.3198, -3.6198,    -inf],
          [-2.0623, -0.2041,  1.0329]]],


        [[[ 5.0593,    -inf,    -inf],
          [-1.4993,  0.2439,    -inf],
          [ 5.9892, -1.2553,  1.5600]],

         [[-4.6679,    -inf,    -inf],
          [ 0.5341,  0.3049,    -inf],
          [-3.2725, -0.4792, -0.8532]]]])

We apply the softmax to get the normalized attention weights:

W = torch.softmax(S_masked_scaled, dim=-1)
print(W[0])
tensor([[[1.0000, 0.0000, 0.0000],
         [0.0041, 0.9959, 0.0000],
         [0.7066, 0.2104, 0.0830]],

        [[1.0000, 0.0000, 0.0000],
         [0.7858, 0.2142, 0.0000],
         [0.0339, 0.2173, 0.7488]]])

We calculate the values and reshape the value matrix. Finally, we get the result:

V = torch.matmul(X, W_V) 
V_view = V.view(2, n_tokens, n_heads, head_dim)
V_view = V_view.transpose(1, 2)
R = torch.matmul(W, V_view)
print(V_view.shape, W.shape)
torch.Size([2, 2, 3, 2]) torch.Size([2, 2, 3, 3])
print(R.shape)
torch.Size([2, 2, 3, 2])

Next, we reshape the result back to three dimensions. The resulting tensor can now be fed into another layer:

R = R.transpose(1, 2)
R.shape
torch.Size([2, 3, 2, 2])
R_combined = R.contiguous().view(2, n_tokens, d_out)
R_combined.shape
torch.Size([2, 3, 4])
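Finally, all of these steps can be packaged into a single module. The following is a sketch of how such a MultiHeadAttention layer might look - it mirrors the computations above, scales by the square root of head_dim, and omits the final output projection that many implementations add:

class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, c_len, n_heads, dropout):
        super().__init__()
        assert d_out % n_heads == 0, "d_out must be divisible by n_heads"
        self.n_heads = n_heads
        self.head_dim = d_out // n_heads

        self.W_Q = nn.Linear(d_in, d_out, bias=False)
        self.W_K = nn.Linear(d_in, d_out, bias=False)
        self.W_V = nn.Linear(d_in, d_out, bias=False)
        self.dropout = nn.Dropout(dropout)

        # causal mask for sequences of up to c_len tokens
        self.register_buffer("mask", torch.tril(torch.ones(c_len, c_len)))

    def forward(self, X):
        batch_size, num_tokens, d_in = X.shape

        # compute Q, K, V and split them into heads:
        # (batch, tokens, d_out) -> (batch, heads, tokens, head_dim)
        Q = self.W_Q(X).view(batch_size, num_tokens, self.n_heads, self.head_dim).transpose(1, 2)
        K = self.W_K(X).view(batch_size, num_tokens, self.n_heads, self.head_dim).transpose(1, 2)
        V = self.W_V(X).view(batch_size, num_tokens, self.n_heads, self.head_dim).transpose(1, 2)

        # masked and scaled attention scores, computed independently per head
        S = torch.matmul(Q, K.transpose(2, 3))
        S = S.masked_fill(self.mask[:num_tokens, :num_tokens] == 0, float("-inf"))
        W = torch.softmax(S / self.head_dim**0.5, dim=-1)
        W = self.dropout(W)

        # weighted sum of the values, then merge the heads back together
        R = torch.matmul(W, V)
        R = R.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.n_heads * self.head_dim)
        return R

mha = MultiHeadAttention(d_in=5, d_out=4, c_len=3, n_heads=2, dropout=0.0)
X = torch.randn(2, 3, 5)
print(mha(X).shape)
torch.Size([2, 3, 4])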