The Mathematical Foundation of Multi-Head Attention and Position-wise Feed-Forward Networks

An exploration of the linear algebraic structures powering the Transformer architecture. We dissect the mechanisms of scaled dot-product attention and the nonlinear transformations of position-wise feed-forward networks.

To understand Multi-Head Attention (MHA), we must first grasp the intuition of 'content-addressable memory.' In a sequence of tokens, each word must determine which other words in the sentence are most relevant to its meaning. This is achieved by projecting each input embedding into three distinct vector spaces: Queries ($Q$), Keys ($K$), and Values ($V$). The Query represents 'what I am looking for,' the Key represents 'what I contain,' and the Value represents 'the information I provide.' By computing the similarity between a Query and all Keys, the model creates a weighted average of the Values, effectively allowing the network to selectively focus on different parts of the input sequence.
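To make the lookup concrete, here is a minimal NumPy sketch of a single content-addressable read. The keys and values are hand-picked toy numbers rather than learned projections of real embeddings, the scaling factor introduced below is omitted for simplicity, and `soft_lookup` is a hypothetical helper name. The point is only that a query closely aligned with one key retrieves mostly that key's value, while every value still contributes a little to the weighted average.

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    shifted = x - x.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

def soft_lookup(query, keys, values):
    """Content-addressable read: score every key against the query,
    normalise the scores, and return the weighted average of the values."""
    scores = keys @ query          # one similarity score per stored item
    weights = softmax(scores)      # non-negative, sums to 1
    return weights @ values, weights

# Toy memory: 4 items with 3-dimensional keys and 2-dimensional values.
keys = np.array([[1.0, 0.0, 0.0],
                 [0.0, 1.0, 0.0],
                 [0.0, 0.0, 1.0],
                 [1.0, 1.0, 0.0]])
values = np.array([[10.0, 0.0],
                   [0.0, 10.0],
                   [5.0, 5.0],
                   [-5.0, 5.0]])

# A query that "looks for" the third item (index 2).
query = np.array([0.0, 0.0, 8.0])
output, weights = soft_lookup(query, keys, values)
print(weights.round(4))
print(output.round(3))
```

Unlike a hash-table lookup, nothing here is an exact match: the result is a blend, and making the query less sharply aligned spreads the weights across more items.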

Mathematically, the core of this mechanism is Scaled Dot-Product Attention. Given input matrices $Q, K \in \mathbb{R}^{n \times d_k}$ and $V \in \mathbb{R}^{n \times d_v}$, the attention output is computed as: $$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$ Here, $QK^T$ computes the raw alignment scores between all pairs of tokens, and the softmax is applied row-wise so that each query's weights over the keys sum to 1. We divide by $\sqrt{d_k}$ because dot products grow with dimension: if the components of a query and a key are independent with mean 0 and variance 1, their dot product has variance $d_k$. Without scaling, these large-magnitude scores push the softmax into saturated regions where its output is nearly one-hot and its gradients are close to zero; dividing by $\sqrt{d_k}$ brings the variance of the scores back to roughly 1.
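The formula translates almost line for line into array code. The following is a sketch assuming NumPy, with random matrices standing in for real projections; the function name `scaled_dot_product_attention` is illustrative. The second half checks the variance argument numerically: raw dot products of unit-variance vectors have variance near $d_k$, and scaling by $1/\sqrt{d_k}$ brings it back near 1.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    shifted = x - x.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V for Q, K of shape (n, d_k) and V of shape (n, d_v)."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)      # (n, n) alignment scores
    weights = softmax(scores, axis=-1)   # each row is a distribution over keys
    return weights @ V, weights

rng = np.random.default_rng(0)
n, d_k, d_v = 5, 64, 32
Q = rng.standard_normal((n, d_k))
K = rng.standard_normal((n, d_k))
V = rng.standard_normal((n, d_v))
out, attn = scaled_dot_product_attention(Q, K, V)
print(out.shape, attn.shape)   # (5, 32) (5, 5)

# Why scale: with unit-variance components, q . k has variance d_k.
q = rng.standard_normal((10_000, d_k))
k = rng.standard_normal((10_000, d_k))
raw = (q * k).sum(axis=1)
print(raw.var(), (raw / np.sqrt(d_k)).var())   # roughly 64 and roughly 1
```

Subtracting the row maximum inside `softmax` does not change the result mathematically; it only prevents `np.exp` from overflowing on large scores.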

While single-head attention is powerful, it is limited to a single 'representation subspace.' Multi-Head Attention overcomes this by running multiple attention mechanisms in parallel. Each head $i$ has its own set of learnable projection matrices $W_Q^{(i)}, W_K^{(i)}, W_V^{(i)}$. The output of the MHA layer is the concatenation of these heads, which is then projected back to the original dimension using a final weight matrix $W_O$: $$ \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W_O $$ where $\text{head}_i = \text{Attention}(QW_Q^{(i)}, KW_K^{(i)}, VW_V^{(i)})$. This allows the model to simultaneously attend to different types of relationships, such as syntactic dependencies in one head and semantic associations in another.
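A minimal NumPy sketch of the multi-head formula follows. It assumes the common choice $d_k = d_v = d_{model}/h$ (not stated above) and uses randomly initialised matrices as stand-ins for learned weights; `multi_head` and `init` are hypothetical helper names. In self-attention the same input $X$ is passed as $Q$, $K$, and $V$.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    shifted = x - x.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    """Scaled dot-product attention for a single head."""
    d_k = Q.shape[-1]
    return softmax(Q @ K.T / np.sqrt(d_k)) @ V

def multi_head(Q, K, V, W_Q, W_K, W_V, W_O):
    """Concat(head_1, ..., head_h) W_O, with head_i = Attention(Q W_Q[i], K W_K[i], V W_V[i]).

    W_Q, W_K: lists of h matrices of shape (d_model, d_k)
    W_V:      list of h matrices of shape (d_model, d_v)
    W_O:      matrix of shape (h * d_v, d_model)
    """
    heads = [attention(Q @ wq, K @ wk, V @ wv) for wq, wk, wv in zip(W_Q, W_K, W_V)]
    return np.concatenate(heads, axis=-1) @ W_O

rng = np.random.default_rng(0)
n, d_model, h = 6, 16, 4
d_k = d_v = d_model // h   # assumed split: each head works in a d_model/h subspace

def init(shape):
    # Random stand-ins for learned weights, scaled by fan-in.
    return rng.standard_normal(shape) / np.sqrt(shape[0])

W_Q = [init((d_model, d_k)) for _ in range(h)]
W_K = [init((d_model, d_k)) for _ in range(h)]
W_V = [init((d_model, d_v)) for _ in range(h)]
W_O = init((h * d_v, d_model))

X = rng.standard_normal((n, d_model))
out = multi_head(X, X, X, W_Q, W_K, W_V, W_O)   # self-attention: Q = K = V = X
print(out.shape)  # (6, 16)
```

Because each head operates in a narrower $d_{model}/h$-dimensional subspace under this assumed split, the total cost stays roughly comparable to one full-width head while still giving $h$ independent attention patterns. Production code usually batches the heads into a single reshaped tensor rather than looping in Python.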

Following the attention mechanism, the Transformer employs a Position-wise Feed-Forward Network (FFN). While attention captures global context, the FFN is responsible for processing the information extracted by the attention heads at each individual position. It is 'position-wise' because it is applied to each token independently and identically. This can be viewed as a local transformation that enhances the features of each token after they have been contextualized by the rest of the sequence.
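The contrast between global mixing and per-position processing can be checked numerically. The sketch below assumes NumPy, uses a simplified self-attention with no learned projections ($Q = K = V = X$), and stands in for the FFN with a single shared affine map followed by ReLU; `self_attention` and `position_wise` are illustrative names. Perturbing only the first token changes every row of the attention output but leaves every other row of the position-wise output untouched.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    shifted = x - x.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X):
    """Simplified self-attention with Q = K = V = X (no learned projections)."""
    d = X.shape[-1]
    return softmax(X @ X.T / np.sqrt(d)) @ X

def position_wise(X, W, b):
    """The same affine map + ReLU applied to every row (token) independently."""
    return np.maximum(0.0, X @ W + b)

rng = np.random.default_rng(1)
n, d = 5, 8
X = rng.standard_normal((n, d))
W = rng.standard_normal((d, d))
b = rng.standard_normal(d)

X_perturbed = X.copy()
X_perturbed[0] += 1.0   # change only the first token

# For each output row, record whether it differs after the perturbation.
attn_changed = ~np.all(np.isclose(self_attention(X), self_attention(X_perturbed)), axis=1)
ffn_changed = ~np.all(np.isclose(position_wise(X, W, b), position_wise(X_perturbed, W, b)), axis=1)
print("rows changed by attention:       ", attn_changed)
print("rows changed by position-wise map:", ffn_changed)
```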

The FFN consists of two linear transformations separated by a nonlinear activation function, typically the Rectified Linear Unit (ReLU) or GELU. With ReLU, it is defined as: $$ \text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2 $$ and the GELU variant simply replaces $\max(0, \cdot)$ with $\text{GELU}(\cdot)$. Here $W_1 \in \mathbb{R}^{d_{model} \times d_{ff}}$ projects each token vector into a higher-dimensional hidden space, with $d_{ff}$ often set to $4 \times d_{model}$, and $W_2 \in \mathbb{R}^{d_{ff} \times d_{model}}$ projects it back down. This 'expansion-contraction' structure gives the nonlinearity a wider space in which to detect more complex feature combinations before the result is condensed back to the model dimension.
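The two-layer structure is short enough to write out directly. This is a sketch assuming NumPy, $d_{ff} = 4 \times d_{model}$, zero biases, and random weights in place of trained ones; for GELU it uses the widely used tanh approximation rather than the exact Gaussian-CDF form. The last lines count parameters, which for this shape works out to $2\,d_{model}d_{ff} + d_{ff} + d_{model} = 2 \cdot 16 \cdot 64 + 64 + 16 = 2128$.

```python
import numpy as np

def relu(x):
    return np.maximum(0.0, x)

def gelu_tanh(x):
    """Tanh approximation of GELU (one common variant, not the exact form)."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def ffn(x, W1, b1, W2, b2, activation=relu):
    """FFN(x) = act(x W1 + b1) W2 + b2, applied row-wise to x of shape (n, d_model)."""
    hidden = activation(x @ W1 + b1)   # (n, d_ff): expansion
    return hidden @ W2 + b2            # (n, d_model): contraction

rng = np.random.default_rng(2)
n, d_model = 6, 16
d_ff = 4 * d_model                     # assumed 4x expansion
W1 = rng.standard_normal((d_model, d_ff)) / np.sqrt(d_model)
b1 = np.zeros(d_ff)
W2 = rng.standard_normal((d_ff, d_model)) / np.sqrt(d_ff)
b2 = np.zeros(d_model)

x = rng.standard_normal((n, d_model))
y_relu = ffn(x, W1, b1, W2, b2)
y_gelu = ffn(x, W1, b1, W2, b2, activation=gelu_tanh)
print(y_relu.shape, y_gelu.shape)   # (6, 16) (6, 16)

n_params = W1.size + b1.size + W2.size + b2.size
print(n_params)                     # 2128
```

Since the parameter count scales as roughly $8\,d_{model}^2$ under the 4x assumption, the FFN sublayers typically hold a large share of a Transformer's weights.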

The synergy between MHA and FFN creates a powerful duality. The MHA layer acts as a dynamic routing mechanism, determining *where* to look for information across the sequence. Conversely, the FFN applies the same learned weights to every token regardless of its neighbors, determining *what* to do with that information. Together, these components ensure that the model can handle long-range dependencies while maintaining the capacity to perform deep, non-linear transformations on each token, forming the bedrock of modern Large Language Models.