MiniMind 学习笔记 06:Attention 详解,QKV、Mask、GQA 与 MQA

Attention 是 LLM 里最核心、也最容易让初学者卡住的模块。

MiniMind 的 Attention 代码同时包含了几个现代 LLM 常见设计:

  • Q/K/V 投影。
  • RoPE 旋转位置编码。
  • causal mask。
  • padding attention mask。
  • FlashAttention fallback。
  • GQA:num_attention_heads 和 num_key_value_heads 可以不同。
  • Q/K Norm。

这篇先聚焦 Attention 主计算和 head 设计,RoPE 单独放到下一篇。

1. Attention 的目标

Attention 的目标是让每个 token 从上下文中取信息。

对于当前位置 token,可以理解成三个问题:

Q:我想找什么?
K:每个历史 token 提供什么索引?
V:每个历史 token 真正提供什么内容?

最终计算过程是:

scores = Q @ K^T / sqrt(head_dim)
weights = softmax(scores + mask)
output = weights @ V

也就是:

先算相关性;
再归一化成权重;
最后按权重加权求和内容。

2. Q/K/V 的形状

MiniMind 默认:

hidden_size = 768
num_attention_heads = 8
num_key_value_heads = 4
head_dim = 96

因此:

Q 总维度 = 8 * 96 = 768
K 总维度 = 4 * 96 = 384
V 总维度 = 4 * 96 = 384

这就是为什么代码里:

q_proj: 768 -> 768
k_proj: 768 -> 384
v_proj: 768 -> 384

你调试时看到:

xq: [1, 6, 8, 96]
xk: [1, 6, 4, 96]
xv: [1, 6, 4, 96]

这里:

1  = batch size
6  = token 长度
8  = Q head 数
4  = K/V head 数
96 = 每个 head 的维度

3. 为什么 Q head 和 KV head 可以不同

传统 Multi-Head Attention 是:

Q heads = K heads = V heads

也就是每个 Q head 都有独立的一套 K/V。

但在推理阶段,K/V 会被缓存到 KV Cache 里。如果 KV head 很多,显存和带宽压力都会增加。

于是出现了:

结构Q headKV head特点
MHA多个多个,数量相同表达力强,KV Cache 大。
GQA多个较少几个 Q head 共享一组 K/V。
MQA多个1 组KV Cache 最省,但可能损失表达力。

MiniMind 默认是 GQA:

8 个 Q head
4 个 KV head

也就是两个 Q head 共享一组 K/V。

4. repeat_kv 在做什么

计算 attention score 时,Q 和 K 的 head 数需要对齐。

但 MiniMind 中:

Q: 8 heads
K: 4 heads
V: 4 heads

所以需要把 K/V 在 head 维度上扩展成 8 个 head:

K: 4 heads -> 8 heads
V: 4 heads -> 8 heads

这就是 repeat_kv 的作用。

它的逻辑不是产生新的语义信息,而是让多个 Q head 共用同一组 K/V:

Q0, Q1 共用 KV0
Q2, Q3 共用 KV1
Q4, Q5 共用 KV2
Q6, Q7 共用 KV3

可以用一个类比:

Q head 像提问的人;
K/V head 像资料库视角。

MHA:每个提问的人都有独立资料库。
GQA:几个提问的人共用一套资料库。
MQA:所有提问的人共用一套资料库。

5. 这样做有什么代价

减少 KV heads 的好处很明确:

KV Cache 更小;
推理带宽压力更低;
长上下文更友好;
decode 更快。

但它也有潜在代价:

不同 Q head 能看到的 K/V 视角变少;
表达能力可能下降;
极端情况下,多样性不如完整 MHA。

工业界通常会在质量和效率之间折中。

现代大模型里,GQA 非常常见。它比 MQA 保留更多表达能力,又比 MHA 更省 KV Cache。

6. causal mask 为什么是上三角

LLM 是 causal language model,当前位置不能看未来 token。

如果 score 矩阵按下面方式排列:

行:query 位置 i
列:key 位置 j
score[i, j] = q_i · k_j

那么第 i 行只能看:

j <= i

不能看:

j > i

矩阵里 j > i 的区域就是上三角。

所以 causal mask 会把上三角位置加上:

-inf

softmax 后这些位置概率变成 0。

这就是代码里:

torch.full((seq_len, seq_len), float("-inf")).triu(1)

的含义。

7. attention_mask 和 causal mask 不一样

MiniMind 代码里还有:

if attention_mask is not None:
    scores += (1.0 - attention_mask.unsqueeze(1).unsqueeze(2)) * -1e9

它和 causal mask 是两回事。

mask作用
causal mask防止看到未来 token。
attention_mask防止关注 padding token。

例如 batch 里有 padding:

真实 token: 1
padding:   0

attention_mask=0 的位置会被加上很大的负数,softmax 后基本不被关注。

8. Q/K Norm 是标准流程吗

MiniMind 里对 Q 和 K 做了 RMSNorm:

xq = self.q_norm(xq)
xk = self.k_norm(xk)

原始 Scaled Dot-Product Attention 公式里没有这一步。

所以它不是最原始的标准 Attention 必需项,而是现代模型中越来越常见的稳定性设计。

它的作用是让 Q/K 的尺度更可控,避免 attention score 过大或分布不稳定。

注意它不能替代:

1 / sqrt(head_dim)

两者作用不同:

Q/K Norm:控制 Q/K 向量本身的尺度。
sqrt 缩放:控制点积后 score 的尺度。

9. Attention 计算主线

完整手写路径可以概括为:

hidden_states
-> q_proj / k_proj / v_proj
-> reshape 成多头
-> q_norm / k_norm
-> apply_rotary_pos_emb
-> 拼接 past_key_values
-> repeat_kv 对齐 head 数
-> scores = Q @ K^T / sqrt(head_dim)
-> 加 causal mask
-> 加 attention_mask
-> softmax
-> dropout
-> weights @ V
-> o_proj

每一步都服务于一个目标:

让当前位置 token 在合法上下文中选择该关注的信息。

10. 小结

Attention 的难点不在公式本身,而在工程实现细节:

  • Q/K/V 的 head 数和形状。
  • causal mask 的方向。
  • padding mask 和 causal mask 的区别。
  • GQA/MQA 为什么能省 KV Cache。
  • repeat_kv 为什么只是对齐维度,不是增加新信息。
  • Q/K Norm 是现代稳定性设计。

先掌握这些,再看 RoPE 和 FlashAttention,会更顺。

发表评论