Flash Attention — Making Transformers Actually Fast
A scroll-driven visual deep dive into Flash Attention. Learn why standard attention is broken, how GPU memory works, and how tiling fixes everything — with quizzes to test your understanding.
🐌 → ⚡
ChatGPT’s brain has a speed problem.
Here’s how they fixed it.
Every AI model you use — ChatGPT, Claude, Gemini — runs something called “attention.”
It’s powerful, but it’s painfully slow. Flash Attention fixes that.
No PhD required to understand this. Seriously.
↓ Scroll to learn — there are quizzes to make sure you actually get it
First, what even is Attention?
You know how when you read a sentence, some words matter more for understanding than others? That’s basically what attention does in AI models.
When ChatGPT reads “The cat sat on the mat,” it figures out which words are related to each other. That process has 4 steps:
The Attention Recipe (don't panic, it's simpler than it looks)
Q, K, V = XWq, XWk, XWv
S = QKᵀ / √d
P = softmax(S)
O = PV

This works great! But there’s a big, ugly problem lurking underneath:
But here’s the twist most people miss — the math itself is fast. Modern GPUs are absolute beasts at multiplication.
The real problem? All that scratch paper needs to be saved and loaded from memory, and that’s what’s slow. It’s like having a calculator that can do math in 0.001 seconds, but you have to walk to a different room every time you need a new sheet of paper.
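To make that concrete, here is the four-step recipe written out as code: a minimal PyTorch sketch (the function and variable names are illustrative), with comments flagging the big intermediate tables that have to be shuffled through memory.

```python
import math
import torch
import torch.nn.functional as F

def standard_attention(X, W_q, W_k, W_v):
    """The four-step recipe from above, computed the straightforward way."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v              # step 1: project the inputs
    d = Q.shape[-1]
    S = Q @ K.transpose(-2, -1) / math.sqrt(d)       # step 2: n x n score table (big!)
    P = F.softmax(S, dim=-1)                         # step 3: another n x n table (big!)
    O = P @ V                                        # step 4: weighted sum of values
    return O                                         # only O is needed afterwards

# Toy example: 6 tokens ("The cat sat on the mat"), embedding size 64
X = torch.randn(6, 64)
W_q, W_k, W_v = (torch.randn(64, 64) for _ in range(3))
print(standard_attention(X, W_q, W_k, W_v).shape)    # torch.Size([6, 64])
```

At 6 tokens those tables are tiny. At tens of thousands of tokens, S and P become the giant scratch paper the rest of this article is about.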
What's the REAL bottleneck in standard transformer attention?
💡 Think about what takes longer: doing math or moving data around...
Nailed it! The GPU does math insanely fast — it's a math machine. But it spends most of its time waiting for data to be shuffled between slow main memory (HBM) and the fast compute units. Think of it this way: the calculator is fast, but the filing cabinet is far away.
Your GPU Has Two Types of Memory (and That Matters A LOT)
So What Does Standard Attention Do With This?
Here’s the painful part — standard attention treats this like the worst possible study strategy:
Standard attention = walking to the bookshelf 3 times 🤦
- Bookshelf → Desk: Grab Q, K
- Desk: Do the math (S = QKᵀ)
- Desk → Bookshelf: Save S
- Bookshelf → Desk: Pick S back up
- Desk → Bookshelf: Save P

Most time spent walking, not studying!

Standard vs Flash: The Data Movement Side-by-Side
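Here's a rough side-by-side in numbers: a back-of-the-envelope sketch of the bytes each approach drags across the HBM bus, assuming 4,096 tokens, head dimension 64, and fp16 values for one head in one layer. (It ignores Flash Attention's repeated K/V tile reads, which the O(N²d²/M) bound later in the article accounts for.)

```python
# Back-of-the-envelope HBM traffic for one attention head, one layer.
# Assumed numbers: 4,096 tokens, head dimension 64, fp16 (2 bytes per value).
n, d, b = 4096, 64, 2

qkv      = 3 * n * d * b                 # read Q, K, V from HBM (both approaches)
output   = n * d * b                     # write the final answer O back (both approaches)
nn_table = n * n * b                     # one n x n intermediate table (S or P)

standard = qkv + output + 4 * nn_table   # S written + read back, P written + read back
flash    = qkv + output                  # S and P never leave the chip

print(f"standard attention: ~{standard / 1e6:.0f} MB of HBM traffic")
print(f"flash attention:    ~{flash / 1e6:.0f} MB of HBM traffic")
```

Almost all of the standard path's traffic is the scratch paper, not the inputs or the answer.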
SRAM is ~19 TB/s while HBM is ~2 TB/s. Your algorithm reads the attention matrix from HBM 3 times. What happens?
💡 What happens when a fast worker is waiting on a slow delivery?
Exactly! The GPU's math units (tensor cores) are incredibly fast, but they spend most of their time twiddling their thumbs, waiting for data to show up from HBM. It's like having a Formula 1 engine stuck in traffic. Flash Attention's whole idea is: stop making so many trips to the bookshelf.
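Using the bandwidth figures from the quiz above (roughly 19 TB/s for SRAM vs 2 TB/s for HBM), here's how long it takes just to move one of those n × n tables. Illustrative, not measured, numbers:

```python
# Time to move ONE 4096 x 4096 fp16 attention table at the quoted bandwidths.
table_bytes = 4096 * 4096 * 2

hbm_bw  = 2e12     # ~2 TB/s  (the bookshelf)
sram_bw = 19e12    # ~19 TB/s (the desk)

print(f"over HBM:  {table_bytes / hbm_bw  * 1e6:.1f} microseconds per crossing")
print(f"in SRAM:   {table_bytes / sram_bw * 1e6:.1f} microseconds per crossing")
# Standard attention makes several such HBM crossings per head per layer, while
# the tensor cores could finish the matmul itself in a comparable amount of time.
```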
💡 The Big Idea
What if we never save that giant attention table to the bookshelf at all?
Instead, work on small pieces of the problem that fit on your desk. Do the math, use the result, and move on — no walking required.
But Wait… There’s a Catch with Softmax
Softmax is that step where you convert raw scores to percentages. And it has an annoying requirement:
Why softmax seems impossible to do in pieces
softmax(xᵢ) = eˣⁱ / Σⱼ eˣʲ

Problem: You can't know the total without seeing everything first.
Solution: The 'online softmax' trick! 🧠

Why can't we just compute softmax block-by-block without any tricks?
💡 Think about what that denominator in the softmax formula needs...
Right! Softmax divides each value by the sum of ALL exponentials in the row. So you'd normally need to see every single element before producing any output. The online softmax trick solves this by keeping running statistics and rescaling as you go — giving the exact same result.
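Here is what those "running statistics" look like in practice: a minimal NumPy sketch of online softmax (block size and names are illustrative). It walks the row block by block, keeps a running maximum and a running sum of exponentials, rescales whenever the maximum changes, and lands on exactly the same answer as the one-shot version.

```python
import numpy as np

def online_softmax(x, block_size=4):
    """Softmax computed block by block, keeping only running statistics."""
    m = -np.inf                              # running max (for numerical stability)
    s = 0.0                                  # running sum of exp(x - m)
    out = np.zeros_like(x, dtype=float)      # exp(x - m), rescaled whenever m changes

    for start in range(0, len(x), block_size):
        block = x[start:start + block_size]
        m_new = max(m, block.max())          # updated running max
        scale = np.exp(m - m_new)            # how much the old statistics shrink
        s = s * scale + np.exp(block - m_new).sum()
        out[:start] *= scale                 # rescale everything produced so far
        out[start:start + len(block)] = np.exp(block - m_new)
        m = m_new

    return out / s                           # one final normalization

x = np.random.randn(10)
reference = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
print(np.allclose(online_softmax(x), reference))   # True
```

Flash Attention applies the same rescaling trick, except the thing being rescaled is its running output accumulator rather than a stored row of probabilities, so nothing of size n² ever needs to be kept around.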
Tiling: Working on Small Pieces at a Time
Instead of building the entire giant table at once, Flash Attention works on small tiles — like solving a jigsaw puzzle one section at a time, instead of dumping all 1,000 pieces on the floor.
The Algorithm in Plain English (no jargon, pinky promise)
- Chop up your data into desk-sized blocks — small enough to fit in SRAM
- For each block of Questions:
  - Grab it from the bookshelf (one trip — that’s fine)
  - For each block of Keys & Values:
    - Grab them too (one trip each)
    - Do the math right there on your desk — fast!
    - Update your running softmax stats (the online trick from before)
    - Add the partial result to your running total
  - Put only the final answer back on the bookshelf (one trip)
- Done! The giant intermediate table (S and P) never left your desk — it was computed, used, and discarded on the spot. (A code sketch of this loop follows below.)
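Here is that loop as a minimal NumPy sketch: one attention head, no masking, and plain Python loops standing in for a fused GPU kernel, so it shows the bookkeeping rather than the speed. Block sizes and variable names are illustrative; in the real kernel each tile lives in SRAM.

```python
import numpy as np

def flash_attention_forward(Q, K, V, block_q=64, block_k=64):
    """Tiled attention: S and P exist only one small tile at a time."""
    n, d = Q.shape
    O = np.zeros((n, d))                     # running (rescaled) output
    m = np.full(n, -np.inf)                  # running row max
    s = np.zeros(n)                          # running row sum of exponentials

    for i in range(0, n, block_q):           # "for each block of Questions"
        qi = slice(i, min(i + block_q, n))
        for j in range(0, n, block_k):       # "for each block of Keys & Values"
            kj = slice(j, min(j + block_k, n))

            S_tile = Q[qi] @ K[kj].T / np.sqrt(d)       # scores for this tile only
            m_new = np.maximum(m[qi], S_tile.max(axis=1))
            P_tile = np.exp(S_tile - m_new[:, None])    # un-normalized probabilities

            scale = np.exp(m[qi] - m_new)               # online-softmax rescaling
            s[qi] = s[qi] * scale + P_tile.sum(axis=1)
            O[qi] = O[qi] * scale[:, None] + P_tile @ V[kj]
            m[qi] = m_new

    return O / s[:, None]                    # final normalization

# Sanity check against the straightforward version
n, d = 256, 64
Q, K, V = (np.random.randn(n, d) for _ in range(3))
S = Q @ K.T / np.sqrt(d)
P = np.exp(S - S.max(axis=1, keepdims=True))
P /= P.sum(axis=1, keepdims=True)
print(np.allclose(flash_attention_forward(Q, K, V), P @ V))   # True
```

The key detail: S_tile and P_tile are small block_q × block_k scratch arrays that vanish at the end of each inner iteration. The full n × n matrix is never built.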
In Flash Attention, which data NEVER gets saved to HBM (the bookshelf)?
💡 Think about which data is 'scratch work' vs 'final answer'...
Bingo! The intermediate matrices S (the giant 'who relates to who' table) and P (softmax percentages) are computed and used tile by tile entirely on the desk (SRAM). They never get saved to the bookshelf (HBM). Only Q, K, V get read from HBM, and only the final output O gets written back.
This Isn’t Just a Hack — It’s Mathematically Optimal
Here’s where it gets cool. The authors didn’t just build something that works — they proved no algorithm can possibly do better.
Memory access comparison (big deal alert 🚨)
Standard attention: O(N²) HBM accesses
Reads and writes the full N×N attention matrix to slow HBM memory multiple times. Each round trip wastes time the GPU could spend computing.

Flash Attention: O(N²d²/M) HBM accesses
Processes attention in small tiles that fit entirely in fast SRAM. The N×N matrix is never materialized in HBM — only the final output O is written back. With SRAM (M) at ~20 MB and word vector size (d) at 64–128, the desk is roughly 300× bigger than what each tile needs, which translates to massive savings in memory traffic.
This isn't just a clever heuristic — it's been mathematically proven that no algorithm can achieve fewer HBM accesses for exact attention computation.
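To put rough numbers on that ratio: the N² factor cancels between the two bounds, so the savings come out to roughly M/d². A quick illustration using the figures quoted above (around 20 MB of SRAM, d between 64 and 128); the exact constants depend on the GPU and the kernel's block sizes, so treat this as order-of-magnitude only.

```python
# Rough ratio of HBM accesses: standard O(N^2) vs Flash Attention O(N^2 * d^2 / M).
# The N^2 cancels, so the savings factor is roughly M / d^2.
M_bytes = 20e6                       # ~20 MB of on-chip SRAM (A100-class GPU)
for d in (64, 128):
    tile_bytes = d * d * 4           # a d x d fp32 tile of work
    print(f"d = {d:3d}: roughly {M_bytes / tile_bytes:.0f}x fewer HBM accesses")
```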
The Memory Savings Are Wild
Since that giant table never hits the bookshelf:
- Standard attention memory: O(n²) — double the input, 4× the memory. Brutal.
- Flash Attention memory: O(n) — double the input, double the memory. Linear!
This is the reason modern AI models can handle 128K+ word conversations. Before Flash Attention, 2K words was already pushing it. Now ChatGPT can read entire books in one shot.
What This Actually Changed in the Real World
You’re Probably Already Using It
Flash Attention isn’t some research curiosity — it’s everywhere:
- PyTorch — F.scaled_dot_product_attention() uses it by default since PyTorch 2.0 (quick example below)
- Hugging Face — One flag: attn_implementation="flash_attention_2"
- ChatGPT, Claude, Gemini — All use it under the hood
- LLaMA, Mistral, Falcon — Every popular open-source model too
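Here's roughly what that looks like in practice. Both calls below are the ones named in the list above, but treat this as a sketch: whether PyTorch actually dispatches to the Flash Attention kernel depends on your GPU, dtype, and tensor shapes, and the Hugging Face flag needs a CUDA GPU plus the separate flash-attn package installed. The model checkpoint is just an example.

```python
import torch
import torch.nn.functional as F

# PyTorch: scaled_dot_product_attention picks a fused backend (Flash Attention,
# memory-efficient attention, or the plain math path) based on hardware and dtype.
q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)  # (batch, heads, seq, dim)
k = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Hugging Face Transformers: one flag at load time (requires the flash-attn
# package and a supported GPU; the checkpoint name is illustrative).
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
)
```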
You want to train a model on 64K-word documents. Standard attention needs ~8 GB for the attention table per head per layer. With Flash Attention, what changes?
💡 Remember: what massive intermediate data does Flash Attention avoid storing?
Flash Attention uses O(n) memory instead of O(n²) by never saving the full attention table. For n = 64K, that's the difference between ~4 billion entries (64K²) and ~64 thousand entries (64K). This is literally the reason AI models can now read entire books in one conversation.
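Those numbers are easy to sanity-check. A quick calculation, assuming fp16 values, one attention head, one layer, and reading "64K" as 65,536 tokens:

```python
# Sanity check on the quiz numbers (fp16, one attention head, one layer).
n = 64 * 1024                            # "64K" tokens = 65,536
bytes_per_value = 2                      # fp16

full_table = n * n * bytes_per_value     # the n x n table standard attention must store
row_stats  = 2 * n * bytes_per_value     # Flash Attention's running max + sum, one pair per row

print(f"standard attention table: {full_table / 1e9:.1f} GB")    # ~8.6 GB
print(f"flash attention extras:   {row_stats / 1e6:.2f} MB")     # ~0.26 MB
```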
🎓 What You Now Know
✓ Standard attention is memory-bound — the GPU does math fast but wastes time walking to the bookshelf (HBM) to save/load that giant n² table.
✓ GPUs have two memory types — SRAM (tiny desk, ~10x higher bandwidth) and HBM (massive bookshelf, slow). Flash Attention keeps all the work on the desk.
✓ Online softmax makes tiling possible — running statistics let you compute softmax block-by-block and get the exact same result.
✓ The score table never hits the bookshelf — S and P are computed, used, and trashed in SRAM. Only the final answer goes to HBM.
✓ Result: O(n) memory, 2–4x speed, provably optimal — this is what enabled 128K+ context windows and made modern AI practical.
Check your quiz score → How many did you nail? 🎯
📄 Read the original paper: FlashAttention (Dao et al., 2022)
↗ Keep Learning
Transformers — The Architecture That Changed AI
A scroll-driven visual deep dive into the Transformer architecture. From RNNs to self-attention to GPT — understand the engine behind every modern AI model.
Speculative Decoding — Making LLMs Think Ahead
A scroll-driven visual deep dive into speculative decoding. Learn why LLM inference is slow, how a small 'draft' model can speed up a big model by 2-3x, and why the output is mathematically identical.