🧠 Flash Attention — AI / ML Interview Guide
LLM Internals · interactive visualization + interview prep
What it is
Standard attention builds the full n×n score matrix in memory — O(n²) — which is slow and memory-hungry for long sequences. FlashAttention computes the SAME attention in tiles, keeping a running softmax so it never stores the big matrix. It is an exact, IO-aware reimplementation: same result, far less memory traffic.
Mental model
Don't write out the whole giant spreadsheet and then sum it; stream it block by block, keeping a running total as you go. FlashAttention does attention this way: process a tile of keys/values, update a running max and sum (the "online softmax"), discard the tile, repeat. The n×n matrix never exists all at once, so far less data moves to and from slow GPU memory.
Theory
The bottleneck in attention is usually NOT arithmetic; it is memory bandwidth. Materializing the n×n score matrix means writing it to, and reading it back from, the GPU's high-bandwidth memory (HBM), which is large but far slower than on-chip SRAM. For long sequences, that traffic dominates runtime and the O(n²) matrix dominates memory.
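For contrast, here is a minimal plain-Python sketch of standard attention that materializes the full score matrix before doing anything else with it. It follows the guide's S = QKᵀ and omits the usual 1/√d scaling; the function name and toy inputs are illustrative assumptions, not any library's API.

```python
import math


def naive_attention(Q, K, V):
    """Standard attention that builds the whole n x n score matrix at once.

    Q is n x d, K is n_k x d, V is n_k x d_v, all as lists of rows.
    Uses S = Q K^T with no 1/sqrt(d) scaling, to match the guide's equations.
    """
    # S = Q K^T: all n * n_k scores exist in memory simultaneously.
    S = [[sum(q * k for q, k in zip(q_row, k_row)) for k_row in K] for q_row in Q]
    # Row-wise softmax; subtracting the row max keeps exp() from overflowing.
    P = []
    for row in S:
        m = max(row)
        exps = [math.exp(s - m) for s in row]
        total = sum(exps)
        P.append([e / total for e in exps])
    # O = P V
    d_v = len(V[0])
    O = [[sum(p * v[c] for p, v in zip(p_row, V)) for c in range(d_v)] for p_row in P]
    return O, S  # S is returned only to show that the full matrix was built


Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
O, S = naive_attention(Q, K, V)
print(len(S), len(S[0]))  # 3 3: the full n x n score matrix was stored
```

On a GPU, both S and P in this picture would be written to and read back from HBM, which is exactly the traffic FlashAttention avoids.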
FlashAttention is IO-AWARE: it restructures the computation to maximize work done in fast on-chip SRAM and minimize reads/writes to HBM. It loads small TILES of Q, K, V into SRAM, computes their partial attention there, and only writes the final output back — never the full score matrix.
The enabling trick is the ONLINE (streaming) softmax. Softmax normally needs the whole row to find its max and sum, but you can process the row in chunks while maintaining a running max and a running normalizer, rescaling the partial result as each new tile arrives. This makes the tiled, never-materialized computation mathematically EXACT — not an approximation.
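The online softmax can be sketched for a single row of scores. This plain-Python version (function names are hypothetical) keeps only a running max m and a running normalizer l, rescales l whenever the max grows, and checks the result against the usual two-pass computation.

```python
import math


def online_softmax_stats(scores, chunk_size):
    """One pass over `scores` in chunks, keeping a running max m and a running
    normalizer l = sum(exp(s - m)) over every score seen so far."""
    m = float("-inf")
    l = 0.0
    for start in range(0, len(scores), chunk_size):
        chunk = scores[start:start + chunk_size]
        m_new = max(m, max(chunk))
        # Rescale the old normalizer to the new max, then add this chunk's terms.
        # On the first chunk exp(-inf) == 0.0, so the empty initial state vanishes.
        l = l * math.exp(m - m_new) + sum(math.exp(s - m_new) for s in chunk)
        m = m_new
    return m, l


def two_pass_softmax_stats(scores):
    """Reference: find the max first, then sum exp(s - max) in a second pass."""
    m = max(scores)
    return m, sum(math.exp(s - m) for s in scores)


scores = [0.3, 2.0, -1.0, 4.5, 0.0, 3.2, -2.5]
m1, l1 = online_softmax_stats(scores, chunk_size=3)
m2, l2 = two_pass_softmax_stats(scores)
probs = [math.exp(s - m1) / l1 for s in scores]
print(m1 == m2, math.isclose(l1, l2), math.isclose(sum(probs), 1.0))  # True True True
```

Because every chunk's contribution is rescaled by exp(m_old - m_new), the final (m, l) is the same as if the whole row had been seen at once, up to floating-point rounding.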
So the win is memory traffic, not fewer FLOPs: the arithmetic is still O(n²·d), but HBM accesses drop from O(n·d + n²) to O(n²·d²/M), where M is the SRAM size (d² is many times smaller than M for typical head dimensions), and peak EXTRA memory drops from O(n²) to O(n). In practice that means large wall-clock speedups and the ability to train/serve much longer contexts on the same hardware.
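To put the O(n²) versus O(n) gap in numbers, here is back-of-the-envelope arithmetic. It assumes fp16 scores (2 bytes) and two fp32 statistics (running max and running sum) per query row, for a single head and batch element; the helper names are hypothetical, and real kernels also hold Q, K, V, and the output, which cost O(n·d) either way.

```python
def score_matrix_bytes(n: int, bytes_per_elem: int = 2) -> int:
    """Bytes for one materialized n x n score matrix (one head, one batch
    element), assuming fp16 scores."""
    return n * n * bytes_per_elem


def online_stats_bytes(n: int, bytes_per_elem: int = 4) -> int:
    """Bytes for per-row online-softmax statistics: one running max m and one
    running sum l per query row, assuming fp32."""
    return 2 * n * bytes_per_elem


for n in (1_024, 32_768, 131_072):
    gib = score_matrix_bytes(n) / 2**30
    kib = online_stats_bytes(n) / 2**10
    print(f"n={n:>7}: scores {gib:8.3f} GiB vs stats {kib:6.0f} KiB")
```

For n = 32,768 this gives 2 GiB of scores per head versus 256 KiB of statistics; at n = 131,072 the score matrix alone would be 32 GiB.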
Contrast this with sparse or linear attention, which change the math (and the results) to cut cost. FlashAttention changes only the IMPLEMENTATION (outputs match vanilla attention up to floating-point rounding), which is why it became a default kernel rather than a modeling choice.
Concrete example
FlashAttention (and its v2/v3 successors) is widely used in practice: PyTorch's scaled_dot_product_attention can dispatch to a FlashAttention kernel when the GPU, dtype, and inputs support it, and vLLM and most training stacks use it as well. It is a big reason 32K–128K context windows became practical without exotic approximations.
Key equations
standard: $S = QK^\top \in \mathbb{R}^{n \times n}$ written to HBM → $P = \mathrm{softmax}(S)$ → $O = PV$ (O(n²) memory I/O)
flash: tile Q, K, V into SRAM; for each tile, update the running max $m$, running sum $\ell$, and running output $o$ (online softmax)
exact result; peak extra memory O(n) instead of O(n²)
FLOPs unchanged at O(n²·d); HBM traffic drops sharply (the real win)
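The running-state update can be written out for a single query row $q$ as it streams over key/value tiles $K_j, V_j$ for $j = 1$ to $T$, where $v^{(j)}_k$ is the $k$-th row of $V_j$. This is one common formulation, assumed here rather than taken from the guide: it keeps $o$ unnormalized and divides by $\ell$ once at the end, as FlashAttention-2 does (the original FlashAttention instead rescales a normalized output at every step).

```latex
\begin{aligned}
m^{(0)} &= -\infty, \qquad \ell^{(0)} = 0, \qquad o^{(0)} = 0 \\
s^{(j)} &= q\,K_j^\top \\
m^{(j)} &= \max\!\Big(m^{(j-1)},\ \max_k s^{(j)}_k\Big) \\
\ell^{(j)} &= e^{\,m^{(j-1)} - m^{(j)}}\,\ell^{(j-1)} + \sum_k e^{\,s^{(j)}_k - m^{(j)}} \\
o^{(j)} &= e^{\,m^{(j-1)} - m^{(j)}}\,o^{(j-1)} + \sum_k e^{\,s^{(j)}_k - m^{(j)}}\, v^{(j)}_k \\
\text{output} &= o^{(T)} / \ell^{(T)}
\end{aligned}
```

Unrolling the recurrences gives $\ell^{(T)} = \sum_k e^{s_k - m^{(T)}}$ and $o^{(T)} = \sum_k e^{s_k - m^{(T)}} v_k$ over all keys, so the common factor $e^{-m^{(T)}}$ cancels and $o^{(T)}/\ell^{(T)}$ equals $\mathrm{softmax}(qK^\top)V$ exactly.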
Step by step
- Split Q, K, V into tiles small enough to fit in fast on-chip SRAM.
- Load a query tile and a key/value tile into SRAM; compute that block of scores there.
- Update the running max and sum (online softmax) and rescale the running output to match.
- Drop the key/value tile and stream in the next; the full n×n matrix is never stored.
- After the last key/value tile, divide the running output by the running sum and write that output block back to HBM.
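The steps above can be sketched in plain Python. This is an illustrative CPU sketch, not a GPU kernel: block_q and block_k are hypothetical stand-ins for tile sizes that fit in SRAM, the loop order (query tiles outside, key/value tiles inside) is an assumption in the style of FlashAttention-2, and the 1/√d scaling is omitted to match the guide's equations. The block checks the tiled result against a direct computation.

```python
import math


def flash_attention_forward(Q, K, V, block_q=2, block_k=2):
    """Tiled attention with an online softmax: a pure-Python sketch, not a GPU kernel.

    Q is n x d, K is n_k x d, V is n_k x d_v, all as lists of rows. Uses S = Q K^T
    without the 1/sqrt(d) scale. block_q and block_k are hypothetical tile sizes.
    """
    n, d_v = len(Q), len(V[0])
    O = [None] * n
    for qs in range(0, n, block_q):
        # Load one tile of query rows; it stays "on chip" while K/V tiles stream past.
        q_tile = Q[qs:qs + block_q]
        m = [float("-inf")] * len(q_tile)    # running max per query row
        l = [0.0] * len(q_tile)              # running softmax denominator per row
        acc = [[0.0] * d_v for _ in q_tile]  # running unnormalized output per row
        for ks in range(0, len(K), block_k):
            k_tile, v_tile = K[ks:ks + block_k], V[ks:ks + block_k]
            for r, q in enumerate(q_tile):
                # Scores for this (query tile, key tile) block only.
                s = [sum(a * b for a, b in zip(q, k)) for k in k_tile]
                m_new = max(m[r], max(s))
                scale = math.exp(m[r] - m_new)  # rescales old state to the new max
                p = [math.exp(x - m_new) for x in s]
                l[r] = l[r] * scale + sum(p)
                acc[r] = [a * scale + sum(pj * v[c] for pj, v in zip(p, v_tile))
                          for c, a in enumerate(acc[r])]
                m[r] = m_new
            # k_tile, v_tile, and s are replaced next iteration: no n x n matrix is kept.
        for r in range(len(q_tile)):
            # After the last K/V tile, normalize once and write these output rows.
            O[qs + r] = [a / l[r] for a in acc[r]]
    return O


def reference_attention(Q, K, V):
    """Direct softmax(Q K^T) V, one query row at a time, for comparison."""
    out = []
    for q in Q:
        s = [sum(a * b for a, b in zip(q, k)) for k in K]
        mx = max(s)
        w = [math.exp(x - mx) for x in s]
        z = sum(w)
        out.append([sum(wj * v[c] for wj, v in zip(w, V)) / z for c in range(len(V[0]))])
    return out


Q = [[0.1, 0.4], [1.0, -0.5], [0.3, 0.3], [-0.7, 0.2], [0.5, 0.9]]
K = [[0.2, -0.1], [0.6, 0.8], [-0.3, 0.5], [1.0, 0.0], [0.0, -1.0]]
V = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0], [3.0, 1.0, 0.0], [0.5, 0.5, 0.5], [-1.0, 2.0, 1.0]]
fast = flash_attention_forward(Q, K, V, block_q=2, block_k=2)
ref = reference_attention(Q, K, V)
print(all(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12)
          for ra, rb in zip(fast, ref) for a, b in zip(ra, rb)))  # True
```

Changing block_q and block_k changes how the work is chunked (including an uneven last tile here, since n = 5) but not the output beyond floating-point rounding, which is the exactness claim in miniature.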
Interview questions & answers
Does FlashAttention change the attention result?
No. It is mathematically exact: outputs match standard attention up to floating-point rounding, since only the summation order changes. It reorders the computation to be IO-aware; nothing is approximated (unlike sparse/linear attention).
What does it actually optimize, given the FLOPs are the same?
Memory traffic. Attention is memory-bandwidth bound, so avoiding writing/reading the n×n matrix to slow HBM — by tiling in fast SRAM — is what yields the speedup and the O(n) peak memory.
What is the online softmax and why is it needed?
A way to compute softmax incrementally with a running max and normalizer, rescaling partial sums as new tiles arrive. It lets you finish softmax without ever holding a full row, which is what makes tiling exact.
FlashAttention vs sparse/linear attention?
FlashAttention is an exact kernel-level optimization (same math, less I/O). Sparse/linear attention change the math to reduce O(n²) cost and give approximate results. They are complementary, not the same thing.
Common pitfalls
- Thinking it reduces FLOPs or approximates attention — it does neither.
- Confusing it with sparse/linear attention (those change the result).
- Expecting gains on tiny sequences — the win grows with sequence length.
Where it shows up
- PyTorch scaled_dot_product_attention, vLLM, most training stacks
- Long-context training & inference (32K–128K+)
- FlashAttention-2 / 3 kernels on modern GPUs