Naive attention writes the full N×N attention matrix to GPU memory. FlashAttention computes it block by block in SRAM (fast on-chip memory) and never materializes the full matrix. Same math (exact, not an approximation), often several times faster on long contexts.
The IO problem
```
# Naive:
# 1. Compute S = Q·Kᵀ/√d           → write [N, N] to HBM (slow)
# 2. Softmax over rows of S        → read [N, N], write [N, N] back (slow)
# 3. Multiply P = softmax(S) by V  → read [N, N] again, write [N, dv] to HBM
#
# For N=8192: 8192² ≈ 67M scores = 256 MiB in fp32 (128 MiB in fp16), per head, per sequence
```

HBM (GPU global memory) is slow relative to SRAM (on-chip memory with ~10× the bandwidth but ~1000× less capacity). Materializing the full N×N matrix forces it into HBM, and every step above then round-trips through it. FlashAttention avoids this.
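The 256 MiB figure is plain arithmetic: N² entries times bytes per element, multiplied again by heads and batch size. A minimal sketch; the 32-head configuration is a hypothetical example, not taken from any specific model.

```python
MiB = 2**20

def score_matrix_bytes(n: int, bytes_per_elem: int, heads: int = 1, batch: int = 1) -> int:
    """Bytes needed to materialize the [N, N] score matrix for every head and sequence."""
    return batch * heads * n * n * bytes_per_elem

print(score_matrix_bytes(8192, 4) // MiB)            # fp32, one head  → 256
print(score_matrix_bytes(8192, 2) // MiB)            # fp16, one head  → 128
print(score_matrix_bytes(8192, 2, heads=32) // MiB)  # fp16, 32 heads (hypothetical) → 4096
```

The head count multiplies the cost, which is why the N² term dominates at long context even though d never appears in it.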
Tiling and recompute
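```
for q_block in Q_tiles:
    m, l, acc = -inf, 0, 0   # per-row running max, running sum, unnormalized output
    for k_block, v_block in zip(K_tiles, V_tiles):
        # all in SRAM:
        s     = q_block · k_blockᵀ / √d
        m_new = max(m, rowmax(s))
        p     = exp(s - m_new)
        scale = exp(m - m_new)          # shrink earlier partial results to the new max
        l     = scale · l + rowsum(p)
        acc   = scale · acc + p · v_block
        m     = m_new
    write acc / l to HBM                # normalize once, write once
```

Each block is small enough to fit in SRAM. Online softmax (a running max and sum per row, plus rescaling of the partial output whenever the max grows) gives the exact result without ever materializing the full softmax. Output written once. Extra memory: O(N) softmax statistics instead of an O(N²) matrix; HBM traffic drops from Θ(N·d + N²) to Θ(N²·d²/M) for SRAM size M. The recompute part happens in the backward pass: instead of storing the N×N attention matrix for gradients, FlashAttention recomputes each block from Q, K, V and the saved per-row statistics, spending extra FLOPs to avoid HBM traffic.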
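The tiled loop can be checked against naive attention on the CPU. A minimal NumPy sketch, assuming a single head and no masking. NumPy has no notion of SRAM, so this shows only the math of tiling plus online softmax, not the speedup. The function names and the `block` parameter are illustrative.

```python
import numpy as np

def naive_attention(Q, K, V):
    """Reference: materializes the full [N, N] score matrix."""
    S = Q @ K.T / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

def tiled_attention(Q, K, V, block=64):
    """Block-wise forward pass with online softmax; never builds an [N, N] array."""
    N, d = Q.shape
    out = np.empty((N, V.shape[1]))
    for i in range(0, N, block):
        q = Q[i:i + block]
        m = np.full(q.shape[0], -np.inf)          # running row max
        l = np.zeros(q.shape[0])                  # running row sum of exp(s - m)
        acc = np.zeros((q.shape[0], V.shape[1]))  # unnormalized output rows
        for j in range(0, K.shape[0], block):
            s = q @ K[j:j + block].T / np.sqrt(d)  # only a [block, block] score tile
            m_new = np.maximum(m, s.max(axis=1))
            p = np.exp(s - m_new[:, None])
            scale = np.exp(m - m_new)             # rescale earlier partial results
            l = scale * l + p.sum(axis=1)
            acc = scale[:, None] * acc + p @ V[j:j + block]
            m = m_new
        out[i:i + block] = acc / l[:, None]       # normalize once, write once
    return out

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((300, 32)) for _ in range(3))
print(np.allclose(naive_attention(Q, K, V), tiled_attention(Q, K, V)))  # True
```

On the first KV block, `m` is `-inf`, so `scale` is `exp(-inf) = 0` and the empty accumulator drops out cleanly. N = 300 is deliberately not a multiple of the block size, so the ragged last tile is exercised too.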
The math equivalence
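The output is identical to naive attention (exactly, up to floating-point rounding). Online softmax is an algebraic rearrangement: softmax(s) = exp(s - max) / sum(exp(s - max)). Two blocks with statistics (m₁, l₁) and (m₂, l₂) merge as m = max(m₁, m₂), l = exp(m₁ - m)·l₁ + exp(m₂ - m)·l₂, so the running max and sum combine across blocks exactly. It is the standard log-sum-exp trick applied incrementally.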
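The merge rule for per-block (max, sum) statistics can be verified in a few lines of standard-library Python. This is a sketch for a single score row; `block_stats` and `merge` are hypothetical helper names.

```python
import math
from functools import reduce

def block_stats(scores):
    """Softmax statistics for one block: (max, sum of exp(s - max))."""
    m = max(scores)
    return m, sum(math.exp(s - m) for s in scores)

def merge(a, b):
    """Combine two blocks' (max, sum) by rescaling each sum to the shared max."""
    (m1, l1), (m2, l2) = a, b
    m = max(m1, m2)
    return m, math.exp(m1 - m) * l1 + math.exp(m2 - m) * l2

row = [3.0, -1.0, 10.0, 2.5, 7.0, 0.0]
blocks = [row[0:2], row[2:4], row[4:6]]
m, l = reduce(merge, (block_stats(b) for b in blocks))
full_m, full_l = block_stats(row)
print(math.isclose(m, full_m), math.isclose(l, full_l))  # True True
```

Because `merge` is associative, the blocks can be folded in any grouping. The same fold is what the tiled loop performs one KV block at a time.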
Speed and memory
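```
# FlashAttention-2:
# Memory: O(N·d) for Q, K, V, O plus O(N) softmax statistics (no N×N matrix)
# Time:   O(N²·d), same FLOPs as naive, but typically several times faster wall-clock
#         (exact speedup depends on N, d, and the GPU)
```

Same compute count; far better memory access pattern. PyTorch's F.scaled_dot_product_attention dispatches to a FlashAttention-2 kernel automatically when the inputs qualify (e.g. fp16/bf16 on supported GPUs) and falls back to other backends otherwise. Standard in production training and inference.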
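The FlashAttention paper's IO analysis bounds HBM accesses at Θ(N·d + N²) for standard attention and Θ(N²·d²/M) for FlashAttention, where M is the SRAM size. The win comes from d² being much smaller than M for typical head dimensions. The sketch below plugs numbers into those two bounds with all constants dropped, so the ratio is an order-of-magnitude illustration, not a speedup prediction. The value M = 100,000 elements is an assumption chosen for illustration.

```python
def naive_hbm_accesses(n, d):
    """Standard attention: Θ(N·d + N²) HBM accesses, constants dropped."""
    return n * d + n * n

def flash_hbm_accesses(n, d, m):
    """FlashAttention forward: Θ(N²·d²/M) HBM accesses for SRAM of M elements, constants dropped."""
    return n * n * d * d / m

N, D, M = 8192, 64, 100_000   # M is an assumed SRAM capacity in elements
ratio = naive_hbm_accesses(N, D) / flash_hbm_accesses(N, D, M)
print(f"{ratio:.1f}")  # 24.6
```

The factor d²/M, not N, sets the gain: both bounds grow as N², so a longer context alone does not change this ratio much.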
CPU equivalents
CPU attention can use the same tiling idea, with L2/L3 cache standing in for SRAM. llama.cpp's CPU attention does block-wise computation to fit in L2/L3 cache. Not as dramatic a win as GPU FlashAttention, but still ~2× for long contexts.
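Tile sizes have to be chosen so that the working set fits the cache. One starting point is the rule from FlashAttention's Algorithm 1, Bc = ⌈M/(4d)⌉ and Br = min(⌈M/(4d)⌉, d), with M taken here as cache capacity in elements. Applying that GPU rule to a CPU cache is an assumption, not llama.cpp's actual scheme. The 1 MiB L2 size and fp32 elements are also hypothetical.

```python
import math

def block_sizes(cache_bytes, d, bytes_per_elem=4):
    """Block sizes via FlashAttention's Algorithm 1 rule, with M = cache capacity in elements."""
    m_elems = cache_bytes // bytes_per_elem
    bc = math.ceil(m_elems / (4 * d))  # rows per K/V tile
    br = min(bc, d)                    # rows per Q/O tile
    return br, bc

def working_set_elems(br, bc, d):
    """Elements live per inner step: Q and O tiles [br, d], K and V tiles [bc, d], scores [br, bc]."""
    return 2 * br * d + 2 * bc * d + br * bc

L2_BYTES = 1 * 2**20  # hypothetical 1 MiB per-core L2
br, bc = block_sizes(L2_BYTES, d=128)
print(br, bc)                                            # 128 512
print(working_set_elems(br, bc, 128) <= L2_BYTES // 4)  # True
```

With d = 128, the working set is 229,376 fp32 elements against a 262,144-element budget, leaving some headroom for the per-row max and sum.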