Transformer-style Multi-Headed Self-Attention (MHSA) [^1][^2] is a way to create fully connected layers of neurons [^3] where the number of neurons changes to match the size of the input. The network literally makes itself bigger or smaller as needed. It's not alone in this, of course; other architectures also handle variable-sized inputs. Convolutional Neural Networks (CNNs) and Recurrent Neural Networks (RNNs) do this easily. However, these approaches introduce *inductive biases*, assumptions suited to specific types of data [^4][^5]. These biases are baked into the architecture and cannot be removed.

This isn't necessarily a bad thing. Because of their biases, CNNs work very well for images and RNNs work very well for time series data. In fact, you can take a CNN with completely random weights, freeze it, then add and train a simple linear output layer to do image classification with astonishingly high accuracy [^6]. You can do the same thing with RNNs and time series data [^7]. This works so well it's a wonder we ever bother training whole models.

At the extreme end of this are Weight Agnostic Neural Networks (WANNs) [^8]: networks that use a single shared value for all of their weights (e.g. 1). During training these networks evolve, changing shape and adding or removing connections, but never changing the weights. In the end, none of the processing or "knowledge" is in the weights; it's solely in the architecture. But of course this means WANNs are extremely data specific.

MHSA goes in exactly the opposite direction: extremely simple, extremely general. It's even more general than regular fully connected layers, as neither the number of neurons nor their order is hard coded. MHSA doesn't start out naturally good at anything, but, given enough data, it can *learn* to be good at basically everything. This generality is why we now have truly multimodal models.
You can jointly train a single transformer to work with images, time series, audio, radar, and whatever else you've got. Transformers are currently our best shot at Artificial General Intelligence (AGI) [^9][^10]. To me, this is surprising, given that transformers are basically modern versions of [Frank Rosenblatt's](https://en.wikipedia.org/wiki/Frank_Rosenblatt) 67-year-old Multi Layer Perceptrons (MLPs) [^11]. I doubt it would have surprised Rosenblatt, though. He was so self-assured that in 1958 he claimed some day a version of his Perceptrons would "be able to walk, talk, see, write, reproduce itself and be conscious of its existence" [^12]. Too soon to say, I suppose.

Not wanting to bury the lede, here is the entire procedure to implement MHSA. We will go through each part in turn.

>[!Full MHSA Block]
>![[Multi Headed Self Attention.svg]]

---

# Step 1: Input Projection

A regular fully connected layer (also called a linear projection) takes an input vector of shape $[In]$ and linearly projects it into a new vector of shape $[Out]$ using a matrix multiplication between the input and a weight matrix of shape $[Out, In]$. An optional bias vector of shape $[Out]$ can then be added [^3].

MHSA aims to calculate the weight matrix on the fly, based on the input. To do this, each element in our input actually has to be a vector. These vectors are typically called "tokens" or "embeddings", and their size is hard coded. So our input goes from being a vector of shape $[In]$ to a matrix of shape $[Tokens, Embedding\_Dim]$ where:

- $Tokens$ is the length of the sequence (image patches, words, etc.)
- $Embedding\_Dim$ is the size of each token

>[!Input Data]
>![[Input Data.svg]]

How exactly you arrange your inputs into tokens (tokenisation) is clearly critical, and it's an active area of research in itself. For brevity we're leaving it out of this discussion.
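To make the two shapes above concrete, here is a minimal NumPy sketch. NumPy, the toy sizes, and the helper name `linear` are illustrative assumptions, not part of the original. It shows a fully connected layer acting on a single $[In]$ vector, and a token matrix whose number of rows can vary while its width ($Embedding\_Dim$) stays fixed.

```python
import numpy as np

rng = np.random.default_rng(0)


def linear(x, W, b=None):
    """Hypothetical helper. Fully connected layer: x is [In], W is [Out, In], optional b is [Out]."""
    y = W @ x          # [Out, In] @ [In] -> [Out]
    if b is not None:
        y = y + b      # optional bias of shape [Out]
    return y


In, Out = 4, 3
W = rng.standard_normal((Out, In))
b = rng.standard_normal(Out)
x = rng.standard_normal(In)
print(linear(x, W, b).shape)   # (3,)

# For MHSA the input becomes a matrix of tokens. The number of tokens may
# change from one input to the next, but Embedding_Dim is fixed.
embedding_dim = 8
for tokens in (5, 12):
    X = rng.standard_normal((tokens, embedding_dim))
    print(X.shape)             # (5, 8), then (12, 8)
```

The weight matrix `W` above has a fixed shape, so it can only ever accept inputs of length `In`; MHSA's job is to remove that restriction.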
Our first step towards a dynamically calculated weight matrix is to take each token separately and run it through the same fully connected linear layer. We do this by multiplying each token vector by the same learned weight matrix. Thankfully this is parallelised on the GPU, meaning we can process all the tokens in a single step.

This linear layer is called the $QKV$ Projection, and it projects each token from a length of $[Embedding\_Dim]$ to $[3 \times Inner\_Dim]$, where $Inner\_Dim$ is a hyperparameter that is usually just set so that $Inner\_Dim = Embedding\_Dim$. Why we do this will be clear in a second.

We project to $[3 \times Inner\_Dim]$ because we immediately split each token into 3 separate tokens of shape $[Inner\_Dim]$. The three outputs are called the $Query$, $Key$, and $Value$ ($Q$, $K$, $V$) tokens. The trick here is that we have basically done 3 separate linear projections of the input with a single matrix multiplication. To make this clear, diagrams will typically show the three projections as separate but parallel.

>[!QKV Projection]
>![[QKV Projection.svg]]

Why are the $Queries$, $Keys$, and $Values$ called that? Well, when we're calculating the dynamic weights, the $Keys$ represent the input neurons, the $Queries$ represent the output neurons, and the $Values$ are the data being sent. So the outputs **query** the input **keys** and receive their **values**. If that's as clear as mud, don't worry about it. You just have to call them something, and this is what the original authors chose [^1].

Finally we reshape the $Q$, $K$, and $V$ tokens to have a new "Head" dimension. This is a trivial shuffling around of data: nothing is created or removed, just reorganised. This is the "Multi Head" part of MHSA. Before the reshaping each token has length $[Inner\_Dim]$. After the reshaping each token has the shape $[Heads, Head\_Dim]$, where $Head\_Dim = \frac{Inner\_Dim}{Heads}$. We then move the $Heads$ dimension to the front, so that $Q$, $K$, and $V$ each have the shape $[Heads, Tokens, Head\_Dim]$.

>[!Full Input Projection]
>![[QKV Projection With Heads.svg]]
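The QKV projection and head split can be sketched in NumPy as follows. This is an illustrative sketch, not the author's code: the sizes, the random weights, and the names `W_qkv` and `split_heads` are assumptions. It checks that slicing the stacked weight matrix recovers a separate $Q$ projection, and that the head split only reorganises numbers.

```python
import numpy as np

rng = np.random.default_rng(0)

tokens, embedding_dim = 6, 16
heads = 4
inner_dim = embedding_dim            # the usual choice: Inner_Dim = Embedding_Dim
head_dim = inner_dim // heads        # Head_Dim = Inner_Dim / Heads

X = rng.standard_normal((tokens, embedding_dim))              # [Tokens, Embedding_Dim]

# One learned matrix holds all three projections: [3 * Inner_Dim, Embedding_Dim]
W_qkv = rng.standard_normal((3 * inner_dim, embedding_dim))

# The same weights are applied to every token in one matmul: [Tokens, 3 * Inner_Dim]
qkv = X @ W_qkv.T
Q, K, V = np.split(qkv, 3, axis=-1)                            # each [Tokens, Inner_Dim]

# Slicing the stacked matrix shows it really is three separate projections
W_q = W_qkv[:inner_dim]
assert np.allclose(Q, X @ W_q.T)


def split_heads(t):
    """Hypothetical helper: [Tokens, Inner_Dim] -> [Tokens, Heads, Head_Dim] -> [Heads, Tokens, Head_Dim]."""
    return t.reshape(tokens, heads, head_dim).transpose(1, 0, 2)


Qh, Kh, Vh = split_heads(Q), split_heads(K), split_heads(V)
print(Qh.shape)   # (4, 6, 4)

# Pure reorganisation: head h of token t is just a slice of the original token
t, h = 2, 1
assert np.array_equal(Qh[h, t], Q[t, h * head_dim:(h + 1) * head_dim])
```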
Why bother shuffling the data around into a new dimension? Well, this actually lets us implement not just one variable-sized fully connected layer, but as many as we have heads. This is why the QKV projection is necessary even if $Embedding\_Dim = Inner\_Dim$. What it's actually doing is mapping each token into many smaller tokens, one per head.

Why is it good to have multiple heads? Well, it allows each one to become an "expert": each one learns to handle a different sort of data. This isn't just useful in attention; it's also useful in traditional fully connected layers. There's recently been a push towards replacing the regular fully connected layers in transformers (outside of MHSA) with "mixture-of-experts" layers (basically multi-headed linear layers, with a learned router choosing which experts each token uses) [^13].

Getting extremely meta, you might wonder: "if people are replacing the linear layers outside of MHSA with mixture-of-experts layers, can you replace the $QKV$ linear projection with a mixture-of-experts layer?" The answer is yes! [^14] Although, for whatever reason, it seems to only really work for the $V$ projection, not the $Q$ or $K$. Still, it's not common practice, because it's a bit involved and doesn't help much.

---

# Step 2: Getting the Dynamic Weights

Next comes the real magic. We are going to use the $Q$ and $K$ projections to generate the weights for our token-to-token connections.

Let's first remember how matrix multiplication works. For matrices $A$ and $B$ of shapes $[m, n]$ and $[n, p]$, their product ($\otimes$) looks like this:

$$
\begin{array}{l c c c c c}
 & A & \otimes & B & = & AB \\[6pt]
\text{Shapes:} & m\times{\color{red}{n}} & & {\color{red}{n}}\times p & & m\times p
\end{array}
$$

The inner dimension $n$ has to be the same in both matrices, but when we multiply them together $n$ disappears. Spooky.

At this point the $Q$, $K$, and $V$ tensors have the shape $[Heads, Tokens, Head\_Dim]$. $Head\_Dim$ is going to be the inner dimension that gets sacrificed (RIP).
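The shape rule above is easy to check numerically. Here is a small NumPy sketch with arbitrary toy sizes (an assumption for illustration): the inner dimension disappears from the result, each output entry is a sum over that inner dimension, and mismatched inner dimensions are rejected.

```python
import numpy as np

rng = np.random.default_rng(0)

m, n, p = 2, 5, 3
A = rng.standard_normal((m, n))   # [m, n]
B = rng.standard_normal((n, p))   # [n, p]
AB = A @ B                        # the shared inner dimension n disappears
print(AB.shape)                   # (2, 3)

# Each output entry is a sum over the inner dimension
assert np.isclose(AB[1, 2], sum(A[1, k] * B[k, 2] for k in range(n)))

# If the inner dimensions do not match, the multiplication is undefined
try:
    A @ rng.standard_normal((n + 1, p))
except ValueError as err:
    print("inner dimensions differ:", err)
```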
First we transpose the $K$ tensor, meaning we swap its last two dimensions, giving $K^T$ the shape $[Heads, Head\_Dim, Tokens]$. Then, taking each head separately, we do the matrix multiplication $Q \otimes K^T$:

$$
\begin{array}{c c c c c}
Q & \otimes & K^T & = & Output \\[6pt]
T\times{\color{red}{H\_D}} & & {\color{red}{H\_D}}\times T & & T\times T
\end{array}
$$

*Here $T = Tokens$ and $H\_D = Head\_Dim$.*

By the magic of matrices, the output has the shape $[Heads, Tokens, Tokens]$. $Q \otimes K^T$ has given us our token-to-token weight matrix of shape $[Heads, Out, In]$. It's just that there are as many output tokens as input tokens, so it's $[Heads, Tokens, Tokens]$.

>[!Q Times Kᵀ]
>![[Q Times KT.svg]]

So what is actually happening here? Multiplying $Q$ and $K^T$ gives us an output matrix that's the right shape, but what's actually *in* that matrix? For any given head, $Q$ and $K$ both have the shape $[Tokens, Head\_Dim]$. We can view these as arrays of tokens. Let's imagine we have a sequence of 3 tokens. Then:

$$
Q = \begin{bmatrix} Q_1 \\ Q_2 \\ Q_3 \end{bmatrix} \quad\quad K = \begin{bmatrix} K_1 \\ K_2 \\ K_3 \end{bmatrix}
$$

where every element $Q_n$ and $K_n$ is a token of length $Head\_Dim$. Transposing $K$ turns each key token into a column, so when we do $Q \otimes K^T$ we have:

$$
\begin{bmatrix} Q_1 \\ Q_2 \\ Q_3 \end{bmatrix} \otimes \begin{bmatrix} K_1 & K_2 & K_3 \end{bmatrix} = \begin{bmatrix} Q_1 K_1 & Q_1 K_2 & Q_1 K_3 \\ Q_2 K_1 & Q_2 K_2 & Q_2 K_3 \\ Q_3 K_1 & Q_3 K_2 & Q_3 K_3 \end{bmatrix}
$$

This is, again, our weight matrix of shape $[Out, In]$ ($[Rows, Columns]$). So the weights for output neuron 1 are:

$$
\begin{bmatrix} Q_{\color{red}1} K_1 & Q_{\color{red}1} K_2 & Q_{\color{red}1} K_3 \end{bmatrix}
$$

For each output neuron there is a single $Q$ token and as many $K$ tokens as there are inputs. The $Q$ tokens can therefore be said to represent the output neurons, and the $K$ tokens the input neurons. Each weight is the dot product of the given pair of $Q$ and $K$ tokens.
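Here is a NumPy sketch of this step for a toy sequence, with random $Q$ and $K$ standing in for real projections (an assumption for illustration). Transposing the last two axes of $K$ and multiplying gives one $[Tokens, Tokens]$ weight matrix per head, and each entry is the dot product of one query token with one key token.

```python
import numpy as np

rng = np.random.default_rng(0)
heads, tokens, head_dim = 2, 3, 4

Q = rng.standard_normal((heads, tokens, head_dim))   # [Heads, Tokens, Head_Dim]
K = rng.standard_normal((heads, tokens, head_dim))

K_T = K.transpose(0, 2, 1)    # swap the last two axes: [Heads, Head_Dim, Tokens]
weights = Q @ K_T             # one matmul per head: [Heads, Tokens, Tokens]
print(weights.shape)          # (2, 3, 3)

# Entry [h, i, j] is the dot product of query token i and key token j in head h,
# so row i holds all the incoming weights for output neuron i
h, i, j = 1, 0, 2
assert np.isclose(weights[h, i, j], np.dot(Q[h, i], K[h, j]))

# Batching over heads gives the same result as looping over them
assert all(np.allclose(weights[g], Q[g] @ K[g].T) for g in range(heads))
```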
There are two ways to compute this dot product. There's the algebraic way: multiply the vectors together element-wise and then sum the results. Or the geometric way, using their magnitudes and the angle ($\theta$) between them:

$$
Q_1 K_1 = \|Q_1\| \, \|K_1\| \cos(\theta)
$$

Doing it the geometric way, we can see that, for tokens of a given length, the value of the weight is controlled by the angle between the $Q$ and $K$ tokens. This is the trick. The network can make the weights of the connections larger or smaller by changing the $Q$ and $K$ projections to make the tokens more or less similar.

Imagine your model is processing an image. Each token represents an area of the image. And say, for whatever reason, the model wants to send information from areas that have red squares 🟥 to areas that have blue circles 🔵. It just has to learn the projections $Q$ and $K$ such that the output of $Q(🔵)$ is similar to $K(🟥)$. Easy.

But now we get to the main reason that multiple heads are necessary! A regular fully connected layer isn't limited to routing a single type of data. It can do

$$
\begin{aligned}
\text{🟥} &\to \text{🔵}\\
\text{🔵} &\to \text{🟩}\\
\text{🟩} &\to \text{❌}
\end{aligned}
$$

all at once in a single layer. A single pair of $Q$ and $K$ projections might be able to achieve this, but it's not guaranteed. However, if we have multiple heads, each with its own $Q$ and $K$ projection, then they can each learn to route different types of information, more closely emulating a regular fully connected layer.

But there's a slight snag. The outputs of a linear layer are just a bunch of weighted sums of the inputs. Imagine if all the weights in our $Q \otimes K^T$ weight matrix were 1. Then the output for any given token would be the sum of the entire input sequence, and as the length of the sequence approached infinity, so could the magnitude of the output. Disaster. But we can fix this with a simple SoftMax [^15] across the $In$ dimension of the weight matrix $Q \otimes K^T$.
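Both ideas in this step can be checked with a short NumPy sketch. The toy vectors and sizes are assumptions, and the hand-written `softmax` below is a stand-in for `torch.nn.Softmax`, not the library itself. The algebraic and geometric dot products agree, a SoftMax over the $In$ axis keeps every weight between 0 and 1 with each row summing to 1, and the all-ones example shows the unbounded growth that the SoftMax removes.

```python
import numpy as np

# Geometric vs algebraic dot product, using 2-D vectors built from explicit lengths and angles
len_q, angle_q = 2.0, 0.3
len_k, angle_k = 1.5, 1.2
q = len_q * np.array([np.cos(angle_q), np.sin(angle_q)])
k = len_k * np.array([np.cos(angle_k), np.sin(angle_k)])

algebraic = np.sum(q * k)                                # element-wise multiply, then sum
geometric = len_q * len_k * np.cos(angle_k - angle_q)    # |q| |k| cos(theta)
assert np.isclose(algebraic, geometric)

# For fixed lengths the angle sets the weight: aligned -> largest, perpendicular -> 0, opposite -> most negative
for theta in (0.0, np.pi / 2, np.pi):
    print(f"theta = {theta:.2f}: weight = {len_q * len_k * np.cos(theta):.2f}")


def softmax(scores, axis=-1):
    """SoftMax along the In (last) axis; subtracting the row max keeps exp() from overflowing."""
    e = np.exp(scores - scores.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(0)
scores = rng.standard_normal((3, 3))          # one head's [Out, In] weight matrix
attn = softmax(scores)
assert np.all((attn > 0) & (attn < 1))
assert np.allclose(attn.sum(axis=-1), 1.0)    # each output token's weights sum to 1

# With all-ones weights the output grows with sequence length;
# after SoftMax the same row becomes an average and stays near 1
for n in (10, 1000):
    values = np.ones(n)
    print(n, np.ones(n) @ values, softmax(np.ones(n)) @ values)
```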
SoftMax is a nonlinear function similar to a Sigmoid [^16] that squashes every input into the range $0 \to 1$. It also scales each row of the weight matrix so that its weights sum to 1. This basically turns them into percentages. Each output token says "I want to receive a% of my information from token 1, b% of my information from token 2..." and so on.

>[!The Attention Matrix]
>![[Attn.svg]]

Technically 🙄, we actually use a scaled SoftMax. Before applying the SoftMax we divide the $Q \otimes K^T$ matrix by $\sqrt{Head\_Dim}$ [^1]. I don't want to get too complicated here, but if every element of the $Q$ and $K$ tokens were independent random noise with mean 0 and variance 1, then each entry of $Q \otimes K^T$ (a sum of $Head\_Dim$ such products) would have a variance of $Head\_Dim$. Dividing by $\sqrt{Head\_Dim}$ restores a variance of 1. This scaling factor plays the role of the "temperature" of the SoftMax, and some folk let the model learn it [^17]. In our diagrams we have left out the scaling, considering it to be part of the SoftMax.

The final, SoftMax-squished version of $Q \otimes K^T$ is called the attention matrix, because each output token's weights sort of show you (in percentages) where it's paying attention.

---

# Step 3: Using the Dynamic Weights

Now that we have the dynamic token-to-token weights (the attention matrix), we can apply them to our input. But, again, we don't have 1 set of weights; we have as many as we have heads. That's why the $V$ projection was necessary earlier. It projected our input from shape $[Tokens, Embedding\_Dim]$ into $[Heads, Tokens, Head\_Dim]$.

We can apply these weights by taking each head separately and doing the matrix multiplication $Attn \otimes V$. Again, we're able to do the matrix multiplications for all the heads at once thanks to GPU parallelisation.

>[!Multiplying Attn and V]
>![[Attn Times V.svg]]

Note that a regular fully connected [^3] layer with an input of shape $[In]$, output of shape $[Out]$, and a weight matrix $W$ of shape $[Out, In]$ is implemented as:

$$
\begin{array}{c c c c c}
Input & \otimes & W^T & = & Output \\[6pt]
{\color{red}{In}} & & {\color{red}{In}}\times Out & & Out
\end{array}
$$

In our MHSA block our weight matrix is $Attn$, and our input is $V$. So why aren't we doing $V \otimes Attn^T$? It's because the tokens are vectors. Let's look at the shapes. For each head we have:

$$
\begin{array}{c c c c c}
Attn & \otimes & V & = & Output \\[6pt]
T_O\times{\color{red}{T_I}} & & {\color{red}{T_I}}\times H\_D & & T_O\times H\_D
\end{array}
$$

where $T_O$ and $T_I$ are the numbers of output and input tokens respectively, and $H\_D$ is the $Head\_Dim$. If we were to try doing $V \otimes Attn^T$ like in a regular linear layer, we would have:

$$
\begin{array}{c c c c c}
V & \otimes & Attn^T & = & ?? \\[6pt]
T_I\times{\color{red}{H\_D}} & & {\color{red}{T_I}}\times T_O & &
\end{array}
$$

The $Head\_Dim$ gets in the way: the inner dimensions don't match. Of course we could transpose $V$ and do $V^T \otimes Attn^T$. That would work, but since transposing a product reverses its order, it just gives us the transpose of the same result:

$$
V^T \otimes Attn^T \; = \; (Attn \otimes V)^T
$$

So $Attn \otimes V$ is simply the direct way to get it.

---

# Step 4: Output Projection

At this point we're basically done. We've generated the dynamic weights and used them. All that's left is to undo the whole "heads" thing, and that's very simple. It's just the inverse of the input projection we did right at the start.

First we take the output of $Attn \otimes V$, which has the shape $[Heads, Tokens, Head\_Dim]$, and reshape it back to $[Tokens, Inner\_Dim]$, where $Inner\_Dim = Heads \times Head\_Dim$. Finally we take each token of shape $[Inner\_Dim]$ and use a regular fully connected layer to project it back to shape $[Embedding\_Dim]$. Again, we do this for all of the tokens at once with GPU parallelisation.

>[!Output Projection]
>![[Output Projection.svg]]

And that's it!
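Putting the four steps together, here is a from-scratch NumPy sketch of a single MHSA block for one unbatched sequence. It is an illustrative version built on assumptions (the hypothetical function name `mhsa`, random weights, no Layer Norm or dropout), not the author's or any library's implementation. It also checks the transpose identity from Step 3 and the variance argument behind the $\sqrt{Head\_Dim}$ scaling.

```python
import numpy as np


def softmax(x, axis=-1):
    """SoftMax along `axis`; subtracting the row max keeps exp() from overflowing."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def mhsa(X, W_qkv, W_out, b_out, heads):
    """Hypothetical sketch of multi-headed self-attention for one sequence (no batch axis).

    X:     [Tokens, Embedding_Dim]
    W_qkv: [3 * Inner_Dim, Embedding_Dim]   QKV projection (no bias)
    W_out: [Embedding_Dim, Inner_Dim]       output projection
    b_out: [Embedding_Dim]                  output bias
    """
    tokens = X.shape[0]
    inner_dim = W_qkv.shape[0] // 3
    head_dim = inner_dim // heads

    # Step 1: project every token, split into Q, K, V, move heads to the front
    q, k, v = np.split(X @ W_qkv.T, 3, axis=-1)                   # each [Tokens, Inner_Dim]
    q, k, v = [t.reshape(tokens, heads, head_dim).transpose(1, 0, 2) for t in (q, k, v)]

    # Step 2: scaled Q K^T, then SoftMax over the In (last) axis
    attn = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(head_dim))  # [Heads, T_O, T_I]

    # Step 3: apply the dynamic weights to V, one matmul per head
    out = attn @ v                                                # [Heads, Tokens, Head_Dim]

    # Step 4: undo the head split, then project back to Embedding_Dim
    out = out.transpose(1, 0, 2).reshape(tokens, inner_dim)       # [Tokens, Inner_Dim]
    return out @ W_out.T + b_out, attn


rng = np.random.default_rng(0)
embedding_dim, heads = 16, 4
inner_dim = embedding_dim
W_qkv = rng.standard_normal((3 * inner_dim, embedding_dim)) / np.sqrt(embedding_dim)
W_out = rng.standard_normal((embedding_dim, inner_dim)) / np.sqrt(inner_dim)
b_out = np.zeros(embedding_dim)

# The same weights work for sequences of any length
for tokens in (5, 11):
    X = rng.standard_normal((tokens, embedding_dim))
    Y, attn = mhsa(X, W_qkv, W_out, b_out, heads)
    print(Y.shape, attn.shape)

# Transposing both operands only transposes the result: V^T Attn^T = (Attn V)^T
A, V = attn[0], rng.standard_normal((attn.shape[-1], 4))
assert np.allclose(V.T @ A.T, (A @ V).T)

# Why divide by sqrt(Head_Dim): unit-variance noise gives dot products with variance ~Head_Dim
head_dim = 64
q = rng.standard_normal((20_000, head_dim))
k = rng.standard_normal((20_000, head_dim))
dots = np.sum(q * k, axis=-1)
print(dots.var(), (dots / np.sqrt(head_dim)).var())   # roughly 64 and roughly 1
```

Because the weight matrices only depend on $Embedding\_Dim$, $Inner\_Dim$ and $Heads$, the same `mhsa` call handles both the 5-token and the 11-token sequence, which is the variable-size behaviour described at the start.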
---

# The Code

The following code is taken from lucidrains' PyTorch port [^18] of the official JAX code [^19] for the Vision Transformer (ViT) [^20]. Please note that this implementation adds a Layer Norm before the input projection, and optional dropout both on the attention matrix and after the output projection. It relies on PyTorch and on the `einops` package for `rearrange`.

```python
import torch
from torch import nn
from einops import rearrange


class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads                          # Inner_Dim = Heads * Head_Dim
        project_out = not (heads == 1 and dim_head == dim)    # one head with dim_head == dim: skip output projection

        self.heads = heads
        self.scale = dim_head ** -0.5                         # 1 / sqrt(Head_Dim)

        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)                    # SoftMax over the In (last) dimension
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)   # QKV projection

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),                        # output projection
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        # x: [Batch, Tokens, Embedding_Dim]
        x = self.norm(x)

        # Step 1: QKV projection, split into Q, K, V, then split heads -> [Batch, Heads, Tokens, Head_Dim]
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        # Step 2: scaled Q K^T, then SoftMax -> attention matrix [Batch, Heads, Tokens, Tokens]
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        # Step 3: apply the dynamic weights to V
        out = torch.matmul(attn, v)
        # Step 4: merge the heads and project back to Embedding_Dim
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
```

---

# References

[^1]: Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A.N., Kaiser, L. and Polosukhin, I. (2017) _Attention Is All You Need_. CoRR, abs/1706.03762. Available at: [http://arxiv.org/abs/1706.03762](http://arxiv.org/abs/1706.03762) (Accessed: 30 April 2025).
[^2]: PyTorch Team (2025) ‘torch.nn.MultiheadAttention — PyTorch 2.7 documentation’, _PyTorch Documentation_. Available at: [https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) (Accessed: 30 April 2025).
[^3]: PyTorch Team (2025) ‘torch.nn.Linear — PyTorch 2.7 documentation’, _PyTorch Documentation_.
Available at: [https://pytorch.org/docs/stable/generated/torch.nn.Linear.html](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) (Accessed: 30 April 2025).
[^4]: Cohen, N. and Shashua, A. (2017) ‘Inductive Bias of Deep Convolutional Networks through Pooling Geometry’, _Proceedings of the 5th International Conference on Learning Representations (ICLR 2017)_, Toulon, France, 24–26 April 2017. OpenReview.net. Available at: [https://openreview.net/forum?id=BkVsEMYel](https://openreview.net/forum?id=BkVsEMYel) (Accessed: 30 April 2025).
[^5]: Elman, J.L. (1990) ‘Finding Structure in Time’, _Cognitive Science_, 14(2), pp. 179–211. doi: 10.1207/S15516709COG1402_1. Available at: [https://www.sciencedirect.com/science/article/pii/036402139090002E](https://www.sciencedirect.com/science/article/pii/036402139090002E) (Accessed: 30 April 2025).
[^6]: Saxe, A.M., Koh, P.W., Chen, Z., Bhand, M., Suresh, B. and Ng, A.Y. (2011) ‘On Random Weights and Unsupervised Feature Learning’, _Proceedings of the 28th International Conference on Machine Learning (ICML 2011)_, Bellevue, WA, USA, 28 June – 2 July 2011, pp. 1089–1096. Available at: [https://dblp.org/rec/conf/icml/SaxeKCBSN11](https://dblp.org/rec/conf/icml/SaxeKCBSN11) (Accessed: 30 April 2025).
[^7]: Schmidhuber, J., Wierstra, D., Gagliolo, M. and Gomez, F. (2007) ‘Training Recurrent Networks by Evolino’, _Neural Computation_, 19(3), pp. 757–779. MIT Press. doi: 10.1162/neco.2007.19.3.757.
[^8]: Gaier, A. and Ha, D. (2019) ‘Weight Agnostic Neural Networks’, _arXiv preprint_ arXiv:1906.04358. Available at: [http://arxiv.org/abs/1906.04358](http://arxiv.org/abs/1906.04358) (Accessed: 30 April 2025).
[^9]: Yue, X., Ni, Y., Zhang, K., Zheng, T., Liu, R., Zhang, G. _et al._ (2024) ‘MMMU: A massive multi-discipline multimodal understanding and reasoning benchmark for expert AGI’, _Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR 2024)_. Available at: [https://arxiv.org/abs/2311.16502](https://arxiv.org/abs/2311.16502) (Accessed: 30 April 2025).
[^10]: Huh, M., Cheung, B., Wang, T. and Isola, P. (2024) ‘The Platonic Representation Hypothesis’, _arXiv preprint_ arXiv:2405.07987. Available at: [https://arxiv.org/abs/2405.07987](https://arxiv.org/abs/2405.07987) (Accessed: 30 April 2025).
[^11]: Rosenblatt, F. (1958) ‘The perceptron: A probabilistic model for information storage and organization in the brain’, _Psychological Review_, 65(6), pp. 386–408. doi: 10.1037/h0042519.
[^12]: Olazaran, M. (1996) ‘A sociological study of the official history of the perceptrons controversy’, _Social Studies of Science_, 26(3), pp. 611–659. Available at: [http://www.jstor.org/stable/285702](http://www.jstor.org/stable/285702) (Accessed: 30 April 2025).
[^13]: Cai, W., Jiang, J., Wang, F., Tang, J., Kim, S. and Huang, J. (2025) ‘A Survey on Mixture of Experts in Large Language Models’, _IEEE Transactions on Knowledge and Data Engineering_, pp. 1–20. doi: 10.1109/TKDE.2025.3554028. Preprint available at: [https://arxiv.org/pdf/2407.06204](https://arxiv.org/pdf/2407.06204) (Accessed: 30 April 2025).
[^14]: Csordás, R., Piękos, P., Irie, K. and Schmidhuber, J. (2024) ‘SwitchHead: Accelerating Transformers with Mixture-of-Experts Attention’, in Globerson, A., Mackey, L., Belgrave, D., Fan, A., Paquet, U., Tomczak, J. and Zhang, C. (eds) _Advances in Neural Information Processing Systems_, vol. 37. Curran Associates, Inc., pp. 74411–74438. Available at: [https://proceedings.neurips.cc/paper_files/paper/2024/file/87be61bf9338389702712f5e9754a986-Paper-Conference.pdf](https://proceedings.neurips.cc/paper_files/paper/2024/file/87be61bf9338389702712f5e9754a986-Paper-Conference.pdf).
[^15]: PyTorch Team (2025) ‘torch.nn.Softmax — PyTorch 2.7 documentation’, _PyTorch Documentation_. Available at: [https://pytorch.org/docs/stable/generated/torch.nn.Softmax.html](https://pytorch.org/docs/stable/generated/torch.nn.Softmax.html) (Accessed: 30 April 2025).
[^16]: PyTorch Team (2025) ‘torch.nn.Sigmoid — PyTorch 2.7 documentation’, _PyTorch Documentation_. Available at: [https://pytorch.org/docs/stable/generated/torch.nn.Sigmoid.html](https://pytorch.org/docs/stable/generated/torch.nn.Sigmoid.html) (Accessed: 30 April 2025).
[^17]: Dufter, P., Schmitt, M. and Schütze, H. (2020) ‘Increasing learning efficiency of self-attention networks through direct position interactions, learnable temperature, and convoluted attention’, _Proceedings of the 28th International Conference on Computational Linguistics (COLING 2020)_, 8–13 December 2020, Barcelona (online). Stroudsburg, PA: Association for Computational Linguistics, pp. 3630–3636. Available at: [https://aclanthology.org/2020.coling-main.323](https://aclanthology.org/2020.coling-main.323) (Accessed: 30 April 2025).
[^18]: Lucidrains (n.d.) vit-pytorch: Vision Transformer - Pytorch. GitHub. Available at: [https://github.com/lucidrains/vit-pytorch](https://github.com/lucidrains/vit-pytorch) (Accessed: 1 May 2025).
[^19]: Google Research (n.d.) vision_transformer: An implementation of the Vision Transformer architecture. GitHub. Available at: [https://github.com/google-research/vision_transformer](https://github.com/google-research/vision_transformer) (Accessed: 1 May 2025).
[^20]: Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., Uszkoreit, J. and Houlsby, N. (2020) ‘An image is worth 16x16 words: Transformers for image recognition at scale’, _arXiv preprint_ arXiv:2010.11929. Available at: [https://arxiv.org/abs/2010.11929](https://arxiv.org/abs/2010.11929) (Accessed: 1 May 2025).