前面几篇已经看过 MiniMind 的 Attention、RoPE 和推理生成。
这篇继续沿着同一条线看一个近两年很重要的 Attention 改造:Multi-Latent Attention,简称 MLA。
MLA 最直接的目标是:
降低推理阶段的 KV Cache 和显存带宽压力,同时尽量保持 Attention 表达能力。
它不是 MiniMind 当前主线已经实现的结构。MiniMind 当前使用的是 GQA,也就是 Query head 多、KV head 少的折中方案。
但 MLA 很适合作为 MiniMind 的实验性扩展,因为它刚好接在这几个主题后面:
MHA / MQA / GQA -> KV Cache -> RoPE -> 长上下文推理成本 -> MLA
1. 为什么 Attention 推理会越来越贵
LLM 推理是自回归生成:
已有 prompt -> 生成下一个 token -> 把新 token 加回上下文 -> 再生成下一个 token
为了避免每一步都重新计算所有历史 token 的 K/V,模型会缓存历史 K/V。
这就是 KV Cache。
标准 Attention 可以写成:
scores = Q @ K^T / sqrt(head_dim) weights = softmax(scores) output = weights @ V
生成第 t 个 token 时,需要拿当前 token 的 Q 去和历史所有 K 做匹配,再用权重汇聚历史所有 V:
当前 Q @ 历史 K cache -> attention weights @ 历史 V cache -> 当前输出
因此长上下文推理的瓶颈经常不是纯计算,而是:
KV Cache 占显存; 每步 decode 都要读取大量历史 K/V; batch 越大、上下文越长,显存带宽越紧张。
2. 从 MHA 到 MQA、GQA
先看 head 设计的演化。
MHA
传统 Multi-Head Attention 中,每个 head 都有自己的 K/V:
Q0 -> K0 / V0 Q1 -> K1 / V1 Q2 -> K2 / V2 ...
优点是表达能力强。
缺点是 KV Cache 大:
seq_len * num_heads * head_dim * 2
最后的 2 分别对应 K 和 V。
MQA
Multi-Query Attention 让多个 Q head 共用一组 K/V:
Q0 \ Q1 \ Q2 -> shared K / V Q3 /
它能大幅减少 KV Cache,但共享太狠,表达能力可能下降。
GQA
Grouped-Query Attention 是折中:
Q0, Q1 -> KV group 0 Q2, Q3 -> KV group 1 Q4, Q5 -> KV group 2 Q6, Q7 -> KV group 3
MiniMind 默认就是 GQA:
num_attention_heads = 8 num_key_value_heads = 4 head_dim = 96
也就是:
Q 总维度 = 8 * 96 = 768 K 总维度 = 4 * 96 = 384 V 总维度 = 4 * 96 = 384
对应代码在 model/model_minimind.py:
self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
缓存时存的是:
past_kv = (xk, xv) if use_cache else None
所以 MiniMind 当前的 KV Cache 大小大致是:
seq_len * num_key_value_heads * head_dim * 2
默认配置下,每层每个 token 需要缓存:
4 * 96 * 2 = 768 个数
3. MLA 的核心想法
MQA/GQA 的思路是:
减少 KV head 数。
MLA 的思路更进一步:
不直接缓存完整 K/V,而是缓存更小的 latent 表示。
可以把传统 Attention 看成:
hidden state -> K / V -> cache full K / V
MLA 改成:
hidden state -> compressed latent -> cache latent -> 需要时从 latent 恢复 K / V
也就是:
token -> full KV -> cache
变成:
token -> compressed latent -> cache
这就是 Multi-Latent Attention 最重要的直觉。
它有点像把 KV 做了一次低秩分解:
K = W_up_k(W_down_kv h) V = W_up_v(W_down_kv h)
其中:
c = W_down_kv h
就是要缓存的 latent。
如果 latent_dim 远小于完整 K/V 维度,KV Cache 就会明显下降。
4. 一个最小可理解版 MLA
先不考虑 RoPE,只看最简单的结构:
class SimpleMLA(nn.Module):
def __init__(self, hidden_size, num_heads, head_dim, latent_dim):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
# hidden -> latent
self.kv_down_proj = nn.Linear(hidden_size, latent_dim, bias=False)
# latent -> K / V
self.k_up_proj = nn.Linear(latent_dim, num_heads * head_dim, bias=False)
self.v_up_proj = nn.Linear(latent_dim, num_heads * head_dim, bias=False)
self.o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=False)
def forward(self, x, past_latent=None):
bsz, seq_len, _ = x.shape
q = self.q_proj(x)
latent = self.kv_down_proj(x)
if past_latent is not None:
kv_latent = torch.cat([past_latent, latent], dim=1)
else:
kv_latent = latent
k = self.k_up_proj(kv_latent)
v = self.v_up_proj(kv_latent)
q = q.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out = out.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
return self.o_proj(out), kv_latent
这段代码里最关键的是:
latent = self.kv_down_proj(x)
传统 Attention 缓存:
K cache + V cache
MLA 缓存:
latent cache
需要 K/V 时再恢复:
k = self.k_up_proj(kv_latent) v = self.v_up_proj(kv_latent)
这个版本适合理解 MLA 的主干,但还不是 DeepSeek 真实 MLA。
真实 MLA 还要处理:
RoPE 怎么放; Q 是否也做低秩压缩; 如何避免每步显式恢复完整 K/V; 如何兼容高性能推理 kernel。
5. MLA 和 LoRA 为什么看起来很像
MLA 很容易让人联想到 LoRA。
LoRA 的核心是:
W + B @ A
也就是用一个低秩旁路表达参数更新。
MLA 的 K/V 生成可以看成:
K = W_up_k @ W_down_kv @ h V = W_up_v @ W_down_kv @ h
两者都用了:
低秩; bottleneck; latent compression。
但目标不同:
| 技术 | 目标 |
|---|---|
| LoRA | 降低微调参数量和训练成本。 |
| MLA | 降低推理 KV Cache 和带宽成本。 |
所以它们只是思想相近,不是同一个用途。
6. MLA 和 RoPE 的关键关系
如果没有位置编码,前面的简单 MLA 很好理解。
但 MiniMind 和大多数现代 LLM 一样使用 RoPE。
MiniMind 当前是在 Q/K 上做旋转:
xq, xk = apply_rotary_pos_emb(xq, xk, cos, sin)
RoPE 的位置关系来自:
RoPE(Q) @ RoPE(K)^T
问题是:MLA 缓存的是 latent,不是完整 K。
如果把完整 K 都压到 latent 里,再恢复后做 RoPE,会遇到一个工程矛盾:
RoPE 和 token 位置绑定; latent 希望位置无关、便于压缩和缓存。
DeepSeek MLA 的处理方式是把 K 拆成两部分:
K = concat(K_nope, K_rope)
其中:
K_nope:不带 RoPE 的主要内容部分,可以从 latent 恢复。 K_rope:带 RoPE 的位置部分,维度较小,单独生成和缓存。
Q 也对应拆成:
Q = concat(Q_nope, Q_rope)
Attention score 变成:
score = Q_nope @ K_nope^T + RoPE(Q_rope) @ RoPE(K_rope)^T
这样做的好处是:
内容主干仍然走 latent 压缩; 位置信息保留在较小的 rope 维度里; KV Cache 不需要回到完整 K/V 的大小。
这也是 MLA 比最小教学版复杂的主要原因。
7. 结合 MiniMind 看 MLA 要改哪里
MiniMind 当前 Attention.forward 主线可以概括成:
x -> q_proj / k_proj / v_proj -> view 成多头 -> q_norm / k_norm -> RoPE(Q, K) -> 拼接 past K/V -> repeat_kv -> scores = Q @ K^T / sqrt(head_dim) -> softmax -> weights @ V -> o_proj
如果新增 MLA,可以先不要改动训练框架、数据集和生成逻辑,而是在 model/model_minimind.py 里新增一个 Attention 变体。
建议分三步做。
8. 第一步:给 Config 增加 MLA 开关
在 MiniMindConfig 中增加:
self.use_mla = kwargs.get("use_mla", False)
self.mla_kv_lora_rank = kwargs.get("mla_kv_lora_rank", 128)
self.mla_q_lora_rank = kwargs.get("mla_q_lora_rank", 0)
self.mla_qk_rope_head_dim = kwargs.get("mla_qk_rope_head_dim", 32)
self.mla_v_head_dim = kwargs.get("mla_v_head_dim", self.head_dim)
self.mla_qk_nope_head_dim = kwargs.get(
"mla_qk_nope_head_dim",
self.head_dim - self.mla_qk_rope_head_dim
)
这里的名字参考了 DeepSeek 公开实现里的习惯:
kv_lora_rank:KV latent 维度。 q_lora_rank:Q 是否也走低秩压缩。 qk_nope_head_dim:不带 RoPE 的 Q/K 维度。 qk_rope_head_dim:带 RoPE 的 Q/K 维度。 v_head_dim:V 的 head 维度。
对 MiniMind 默认 head_dim=96,可以先试:
qk_nope_head_dim = 64 qk_rope_head_dim = 32 v_head_dim = 96 kv_lora_rank = 128
9. 第二步:新增 MLAAttention
一个贴近 MiniMind 风格的 naive MLA 可以这样设计:
class MLAAttention(nn.Module):
def __init__(self, config: MiniMindConfig):
super().__init__()
self.n_local_heads = config.num_attention_heads
self.qk_nope_head_dim = config.mla_qk_nope_head_dim
self.qk_rope_head_dim = config.mla_qk_rope_head_dim
self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
self.v_head_dim = config.mla_v_head_dim
self.kv_lora_rank = config.mla_kv_lora_rank
self.q_proj = nn.Linear(config.hidden_size, self.n_local_heads * self.qk_head_dim, bias=False)
self.kv_down_proj = nn.Linear(config.hidden_size, self.kv_lora_rank + self.qk_rope_head_dim, bias=False)
self.kv_norm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
self.kv_up_proj = nn.Linear(
self.kv_lora_rank,
self.n_local_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False
)
self.o_proj = nn.Linear(self.n_local_heads * self.v_head_dim, config.hidden_size, bias=False)
这和当前 Attention 最大的区别是:
没有单独的 k_proj / v_proj; 先用 kv_down_proj 得到 latent; 再用 kv_up_proj 从 latent 恢复 K_nope 和 V; K_rope 单独从 kv_down_proj 的尾部切出来。
forward 主线可以写成:
q = self.q_proj(x).view(bsz, seq_len, self.n_local_heads, self.qk_head_dim)
q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
kv = self.kv_down_proj(x)
kv_latent, k_rope = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_latent = self.kv_norm(kv_latent)
q_rope, k_rope = apply_rotary_pos_emb(
q_rope,
k_rope.unsqueeze(2),
cos,
sin
)
k_rope = k_rope.squeeze(2)
if past_key_value is not None:
kv_latent = torch.cat([past_key_value[0], kv_latent], dim=1)
k_rope = torch.cat([past_key_value[1], k_rope], dim=1)
past_kv = (kv_latent, k_rope) if use_cache else None
注意这里 cache 不再是:
(xk, xv)
而是:
(kv_latent, k_rope)
这就是 MiniMind 接入 MLA 后最关键的行为变化。
10. 第三步:先做 naive 版本,再做 absorb 优化
为了先跑通训练和推理,建议第一版 MLA 显式恢复 K/V:
kv_full = self.kv_up_proj(kv_latent)
kv_full = kv_full.view(
bsz,
-1,
self.n_local_heads,
self.qk_nope_head_dim + self.v_head_dim
)
k_nope, v = torch.split(kv_full, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k = torch.cat(
[k_nope, k_rope.unsqueeze(2).expand(-1, -1, self.n_local_heads, -1)],
dim=-1
)
q = torch.cat([q_nope, q_rope], dim=-1)
然后继续使用普通 attention:
q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.qk_head_dim) weights = F.softmax(scores.float(), dim=-1).type_as(q) output = weights @ v
这个 naive 版本已经能做到:
推理 cache 只存 latent + 小维度 rope key; 结构上验证 MLA 是否可训练; 代码路径比较容易调试。
但它还没有完全释放 MLA 的推理收益,因为每步 attention 前仍然恢复了完整 K/V。
进一步优化时,可以学习 DeepSeek 的 absorb 思路:
不显式恢复 K_nope; 把 kv_up_proj 中 K_nope 对应的权重吸收到 Q_nope 上; 直接用 Q_nope' 和 latent cache 计算 score; V 侧也延迟到 score @ latent 后再投影回每个 head 的 V 空间。
也就是把:
Q_nope @ (W_k_up latent)^T
改写成:
(Q_nope @ W_k_up) @ latent^T
这是 MLA 真正适合高性能推理的地方。
11. MiniMind 默认配置下能省多少 cache
当前 MiniMind GQA 每层每 token 的 KV cache 约为:
num_key_value_heads * head_dim * 2 = 4 * 96 * 2 = 768
如果 MLA 取:
kv_lora_rank = 128 qk_rope_head_dim = 32
那么每层每 token cache 约为:
kv_lora_rank + qk_rope_head_dim = 128 + 32 = 160
压缩比例约为:
160 / 768 = 20.8%
也就是 KV Cache 理论上下降约:
79.2%
这只是按元素数估算,真实收益还会受 dtype、kernel、batch size、显存访问模式、是否显式恢复 K/V 等因素影响。
12. 需要注意的工程问题
如果在 MiniMind 中真正新增 MLA,需要注意几个点。
RoPE buffer 维度
当前 MiniMind 预计算 RoPE 时使用的是:
precompute_freqs_cis(dim=config.head_dim, ...)
MLA 中 RoPE 只作用在 qk_rope_head_dim 上。
因此需要让 RoPE buffer 支持 MLA 维度:
use_mla=False:dim = head_dim use_mla=True: dim = mla_qk_rope_head_dim
否则 q_rope / k_rope 的最后一维和 cos / sin 对不上。
causal mask
当前 MiniMind 在有 cache 时会让当前 query 看完整历史:
scores[:, :, :, -seq_len:] += causal_mask
MLA 仍然需要同样的 mask 逻辑。
变化的是 K/V 来源,不是 causal LM 的可见性规则。
FlashAttention
naive MLA 恢复完整 K/V 后,理论上仍可以复用 scaled_dot_product_attention。
但 optimized MLA 不再显式构造标准 K/V,普通 FlashAttention kernel 未必能直接吃这种 latent cache。
所以第一版可以先关闭或绕开 flash 路径,等结构正确后再优化。
checkpoint 兼容性
MLA 会改变 attention 权重名字和形状。
原来的:
q_proj / k_proj / v_proj / o_proj
会变成:
q_proj / kv_down_proj / kv_up_proj / o_proj
因此旧的 MiniMind 权重不能直接完整加载。
通常有三种选择:
从头预训练 MLA 版本; 写权重迁移脚本做近似初始化; 只把 MLA 作为新实验分支,不兼容旧 checkpoint。
对教学项目来说,第一种最清晰。
训练脚本参数
trainer/train_pretrain.py 当前创建配置时是:
lm_config = MiniMindConfig(
hidden_size=args.hidden_size,
num_hidden_layers=args.num_hidden_layers,
use_moe=bool(args.use_moe)
)
如果要训练 MLA,需要加上命令行参数:
--use_mla --mla_kv_lora_rank --mla_qk_rope_head_dim
并传入 MiniMindConfig。
SFT、DPO、PPO、GRPO 等训练脚本如果也要支持 MLA,也需要同样补配置。
13. MLA 的代价
MLA 不是无脑替换。
它带来的好处主要在:
长上下文; 大 batch 推理; 高并发服务; KV Cache 成为瓶颈的场景。
代价也很明确:
结构更复杂; 参数形状不再兼容原 checkpoint; latent_dim 太小会损失表达能力; 训练稳定性需要重新验证; optimized MLA 需要专门推理实现才能吃满收益。
所以对 MiniMind 这种小模型教学项目,比较合适的路线是:
先实现 naive MLA,确认 loss 能下降; 再比较相同训练配置下的 perplexity / 生成效果; 最后再考虑 absorb 优化和推理 benchmark。
14. 小结
可以用一句话概括 MLA:
用低维 latent 压缩 Attention 的历史 K/V 表示,从而降低 KV Cache 和推理带宽成本。
它和 MQA/GQA 的区别是:
MQA/GQA:减少 KV head 数。 MLA:压缩 KV 内容本身。
结合 MiniMind 来看:
当前 Attention 缓存完整 xk / xv; MLA 版本应缓存 kv_latent / k_rope; RoPE 需要拆成 nope 部分和 rope 部分; 第一版建议显式恢复 K/V,先保证工程可读、可训、可推理。
理解 MLA 的关键不是背公式,而是抓住这条工程主线:
KV Cache 越大,长上下文推理越贵; GQA 通过减少 KV head 缓解; MLA 通过压缩 KV 表示进一步缓解; RoPE 位置部分需要单独处理。
参考资料
- DeepSeek-V2 论文:DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model
- DeepSeek-V3 官方推理实现:deepseek-ai/DeepSeek-V3 inference/model.py
- MiniMind 当前 Attention 实现:
model/model_minimind.py