Speculative Decoding — Making LLMs Think Ahead
A scroll-driven visual deep dive into speculative decoding. Learn why LLM inference is slow, how a small 'draft' model can speed up a big model by 2-3x, and why the output is mathematically identical.
🐢 → 🐇
What if a small model could make
a big model faster?
GPT-4 generates text one word at a time. That’s painfully slow.
Speculative decoding uses a tiny “draft” model to guess ahead, then lets the big model verify in parallel.
Same output. 2–3x faster.
↓ Scroll to learn — quizzes will test your understanding
Why Is ChatGPT So Slow at Typing?
Ever noticed how ChatGPT “types” one word at a time? That’s not a UI trick — the model literally generates one token at a time. And each token requires a full forward pass through the entire model.
The Autoregressive Bottleneck
Here’s the thing most people don’t realize: the GPU is barely working during text generation. The model is so big that most of the time is spent loading model weights from memory, not doing math.
Why each token is expensive
- 1 token = 1 full forward pass
- GPU utilization: ~1-5%
- 100 tokens at 30 ms each = 3 seconds

Why is autoregressive text generation slow even on powerful GPUs?
💡 Think about what takes longer: doing the math or loading the data...
During generation, each token requires loading the full model from memory. The GPU's compute units are incredibly fast, but they spend most of their time idle, waiting for weights to arrive. This is called being 'memory-bandwidth-bound.'
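To make that concrete, here's a rough back-of-envelope sketch in Python. The numbers (a 70B-parameter model in fp16, roughly 2 TB/s of GPU memory bandwidth) are illustrative assumptions, not measurements; the point is that just streaming the weights already costs tens of milliseconds per token.

```python
def min_ms_per_token(n_params_billion, bytes_per_param=2, bandwidth_tb_per_s=2.0):
    """Lower bound on per-token latency when decoding is memory-bandwidth-bound:
    at batch size 1, every generated token has to stream all model weights
    from GPU memory at least once (KV cache and compute ignored here)."""
    model_bytes = n_params_billion * 1e9 * bytes_per_param
    bandwidth_bytes_per_s = bandwidth_tb_per_s * 1e12
    return model_bytes / bandwidth_bytes_per_s * 1e3

# Illustrative: a 70B model in fp16 on a GPU with ~2 TB/s of memory bandwidth.
print(min_ms_per_token(70))  # ~70 ms per token, even with the compute units idle
```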
💡 The Big Idea
What if a tiny model guessed the next 5 words, and then the big model checked all 5 at once?
The small model is 100x faster but less accurate.
The big model can verify multiple tokens in one pass (same cost as generating one!).
If the guesses are right? Free speedup. If wrong? Just try again.
Here’s the key insight that makes this work: verification is parallel, generation is sequential.
When you give a language model a sequence like “The cat sat on the”, it computes the probability of every next token at every position in a single forward pass. That’s just how transformers work! So checking 5 guesses costs the same as generating 1 token.
Why can the big model verify K guessed tokens in the same time it takes to generate 1 token?
💡 How does a transformer process a prompt of 100 tokens — one at a time or all at once?
Transformers are inherently parallel — a single forward pass computes the next-token probability at every position simultaneously. So feeding in 5 draft tokens and checking them all costs roughly the same as generating 1 token from scratch.
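Here's a minimal sketch of that idea. The `target_model` callable is a hypothetical stand-in for a transformer forward pass that returns a next-token distribution at every input position; one call scores all K guesses.

```python
def verify_in_one_pass(target_model, prompt_ids, draft_ids):
    """Score K drafted tokens with a single big-model forward pass.

    target_model(token_ids) is assumed (hypothetical interface) to return a
    next-token probability distribution for EVERY position of its input:
    one forward pass, len(token_ids) distributions.
    """
    full = prompt_ids + draft_ids      # feed the prompt plus all K guesses at once
    dists = target_model(full)         # one forward pass over the whole sequence

    # dists[len(prompt_ids) - 1] is the target's prediction for draft_ids[0],
    # dists[len(prompt_ids)]     is its prediction for draft_ids[1], and so on.
    start = len(prompt_ids) - 1
    return [dists[start + i] for i in range(len(draft_ids))]
```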
The Algorithm: Draft, Verify, Accept
Step by Step
The speculative decoding loop
Step 1: Draft model generates K tokens
Step 2: Run big model on all K tokens at once
Step 3: Compare draft vs target probabilities
Step 4: Accept prefix, reject suffix
Result: Got 3-4 tokens for the cost of ~1 big-model pass!

The draft model generates 5 tokens. The big model verifies and finds token 3 is wrong. What happens?
💡 Think about it: if token 3 is wrong, can tokens 4-5 (which depend on token 3) be trusted?
Speculative decoding accepts the longest valid prefix. Tokens 1-2 are correct, so they're kept. Token 3 is wrong, so it gets resampled from the big model's distribution. Tokens 4-5 are after the rejection point, so they're discarded — they were conditioned on a wrong token.
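Putting the four steps together, here's a sketch of one round of the loop. The model-specific pieces (`draft_sample`, `draft_dists`, `target_dists`, `resample_residual`) are hypothetical callables you'd supply; the sketch only shows the control flow, and the resampling rule itself is spelled out in the next section.

```python
import random

def one_round(draft_sample, draft_dists, target_dists, resample_residual,
              prompt, k=5):
    """One draft -> verify -> accept round of speculative decoding.
    The four callables are hypothetical stand-ins for the two models;
    this sketch only shows the control flow of the loop."""
    # Step 1: the small draft model proposes k tokens autoregressively (cheap).
    drafted, ctx = [], list(prompt)
    for _ in range(k):
        tok = draft_sample(ctx)
        drafted.append(tok)
        ctx.append(tok)

    # Step 2: a SINGLE big-model forward pass over prompt + drafted tokens
    # gives the target distribution p[i] at every drafted position at once.
    p = target_dists(prompt, drafted)
    q = draft_dists(prompt, drafted)   # the draft model's own distributions

    # Steps 3-4: keep the longest accepted prefix; at the first rejection,
    # resample that position from (p - q)+ and discard everything after it.
    accepted = []
    for i, tok in enumerate(drafted):
        if random.random() < min(1.0, p[i][tok] / q[i][tok]):
            accepted.append(tok)
        else:
            accepted.append(resample_residual(p[i], q[i]))
            break
    # (The full algorithm also emits one extra "bonus" token when every draft
    # token is accepted; omitted here to keep the sketch short.)
    return accepted
```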
The Rejection Sampling Trick (Why the Output Is Identical)
This is the magical part: speculative decoding doesn’t produce “approximate” output. It produces the exact same distribution as running the big model alone.
How rejection sampling preserves the target distribution
p(x) = target model probability
q(x) = draft model probability
Accept if: rand() < min(1, p(x)/q(x))
If rejected: sample from (p(x) - q(x))⁺ / Z
Result: output distribution = p(x) exactly
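Here's that accept/resample rule as a minimal NumPy sketch: `p` and `q` are probability vectors over the vocabulary, and the toy numbers at the bottom are made up for illustration. It's just the formula above translated directly into code.

```python
import numpy as np

def accept_or_resample(p, q, x, rng):
    """Rejection-sampling step for one drafted token.

    p, q : target / draft probability vectors over the vocabulary
    x    : the token id the draft model sampled (from q)
    Returns a token whose distribution is exactly p: either the accepted
    draft token, or a resample from the residual (p - q)+ / Z.
    """
    # Accept the drafted token with probability min(1, p(x) / q(x)).
    if rng.random() < min(1.0, p[x] / q[x]):
        return x

    # Rejected: resample from the leftover probability mass (p - q)+, renormalized.
    residual = np.maximum(p - q, 0.0)
    return rng.choice(len(p), p=residual / residual.sum())

# Toy usage with a made-up 4-token vocabulary:
rng = np.random.default_rng(0)
p = np.array([0.5, 0.3, 0.1, 0.1])   # target distribution (assumed)
q = np.array([0.7, 0.1, 0.1, 0.1])   # draft distribution (assumed)
token = accept_or_resample(p, q, x=0, rng=rng)
```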
Expected Speedup
The speedup depends on the acceptance rate — how often the small model's guesses match the big model.
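To put a number on it: under the simplifying assumption that each drafted token is accepted independently with the same probability α (the assumption used in the paper's analysis), the expected number of tokens per big-model pass has a closed form. A quick sketch:

```python
def expected_tokens_per_pass(alpha, k):
    """Expected tokens produced per big-model forward pass, assuming each of
    the k drafted tokens is accepted independently with probability alpha.
    Equals 1 + alpha + alpha^2 + ... + alpha^k: one token is always produced
    (a resample or the bonus token), plus every accepted draft token."""
    return sum(alpha ** i for i in range(k + 1))

# e.g. an 80% acceptance rate with k = 5 drafted tokens:
print(expected_tokens_per_pass(0.8, 5))   # ~3.69 tokens per big-model pass
```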
Does speculative decoding reduce the quality of the generated text?
💡 Remember the rejection sampling formula — what does it guarantee about the output distribution?
The rejection sampling mechanism mathematically guarantees that the output distribution is identical to the big model. If the draft token doesn't match, it gets rejected and resampled from the target distribution. Zero quality loss — only the speed changes.
Real-World Impact
Who Uses It?
- Google — Uses it in Gemini for faster inference
- Meta — LLaMA models support speculative decoding natively
- vLLM — The most popular LLM serving framework supports it out of the box
- Apple — Uses it for on-device LLM inference in Apple Intelligence
- Medusa, EAGLE — Variants that use multiple draft heads instead of a separate model
A draft model with 1B parameters is drafting for a 70B target model. The draft generates 6 tokens. The target verifies and accepts tokens 1-4 but rejects token 5. How many total tokens do you get from this round?
💡 What happens at the exact position where the rejection occurs?
You get 5 tokens: the 4 accepted draft tokens, plus 1 token resampled from the target model at the rejection point. Token 6 is discarded because it was conditioned on the wrong token 5. So you got 5 tokens for the cost of ~1 target forward pass — that's a great deal!
🎓 What You Now Know
✓ LLM generation is memory-bound — The GPU spends most of its time loading weights, not computing. Each token requires a full forward pass.
✓ A small model drafts, a big model verifies — The draft model is fast but imperfect. The big model checks all draft tokens in parallel.
✓ Rejection sampling preserves quality — The output distribution is mathematically identical to the big model alone.
✓ 2-3x speedup for free — No quality loss, no extra hardware, no approximations. Just clever scheduling.
Check your quiz score → How many did you nail? 🎯
📄 Read the paper: Fast Inference from Transformers via Speculative Decoding (Leviathan et al., 2022)
📄 Accelerating Large Language Model Decoding with Speculative Sampling (Chen et al., 2023)
↗ Keep Learning
Transformers — The Architecture That Changed AI
A scroll-driven visual deep dive into the Transformer architecture. From RNNs to self-attention to GPT — understand the engine behind every modern AI model.
Flash Attention — Making Transformers Actually Fast
A scroll-driven visual deep dive into Flash Attention. Learn why standard attention is broken, how GPU memory works, and how tiling fixes everything — with quizzes to test your understanding.
Caching — The Art of Remembering What's Expensive to Compute
A visual deep dive into caching. From CPU caches to CDNs — understand cache strategies, eviction policies, and the hardest problem in computer science: cache invalidation.