Attention 是 LLM 里最核心、也最容易让初学者卡住的模块。
MiniMind 的 Attention 代码同时包含了几个现代 LLM 常见设计:
- Q/K/V 投影。
- RoPE 旋转位置编码。
- causal mask。
- padding attention mask。
- FlashAttention fallback。
- GQA:
num_attention_heads和num_key_value_heads可以不同。 - Q/K Norm。
这篇先聚焦 Attention 主计算和 head 设计,RoPE 单独放到下一篇。
1. Attention 的目标
Attention 的目标是让每个 token 从上下文中取信息。
对于当前位置 token,可以理解成三个问题:
Q:我想找什么? K:每个历史 token 提供什么索引? V:每个历史 token 真正提供什么内容?
最终计算过程是:
scores = Q @ K^T / sqrt(head_dim) weights = softmax(scores + mask) output = weights @ V
也就是:
先算相关性; 再归一化成权重; 最后按权重加权求和内容。
2. Q/K/V 的形状
MiniMind 默认:
hidden_size = 768 num_attention_heads = 8 num_key_value_heads = 4 head_dim = 96
因此:
Q 总维度 = 8 * 96 = 768 K 总维度 = 4 * 96 = 384 V 总维度 = 4 * 96 = 384
这就是为什么代码里:
q_proj: 768 -> 768 k_proj: 768 -> 384 v_proj: 768 -> 384
你调试时看到:
xq: [1, 6, 8, 96] xk: [1, 6, 4, 96] xv: [1, 6, 4, 96]
这里:
1 = batch size 6 = token 长度 8 = Q head 数 4 = K/V head 数 96 = 每个 head 的维度
3. 为什么 Q head 和 KV head 可以不同
传统 Multi-Head Attention 是:
Q heads = K heads = V heads
也就是每个 Q head 都有独立的一套 K/V。
但在推理阶段,K/V 会被缓存到 KV Cache 里。如果 KV head 很多,显存和带宽压力都会增加。
于是出现了:
| 结构 | Q head | KV head | 特点 |
|---|---|---|---|
| MHA | 多个 | 多个,数量相同 | 表达力强,KV Cache 大。 |
| GQA | 多个 | 较少 | 几个 Q head 共享一组 K/V。 |
| MQA | 多个 | 1 组 | KV Cache 最省,但可能损失表达力。 |
MiniMind 默认是 GQA:
8 个 Q head 4 个 KV head
也就是两个 Q head 共享一组 K/V。
4. repeat_kv 在做什么
计算 attention score 时,Q 和 K 的 head 数需要对齐。
但 MiniMind 中:
Q: 8 heads K: 4 heads V: 4 heads
所以需要把 K/V 在 head 维度上扩展成 8 个 head:
K: 4 heads -> 8 heads V: 4 heads -> 8 heads
这就是 repeat_kv 的作用。
它的逻辑不是产生新的语义信息,而是让多个 Q head 共用同一组 K/V:
Q0, Q1 共用 KV0 Q2, Q3 共用 KV1 Q4, Q5 共用 KV2 Q6, Q7 共用 KV3
可以用一个类比:
Q head 像提问的人; K/V head 像资料库视角。 MHA:每个提问的人都有独立资料库。 GQA:几个提问的人共用一套资料库。 MQA:所有提问的人共用一套资料库。
5. 这样做有什么代价
减少 KV heads 的好处很明确:
KV Cache 更小; 推理带宽压力更低; 长上下文更友好; decode 更快。
但它也有潜在代价:
不同 Q head 能看到的 K/V 视角变少; 表达能力可能下降; 极端情况下,多样性不如完整 MHA。
工业界通常会在质量和效率之间折中。
现代大模型里,GQA 非常常见。它比 MQA 保留更多表达能力,又比 MHA 更省 KV Cache。
6. causal mask 为什么是上三角
LLM 是 causal language model,当前位置不能看未来 token。
如果 score 矩阵按下面方式排列:
行:query 位置 i 列:key 位置 j score[i, j] = q_i · k_j
那么第 i 行只能看:
j <= i
不能看:
j > i
矩阵里 j > i 的区域就是上三角。
所以 causal mask 会把上三角位置加上:
-inf
softmax 后这些位置概率变成 0。
这就是代码里:
torch.full((seq_len, seq_len), float("-inf")).triu(1)
的含义。
7. attention_mask 和 causal mask 不一样
MiniMind 代码里还有:
if attention_mask is not None:
scores += (1.0 - attention_mask.unsqueeze(1).unsqueeze(2)) * -1e9
它和 causal mask 是两回事。
| mask | 作用 |
|---|---|
| causal mask | 防止看到未来 token。 |
| attention_mask | 防止关注 padding token。 |
例如 batch 里有 padding:
真实 token: 1 padding: 0
attention_mask=0 的位置会被加上很大的负数,softmax 后基本不被关注。
8. Q/K Norm 是标准流程吗
MiniMind 里对 Q 和 K 做了 RMSNorm:
xq = self.q_norm(xq) xk = self.k_norm(xk)
原始 Scaled Dot-Product Attention 公式里没有这一步。
所以它不是最原始的标准 Attention 必需项,而是现代模型中越来越常见的稳定性设计。
它的作用是让 Q/K 的尺度更可控,避免 attention score 过大或分布不稳定。
注意它不能替代:
1 / sqrt(head_dim)
两者作用不同:
Q/K Norm:控制 Q/K 向量本身的尺度。 sqrt 缩放:控制点积后 score 的尺度。
9. Attention 计算主线
完整手写路径可以概括为:
hidden_states -> q_proj / k_proj / v_proj -> reshape 成多头 -> q_norm / k_norm -> apply_rotary_pos_emb -> 拼接 past_key_values -> repeat_kv 对齐 head 数 -> scores = Q @ K^T / sqrt(head_dim) -> 加 causal mask -> 加 attention_mask -> softmax -> dropout -> weights @ V -> o_proj
每一步都服务于一个目标:
让当前位置 token 在合法上下文中选择该关注的信息。
10. 小结
Attention 的难点不在公式本身,而在工程实现细节:
- Q/K/V 的 head 数和形状。
- causal mask 的方向。
- padding mask 和 causal mask 的区别。
- GQA/MQA 为什么能省 KV Cache。
repeat_kv为什么只是对齐维度,不是增加新信息。- Q/K Norm 是现代稳定性设计。
先掌握这些,再看 RoPE 和 FlashAttention,会更顺。