🧠 Flash Attention — AI / ML Interview Guide



What it is

Standard attention builds the full n×n score matrix in memory — O(n²) — which is slow and memory-hungry for long sequences. FlashAttention computes the SAME attention in tiles, keeping a running softmax so it never stores the big matrix. It is an exact, IO-aware reimplementation: same result, far less memory traffic.

Mental model

Don't write out the whole giant spreadsheet and then sum it; stream it block by block, keeping a running total as you go. FlashAttention does attention this way: process a tile of keys/values, update a running max and sum (the "online softmax"), discard the tile, repeat. The n×n matrix never exists all at once, so far less data moves to and from slow GPU memory.

Theory

The bottleneck in attention is usually NOT arithmetic; it is memory bandwidth. Materializing the n×n score matrix means writing it to, and reading it back from, the GPU's high-bandwidth memory (HBM), which is large but far slower than on-chip SRAM. For long sequences that traffic dominates runtime, and the O(n²) matrix dominates memory.
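To make the O(n²) memory cost concrete, this small sketch computes the size of a single materialized score matrix. It assumes fp16 scores (2 bytes each), one attention head and batch size 1; `score_matrix_bytes` is a hypothetical helper, and real models multiply the result by batch size times number of heads.

```python
# Back-of-the-envelope size of ONE materialized n x n attention score matrix.
# Assumptions (illustrative only): fp16 scores = 2 bytes, 1 head, batch size 1.

def score_matrix_bytes(n: int, bytes_per_elem: int = 2) -> int:
    """Bytes needed to store an n x n score matrix."""
    return n * n * bytes_per_elem

for n in (2_048, 8_192, 32_768, 131_072):
    gib = score_matrix_bytes(n) / 2**30  # convert bytes to GiB
    print(f"n={n:>7}: {gib:8.3f} GiB per head")
```

At n = 32,768 a single head's scores already take 2 GiB, and at n = 131,072 they take 32 GiB, before counting the extra HBM round trips for softmax.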

FlashAttention is IO-AWARE: it restructures the computation to maximize work done in fast on-chip SRAM and minimize reads/writes to HBM. It loads small TILES of Q, K, V into SRAM, computes their partial attention there, and writes back only the final output (plus small per-row softmax statistics that the backward pass uses to recompute attention), never the full score matrix.
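How small is a tile? Algorithm 1 of the original FlashAttention paper derives block sizes from the SRAM size M (in elements) and the head dimension d, roughly so that blocks of Q, K, V and the output fit on chip together. A minimal sketch of that rule, assuming M is given in elements; `flash_block_sizes` is a hypothetical helper, and production kernels typically tune tile sizes per GPU and head dimension instead.

```python
import math

def flash_block_sizes(sram_elems: int, head_dim: int) -> tuple[int, int]:
    """Return (B_r, B_c): query-block rows and key/value-tile rows.

    Rule from FlashAttention v1, Algorithm 1:
        B_c = ceil(M / (4d)),  B_r = min(ceil(M / (4d)), d)
    The factor 4 leaves room for roughly four B x d blocks (Q, K, V, O) in SRAM.
    """
    b_c = math.ceil(sram_elems / (4 * head_dim))
    b_r = min(b_c, head_dim)
    return b_r, b_c

# Illustrative assumption: ~100 KB of SRAM holding fp16 values = 51,200 elements.
for d in (64, 128):
    print(d, flash_block_sizes(51_200, d))
```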

The enabling trick is the ONLINE (streaming) softmax. Softmax normally needs the whole row to find its max and sum, but you can process the row in chunks while maintaining a running max and a running normalizer, rescaling the partial result as each new tile arrives. This makes the tiled, never-materialized computation mathematically EXACT — not an approximation.
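A minimal sketch of the online softmax on a single score row, in pure Python. It streams the row in chunks, keeps the running max `m` and running normalizer `l`, and rescales `l` whenever the max grows; `online_logsumexp` is a hypothetical name and chunks are assumed non-empty with finite values. It recovers only the softmax statistics; the full attention kernel applies the same rescaling to its output accumulator, so it never needs a second pass over the row.

```python
import math

def online_logsumexp(chunks):
    """Single pass over chunks of one score row.

    Maintains m = running max and l = sum(exp(x - m)) over everything seen,
    so that log(sum(exp(row))) == m + log(l) at the end.
    """
    m = float("-inf")  # running max (no values seen yet)
    l = 0.0            # running normalizer relative to m
    for chunk in chunks:
        m_new = max(m, max(chunk))
        # Rescale the old sum to the new max, then add this chunk's terms.
        l = l * math.exp(m - m_new) + sum(math.exp(x - m_new) for x in chunk)
        m = m_new
    return m, l

row = [2.0, -1.0, 0.5, 3.0, 1.5, -0.5]
m, l = online_logsumexp([row[0:2], row[2:4], row[4:6]])

# Softmax from the streamed statistics vs. the direct whole-row formula.
probs = [math.exp(x - m) / l for x in row]
total = sum(math.exp(y) for y in row)
direct = [math.exp(x) / total for x in row]
print(max(abs(a - b) for a, b in zip(probs, direct)))
```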

So the win is memory traffic, not fewer FLOPs: the arithmetic is still O(n²·d), but HBM accesses drop from Θ(n·d + n²) to Θ(n²·d²/M), where M is the SRAM size, and peak EXTRA memory drops from O(n²) to O(n). Because d² is typically many times smaller than M, in practice that means large wall-clock speedups and the ability to train/serve much longer contexts on the same hardware.
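Plugging numbers into the two asymptotic HBM-access counts gives a feel for the gap. This is an order-of-magnitude sketch only: constants are dropped, the counts are in elements, the Θ(n²·d²/M) bound assumes d ≤ M ≤ n·d, and the sizes (n = 4096, d = 64, M = 51,200 elements) are illustrative assumptions. The helper names are hypothetical.

```python
def hbm_accesses_standard(n: int, d: int) -> float:
    # Θ(n·d + n²): read Q, K, V and write/read the n x n score/probability matrices.
    return n * d + n * n

def hbm_accesses_flash(n: int, d: int, sram_elems: int) -> float:
    # Θ(n²·d² / M): each K/V tile is streamed past every query block.
    return n * n * d * d / sram_elems

n, d, M = 4096, 64, 51_200  # illustrative sizes, constants ignored
std = hbm_accesses_standard(n, d)
flash = hbm_accesses_flash(n, d, M)
print(f"standard ~{std:.3g}, flash ~{flash:.3g}, ratio ~{std / flash:.1f}x")
```

The ratio is only indicative; real speedups also depend on kernel fusion, occupancy and the GPU's compute-to-bandwidth balance.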

Contrast this with sparse or linear attention, which change the math (and the results) to cut cost. FlashAttention changes only the IMPLEMENTATION — identical outputs to vanilla attention — which is why it became a default kernel rather than a modeling choice.

Concrete example

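FlashAttention and its successors (v2, v3) are widely deployed: PyTorch's scaled_dot_product_attention can dispatch to a FlashAttention backend when the inputs qualify (e.g. CUDA tensors in fp16/bf16), and vLLM and most training stacks use FlashAttention-style kernels. It is a big reason 32K–128K context windows became practical without exotic approximations.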

Key equations
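Standard attention and the online-softmax recurrence described in the Theory section can be written as follows. The notation is assumed for illustration: Q_i is one query block, K_j and V_j are the j-th key/value tile, T is the number of tiles, and m and ℓ are per-row vectors (max and exp are applied row-wise and elementwise). This is the form that keeps an unnormalized output accumulator and divides once at the end, as FlashAttention-2 does.

```latex
% Standard attention for one head (d = head dimension)
S = \frac{QK^\top}{\sqrt{d}}, \qquad P = \operatorname{softmax}_{\text{row}}(S), \qquad O = PV

% Online softmax for one query block Q_i, streaming tiles j = 1, \dots, T
% Initialization: m^{(0)} = -\infty, \quad \ell^{(0)} = 0, \quad O^{(0)} = 0
\begin{aligned}
S^{(j)} &= \frac{Q_i K_j^\top}{\sqrt{d}} \\
m^{(j)} &= \max\!\big(m^{(j-1)},\ \operatorname{rowmax}(S^{(j)})\big) \\
\tilde{P}^{(j)} &= \exp\!\big(S^{(j)} - m^{(j)}\big) \\
\ell^{(j)} &= e^{m^{(j-1)} - m^{(j)}}\,\ell^{(j-1)} + \operatorname{rowsum}\big(\tilde{P}^{(j)}\big) \\
O^{(j)} &= \operatorname{diag}\!\big(e^{m^{(j-1)} - m^{(j)}}\big)\,O^{(j-1)} + \tilde{P}^{(j)} V_j \\
O_i &= \operatorname{diag}\!\big(\ell^{(T)}\big)^{-1} O^{(T)}
\end{aligned}
```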

Step by step

  1. Split Q, K, V into tiles small enough to fit in fast on-chip SRAM.
  2. Load one block of Q and one tile of K, V; compute that block of scores QKᵀ/√d there.
  3. Update the running max and sum (online softmax) and the running output.
  4. Drop the tile and stream the next — the full n×n matrix is never stored.
  5. After the last tile, apply the final softmax normalization and write the output block to HBM.
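The five steps can be sketched in pure Python and checked against naive attention. This is a teaching sketch, not a GPU kernel: it processes one query row at a time (real kernels process query blocks in parallel on chip), uses plain lists, and the names `naive_attention` and `tiled_attention` and the tile size are hypothetical choices.

```python
import math
import random

def naive_attention(Q, K, V):
    """Reference: build each full row of S = QK^T/sqrt(d), softmax it, multiply by V."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in K]  # full row of S
        mx = max(scores)
        w = [math.exp(s - mx) for s in scores]
        z = sum(w)
        out.append([sum(w[j] * V[j][c] for j in range(len(V))) / z
                    for c in range(len(V[0]))])
    return out

def tiled_attention(Q, K, V, tile=3):
    """Stream K/V in tiles with a running max m, normalizer l and
    unnormalized output accumulator acc (online softmax)."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    dv = len(V[0])
    out = []
    for q in Q:
        m, l, acc = float("-inf"), 0.0, [0.0] * dv
        for start in range(0, len(K), tile):
            K_t, V_t = K[start:start + tile], V[start:start + tile]
            # Step 2: partial scores for this tile only.
            s = [scale * sum(a * b for a, b in zip(q, k)) for k in K_t]
            # Step 3: update running max, rescale old sum and output, add new terms.
            m_new = max(m, max(s))
            alpha = math.exp(m - m_new)
            p = [math.exp(x - m_new) for x in s]
            l = alpha * l + sum(p)
            acc = [alpha * acc[c] + sum(p[j] * V_t[j][c] for j in range(len(V_t)))
                   for c in range(dv)]
            m = m_new
            # Step 4: the tile's scores go out of scope; nothing n x n is kept.
        # Step 5: normalize once after the last tile.
        out.append([a / l for a in acc])
    return out

random.seed(0)
n, d = 7, 4  # 7 keys with tile=3 gives tiles of 3, 3 and 1
Q = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
K = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
V = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]

ref = naive_attention(Q, K, V)
got = tiled_attention(Q, K, V)
max_err = max(abs(a - b) for r, g in zip(ref, got) for a, b in zip(r, g))
print(f"max abs difference: {max_err:.2e}")
```

The difference is at floating-point rounding level, which is the practical meaning of "exact" here.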

Interview questions & answers

Does FlashAttention change the attention result?

No. It is mathematically exact: outputs match standard attention up to floating-point rounding. It only reorders the computation to be IO-aware; nothing is approximated (unlike sparse/linear attention).

What does it actually optimize, given the FLOPs are the same?

Memory traffic. Attention is memory-bandwidth bound, so avoiding writing/reading the n×n matrix to slow HBM (by tiling in fast SRAM) is what yields the speedup and the O(n) extra memory beyond Q, K, V and the output.

What is the online softmax and why is it needed?

A way to compute softmax incrementally with a running max and normalizer, rescaling partial sums as new tiles arrive. It lets you finish softmax without ever holding a full row, which is what makes tiling exact.

FlashAttention vs sparse/linear attention?

FlashAttention is an exact kernel-level optimization (same math, less I/O). Sparse/linear attention change the math to reduce O(n²) cost and give approximate results. They are complementary, not the same thing.
