老学庵

天行健,君子以自强不息;地势坤,君子以厚德载物!

0%

LLM之KVCache

  在Transformer架构的生成式模型(如GPT系列)中,推理过程需要逐个生成token。传统方式每次生成都需重新计算所有历史token的注意力信息,导致计算复杂度达到O(n²)。KV Cache技术通过缓存历史token的Key和Value矩阵,将后续生成的计算复杂度降至O(n),实现推理加速,本质上是一种用空间来换取时间的加速策略。

1.自注意力机制的冗余计算

  GPT系列模型是经典的自回归模型,在推理过程中,一次推理只输出一个token,输出token会与输入tokens拼接在一起,然后作为下一次推理的输入,这样不断反复直到遇到终止符。以GPT2为例,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

model = GPT2LMHeadModel.from_pretrained("gpt2", torchscript=True).eval()

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
in_text = "Hello, how "
in_tokens = torch.tensor(tokenizer.encode(in_text))

token_eos = torch.tensor([198])
out_token = None
i = 0
with torch.no_grad():
while out_token != token_eos:
logits, _ = model(in_tokens)
out_token = torch.argmax(logits[-1, :], dim=0, keepdim=True)
in_tokens = torch.cat((in_tokens, out_token), 0)
text = tokenizer.decode(in_tokens)
print(f'step {i} input: {text}', flush=True)
i += 1

out_text = tokenizer.decode(in_tokens)
print(f'Input: {in_text}')
print(f'Output: {out_text}')
---

执行输出结果如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
step 0  input: Hello, how   
step 1 input: Hello, how do
step 2 input: Hello, how do you
step 3 input: Hello, how do you feel
step 4 input: Hello, how do you feel about
step 5 input: Hello, how do you feel about the
step 6 input: Hello, how do you feel about the new
step 7 input: Hello, how do you feel about the new season
step 8 input: Hello, how do you feel about the new season of
step 9 input: Hello, how do you feel about the new season of The
step 10 input: Hello, how do you feel about the new season of The Walking
step 11 input: Hello, how do you feel about the new season of The Walking Dead
step 12 input: Hello, how do you feel about the new season of The Walking Dead?
step 13 input: Hello, how do you feel about the new season of The Walking Dead?

Input: Hello, how
Output: Hello, how do you feel about the new season of The Walking Dead?

2. KV Cache工作流程

  显然上面的计算过程中存在大量重复计算,在上面的推理过程中,每step内,输入一个token序列,经过Embedding层将输入token序列变为一个三维张量\([b, s, h]\),经过一通计算,最后经logits层将计算结果映射至词表空间,输出张量维度为\([b, s, vocab_{size}]\)。当前轮输出token与输入tokens拼接,并作为下一轮的输入tokens,反复多次。可以看出第轮输入\(n\)数据只比第\(n-1\)轮输入数据新增了一个token,其他全部相同!因此第\(n\)轮推理时必然包含了第\(n-1\)轮的部分计算。   基于上面的观察,KV Cache策略缓存当前轮可重复利用的计算结果,下一轮计算时直接读取缓存结果,避免了大量的冗余计算。采用KV Cache后,自回归模型的推理策略分为两个阶段:

  • 预填充阶段

      发生在计算第一个输出token过程中,此时Cache为空,需要为每个 transformer layer 计算并保存key cache和value cache,在输出token时Cache完成填充;FLOPs同KV Cache关闭一致,存在大量gemm操作,推理速度慢。

  • 使用KV Cache阶段

      发生在计算非首个token的过程中,这时Cache已缓存历史kv,每轮推理只需读取Cache,同时将当前轮计算出的新的Key、Value追加写入至Cache;FLOPs降低,gemm变为gemv操作,推理速度相对第一阶段变快,这时属于Memory-bound类型计算。

需要注意的时,由于自注意力机制的特性,在推理时仅需计算当前token的Q,并与缓存中的K/V进行注意力计算。

同样以GPT2为例,代码实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class GPT2Attention(nn.Module):
def __init__(self):
self.register_buffer('cached_k', torch.zeros(...))
self.register_buffer('cached_v', torch.zeros(...))

def forward(self, x, layer_past=None):
# 计算当前Q/K/V
q, k, v = self.qkv_proj(x).split(...)

if layer_past is not None: # 存在历史缓存
past_k, past_v = layer_past
k = torch.cat([past_k, k], dim=-2) # 序列维度拼接
v = torch.cat([past_v, v], dim=-2)

return (k, v) if use_cache else output
  KV Cache是典型的以空间换时间的策略,显然其通过缓存历史KV的方式,避免了重复计算,使得自回归模型的推理速度提升了2~3倍,尤其在处理长序列时效果显著,显著减少了对计算资源的要求,同时缓存完整的历史Key和Value,确保模型在长文本生成中保持语义一致性。

3.改良与优化

  当然,由于KV Cache的机制,随着序列长度的增加,缓存的Key和Value矩阵会占用大量显存。以一个层数为\(L\),隐藏层维度为\(H\),输入长度为\(S\),输出长度为\(N\)的模型,KV Cache的显存占用为: \[M = 4 * B * L * H * (S + N)\] 其中B为批量大小。为了缓解推理过程中的显存压力,针对性的提出了一些优化策略:

  • 注意力结构改进

      采用多查询注意力(Multi-Query Attention,MQA)或分组查询注意力(Grouped-Query Attention,GQA)减少Key和Value矩阵的参数量。 

  • 滑动窗口机制

      仅缓存最近L个token的Key和Value,结合初始token的“注意力锚点”稳定计算,减少显存占用。 

  • 量化压缩

      对KV Cache进行低精度存储(如FP8),进一步减少显存需求。 

  除了上面几种策略外,基于FlashAttention的优化注意力计算访存模式的技巧也进一步提升推理效率,KV Cache通过空间换时间的策略,已成为大模型推理优化的标配技术。

参考

[1] 大模型推理加速:看图学KV Cache [1] Transformers官方文档