🧠 Multi-Query & Grouped-Query Attention — AI / ML Interview Guide
What it is
In standard multi-head attention (MHA), every query head has its OWN Key/Value projection, so the KV cache stores K/V for every head. At long contexts and large batch sizes, this cache is the main memory cost of inference. MQA shares ONE K/V head across all query heads; GQA shares a few K/V heads, one per group of query heads. Fewer K/V heads mean a much smaller cache and faster decoding, at a small quality cost.
Mental model
Picture the query heads as researchers and the K/V projections as shared filing cabinets. MHA gives every researcher a private cabinet (lots of memory). MQA makes them all share one cabinet (tiny memory, some contention). GQA puts them in a few teams, one cabinet per team — most of the memory savings, most of the quality. The knob is simply: how many cabinets?
Theory
The KV cache stores past tokens' Key and Value vectors for every layer and K/V head, and it is the dominant memory cost that grows during inference. Its size scales with batch size × sequence length × layers × K/V heads × head dimension, which caps both batch size and context length (see the KV Cache concept). Shrinking the per-token K/V footprint is the highest-leverage way to reduce it.
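A back-of-the-envelope sketch of that scaling follows. It assumes a dense cache stored in fp16/bf16 (2 bytes per element) with no paging or fragmentation overhead. The full product includes batch size and head dimension alongside layers and K/V heads. The helper name `kv_cache_bytes` and the example model shape are hypothetical.

```python
def kv_cache_bytes(batch: int, seq_len: int, n_layers: int,
                   n_kv_heads: int, head_dim: int, bytes_per_elem: int = 2) -> int:
    """Approximate dense KV-cache size in bytes (hypothetical helper).

    The leading factor 2 counts one Key and one Value vector per
    (sequence, token, layer, K/V head). bytes_per_elem=2 assumes fp16/bf16.
    """
    return 2 * batch * seq_len * n_layers * n_kv_heads * head_dim * bytes_per_elem


# Hypothetical shape: 32 layers, 32 query heads of dim 128, 4k context, batch 8.
mha = kv_cache_bytes(8, 4096, 32, n_kv_heads=32, head_dim=128)  # one K/V head per query head
gqa = kv_cache_bytes(8, 4096, 32, n_kv_heads=8, head_dim=128)   # 8 shared K/V heads
print(mha / 2**30, gqa / 2**30)  # GiB -> 16.0 4.0
```

Only `n_kv_heads` changes between the two calls. The cache shrinks by exactly the ratio of query heads to K/V heads (here 32 / 8 = 4).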
Multi-Query Attention (MQA) keeps all the QUERY heads but uses a SINGLE shared Key and Value projection for them. This cuts the KV cache by a factor of n_heads, dramatically speeding memory-bound decoding — but collapsing all K/V into one head can cost some quality and make training less stable.
Grouped-Query Attention (GQA) is the middle ground: partition the query heads into G groups and give each group its own K/V. G = n_heads is plain MHA; G = 1 is MQA; G = 8 (say, 64 query heads → 8 K/V groups) keeps almost all the quality while shrinking the cache ~8×. It interpolates smoothly between the two extremes.
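The grouping itself is just an index mapping from each query head to the K/V head it reads. This sketch assumes the common contiguous layout, in which query heads 0..g-1 share K/V head 0, the next g share K/V head 1, and so on. `kv_group_for_query_head` is a hypothetical helper, not a library function.

```python
def kv_group_for_query_head(n_q_heads: int, n_kv_heads: int) -> list[int]:
    """For each query head, return the index of the K/V head it shares.

    Assumes contiguous groups of size n_q_heads // n_kv_heads.
    """
    if n_q_heads % n_kv_heads != 0:
        raise ValueError("n_q_heads must be divisible by n_kv_heads")
    group_size = n_q_heads // n_kv_heads
    return [h // group_size for h in range(n_q_heads)]


print(kv_group_for_query_head(8, 8))  # G = n_heads (MHA): [0, 1, 2, 3, 4, 5, 6, 7]
print(kv_group_for_query_head(8, 2))  # G = 2 (GQA):       [0, 0, 0, 0, 1, 1, 1, 1]
print(kv_group_for_query_head(8, 1))  # G = 1 (MQA):       [0, 0, 0, 0, 0, 0, 0, 0]
```

The two extremes fall out of the same formula. With G = n_heads each query head gets a private K/V head, and with G = 1 every query head maps to head 0.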
Only the K and V projections shrink — the number and size of QUERY heads is unchanged, so expressiveness on the query side is preserved. At attention time the shared K/V is broadcast to the query heads in its group.
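The broadcast can be sketched as one decoding step in plain Python. The query for the newest token has `n_q` heads, and the cache holds only `n_kv` K/V heads. Each query head indexes into its group's shared K/V instead of owning a copy. This is a minimal, unoptimized sketch for illustration. It assumes scaled dot-product attention over the whole cache (no masking is needed for a single new token) and contiguous grouping. The function names are hypothetical.

```python
import math


def softmax(xs: list[float]) -> list[float]:
    """Numerically stable softmax."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def gqa_decode_step(q, k_cache, v_cache):
    """One grouped-query attention step for the newest token.

    q:       [n_q_heads][head_dim]            queries for the new token
    k_cache: [n_kv_heads][seq_len][head_dim]  cached Keys (only n_kv heads stored)
    v_cache: [n_kv_heads][seq_len][head_dim]  cached Values
    returns: [n_q_heads][head_dim]            one output vector per query head
    """
    n_q, n_kv = len(q), len(k_cache)
    assert n_q % n_kv == 0, "query heads must split evenly into K/V groups"
    group_size = n_q // n_kv
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for h, q_h in enumerate(q):
        g = h // group_size                     # K/V head shared by this query head's group
        keys, values = k_cache[g], v_cache[g]
        scores = [scale * sum(a * b for a, b in zip(q_h, k_t)) for k_t in keys]
        weights = softmax(scores)               # attention over all cached tokens
        out.append([sum(w * v_t[i] for w, v_t in zip(weights, values))
                    for i in range(len(values[0]))])
    return out


# Hypothetical toy cache: 4 query heads share 2 K/V heads (2 query heads per group).
q = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [1.0, 1.0]]
k_cache = [[[1.0, 0.0], [0.0, 1.0]], [[0.2, 0.1], [0.3, 0.7]]]
v_cache = [[[1.0, 2.0], [3.0, 4.0]], [[0.0, 1.0], [1.0, 0.0]]]
out = gqa_decode_step(q, k_cache, v_cache)  # 4 output vectors of length 2
```

Physically repeating each K/V head `group_size` times and running ordinary MHA gives the same result. The sketch indexes instead of copying, and that is the point: the cache only ever stores `n_kv` heads.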
This is purely an inference-efficiency trade-off and composes with the rest of the stack (RoPE, FlashAttention, quantized cache). That is why GQA is now standard in large open models: the quality hit is small and the serving wins (throughput, context length) are large.
Concrete example
Llama-2-70B and Llama-3 use GQA; many fast-serving and on-device models use MQA. Going from MHA to GQA-8 on a 64-head model cuts KV-cache memory ~8×, which directly raises the batch size and context length you can serve on the same GPU.
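To put numbers on the 64-head example, here is a per-token calculation. It assumes 80 layers, head dimension 128, and fp16/bf16 storage (2 bytes per element). These match the published Llama-2-70B configuration, but treat them as illustrative inputs rather than a measurement.

$$\text{bytes/token} = 2\,(\text{K and V}) \times n_{\text{layers}} \times n_{kv} \times d_{\text{head}} \times 2\,\text{B}$$
$$\text{MHA } (n_{kv} = 64): \; 2 \times 80 \times 64 \times 128 \times 2 = 2{,}621{,}440\,\text{B} = 2.5\,\text{MiB}$$
$$\text{GQA-8 } (n_{kv} = 8): \; 2 \times 80 \times 8 \times 128 \times 2 = 327{,}680\,\text{B} = 320\,\text{KiB}$$

At a 4,096-token context, that is 10 GiB versus 1.25 GiB of cache per sequence. The same cache budget therefore holds 8× as many concurrent sequences, ignoring weights and activations.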
Key equations
MHA: $n_q$ query heads, $n_{kv} = n_q$, so KV cache $\propto n_q$
MQA: $n_q$ query heads, $n_{kv} = 1$, so KV cache $\propto 1$ ($n_q\times$ smaller than MHA)
GQA: $n_q$ query heads, $n_{kv} = G$ with $1 < G < n_q$, so KV cache $\propto G$
Each K/V group is shared (broadcast) by $n_q / G$ query heads.
Query heads (count and size) are unchanged; only the K and V projections shrink.
Step by step
- Start from MHA: 8 query heads, 8 K/V heads — the full cache.
- MQA: keep 8 query heads but a single shared K/V — cache ÷8.
- GQA: split the 8 query heads into 2 groups, each with its own K/V — cache ÷4.
- Each query head attends using its group's shared K/V.
- Compare the KV-cache bars: fewer K/V heads = smaller cache, faster decode.
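The steps above can be reproduced numerically with a text version of the cache bars. This sketch assumes everything except the number of K/V heads is held fixed, so the cache scales linearly with that count. `cache_comparison` is a hypothetical helper.

```python
def cache_comparison(n_q_heads: int,
                     kv_heads_by_name: dict[str, int]) -> list[tuple[str, int, int, float]]:
    """Return (name, n_kv_heads, query heads per K/V head, cache size relative to MHA)."""
    rows = []
    for name, n_kv in kv_heads_by_name.items():
        rows.append((name, n_kv, n_q_heads // n_kv, n_kv / n_q_heads))
    return rows


# The three configurations from the steps: 8 query heads throughout.
rows = cache_comparison(8, {"MHA": 8, "MQA": 1, "GQA-2": 2})
for name, n_kv, group_size, rel in rows:
    bar = "#" * n_kv  # one '#' per cached K/V head
    print(f"{name:<6} kv_heads={n_kv}  heads_per_kv={group_size}  cache={rel:.3f}x  {bar}")
```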
Interview questions & answers
What problem do MQA/GQA solve?
The KV-cache memory bottleneck at inference. By sharing K/V across query heads they shrink the cache, raising max batch size and context length and speeding up memory-bound decoding.
How does GQA relate to MHA and MQA?
It generalizes both: with G K/V groups, G = n_heads is MHA and G = 1 is MQA. GQA picks an intermediate G to get most of MQA's savings with most of MHA's quality.
What stays the same — queries or keys/values?
The query heads (count and size) are unchanged; only the number of Key/Value projections is reduced and shared. So query-side expressiveness is preserved.
What's the downside of MQA?
Collapsing all K/V into one head can reduce quality and make training less stable. GQA recovers most of that quality, which is why GQA is the common choice for large models.
Common pitfalls
- Thinking MQA/GQA reduce query heads — they reduce K/V heads only.
- Assuming MQA is always best — the quality hit can matter; GQA is usually the sweet spot.
- Forgetting the win is at INFERENCE (cache/throughput), not training FLOPs.
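The first pitfall becomes concrete when you look at weight shapes. This sketch assumes bias-free linear projections written as (d_model, out_features) and the common convention head_dim = d_model / n_q_heads. `projection_shapes` is a hypothetical helper, and the output projection (unchanged in all three variants) is not shown.

```python
def projection_shapes(d_model: int, n_q_heads: int,
                      n_kv_heads: int) -> dict[str, tuple[int, int]]:
    """Shapes of the Q, K, V projection weights for MHA/GQA/MQA."""
    head_dim = d_model // n_q_heads  # per-head size is set by the query heads
    return {
        "W_q": (d_model, n_q_heads * head_dim),   # same in MHA, GQA and MQA
        "W_k": (d_model, n_kv_heads * head_dim),  # shrinks with the K/V head count
        "W_v": (d_model, n_kv_heads * head_dim),  # shrinks with the K/V head count
    }


for name, n_kv in [("MHA", 8), ("GQA-2", 2), ("MQA", 1)]:
    print(name, projection_shapes(512, 8, n_kv))
```

`W_q` stays (512, 512) in every row, while `W_k` and `W_v` go from 512 to 128 to 64 output features.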
Where it shows up
- GQA: Llama-2-70B, Llama-3, Mistral, many large open models
- MQA: PaLM, some Falcon variants, and other models tuned for fast serving or on-device use
- Inference-throughput and long-context optimization