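FlashAttention rewrites self-attention to be IO-aware: it tiles the computation so that intermediate attention scores stay in on-chip SRAM instead of being written to and read back from HBM. The math is unchanged (it computes exact attention, not an approximation), yet it runs roughly 2-10x faster on modern GPUs, depending on sequence length and workload. It is now standard in both training and inference pipelines.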
The IO bottleneck
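Naive attention needs O(N²) memory because it materializes the full N×N attention matrix in HBM. At N = 8K that is 64M (2^26) entries per head, for each sequence in the batch. Writing that matrix out and reading it back for the softmax (and again for masking or dropout) makes attention memory-bandwidth-bound: the GPU spends more time waiting on HBM traffic than doing arithmetic.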
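The 64M figure is easy to check. A minimal sketch of the arithmetic, assuming fp16 scores (2 bytes each) for a single head and a single sequence; the function name score_matrix_bytes is illustrative:

```python
# Size of the N x N attention-score matrix that naive attention writes to HBM.
# Assumptions (illustrative): one batch element, one head, fp16 scores (2 bytes each).

def score_matrix_bytes(n: int, bytes_per_elem: int = 2) -> int:
    """Bytes needed to materialize one N x N score matrix."""
    return n * n * bytes_per_elem

N = 8 * 1024                          # "8K" tokens
entries = N * N                       # 2**26 entries, i.e. "64M" in binary units
mib = score_matrix_bytes(N) / 2**20   # fp16 bytes converted to MiB

print(f"{entries:,} entries, {mib:.0f} MiB per head in fp16")
```

Multiply by the number of heads and the batch size to get the cost for a whole layer; a hypothetical 32-head layer at batch size 1 would need 4 GiB just for the scores.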
Tiling trick
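Compute attention in blocks small enough to fit in SRAM (fast on-chip memory). Because softmax normalizes over an entire row, the blocks are combined with an online softmax: each query row keeps a running maximum and a running sum of exponentials, and earlier partial results are rescaled whenever a new block raises the maximum. The backward pass does not store the N×N matrix either; it recomputes the blocks from Q, K, V and the saved per-row statistics. FlashAttention-2 kept this scheme and improved the work partitioning, parallelizing across the sequence dimension and cutting non-matmul FLOPs. Extra memory becomes O(N) instead of O(N²).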
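The online-softmax step is what makes tiling exact. Below is a minimal pure-Python sketch of the forward pass only, assuming small lists of floats for Q, K, V. It tiles only the key/value sequence, ignores the SRAM/HBM distinction entirely, and the names naive_attention and tiled_attention are hypothetical. It shows that processing K/V block by block with a running max and normalizer reproduces full softmax attention.

```python
import math

def naive_attention(Q, K, V):
    """Reference: builds every score row S = q K^T / sqrt(d) in full, then softmaxes."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
        m = max(scores)                               # subtract max for stability
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([sum(wj * v[c] for wj, v in zip(w, V)) / z
                    for c in range(len(V[0]))])
    return out

def tiled_attention(Q, K, V, block=2):
    """Sketch: visits K/V in blocks of `block` rows, keeping only per-row running
    statistics (max m, normalizer l) and an unnormalized output accumulator."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        m = -math.inf              # running max of scores seen so far
        l = 0.0                    # running sum of exp(score - m)
        acc = [0.0] * len(V[0])    # running sum of exp(score - m) * v
        for start in range(0, len(K), block):
            Kb, Vb = K[start:start + block], V[start:start + block]
            s = [scale * sum(a * b for a, b in zip(q, k)) for k in Kb]
            m_new = max(m, max(s))
            corr = math.exp(m - m_new)    # rescales old stats to the new max (0.0 on first block)
            p = [math.exp(x - m_new) for x in s]
            l = l * corr + sum(p)
            acc = [a * corr + sum(pj * v[c] for pj, v in zip(p, Vb))
                   for c, a in enumerate(acc)]
            m = m_new
        out.append([a / l for a in acc])  # normalize once, at the end
    return out

Q = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.5]]
K = [[0.2, 0.1], [-0.3, 0.4], [0.5, 0.5], [0.1, -0.2], [0.0, 0.3]]
V = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [2.0, -1.0], [0.3, 0.7]]

ref = naive_attention(Q, K, V)
tiled = tiled_attention(Q, K, V, block=2)
err = max(abs(a - b) for r, t in zip(ref, tiled) for a, b in zip(r, t))
print(f"max abs difference: {err:.2e}")
```

Any block size gives the same result up to floating-point rounding. In a real kernel the block size is chosen so the tiles and accumulators fit in SRAM, and Q is tiled as well.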
Practical impact
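Long contexts that ran out of memory at 16K now fit at 128K and beyond. Speedups are roughly 1.5-3x for training and 5-10x for long-sequence inference, varying with sequence length, head dimension, and hardware. It is now the baseline: PyTorch's scaled_dot_product_attention (SDPA) dispatches to a FlashAttention kernel automatically when the hardware, dtype, and inputs are supported.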
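To see why the shift from O(N²) to O(N) matters at long context, here is a rough per-head comparison. It assumes (illustratively, not from measurements) fp16 scores for the naive path and two fp32 statistics per query row for a FlashAttention-style kernel; it ignores Q, K, V, and the output, which both approaches need. The function names are hypothetical.

```python
# Extra memory (beyond Q, K, V and the output) for one head, one batch element.
# Assumptions (illustrative): naive attention stores the N x N score matrix in
# fp16; a FlashAttention-style kernel keeps two fp32 statistics per query row
# (running max and softmax normalizer) instead.

def naive_extra_bytes(n: int) -> int:
    return n * n * 2          # N x N fp16 scores: quadratic in N

def tiled_extra_bytes(n: int) -> int:
    return n * 2 * 4          # N rows x 2 fp32 statistics: linear in N

for n in (16 * 1024, 128 * 1024):
    print(f"N={n:>7,}: naive {naive_extra_bytes(n) / 2**30:6.2f} GiB, "
          f"tiled {tiled_extra_bytes(n) / 2**20:5.3f} MiB")
```

At 128K the naive score matrix alone is 32 GiB per head, versus about 1 MiB of row statistics, before multiplying by heads and batch size (and, in training, keeping it around for the backward pass).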