15 min deep dive · transformers · gpu · flash-attention

Flash Attention — Making Transformers Actually Fast

A scroll-driven visual deep dive into Flash Attention. Learn why standard attention is broken, how GPU memory works, and how tiling fixes everything — with quizzes to test your understanding.

Introduction

🐌 → ⚡

ChatGPT’s brain has a speed problem.


Here’s how they fixed it.

Every AI model you use — ChatGPT, Claude, Gemini — runs something called “attention.”
It’s powerful, but it’s painfully slow. Flash Attention fixes that.
No PhD required to understand this. Seriously.

↓ Scroll to learn — there are quizzes to make sure you actually get it

The Problem

First, what even is Attention?

You know how when you read a sentence, some words matter more for understanding than others? That’s basically what attention does in AI models.

When ChatGPT reads “The cat sat on the mat,” it figures out which words are related to each other. That process has 4 steps:

The Attention Recipe (don't panic, it's simpler than it looks)

1
Q, K, V = XWq, XWk, XWv
Turn each word into 3 different vectors: a Query ('what am I looking for?'), a Key ('what do I contain?'), and a Value ('what info do I carry?')
2
S = QKᵀ / √d
Compare every word to every other word — like making a giant 'who's related to who' table
3
P = softmax(S)
Convert those scores to percentages (so they add up to 100% per word)
4
O = PV
Mix the word info together based on those percentages — done!
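If you like seeing things as code, here's the same recipe as a minimal PyTorch sketch: one attention head, no batching, no masking, and the weight names are just illustrative.

```python
import torch

def standard_attention(X, Wq, Wk, Wv):
    # 1. Turn each word vector into a Query, a Key, and a Value
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    # 2. Compare every word to every other word: the n x n score table
    S = Q @ K.T / (K.shape[-1] ** 0.5)
    # 3. Convert each row of scores into percentages
    P = torch.softmax(S, dim=-1)
    # 4. Mix the Values together using those percentages
    return P @ V
```

Keep an eye on S and P in steps 2 and 3. That n × n table is the villain of the rest of this article.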

This works great! But there’s a big, ugly problem lurking underneath:

[Chart] Attention matrix entries by sequence length: n=8 → 64 · n=64 → 4K · n=512 → 262K · n=2048 → 4M · n=8192 → 67M entries!
The attention matrix grows QUADRATICALLY — double the words, 4x the memory
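Want to see those numbers for yourself? A quick back-of-the-envelope check, assuming the scores are stored in fp16 (2 bytes per entry):

```python
# Entries in the n x n attention table, and their size in fp16 (2 bytes each)
for n in (8, 64, 512, 2048, 8192):
    entries = n * n
    print(f"n={n:>5}: {entries:>12,} entries  (~{entries * 2 / 2**20:,.1f} MB)")
```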

But here’s the twist most people miss — the math itself is fast. Modern GPUs are absolute beasts at multiplication.

The real problem? All that scratch paper needs to be saved and loaded from memory, and that’s what’s slow. It’s like having a calculator that can do math in 0.001 seconds, but you have to walk to a different room every time you need a new sheet of paper.

🟢 Knowledge Check: Quick Check

What's the REAL bottleneck in standard transformer attention?

GPU Memory 101

Your GPU Has Two Types of Memory (and That Matters A LOT)

[Diagram] ⚡ SRAM (Your Desk): ~20 MB, 19 TB/s, instant access. The Tensor Cores do the math here (fast!). 🐌 HBM (The Bookshelf): 40–80 GB, 2 TB/s, gotta walk there. Q, K, V, S, P, and O are all stored here (slow!). Data transfer = the bottleneck.
NVIDIA A100 GPU: Your desk (SRAM) has ~10x more bandwidth, but the bookshelf (HBM) is ~4,000x bigger

So What Does Standard Attention Do With This?

Here’s the painful part — standard attention treats this like the worst possible study strategy:

Standard attention = walking to the bookshelf 3 times 🤦

1
Bookshelf → Desk: Grab Q, K
Walk to the bookshelf, grab Query and Key. Bring them back to your desk.
2
Desk: Do the math (S = QKᵀ)
This part is fast! You're at your desk, crunching numbers.
3
Desk → Bookshelf: Save S
⚠️ Walk that huge NxN result all the way back to the bookshelf
4
Bookshelf → Desk: Pick S back up
⚠️ Walk back AGAIN to grab it for the next step (softmax). Why?!
5
Desk → Bookshelf: Save P
⚠️ Walk BACK to save the softmax result. THREE round trips total.
6
Most time spent walking, not studying!
Your GPU does the math in milliseconds but spends most of its time waiting for data to travel
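How bad is the walking, really? Here's a rough back-of-the-envelope using A100-ish numbers (assumptions: n = 8192 words, d = 128, fp16, ~312 TFLOPS of tensor-core math, ~2 TB/s of HBM bandwidth, and counting only the three trips the score table makes; real kernels are messier than this):

```python
n, d = 8192, 128
bytes_per_entry = 2                    # fp16

math_flops = 2 * n * n * d             # multiply-adds for S = Q @ K.T
math_time = math_flops / 312e12        # seconds at ~312 TFLOPS

table_bytes = n * n * bytes_per_entry  # the n x n score table
walk_time = 3 * table_bytes / 2e12     # 3 trips at ~2 TB/s

print(f"doing the math:   {math_time * 1e6:5.0f} µs")
print(f"walking the data: {walk_time * 1e6:5.0f} µs")
```

Even in this cartoon version, the score table spends several times longer in transit than the GPU spends computing it.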

Standard vs Flash: The Data Movement Side-by-Side

[Diagram] Standard: HBM (Bookshelf) stores Q, K, V, S, P; SRAM (Desk) is where the math happens. ① Read Q, K ② Write S back ③ Read S again.
Standard attention: 3 round trips between desk (SRAM) and bookshelf (HBM)
[Diagram] Flash: only Q, K, V, and O ever touch HBM (Bookshelf); S and P stay in SRAM (Desk) forever! Read a tile of Q and K, write only the final O.
Flash Attention: compute Q×K tile on desk, use it, discard. S never leaves SRAM!
🟡 Knowledge Check: Checkpoint

SRAM bandwidth is ~19 TB/s while HBM is ~2 TB/s. Your algorithm moves the attention matrix to and from HBM 3 times. What happens?

The Breakthrough

💡 The Big Idea

What if we never save that giant attention table to the bookshelf at all?



Instead, work on small pieces of the problem that fit on your desk. Do the math, use the result, and move on — no walking required.

But Wait… There’s a Catch with Softmax

Softmax is that step where you convert raw scores to percentages. And it has an annoying requirement:

Why softmax seems impossible to do in pieces

1
softmax(xᵢ) = exp(xᵢ) / Σⱼ exp(xⱼ)
To get the percentage for one value, you need to know the TOTAL of all values in that row
2
Problem: You can't know the total without seeing everything first
It's like asking 'what % of the class got an A?' when you haven't graded all the exams yet
3
Solution: The 'online softmax' trick! 🧠
Turns out, you CAN update your answer as new data comes in. Mind = blown.
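Here's the trick in miniature. A tiny Python sketch of the running statistics (just the softmax denominator; Flash Attention keeps a running max and a running sum exactly like this for every row):

```python
import math

def online_softmax_stats(stream):
    # Running max and running sum of exp(x - max), updated one score at a time
    running_max, running_sum = float("-inf"), 0.0
    for x in stream:
        new_max = max(running_max, x)
        # Old terms were computed against the old max; rescale them to the new one
        running_sum = running_sum * math.exp(running_max - new_max) + math.exp(x - new_max)
        running_max = new_max
    return running_max, running_sum

# Same answer as if we had seen all the scores at once:
scores = [2.0, -1.0, 3.0, 0.5]
m, s = online_softmax_stats(scores)
print(abs(s - sum(math.exp(x - m) for x in scores)) < 1e-9)  # True
```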
🟡 Knowledge Check: Checkpoint

Why can't we just compute softmax block-by-block without any tricks?

Tiling

Tiling: Working on Small Pieces at a Time

Instead of building the entire giant table at once, Flash Attention works on small tiles — like solving a jigsaw puzzle one section at a time, instead of dumping all 1,000 pieces on the floor.

[Diagram] The attention table split into a 4×4 grid of tiles (Q₁–Q₄ × K₁–K₄): one tile is ON DESK computing now, the next is queued behind it. One tile at a time, and S and P never leave the desk.
The attention table is split into blocks. Only one tiny tile lives on your desk (SRAM) at a time.

The Algorithm in Plain English (no jargon, pinky promise)

  1. Chop up your data into desk-sized blocks — small enough to fit in SRAM
  2. For each block of Questions:
    • Grab it from the bookshelf (one trip — that’s fine)
    • For each block of Keys & Values:
      • Grab them too (one trip each)
      • Do the math right there on your desk — fast!
      • Update your running softmax stats (the online trick from before)
      • Add the partial result to your running total
    • Put only the final answer back on the bookshelf (one trip)
  3. Done! The giant intermediate table (S and P) never left your desk — it was computed, used, and discarded on the spot.
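Here's that plain-English recipe as a simplified PyTorch sketch. It is not the real fused CUDA kernel, just the same loop structure on a single head with no masking, so you can watch S and P exist only one small tile at a time (the block size and variable names are ours):

```python
import torch

def flash_attention_forward(Q, K, V, block=64):
    # Q, K, V: (n, d) tensors for one attention head
    n, d = Q.shape
    scale = d ** -0.5
    O = torch.zeros_like(Q)

    for i in range(0, n, block):                           # each block of Queries
        Qi = Q[i:i + block]
        Oi = torch.zeros_like(Qi)                          # running partial output
        row_max = torch.full((Qi.shape[0], 1), float("-inf"), dtype=Q.dtype, device=Q.device)
        row_sum = torch.zeros(Qi.shape[0], 1, dtype=Q.dtype, device=Q.device)

        for j in range(0, n, block):                       # each block of Keys/Values
            Kj, Vj = K[j:j + block], V[j:j + block]
            S = (Qi @ Kj.T) * scale                        # one small tile of scores
            new_max = torch.maximum(row_max, S.max(dim=-1, keepdim=True).values)
            P = torch.exp(S - new_max)                     # one small tile of probabilities
            rescale = torch.exp(row_max - new_max)         # the online-softmax correction
            row_sum = row_sum * rescale + P.sum(dim=-1, keepdim=True)
            Oi = Oi * rescale + P @ Vj
            row_max = new_max                              # S and P are discarded here

        O[i:i + block] = Oi / row_sum                      # only the final answer is written

    return O
```

You can sanity-check it against torch.softmax(Q @ K.T * d**-0.5, dim=-1) @ V: the outputs match up to floating-point noise, because online softmax is exact, not an approximation.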
🟡 Knowledge Check: Checkpoint

In Flash Attention, which data NEVER gets saved to HBM (the bookshelf)?

IO Complexity

This Isn’t Just a Hack — It’s Mathematically Optimal

Here’s where it gets cool. The authors didn’t just build something that works — they proved no algorithm can possibly do better.

Memory access comparison (big deal alert 🚨)

🐌 Standard Attention

Reads and writes the full N×N attention matrix to slow HBM memory multiple times. Each round trip wastes time the GPU could spend computing.

O(N²) HBM accesses
Flash Attention

Processes attention in small tiles that fit entirely in fast SRAM. The N×N matrix is never materialized in HBM — only final output O is written back.

O(N²d² / M) HBM accesses
🖥️ A100 GPU Numbers

With ~20 MB of SRAM (M) and a head dimension (d) of 64–128, M is hundreds of times larger than d², so Flash Attention makes hundreds of times fewer HBM accesses. That translates to massive savings in memory traffic.

🏆 Proven Optimal

This isn't just a clever heuristic — it's been mathematically proven that no algorithm can achieve fewer HBM accesses for exact attention computation.
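To make those O()s concrete, here's a rough element-count comparison (assumptions: n = 8192, d = 128, and M counted as the ~5 million fp32 values that fit in 20 MB of SRAM; the constants hidden inside the O() are ignored):

```python
n, d = 8192, 128
M = 20e6 / 4                     # ~5M fp32 elements of SRAM

standard = n * n                 # materializes the full n x n matrix in HBM
flash = n * n * d * d / M        # the paper's O(N²d²/M) bound

print(f"standard ≈ {standard:>12,.0f} HBM element accesses")
print(f"flash    ≈ {flash:>12,.0f}  ({standard / flash:.0f}× fewer)")
```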

The Memory Savings Are Wild

Since that giant table never hits the bookshelf:

  • Standard attention memory: O(n²) — doubles the input, 4x the memory. Brutal.
  • Flash Attention memory: O(n) — doubles the input, doubles the memory. Linear!

This is the reason modern AI models can handle 128K+ word conversations. Before Flash Attention, 2K words was already pushing it. Now ChatGPT can read entire books in one shot.

Real Impact

What This Actually Changed in the Real World

  • 🔥 Longer Conversations: 512 → 128K+ words. ChatGPT can read entire PDFs now.
  • ⚡ Faster & Cheaper: 2–4× faster attention. Millions saved in GPU costs.
  • 💾 Runs on Small GPUs: O(n²) → O(n) memory. Fine-tune on a gaming GPU.
  • Memory at n=8K: Standard 128 MB vs Flash ~1 MB ✓ (128× less memory 🤯)
Flash Attention went from research paper to industry standard in under a year

You’re Probably Already Using It

Flash Attention isn’t some research curiosity — it’s everywhere:

  • PyTorch — F.scaled_dot_product_attention() uses it by default since PyTorch 2.0
  • Hugging Face — One flag: attn_implementation="flash_attention_2"
  • ChatGPT, Claude, Gemini — All use it under the hood
  • LLaMA, Mistral, Falcon — Every popular open-source model too
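In practice it looks something like this. A hedged sketch: the PyTorch call is the built-in one, the Hugging Face model id is only an example, and the flash_attention_2 path also requires the flash-attn package and a supported GPU.

```python
import torch
import torch.nn.functional as F

# PyTorch 2.x: scaled_dot_product_attention picks the fastest backend available,
# which is the Flash Attention kernel on a supported GPU with fp16/bf16 inputs.
q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)  # (batch, heads, seq, dim)
k, v = torch.randn_like(q), torch.randn_like(q)
out = F.scaled_dot_product_attention(q, k, v)

# Hugging Face: one flag at load time.
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",                 # example model id
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
)
```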
🔴 Knowledge Check: Challenge

You want to train a model on 64K-word documents. Standard attention needs ~8 GB for the attention table per head per layer. With Flash Attention, what changes?

🎓 What You Now Know

Standard attention is memory-bound — the GPU does math fast but wastes time walking to the bookshelf (HBM) to save/load that giant n² table.

GPUs have two memory types — SRAM (tiny desk, ~10x higher bandwidth) and HBM (massive bookshelf, slow). Flash Attention keeps all the work on the desk.

Online softmax makes tiling possible — running statistics let you compute softmax block-by-block and get the exact same result.

The score table never hits the bookshelf — S and P are computed, used, and trashed in SRAM. Only the final answer goes to HBM.

Result: O(n) memory, 2–4x speed, provably optimal — this is what enabled 128K+ context windows and made modern AI practical.

Check your quiz score → How many did you nail? 🎯

📄 Read the original paper: FlashAttention (Dao et al., 2022)


📄 Read the follow-up: FlashAttention-2 (Dao, 2023)
