Flash Attention
Rewrite the attention kernel as an IO-aware tiled CUDA operation — eliminates the O(N²) memory bottleneck of standard attention without changing the math or the output.
Intent & Description
🎯 Intent
Standard attention materializes the full N×N attention matrix in HBM — slow memory IO, not compute, is the bottleneck. Flash Attention tiles the computation into blocks that fit in fast GPU SRAM and fuses the softmax + matmul into a single kernel pass.
📋 Context
For a sequence of length N, standard attention writes and reads a full (N×N) attention matrix to GPU HBM. At N = 4,096, that is about 16.8M elements (32 MiB in fp16) per head per layer, and the repeated read/write trips to slow memory dominate wall-clock time. The computation is cheap; the memory IO is not.
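The size arithmetic can be checked with a few lines of Python. This is only a back-of-the-envelope sketch: the helper name `attention_matrix_bytes` is made up for illustration, and 2 bytes per element assumes fp16 or bf16 scores.

```python
def attention_matrix_bytes(seq_len: int, bytes_per_element: int = 2) -> int:
    """Bytes needed to store one N x N score matrix (one head, one layer)."""
    return seq_len * seq_len * bytes_per_element


n = 4096
elements = n * n                       # entries in the N x N matrix
mib = attention_matrix_bytes(n) / 2**20

print(elements)  # 16777216
print(mib)       # 32.0
```

Multiplying by the number of heads, layers, and batch entries gives the total traffic that standard attention pushes through HBM.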
💡 Solution
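Tile Q, K, V into blocks that fit in SRAM. Compute attention block by block, using the online softmax trick to maintain a running max and sum so the full matrix is never materialized. Fuse the softmax, dropout, and both matmuls into one kernel; the backward pass recomputes attention blocks instead of storing them. Extra memory drops from O(N²) to O(N). HBM accesses drop from Θ(Nd + N²) to Θ(N²d²/M) for head dimension d and SRAM size M, which is many times fewer for typical d (64–128) and M (around 100 KB). FlashAttention-3 (2024) targets NVIDIA Hopper GPUs, exploiting asynchronous execution (warp specialization, TMA) and FP8 low precision; variable-length batching is also supported by the library and predates v3. PyTorch 2.0+ exposes a FlashAttention backend through F.scaled_dot_product_attention, which dispatches to it when the inputs and hardware qualify.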
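The online softmax idea can be sketched in plain Python. This is a didactic sketch, not the CUDA kernel: it tiles only over K/V (the real kernel also tiles Q), omits dropout and masking, and the function names `naive_attention` and `tiled_attention` and the sample data are invented for illustration. The point it shows is that keeping a running max, a running sum, and an unnormalized accumulator per query yields the same output as the full softmax.

```python
import math


def naive_attention(q, k, v):
    """Reference: builds the full row of scores for every query."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) / total
                    for c in range(len(v[0]))])
    return out


def tiled_attention(q, k, v, block_size=2):
    """Visits K/V one block at a time; never holds a full score row."""
    scale = 1.0 / math.sqrt(len(q[0]))
    d_v = len(v[0])
    out = []
    for qi in q:
        running_max = -math.inf
        running_sum = 0.0
        acc = [0.0] * d_v  # unnormalized output accumulator
        for start in range(0, len(k), block_size):
            k_blk = k[start:start + block_size]
            v_blk = v[start:start + block_size]
            scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k_blk]
            new_max = max(running_max, max(scores))
            # Rescale everything accumulated so far to the new max.
            correction = math.exp(running_max - new_max)
            weights = [math.exp(s - new_max) for s in scores]
            running_sum = running_sum * correction + sum(weights)
            acc = [a * correction + sum(w * vj[c] for w, vj in zip(weights, v_blk))
                   for c, a in enumerate(acc)]
            running_max = new_max
        # Normalize once, after the last block.
        out.append([a / running_sum for a in acc])
    return out


# Made-up sample data: 3 queries, 5 keys/values, head dim 2, value dim 3.
q = [[1.0, 0.0], [0.5, -1.0], [0.0, 2.0]]
k = [[1.0, 2.0], [0.0, 1.0], [-1.0, 0.5], [2.0, 0.0], [0.3, 0.3]]
v = [[1.0, 0.0, 2.0], [0.0, 1.0, 0.5], [3.0, -1.0, 0.0],
     [0.5, 0.5, 0.5], [-2.0, 1.0, 1.0]]

reference = naive_attention(q, k, v)
tiled = tiled_attention(q, k, v, block_size=2)
max_diff = max(abs(a - b) for r, t in zip(reference, tiled) for a, b in zip(r, t))
print(max_diff)
```

The per-query state is a constant number of scalars plus one output row, which is where the O(N) extra memory comes from; any difference from the reference is floating-point rounding only.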
Real-world Use Case
📌 TL;DR
Same attention math, radically less memory IO. Tiles the N×N computation into SRAM-resident blocks. Non-negotiable baseline for training and serving any model with context > 2K.
Advantages
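- Attention memory grows linearly in N instead of quadratically, so the savings increase with sequence length rather than being a fixed factor; this enables longer context at the same GPU budget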
- Significant wall-clock speedup (2–4x on common sequence lengths) due to IO reduction
- Exact attention output — no approximation, no quality degradation vs. standard attention
Disadvantages
- Kernel implementations are hardware-specific — AMD ROCm and older GPUs need separate ports
- Custom CUDA kernel complicates debugging and gradient inspection
- The N×N attention matrix is never materialized, so attention weights are not available for visualization without a separate unfused pass