Table of Contents
- TL;DR
- The setup: generation is a loop
- The wasted work
- Why keys and values, but not queries
- Prefill and decode: the two phases of inference
- The bill: doing the memory math
- Implementing one from scratch
- Shrinking the cache
- What to take away
- Sources and further reading
Every time an LLM streams an answer at you, it is running the same trick behind the scenes, thousands of times per response: it refuses to redo work it has already done. That trick is the KV cache, and it is arguably the single most important inference optimization in modern language models. Without it, generating token 1,000 of a response would cost roughly a thousand times more than generating token 1, chat latency would grow quadratically with conversation length, and long-context models would be unusable in practice.
This post builds the KV cache up from first principles: why the transformer’s attention math creates redundant work, why exactly two of the three attention ingredients can be cached (and the third cannot), what the cache costs in memory, how to implement one in a few lines of PyTorch, and the family of architecture and systems tricks (GQA, MLA, PagedAttention, prefix caching) that exist purely to keep this one data structure small.
TL;DR
- Autoregressive generation feeds the model its own output, one token at a time. Naively, every step reprocesses the entire sequence so far, recomputing key and value vectors that are bit-for-bit identical to last step’s.
- Causal attention makes past tokens immutable: token \(t\) cannot see the future, so its key and value vectors never change once computed. Cache them, and each step only has to process the newest token.
- Queries are not cached because each query is used exactly once, at its own step. Keys and values are read by every future step; that asymmetry is the whole design.
- The cache is a pure compute optimization, not an approximation: with and without it, outputs are token-identical. Measured speedups are consistently around 5x on modest models, and grow with sequence length.
- The price is memory that grows linearly with context length and batch size: \(2 \times \text{layers} \times \text{kv heads} \times \text{head dim} \times \text{bytes}\) per token. For Llama 2 7B that is 0.5 MiB per token, or 64 GiB at a 128k context, more than the weights themselves.
- A surprising amount of modern LLM architecture (grouped-query attention, DeepSeek’s multi-head latent attention, sliding windows) and serving infrastructure (vLLM’s PagedAttention, prompt caching discounts) is best understood as engineering around the size of this cache.
The setup: generation is a loop
A decoder-only transformer generates text autoregressively. It takes a sequence of tokens (word pieces, roughly), predicts a probability distribution over the next token, samples one, appends it to the sequence, and repeats. One forward pass of the whole network per token of output.
Everything in this post follows from what happens inside that forward pass, so it is worth a ninety-second refresher on attention with the tensor shapes spelled out. If you can already write attention from memory, skip to the wasted work.
A ninety-second attention refresher
Each token enters the model as an embedding vector, and every layer transforms it into a new hidden vector \(x_i\) of width \(d_{\text{model}}\): token \(i\)’s current representation, as refined by the layers so far. Self-attention is the step where tokens exchange information, and it starts by projecting each token’s hidden vector three ways using learned weight matrices:
\[\underset{1 \times d_{\text{head}}}{q_i} \;=\; \underset{1 \times d_{\text{model}}}{x_i}\;\underset{d_{\text{model}} \times d_{\text{head}}}{W_Q}, \qquad k_i = x_i W_K, \qquad v_i = x_i W_V\]
(\(k_i\) and \(v_i\) have the same shapes as \(q_i\); each is just a different learned projection of the same \(x_i\).)
The query is what token \(i\) uses to interrogate the past. The keys are what each token offers to be matched against. The values are the actual content that gets mixed together. Token \(i\)’s attention output is a weighted average of the values, with weights decided by how well its query matches each key. Stacking the first \(i\) tokens’ keys into a matrix \(K\) and their values into \(V\) (one row per token):
\[\underset{1 \times d_{\text{head}}}{\text{out}_i} \;=\; \underbrace{\text{softmax}\!\left(\underset{1 \times d_{\text{head}}}{q_i}\;\underset{d_{\text{head}} \times i}{K^{\top}} \,/\, \sqrt{d_{\text{head}}}\right)}_{\text{attention weights: } 1 \times i} \;\; \underset{i \times d_{\text{head}}}{V}\]
Read the shapes left to right and the mechanism is plain: one query row against \(i\) key rows gives \(i\) scores, softmax turns them into weights that sum to 1, and those weights blend \(i\) value rows into a single output row.
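The shape walk above can be sketched in a few lines. This is a minimal NumPy illustration (NumPy rather than PyTorch, only to keep it dependency-light); the function name `attend_one` and the toy sizes are assumptions made for the example, not anything from a real model.

```python
import numpy as np

def attend_one(q, K, V):
    """Attention output for a single query row.

    q: (d_head,)   K: (i, d_head)   V: (i, d_head)
    Returns the (d_head,) output row and the (i,) attention weights.
    """
    d_head = q.shape[-1]
    scores = K @ q / np.sqrt(d_head)          # one score per key row: (i,)
    weights = np.exp(scores - scores.max())   # numerically stable softmax
    weights /= weights.sum()                  # weights now sum to 1
    return weights @ V, weights               # blend i value rows into one row

rng = np.random.default_rng(0)
i, d_head = 5, 8                              # toy sizes, chosen arbitrarily
q = rng.standard_normal(d_head)
K = rng.standard_normal((i, d_head))
V = rng.standard_normal((i, d_head))

out, weights = attend_one(q, K, V)
print(out.shape, weights.shape)
```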
Two structural facts matter for everything below:
- The model is causal. A mask ensures position \(i\) can only attend to positions \(1, \dots, i\); the future is invisible. This is usually presented as a training necessity, but it is what makes caching possible at all.
- This happens many times in parallel. Everything above describes one attention head. A layer runs \(n_{\text{heads}}\) of them side by side, each with its own \(W_Q, W_K, W_V\) and its own slice of width \(d_{\text{head}} = d_{\text{model}} / n_{\text{heads}}\), and the model stacks \(n_{\text{layers}}\) such layers. So every head in every layer has its own keys and values. Remember this when the memory bill arrives.
For concreteness, here is the notation with Llama 2 7B’s actual values, since it recurs throughout the post:
- \(d_{\text{model}} = 4096\): the width of a token’s hidden vector.
- \(n_{\text{heads}} = 32\): attention heads per layer.
- \(d_{\text{head}} = d_{\text{model}} / n_{\text{heads}} = 128\): the width of one head’s q, k and v vectors.
- \(n_{\text{layers}} = 32\): stacked transformer layers, each with its own attention.
- \(t\): tokens in the sequence so far. Grows by one with every generated token.
The wasted work
Watch what the naive generation loop actually does. To generate token 7, you feed in tokens 1-6 and run the full forward pass: project all six tokens to queries, keys and values, form the full attention matrix, and take the prediction from the last position. To generate token 8, you feed in tokens 1-7 and do it all again.
Here is the problem: the keys and values you just computed for tokens 1-6 are identical to the ones you computed one step ago. Token 3’s key depends only on token 3’s hidden state, which depends only on tokens 1-3, which have not changed. Causality means nothing about a past token’s representation is affected by tokens that come after it. Every step, you are recomputing an ever-growing pile of results you already had.
The cost of this redundancy compounds. At step \(t\), the naive approach projects \(t\) tokens and computes a \(t \times t\) attention matrix, so generating \(n\) tokens costs \(O(n^2)\) projections and up to \(O(n^3)\) attention work overall. Generation gets progressively slower the longer the response runs, which is precisely the wrong behavior for a chat product.
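To put rough numbers on that growth, here is a small counting sketch. It counts work per attention head per layer and assumes, for simplicity, that generation starts from an empty sequence with no prompt; the function names are illustrative only.

```python
def naive_cost(n):
    """Token projections and attention scores when every step reprocesses everything."""
    projections = sum(t for t in range(1, n + 1))      # step t projects t tokens
    scores = sum(t * t for t in range(1, n + 1))       # step t fills a t x t matrix
    return projections, scores

def cached_cost(n):
    """Same counts when past keys and values are stored instead of recomputed."""
    projections = n                                    # one new token per step
    scores = sum(t for t in range(1, n + 1))           # step t fills one 1 x t row
    return projections, scores

n = 1000
naive_proj, naive_scores = naive_cost(n)
cached_proj, cached_scores = cached_cost(n)
print(f"projections: {naive_proj:,} naive vs {cached_proj:,} cached")
print(f"scores:      {naive_scores:,} naive vs {cached_scores:,} cached")
```

The naive totals grow as \(O(n^2)\) and \(O(n^3)\); with stored keys and values they drop to \(O(n)\) and \(O(n^2)\).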
The fix is the obvious one once you see the picture: store the keys and values the first time you compute them. That store, one \(K\) matrix and one \(V\) matrix per attention head per layer, growing by one row per generated token, is the KV cache.
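The immutability claim is easy to check numerically. The sketch below is a toy single-head causal attention layer in NumPy with random weights (all names and sizes are made up for the illustration): it runs the layer on 6 tokens and on 7 tokens, then projects the outputs to the next layer's keys. The first six key rows should agree up to floating-point rounding.

```python
import numpy as np

def causal_self_attention(X, Wq, Wk, Wv):
    """One causal attention head over all rows of X: (t, d) -> (t, d)."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    t = X.shape[0]
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    future = np.triu(np.ones((t, t), dtype=bool), k=1)   # True above the diagonal
    scores = np.where(future, -np.inf, scores)           # hide future positions
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V

rng = np.random.default_rng(0)
d = 16
Wq, Wk, Wv, Wk_next = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(4))
X7 = rng.standard_normal((7, d))                         # 7 token embeddings

h6 = causal_self_attention(X7[:6], Wq, Wk, Wv)           # sequence of 6 tokens
h7 = causal_self_attention(X7, Wq, Wk, Wv)               # same sequence plus token 7

keys6 = h6 @ Wk_next                                     # next layer's keys, 6 tokens
keys7 = h7 @ Wk_next                                     # next layer's keys, 7 tokens
print(np.allclose(keys6, keys7[:6]))                     # past keys unchanged by token 7
```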
Why keys and values, but not queries
The natural follow-up is why the cache holds exactly two of the three projections. The answer falls out of looking at what a single decode step actually needs.
When the model generates token \(t+1\), the only attention output that matters is the one at position \(t\), the newest token. Positions \(1, \dots, t-1\) already produced their outputs at earlier steps; recomputing them would produce results you would immediately throw away. And the newest token’s attention needs three things: its own query, everyone’s keys, and everyone’s values.
This is the asymmetry that makes the design click:
- Keys and values are producers. Token 3’s key and value will be read again at step 4, step 5, and every step after that, unchanged each time. Computing them once and storing them pays off for the rest of the generation.
- Queries are consumers. Token 3’s query was used exactly once, at step 3, to compute token 3’s output. No future step ever looks at it again. Caching it would be storing garbage.
So each decode step does a tiny amount of new work: project the single newest token to \(q_t, k_t, v_t\), append \(k_t\) and \(v_t\) to the cache, and compute one row of attention, a \(1 \times t\) score vector instead of a \(t \times t\) matrix. Per-step cost drops from quadratic in sequence length to linear.
This is not an approximation. A cached and an uncached forward pass compute mathematically identical outputs, and implementations verify this with token-identical generations (Raschka’s from-scratch build does exactly this check). The KV cache trades memory for compute and changes nothing about model quality. Variants that shrink the cache after the fact (eviction, sliding windows, quantization) are approximations; the cache itself is exact.
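A minimal version of that equivalence check, as a sketch: one attention head in NumPy with random weights, comparing a full causal pass against a token-by-token loop that appends to a cache. The setup is a toy assumption, not Raschka's code, and it compares attention outputs rather than sampled tokens.

```python
import numpy as np

def softmax(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def full_causal_attention(X, Wq, Wk, Wv):
    """Uncached reference: recompute Q, K, V for every token at once."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    t = X.shape[0]
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    scores = np.where(np.triu(np.ones((t, t), dtype=bool), k=1), -np.inf, scores)
    return softmax(scores) @ V

rng = np.random.default_rng(1)
d, n = 8, 10
Wq, Wk, Wv = (rng.standard_normal((d, d)) for _ in range(3))
X = rng.standard_normal((n, d))

reference = full_causal_attention(X, Wq, Wk, Wv)

# Cached path: one token per step, keys and values appended, queries discarded.
K_cache = np.empty((0, d))
V_cache = np.empty((0, d))
cached_rows = []
for x in X:
    q, k, v = x @ Wq, x @ Wk, x @ Wv          # project only the newest token
    K_cache = np.vstack([K_cache, k])         # cache grows by one row
    V_cache = np.vstack([V_cache, v])
    w = softmax(K_cache @ q / np.sqrt(d))     # 1 x t weights for this step
    cached_rows.append(w @ V_cache)
cached = np.stack(cached_rows)

print(np.allclose(cached, reference))
```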
One more subtlety worth internalizing: there is not “a” KV cache, there is one per attention head in every layer. A 32-layer model with 32 heads maintains 1,024 little K/V stores, all growing in lockstep. That per-layer, per-head multiplication is exactly why the memory bill below gets steep.
Prefill and decode: the two phases of inference
The cache also explains a rhythm you have felt in every chatbot: the pause before the first token, then the steady stream after it. With a KV cache, inference has two distinct phases with completely different performance characters.
Prefill processes the entire prompt in a single parallel forward pass. All prompt tokens’ queries, keys and values are computed at once as large matrix-matrix multiplications, and the K/V results are written into the cache in bulk. GPUs love this: it is dense, parallel work, and the phase is compute-bound. Its duration is what you experience as time-to-first-token; during that pause, the model is literally building the KV cache for your prompt.
Decode is everything after: one token per forward pass. Each step is a skinny matrix-vector computation, but it has to read the entire cache (plus the model weights) from GPU memory to produce one token. Arithmetic is cheap; moving bytes is not. Decode is memory-bandwidth-bound, and inter-token latency scales with how much cache there is to stream. This is why long conversations do not just cost more memory, they generate slower per token, and why so much of the optimization effort below is really about shrinking the number of bytes decode has to touch every step.
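A back-of-envelope sketch of why decode slows down with context: if every step must stream the weights plus the cache from memory, step time is bounded below by bytes divided by bandwidth. The 1 TB/s bandwidth here is a hypothetical round number, not a measurement of any particular GPU; the weight and cache sizes are the Llama 2 7B figures used elsewhere in this post, for a single sequence.

```python
def decode_step_floor_ms(weight_bytes, cache_bytes, bandwidth_bytes_per_s):
    """Lower bound on one decode step if it is purely memory-bandwidth-bound."""
    return (weight_bytes + cache_bytes) / bandwidth_bytes_per_s * 1e3

WEIGHT_BYTES = 13e9            # ~13 GB of FP16 weights (figure from this post)
CACHE_PER_TOKEN = 524_288      # 0.5 MiB per token (figure from this post)
BANDWIDTH = 1e12               # ASSUMPTION: 1 TB/s, an illustrative round number

for context in (0, 4_096, 131_072):
    floor = decode_step_floor_ms(WEIGHT_BYTES, context * CACHE_PER_TOKEN, BANDWIDTH)
    print(f"context {context:>7,}: at least {floor:.1f} ms per token")
```

Real systems batch requests and overlap work, so treat this as a floor for intuition rather than a latency prediction.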
How much does the cache actually buy you? Measured end-to-end on small models, three independent write-ups land on the same order: Hugging Face measured 11.7s vs 61s for 300 tokens from SmolLM2-1.7B on a T4 (5.2x), Raschka got ~5x on a 124M model on CPU, and Daily Dose of DS about 4.5x. The multiplier grows with generation length, because the redundant work it eliminates grows quadratically.
The bill: doing the memory math
The cache is not free. Its size is worth being able to compute on a napkin, because it drives a remarkable amount of LLM economics.
Each cached entry is one key vector plus one value vector, per token, per KV head, per layer. In half precision (2 bytes per element):
\[\text{bytes per token} = \underbrace{2}_{K \text{ and } V} \times\; n_{\text{layers}} \times n_{\text{kv heads}} \times d_{\text{head}} \times \underbrace{2}_{\text{fp16}}\]
and the total cache is that times sequence length times batch size. Two things about this formula deserve attention: the model’s width enters only through the KV heads, and the batch size multiplies everything.
Worked example, Llama 2 7B (32 layers, 32 KV heads of dimension 128):
\[2 \times 32 \times 32 \times 128 \times 2 = 524{,}288 \text{ bytes} = 0.5 \text{ MiB per token}\]
Half a megabyte per token sounds harmless until you scale the context: a 4,096-token conversation is 2 GiB of cache, and a (hypothetical, for this model) 128k context would be 64 GiB, roughly five times the ~13 GB the weights themselves occupy in FP16. Serve 32 concurrent users at 4k context and the cache alone is 64 GiB; it is routinely the cache, not the model, that caps how many requests fit on a GPU.
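The napkin math is easy to wrap in a helper. This is a small sketch; the function name and argument names are illustrative, and the numbers are the Llama 2 7B values quoted above.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim,
                   seq_len=1, batch=1, bytes_per_elem=2):
    """Total KV cache size: 2 (K and V) x layers x KV heads x head dim x bytes,
    scaled by sequence length and batch size."""
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem * seq_len * batch

MiB, GiB = 2**20, 2**30
llama2_7b = dict(n_layers=32, n_kv_heads=32, head_dim=128)

per_token = kv_cache_bytes(**llama2_7b)
print(per_token / MiB, "MiB per token")
print(kv_cache_bytes(**llama2_7b, seq_len=4096) / GiB, "GiB at 4k context")
print(kv_cache_bytes(**llama2_7b, seq_len=128 * 1024) / GiB, "GiB at 128k context")
print(kv_cache_bytes(**llama2_7b, seq_len=4096, batch=32) / GiB, "GiB for 32 users at 4k")
```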
The plot holds two genuinely counterintuitive facts. First, cache size is unrelated to parameter count: Llama 3 70B stores less cache per token (0.31 MiB) than Llama 2 7B (0.5 MiB), because it shares each of its 8 KV heads across 8 query heads. Second, DeepSeek-V3, a 671B-parameter model, has the smallest cache on the chart (~69 KiB per token), because it compresses K and V into a small latent vector before caching. Both are the “shrinking the cache” section below in action.
A common miscalculation. Several popular explainers compute cache size as \(2 \times n_{\text{layers}} \times d_{\text{model}} \times 2\), using the full hidden size. That was correct in the multi-head attention era, but for any grouped-query model it overstates the cache by the ratio of query heads to KV heads: 4x for Llama 3 8B (32 vs 8), 8x for Llama 3 70B (64 vs 8). The number that matters is \(n_{\text{kv heads}} \times d_{\text{head}}\), which since roughly 2023 is deliberately much smaller than \(d_{\text{model}}\). Always check num_key_value_heads in the model config before trusting a blog post’s table (this one included).
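The size of that overstatement can be checked directly. In this sketch the dictionaries mimic the relevant fields of a Hugging Face-style config.json; the layer counts and hidden sizes for the two Llama 3 models are recalled from their published configs rather than taken from this post, so treat them as assumptions and verify against the actual config files.

```python
def hidden_size_estimate(cfg, bytes_per_elem=2):
    """The common (MHA-era) estimate: uses the full hidden size."""
    return 2 * cfg["num_hidden_layers"] * cfg["hidden_size"] * bytes_per_elem

def kv_head_estimate(cfg, bytes_per_elem=2):
    """The estimate that respects grouped-query attention."""
    head_dim = cfg["hidden_size"] // cfg["num_attention_heads"]
    return (2 * cfg["num_hidden_layers"] * cfg["num_key_value_heads"]
            * head_dim * bytes_per_elem)

# ASSUMPTION: values recalled from the published model configs.
configs = {
    "Llama 3 8B": dict(num_hidden_layers=32, hidden_size=4096,
                       num_attention_heads=32, num_key_value_heads=8),
    "Llama 3 70B": dict(num_hidden_layers=80, hidden_size=8192,
                        num_attention_heads=64, num_key_value_heads=8),
}

for name, cfg in configs.items():
    wrong, right = hidden_size_estimate(cfg), kv_head_estimate(cfg)
    print(f"{name}: {right / 2**20:.2f} MiB per token, overstated {wrong / right:.0f}x")
```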
Implementing one from scratch
The mechanism is simple enough to fit in a handful of lines. Inside an attention module, the cache is just two tensors that grow along the sequence dimension:
    import torch

    class KVCache:
        """Per-layer cache: K and V of shape [batch, n_kv_heads, seq, head_dim]."""

        def __init__(self):
            self.k = None
            self.v = None

        def update(self, k_new, v_new):
            if self.k is None:
                # First call (prefill): nothing stored yet
                self.k, self.v = k_new, v_new
            else:
                # Later calls (decode): append along the sequence dimension
                self.k = torch.cat([self.k, k_new], dim=2)
                self.v = torch.cat([self.v, v_new], dim=2)
            return self.k, self.v
The attention forward pass changes in one place: instead of using the keys and values it just computed, it uses the cache’s running versions.
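    # Method of the attention module (the file needs `import math` and `import torch`).
    def forward(self, x, cache=None, pos_start=0):
        # x is [batch, seq, d_model]; during decode, seq is just 1
        q = self.W_q(x)  # only the new token's query
        k = self.W_k(x)
        v = self.W_v(x)
        q, k, v = self.split_heads(q, k, v)
        q, k = self.apply_rope(q, k, pos_start)  # positions must keep counting up
        if cache is not None:
            k, v = cache.update(k, v)  # k, v now cover the whole sequence
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        if q.shape[-2] > 1:  # prefill: hide future positions; a 1-token decode step needs no mask
            t_q, t_k = q.shape[-2], k.shape[-2]
            future = torch.ones(t_q, t_k, device=x.device).triu(t_k - t_q + 1).bool()
            scores = scores.masked_fill(future, float("-inf"))
        out = torch.softmax(scores, dim=-1) @ v  # [batch, heads, seq, head_dim]
        return self.W_o(self.merge_heads(out))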
And the generation loop stops re-feeding the whole sequence. After prefill, each step passes only the newest token:
    logits = model.prefill(prompt_ids)  # parallel pass, fills the cache
    next_id = sample(logits[:, -1])
    generated = [next_id]  # the first new token comes from prefill
    for _ in range(max_new_tokens - 1):
        logits = model(next_id, use_cache=True)  # one token in, one token out
        next_id = sample(logits[:, -1])
        generated.append(next_id)
In Hugging Face Transformers all of this is on by default (use_cache=True); you have been using a KV cache all along. But if you build it yourself, three gotchas account for most of the bugs:
- Positions must keep counting. The token you feed at decode step 5 is at position prompt_len + 5, not position 0. If RoPE or your positional encoding restarts from zero, attention still “works” numerically and the output is quietly wrong. Track an explicit offset.
- Reset between sequences. Reuse the cache across two unrelated generations and the new prompt’s queries will happily attend to the previous conversation’s keys, producing surreal output. A cache needs an explicit reset().
- torch.cat in a loop is slow. Reallocating and copying the cache every step fragments memory and can eat the entire speedup on GPU for small models; Raschka measured cached generation losing to uncached in exactly this setup. Production implementations pre-allocate the cache at maximum length and write into slices instead; that is also what unlocks torch.compile.
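The three gotchas suggest a slightly different cache shape: allocate once, write into slices, hand back the position offset, and expose a reset. The following is a sketch of that pattern in NumPy (chosen to keep the example dependency-light; the same slicing works on torch tensors). The class name `PreallocKVCache` and its interface are hypothetical, not any library's API, and the batch dimension is omitted for brevity.

```python
import numpy as np

class PreallocKVCache:
    """Fixed-size per-layer cache: K and V of shape [n_kv_heads, max_len, head_dim]."""

    def __init__(self, max_len, n_kv_heads, head_dim, dtype=np.float32):
        self.k = np.zeros((n_kv_heads, max_len, head_dim), dtype=dtype)
        self.v = np.zeros_like(self.k)
        self.length = 0                      # number of valid positions

    def update(self, k_new, v_new):
        """Write new rows in place; return views of the valid region and the
        position offset the new tokens start at (for RoPE)."""
        n = k_new.shape[1]
        if self.length + n > self.k.shape[1]:
            raise ValueError("KV cache overflow: sequence exceeds max_len")
        pos_start = self.length
        self.k[:, pos_start:pos_start + n] = k_new   # slice write, no reallocation
        self.v[:, pos_start:pos_start + n] = v_new
        self.length += n
        return self.k[:, :self.length], self.v[:, :self.length], pos_start

    def reset(self):
        """Forget the previous sequence; the buffers are reused as-is."""
        self.length = 0

cache = PreallocKVCache(max_len=8, n_kv_heads=2, head_dim=4)
prompt_kv = np.ones((2, 3, 4), dtype=np.float32)        # prefill: 3 prompt tokens
k_all, v_all, prefill_pos = cache.update(prompt_kv, prompt_kv)
step_kv = np.full((2, 1, 4), 2.0, dtype=np.float32)     # decode: 1 new token
k_all, v_all, decode_pos = cache.update(step_kv, step_kv)
print(prefill_pos, decode_pos, k_all.shape)
```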
Shrinking the cache
Once you see that the cache is often the binding constraint on serving, a whole slice of modern LLM design snaps into focus as answers to one question: how do we make this thing smaller? The approaches sort neatly by what they are willing to sacrifice.
Fewer KV heads: MQA and GQA
In classic multi-head attention every query head has a private key and value head, so the cache scales with the full head count. But heads’ K/V content is redundant enough to share. Multi-query attention (MQA) keeps one K/V head for all query heads, shrinking the cache by the head count (32x for a 32-head model) at some quality cost. Grouped-query attention (GQA) is the compromise that won: query heads are partitioned into groups, each sharing one KV head. Llama 3 70B runs 64 query heads against 8 KV heads, an 8x cache reduction at near-MHA quality, and GQA is now the default in Llama 3, Mistral, Qwen and most contemporaries.
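Mechanically, sharing looks like this. The sketch below uses toy sizes (8 query heads, 2 KV heads) and random data; only the small K and V arrays would be cached, and they are expanded per group at attention time. Real kernels typically avoid materializing the expanded copy, so read this as an illustration of the indexing, not of an efficient implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
n_heads, n_kv_heads, head_dim, t = 8, 2, 4, 5    # toy sizes
group = n_heads // n_kv_heads                    # query heads per KV head

q = rng.standard_normal((n_heads, 1, head_dim))  # newest token, one query per head
K = rng.standard_normal((n_kv_heads, t, head_dim))   # what actually gets cached
V = rng.standard_normal((n_kv_heads, t, head_dim))

# Query head h reads KV head h // group.
K_exp = np.repeat(K, group, axis=0)              # (n_heads, t, head_dim)
V_exp = np.repeat(V, group, axis=0)

scores = q @ K_exp.transpose(0, 2, 1) / np.sqrt(head_dim)   # (n_heads, 1, t)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
out = weights @ V_exp                            # (n_heads, 1, head_dim)

print(out.shape, "cache is", K_exp.size // K.size, "x smaller than MHA")
```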
Compress what you store: MLA
DeepSeek’s multi-head latent attention (MLA) attacks the other factor: instead of caching fewer heads, cache a compressed representation. K and V are jointly projected down into a low-rank latent vector (512 dimensions in DeepSeek-V2/V3, plus a small 64-dim decoupled key for RoPE), and that latent is what gets cached; per-head keys and values are reconstructed from it on the fly. The DeepSeek-V2 paper reports a 93.3% cache reduction and 5.76x generation throughput versus their prior 67B model (a baseline that already used GQA, which makes the reduction more impressive, not less), all while matching or beating full MHA quality. It is the rare optimization that gives nothing up on the quality axis. It is why 671B-parameter DeepSeek-V3 sits at the bottom of the memory chart above.
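The per-token arithmetic for MLA is short, because the latent is cached once per layer rather than per head and there is no separate factor of 2 for K and V. In this sketch the 512 and 64 dimensions come from the text above; the layer count of 61 for DeepSeek-V3 and the 2-byte element size are assumptions recalled from the published config, not stated in this post.

```python
def mla_bytes_per_token(n_layers, latent_dim, rope_dim, bytes_per_elem=2):
    """Cache per token when each layer stores one joint latent plus a RoPE key."""
    return n_layers * (latent_dim + rope_dim) * bytes_per_elem

# ASSUMPTION: 61 layers and 2-byte elements for DeepSeek-V3.
v3_bytes = mla_bytes_per_token(n_layers=61, latent_dim=512, rope_dim=64)
print(f"{v3_bytes:,} bytes, about {v3_bytes / 1024:.0f} KiB per token")
```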
Store less history: sliding windows and eviction
If the architecture is fixed, you can bound what you keep. Sliding-window attention caps the cache at the last \(w\) tokens, making cache size constant beyond the window; implementation-wise it is a one-line truncation of the cache tensor. Mistral 7B v0.1 shipped with a 4,096-token window (later versions dropped it for full attention over 32k). More surgical eviction methods exploit the empirical fact that attention mass concentrates on a small set of tokens: H2O keeps only the “heavy hitters” plus recent tokens (5-10x memory reduction), and SnapKV similarly prunes before long generations. These do change outputs; they are approximations bought with accuracy risk, usually small but nonzero.
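The one-line truncation might look like the following sketch, shown for a single head with NumPy arrays of shape [seq, head_dim]. The helper name is made up for the example, and it deliberately ignores positional bookkeeping: position offsets must keep counting even though old rows are dropped.

```python
import numpy as np

def append_with_window(K, V, k_new, v_new, window):
    """Append the newest rows, then keep only the last `window` positions."""
    K = np.concatenate([K, k_new], axis=0)[-window:]
    V = np.concatenate([V, v_new], axis=0)[-window:]
    return K, V

d, window = 4, 4
K = np.empty((0, d))
V = np.empty((0, d))
for step in range(10):
    row = np.full((1, d), float(step))       # stand-in for this step's k and v
    K, V = append_with_window(K, V, row, row, window)

print(K.shape, K[:, 0])                      # cache size stays capped at the window
```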
Cheaper bytes: quantization
The cache does not need the same precision as the weights’ matmuls. FP8 halves it relative to BF16; aggressive 2-bit schemes like KIVI report 2.6x peak-memory reduction and 2.35-3.47x throughput with negligible accuracy loss on their benchmarks. Since decode is bandwidth-bound, halving the bytes read per step also directly speeds up token generation, a rare optimization that saves memory and latency together.
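To make the idea concrete, here is a sketch of the simplest possible scheme: symmetric int8 quantization with one scale per token row. This is a stand-in chosen for clarity, not FP8 and not KIVI's method; int8 just happens to use one byte per element like FP8 does. Function names are illustrative.

```python
import numpy as np

def quantize_int8(x):
    """Per-row symmetric quantization: x is approximated by q * scale."""
    scale = np.abs(x).max(axis=-1, keepdims=True) / 127.0
    scale = np.where(scale == 0, 1.0, scale)           # avoid dividing by zero
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale.astype(np.float32)

def dequantize(q, scale):
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)
K = rng.standard_normal((6, 128)).astype(np.float32)   # 6 cached key rows
K_q, K_scale = quantize_int8(K)
K_hat = dequantize(K_q, K_scale)

max_err = np.abs(K - K_hat).max()
print(f"{K_q.nbytes} bytes stored (plus scales), max abs error {max_err:.4f}")
```

The rounding error per element is bounded by half a quantization step, which is why the output changes slightly: unlike the cache itself, this is an approximation.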
Waste no bytes: PagedAttention and prefix caching
Finally, the systems layer. Early servers pre-allocated each request’s cache contiguously at maximum length. The vLLM paper profiled such systems and found that only 20-38% of cache memory held actual token states; the rest was padding and fragmentation. Its PagedAttention manages the cache the way an OS manages virtual memory, in small fixed-size blocks allocated on demand, recovering that waste for a 2-4x throughput gain with zero model change.
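The bookkeeping side of paging can be sketched without any tensors at all. The toy allocator below only tracks which physical block each request's tokens land in; the class and method names are hypothetical and this is not vLLM's implementation, just the block-table idea in miniature.

```python
class PagedKVAllocator:
    """Toy block table: each request maps logical positions to physical blocks."""

    def __init__(self, num_blocks, block_size):
        self.block_size = block_size
        self.free = list(range(num_blocks))   # physical block ids not in use
        self.tables = {}                      # request id -> list of block ids
        self.lengths = {}                     # request id -> tokens stored

    def append_token(self, req):
        """Reserve a slot for one more token; allocate a block only when needed."""
        n = self.lengths.get(req, 0)
        if n % self.block_size == 0:          # current block full (or first token)
            if not self.free:
                raise MemoryError("no free KV blocks")
            self.tables.setdefault(req, []).append(self.free.pop())
        self.lengths[req] = n + 1
        return self.tables[req][n // self.block_size], n % self.block_size

    def release(self, req):
        """Return a finished request's blocks to the free pool."""
        self.free.extend(self.tables.pop(req, []))
        self.lengths.pop(req, None)

alloc = PagedKVAllocator(num_blocks=4, block_size=16)
for _ in range(17):
    alloc.append_token("a")                   # 17 tokens span 2 blocks
alloc.append_token("b")                       # 1 token takes 1 block
print(len(alloc.tables["a"]), len(alloc.tables["b"]), len(alloc.free))
```

At most one partially filled block per request is wasted, instead of a whole maximum-length reservation.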
Paging also makes cache sharing natural. Prefix caching stores the KV blocks of common prompt prefixes (system prompts, few-shot examples, earlier turns of a conversation) and reuses them across requests, skipping their prefill entirely. When an LLM API offers “prompt caching” at a steep discount for repeated prompt prefixes, this is literally what you are buying: the provider is skipping the compute to rebuild those KV entries. The cache stopped being an implementation detail and became a line item on your bill.
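One way to picture the lookup, as a sketch under simplifying assumptions: split the prompt into fixed-size blocks and key each block by a hash of itself plus everything before it, since a token's keys and values depend on the whole prefix. The helper names, the block size, and the use of SHA-256 are choices made for this illustration, not a description of any provider's system.

```python
import hashlib

def block_keys(token_ids, block_size):
    """One key per full block; each key covers the block and its entire prefix."""
    keys, prev = [], b""
    full = len(token_ids) - len(token_ids) % block_size
    for i in range(0, full, block_size):
        block = token_ids[i:i + block_size]
        prev = hashlib.sha256(prev + repr(block).encode()).digest()
        keys.append(prev)
    return keys

def reusable_prefix_tokens(token_ids, store, block_size):
    """How many leading tokens already have KV blocks in the store."""
    hits = 0
    for key in block_keys(token_ids, block_size):
        if key not in store:
            break                              # first miss ends the shared prefix
        hits += 1
    return hits * block_size

BLOCK = 4
system_prompt = list(range(8))                 # made-up token ids
request_a = system_prompt + [101, 102, 103]
request_b = system_prompt + [7, 7, 7, 7, 7]

store = set(block_keys(request_a, BLOCK))      # request A fills the store
reused = reusable_prefix_tokens(request_b, store, BLOCK)
print(f"request B can skip prefill for {reused} of {len(request_b)} tokens")
```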
What to take away
The KV cache is the kind of idea that looks trivial in hindsight: causality makes the past immutable, so store it instead of recomputing it. But following its consequences explains a striking amount of how modern LLM systems behave and are built:
- The pause before the first token is prefill building the cache; the steady stream after is decode extending it one row at a time.
- Long chats slow down and cost more because every generated token must stream the entire cache through memory bandwidth.
- Architectures changed shape because of it: GQA and MLA exist because KV heads, not parameters, set the marginal cost of context. A 70B model can be cheaper per context token than a 7B one.
- Serving economics follow it: batch size multiplies the cache, PagedAttention exists to stop wasting it, and prompt-caching discounts are providers passing along the prefill they skipped.
A useful mental model to leave with: an LLM’s weights are its education, but its KV cache is its working memory of your conversation, rebuilt for every request, rented by the token, and fought over by every layer of the stack.
Sources and further reading
- KV Caching Explained: Optimizing Transformer Inference Efficiency — Hugging Face community post with the causal-masking intuition and the SmolLM2 5.2x benchmark.
- Understanding and Coding the KV Cache in LLMs from Scratch — Sebastian Raschka’s implementation walkthrough; the source of the register-buffer pattern, the pre-allocation and reset gotchas, and the token-identical verification.
- KV Caching in LLMs, Explained Visually — Daily Dose of DS visual explainer, including the why-queries-are-not-cached argument and the time-to-first-token observation.
- KV Cache Optimization Strategies for Scalable and Efficient LLM Inference — a 2026 survey (Xu, Khaira & Singh) with a five-way taxonomy of cache optimizations and the comparison tables behind the eviction/quantization numbers here. Industry survey, v1, not peer-reviewed.
- DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model — where multi-head latent attention is introduced, with the 93.3% cache-reduction result.
- Efficient Memory Management for Large Language Model Serving with PagedAttention — the vLLM paper; the OS-paging analogy for cache management.
- Mastering LLM Techniques: Inference Optimization — NVIDIA’s overview of prefill vs decode and the compute-bound/bandwidth-bound distinction.
- How GPT, Claude, and Gemini are actually trained and served — Reiner Pope on the Dwarkesh Patel podcast, linked from the moment the conversation turns to the KV cache. Well worth a watch for how the ideas in this post play out in real frontier-scale serving.
