🧠 Multi-Head Attention — AI / ML Interview Guide
LLM Internals · interactive visualization + interview prep
What it is
One attention head can only focus one way at a time. Multi-head attention runs SEVERAL attention computations in parallel — each with its own Q/K/V projections — so different heads can capture different relationships (syntax, coreference, position…). Their outputs are concatenated and projected back.
Mental model
One attention head can form only a SINGLE distribution of "who looks at whom" per token, so it must average all the relationships in a sentence into one focus. Multi-head attention runs several such heads in parallel, each in its own learned subspace, so one can track grammar while another tracks coreference and another tracks position, all simultaneously. Their findings are then concatenated and mixed by a final projection. The extra heads cost almost nothing in compute: instead of one head of dimension d, you run h heads of dimension d/h.
Theory
A single attention head produces one softmax distribution per token, so it can represent exactly one notion of relevance at a time. Real language has many simultaneous relationships (subject-verb agreement, pronoun antecedents, positional adjacency). Forcing one head to capture all of them means averaging them into a single blurred focus — multi-head attention removes that bottleneck.
The construction: with model dimension d and h heads, each head gets its own projections W_Qᵢ, W_Kᵢ, W_Vᵢ that map tokens into a d/h-dimensional subspace. Each head runs standard scaled-dot-product attention independently in its subspace. The h outputs (each d/h-dim) are concatenated back to d dimensions and passed through an output projection W_O.
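The construction above can be sketched in a few lines of NumPy. This is a minimal illustration, not a reference implementation: the function names, the random weights, and the sizes n = 5, d = 16, h = 4 are assumptions chosen for the example, and masking, biases, and batching are left out.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max before exponentiating for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(q, k, v):
    # Scaled dot-product attention: softmax(q k^T / sqrt(d_k)) v
    d_k = q.shape[-1]
    weights = softmax(q @ k.T / np.sqrt(d_k))
    return weights @ v, weights

def multi_head_attention(x, w_q, w_k, w_v, w_o):
    # x: (n, d); w_q, w_k, w_v: lists of h matrices of shape (d, d/h); w_o: (d, d)
    heads = []
    for wq_i, wk_i, wv_i in zip(w_q, w_k, w_v):
        out_i, _ = attention(x @ wq_i, x @ wk_i, x @ wv_i)  # (n, d/h)
        heads.append(out_i)
    concat = np.concatenate(heads, axis=-1)  # back to (n, d)
    return concat @ w_o

# Illustrative sizes and random weights (assumptions, not trained values).
rng = np.random.default_rng(0)
n, d, h = 5, 16, 4
d_k = d // h
x = rng.normal(size=(n, d))
w_q = [rng.normal(size=(d, d_k)) for _ in range(h)]
w_k = [rng.normal(size=(d, d_k)) for _ in range(h)]
w_v = [rng.normal(size=(d, d_k)) for _ in range(h)]
w_o = rng.normal(size=(d, d))

out = multi_head_attention(x, w_q, w_k, w_v, w_o)
print(out.shape)  # (5, 16)
```

The explicit loop over heads mirrors the per-head projections W_Qᵢ, W_Kᵢ, W_Vᵢ in the text. Practical implementations usually fuse the h projections into one matrix multiply and reshape instead.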
Crucially, the compute is roughly the same as one full-width head: h heads of size d/h add up to the same total width d, so you get representational diversity "for free". The separate subspaces are what enable specialization: empirical analyses of trained models find that different heads attend to different linguistic phenomena.
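A quick count makes the "same compute" claim concrete. The sketch below assumes bias-free projections and counts only the multiply-adds in QKᵀ and in the weighted sum over V. The helper names and the sizes d = 768, n = 128 are illustrative assumptions.

```python
def projection_params(d, h):
    # Each head has three (d x d/h) matrices for Q, K, V,
    # plus one shared (d x d) output projection W_O.
    d_k = d // h
    return h * 3 * d * d_k + d * d

def score_flops(n, d, h):
    # Multiply-adds for Q K^T and for weights @ V, summed over heads:
    # each is h * n * n * d_k, so the total is 2 * h * n^2 * d_k.
    d_k = d // h
    return 2 * h * n * n * d_k

d, n = 768, 128
for h in (1, 12):
    print(h, projection_params(d, h), score_flops(n, d, h))
# 1 2359296 25165824
# 12 2359296 25165824
```

The counts match exactly because h · (d/h) = d. The "roughly" in the text covers the parts that do scale with h, such as computing h separate n × n softmaxes instead of one.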
The output projection W_O is not optional bookkeeping: concatenation just stacks the per-head results side by side, and W_O is what mixes information ACROSS heads into a single coherent representation. Skipping it leaves the heads siloed.
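One way to see why W_O matters is to perturb a single head and check which output coordinates move. The sketch below uses random stand-ins for the per-head outputs and for W_O. All values and sizes are hypothetical, chosen only to illustrate the point.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, h = 4, 8, 2
d_k = d // h

# Stand-ins for per-head attention outputs (hypothetical random values).
heads = [rng.normal(size=(n, d_k)) for _ in range(h)]
w_o = rng.normal(size=(d, d))

concat = np.concatenate(heads, axis=-1)  # (n, d): heads sit side by side
mixed = concat @ w_o                     # (n, d): heads are blended

# Perturb only head 0, then recompute both representations.
heads_perturbed = [heads[0] + 1.0, heads[1]]
concat_p = np.concatenate(heads_perturbed, axis=-1)
mixed_p = concat_p @ w_o

# Which of the d output coordinates changed?
changed_concat = np.any(concat_p != concat, axis=0)
changed_mixed = np.any(~np.isclose(mixed_p, mixed), axis=0)
print(changed_concat)  # only the first d_k coordinates change
print(changed_mixed)   # with a dense W_O, every coordinate changes
```

Before W_O, head 0 lives only in its own slice of the vector. After W_O, every output coordinate is a learned combination of all heads.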
Two nuances interviewers like. (1) Many heads are redundant: research shows a large fraction can be pruned with little loss, yet enough specialize that multi-head still wins. (2) The serving-oriented variants Multi-Query Attention (MQA) and Grouped-Query Attention (GQA) keep multiple query heads but SHARE K/V projections across them (a single shared K/V head in MQA, one per group of query heads in GQA), drastically shrinking the KV cache at inference for a small quality cost.
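The KV-cache saving from sharing K/V heads is simple arithmetic. The sketch below assumes a hypothetical model shape (32 layers, 32 query heads, head dimension 128, 4096 cached tokens, 2 bytes per value) that is not taken from any specific model. The helper name is likewise an assumption.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len,
                   batch=1, bytes_per_value=2):
    # Factor 2: one cached tensor for keys and one for values.
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * bytes_per_value

# Hypothetical model shape, chosen only for illustration.
cfg = dict(n_layers=32, head_dim=128, seq_len=4096)

# Same 32 query heads in every case; only the number of K/V heads differs.
for name, n_kv in [("MHA", 32), ("GQA", 8), ("MQA", 1)]:
    size = kv_cache_bytes(n_kv_heads=n_kv, **cfg)
    print(f"{name}: {n_kv} KV heads, {size / 2**20:.0f} MiB")
# MHA: 32 KV heads, 2048 MiB
# GQA: 8 KV heads, 512 MiB
# MQA: 1 KV heads, 64 MiB
```

The cache scales linearly with the number of K/V heads, so going from 32 to 8 to 1 cuts it by 4× and 32× respectively, while the number of query heads is unchanged.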
Concrete example
In a transformer, one head might track subject→verb agreement while another links a pronoun to its antecedent and a third attends to nearby words. GPT/BERT use many heads (e.g. 12–96) per layer; you can literally see heads specialize when you visualize their attention maps.
Key equations
d_k = d/h   (split into h heads, each with smaller dimension d_k)
headᵢ = Attention(Q·W_Qᵢ, K·W_Kᵢ, V·W_Vᵢ)
MultiHead = Concat(head₁, …, head_h) · W_O
Same total compute as one big head, but multiple subspaces.
Step by step
- Project the tokens into h separate Q/K/V subspaces (one per head).
- Each head computes its own attention matrix independently.
- Different heads attend to different patterns (for example syntax, coreference, or position).
- Concatenate all heads’ outputs and apply a final projection W_O.
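The four steps above can also be traced shape by shape. The sketch below uses the fused form common in practice: a single (d, d) matrix per Q, K, and V, which is equivalent to stacking the h per-head (d, d/h) matrices side by side, followed by a reshape into heads. The function name, the sizes, and the random weights are illustrative assumptions.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the max before exponentiating for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def mha_steps(x, w_q, w_k, w_v, w_o, h):
    n, d = x.shape
    d_k = d // h

    # Step 1: project, then split the d columns into h heads of width d_k.
    def split(t):
        return t.reshape(n, h, d_k).transpose(1, 0, 2)  # (h, n, d_k)
    q, k, v = split(x @ w_q), split(x @ w_k), split(x @ w_v)

    # Step 2: each head computes its own (n, n) attention matrix.
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d_k)    # (h, n, n)
    weights = softmax(scores, axis=-1)

    # Step 3: each head applies its own pattern to its own values.
    per_head = weights @ v                              # (h, n, d_k)

    # Step 4: concatenate heads along the feature axis and apply W_O.
    concat = per_head.transpose(1, 0, 2).reshape(n, d)  # (n, d)
    return concat @ w_o, weights

# Illustrative sizes and random weights (assumptions, not trained values).
rng = np.random.default_rng(0)
n, d, h = 6, 12, 3
x = rng.normal(size=(n, d))
w_q, w_k, w_v, w_o = (rng.normal(size=(d, d)) for _ in range(4))

out, weights = mha_steps(x, w_q, w_k, w_v, w_o, h)
print(out.shape, weights.shape)  # (6, 12) (3, 6, 6)
```

The returned `weights` array holds one n × n attention map per head, which is the object usually plotted when inspecting what each head attends to.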
Interview questions & answers
Why multiple heads instead of one big attention?
A single softmax attention averages everything into one focus. Multiple heads attend to different relationships in parallel (different subspaces), giving the model more representational power at the same total cost.
How does the dimension split work?
With model dim d and h heads, each head uses d_k = d/h. So h heads of size d/h cost about the same as one head of size d — you get diversity “for free” in compute.
Do heads learn redundant things?
Some do — research shows many heads can be pruned with little loss. But many specialize (positional, syntactic, rare-token heads), which is why multi-head still helps.
What gets concatenated and why the final W_O?
Each head outputs a d_k-vector per token; concatenating gives back a d-vector, and W_O mixes information across heads into the layer’s output.
Common pitfalls
- Assuming every head is meaningful — many are redundant/prunable.
- Confusing heads (parallel subspaces) with layers (sequential depth).
- Forgetting the output projection W_O that combines the heads.
Where it shows up
- Nearly every transformer (GPT, BERT, ViT, Llama), sometimes in its MQA/GQA variant
- The core of self-attention layers