MiniMind 学习笔记 10:Multi-Latent Attention,从 KV Cache 压缩到 MiniMind 改造思路

前面几篇已经看过 MiniMind 的 Attention、RoPE 和推理生成。

这篇继续沿着同一条线看一个近两年很重要的 Attention 改造:Multi-Latent Attention,简称 MLA。

MLA 最直接的目标是:

降低推理阶段的 KV Cache 和显存带宽压力,同时尽量保持 Attention 表达能力。

它不是 MiniMind 当前主线已经实现的结构。MiniMind 当前使用的是 GQA,也就是 Query head 多、KV head 少的折中方案。

但 MLA 很适合作为 MiniMind 的实验性扩展,因为它刚好接在这几个主题后面:

MHA / MQA / GQA
-> KV Cache
-> RoPE
-> 长上下文推理成本
-> MLA

1. 为什么 Attention 推理会越来越贵

LLM 推理是自回归生成:

已有 prompt
-> 生成下一个 token
-> 把新 token 加回上下文
-> 再生成下一个 token

为了避免每一步都重新计算所有历史 token 的 K/V,模型会缓存历史 K/V。

这就是 KV Cache。

标准 Attention 可以写成:

scores = Q @ K^T / sqrt(head_dim)
weights = softmax(scores)
output = weights @ V

生成第 t 个 token 时,需要拿当前 token 的 Q 去和历史所有 K 做匹配,再用权重汇聚历史所有 V:

当前 Q
  @ 历史 K cache
  -> attention weights
  @ 历史 V cache
  -> 当前输出

因此长上下文推理的瓶颈经常不是纯计算,而是:

KV Cache 占显存;
每步 decode 都要读取大量历史 K/V;
batch 越大、上下文越长,显存带宽越紧张。

2. 从 MHA 到 MQA、GQA

先看 head 设计的演化。

MHA

传统 Multi-Head Attention 中,每个 head 都有自己的 K/V:

Q0 -> K0 / V0
Q1 -> K1 / V1
Q2 -> K2 / V2
...

优点是表达能力强。

缺点是 KV Cache 大:

seq_len * num_heads * head_dim * 2

最后的 2 分别对应 K 和 V。

MQA

Multi-Query Attention 让多个 Q head 共用一组 K/V:

Q0 \
Q1  \
Q2   -> shared K / V
Q3  /

它能大幅减少 KV Cache,但共享太狠,表达能力可能下降。

GQA

Grouped-Query Attention 是折中:

Q0, Q1 -> KV group 0
Q2, Q3 -> KV group 1
Q4, Q5 -> KV group 2
Q6, Q7 -> KV group 3

MiniMind 默认就是 GQA:

num_attention_heads = 8
num_key_value_heads = 4
head_dim = 96

也就是:

Q 总维度 = 8 * 96 = 768
K 总维度 = 4 * 96 = 384
V 总维度 = 4 * 96 = 384

对应代码在 model/model_minimind.py

self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)

缓存时存的是:

past_kv = (xk, xv) if use_cache else None

所以 MiniMind 当前的 KV Cache 大小大致是:

seq_len * num_key_value_heads * head_dim * 2

默认配置下,每层每个 token 需要缓存:

4 * 96 * 2 = 768 个数

3. MLA 的核心想法

MQA/GQA 的思路是:

减少 KV head 数。

MLA 的思路更进一步:

不直接缓存完整 K/V,而是缓存更小的 latent 表示。

可以把传统 Attention 看成:

hidden state
-> K / V
-> cache full K / V

MLA 改成:

hidden state
-> compressed latent
-> cache latent
-> 需要时从 latent 恢复 K / V

也就是:

token -> full KV -> cache

变成:

token -> compressed latent -> cache

这就是 Multi-Latent Attention 最重要的直觉。

它有点像把 KV 做了一次低秩分解:

K = W_up_k(W_down_kv h)
V = W_up_v(W_down_kv h)

其中:

c = W_down_kv h

就是要缓存的 latent。

如果 latent_dim 远小于完整 K/V 维度,KV Cache 就会明显下降。

4. 一个最小可理解版 MLA

先不考虑 RoPE,只看最简单的结构:

class SimpleMLA(nn.Module):
    def __init__(self, hidden_size, num_heads, head_dim, latent_dim):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim

        self.q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)

        # hidden -> latent
        self.kv_down_proj = nn.Linear(hidden_size, latent_dim, bias=False)

        # latent -> K / V
        self.k_up_proj = nn.Linear(latent_dim, num_heads * head_dim, bias=False)
        self.v_up_proj = nn.Linear(latent_dim, num_heads * head_dim, bias=False)

        self.o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=False)

    def forward(self, x, past_latent=None):
        bsz, seq_len, _ = x.shape

        q = self.q_proj(x)
        latent = self.kv_down_proj(x)

        if past_latent is not None:
            kv_latent = torch.cat([past_latent, latent], dim=1)
        else:
            kv_latent = latent

        k = self.k_up_proj(kv_latent)
        v = self.v_up_proj(kv_latent)

        q = q.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)

        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        out = out.transpose(1, 2).contiguous().view(bsz, seq_len, -1)

        return self.o_proj(out), kv_latent

这段代码里最关键的是:

latent = self.kv_down_proj(x)

传统 Attention 缓存:

K cache + V cache

MLA 缓存:

latent cache

需要 K/V 时再恢复:

k = self.k_up_proj(kv_latent)
v = self.v_up_proj(kv_latent)

这个版本适合理解 MLA 的主干,但还不是 DeepSeek 真实 MLA。

真实 MLA 还要处理:

RoPE 怎么放;
Q 是否也做低秩压缩;
如何避免每步显式恢复完整 K/V;
如何兼容高性能推理 kernel。

5. MLA 和 LoRA 为什么看起来很像

MLA 很容易让人联想到 LoRA。

LoRA 的核心是:

W + B @ A

也就是用一个低秩旁路表达参数更新。

MLA 的 K/V 生成可以看成:

K = W_up_k @ W_down_kv @ h
V = W_up_v @ W_down_kv @ h

两者都用了:

低秩;
bottleneck;
latent compression。

但目标不同:

技术目标
LoRA降低微调参数量和训练成本。
MLA降低推理 KV Cache 和带宽成本。

所以它们只是思想相近,不是同一个用途。

6. MLA 和 RoPE 的关键关系

如果没有位置编码,前面的简单 MLA 很好理解。

但 MiniMind 和大多数现代 LLM 一样使用 RoPE。

MiniMind 当前是在 Q/K 上做旋转:

xq, xk = apply_rotary_pos_emb(xq, xk, cos, sin)

RoPE 的位置关系来自:

RoPE(Q) @ RoPE(K)^T

问题是:MLA 缓存的是 latent,不是完整 K。

如果把完整 K 都压到 latent 里,再恢复后做 RoPE,会遇到一个工程矛盾:

RoPE 和 token 位置绑定;
latent 希望位置无关、便于压缩和缓存。

DeepSeek MLA 的处理方式是把 K 拆成两部分:

K = concat(K_nope, K_rope)

其中:

K_nope:不带 RoPE 的主要内容部分,可以从 latent 恢复。
K_rope:带 RoPE 的位置部分,维度较小,单独生成和缓存。

Q 也对应拆成:

Q = concat(Q_nope, Q_rope)

Attention score 变成:

score = Q_nope @ K_nope^T + RoPE(Q_rope) @ RoPE(K_rope)^T

这样做的好处是:

内容主干仍然走 latent 压缩;
位置信息保留在较小的 rope 维度里;
KV Cache 不需要回到完整 K/V 的大小。

这也是 MLA 比最小教学版复杂的主要原因。

7. 结合 MiniMind 看 MLA 要改哪里

MiniMind 当前 Attention.forward 主线可以概括成:

x
-> q_proj / k_proj / v_proj
-> view 成多头
-> q_norm / k_norm
-> RoPE(Q, K)
-> 拼接 past K/V
-> repeat_kv
-> scores = Q @ K^T / sqrt(head_dim)
-> softmax
-> weights @ V
-> o_proj

如果新增 MLA,可以先不要改动训练框架、数据集和生成逻辑,而是在 model/model_minimind.py 里新增一个 Attention 变体。

建议分三步做。

8. 第一步:给 Config 增加 MLA 开关

在 MiniMindConfig 中增加:

self.use_mla = kwargs.get("use_mla", False)
self.mla_kv_lora_rank = kwargs.get("mla_kv_lora_rank", 128)
self.mla_q_lora_rank = kwargs.get("mla_q_lora_rank", 0)
self.mla_qk_rope_head_dim = kwargs.get("mla_qk_rope_head_dim", 32)
self.mla_v_head_dim = kwargs.get("mla_v_head_dim", self.head_dim)
self.mla_qk_nope_head_dim = kwargs.get(
    "mla_qk_nope_head_dim",
    self.head_dim - self.mla_qk_rope_head_dim
)

这里的名字参考了 DeepSeek 公开实现里的习惯:

kv_lora_rank:KV latent 维度。
q_lora_rank:Q 是否也走低秩压缩。
qk_nope_head_dim:不带 RoPE 的 Q/K 维度。
qk_rope_head_dim:带 RoPE 的 Q/K 维度。
v_head_dim:V 的 head 维度。

对 MiniMind 默认 head_dim=96,可以先试:

qk_nope_head_dim = 64
qk_rope_head_dim = 32
v_head_dim = 96
kv_lora_rank = 128

9. 第二步:新增 MLAAttention

一个贴近 MiniMind 风格的 naive MLA 可以这样设计:

class MLAAttention(nn.Module):
    def __init__(self, config: MiniMindConfig):
        super().__init__()
        self.n_local_heads = config.num_attention_heads
        self.qk_nope_head_dim = config.mla_qk_nope_head_dim
        self.qk_rope_head_dim = config.mla_qk_rope_head_dim
        self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
        self.v_head_dim = config.mla_v_head_dim
        self.kv_lora_rank = config.mla_kv_lora_rank

        self.q_proj = nn.Linear(config.hidden_size, self.n_local_heads * self.qk_head_dim, bias=False)
        self.kv_down_proj = nn.Linear(config.hidden_size, self.kv_lora_rank + self.qk_rope_head_dim, bias=False)
        self.kv_norm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
        self.kv_up_proj = nn.Linear(
            self.kv_lora_rank,
            self.n_local_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False
        )
        self.o_proj = nn.Linear(self.n_local_heads * self.v_head_dim, config.hidden_size, bias=False)

这和当前 Attention 最大的区别是:

没有单独的 k_proj / v_proj;
先用 kv_down_proj 得到 latent;
再用 kv_up_proj 从 latent 恢复 K_nope 和 V;
K_rope 单独从 kv_down_proj 的尾部切出来。

forward 主线可以写成:

q = self.q_proj(x).view(bsz, seq_len, self.n_local_heads, self.qk_head_dim)
q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

kv = self.kv_down_proj(x)
kv_latent, k_rope = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_latent = self.kv_norm(kv_latent)

q_rope, k_rope = apply_rotary_pos_emb(
    q_rope,
    k_rope.unsqueeze(2),
    cos,
    sin
)
k_rope = k_rope.squeeze(2)

if past_key_value is not None:
    kv_latent = torch.cat([past_key_value[0], kv_latent], dim=1)
    k_rope = torch.cat([past_key_value[1], k_rope], dim=1)

past_kv = (kv_latent, k_rope) if use_cache else None

注意这里 cache 不再是:

(xk, xv)

而是:

(kv_latent, k_rope)

这就是 MiniMind 接入 MLA 后最关键的行为变化。

10. 第三步:先做 naive 版本,再做 absorb 优化

为了先跑通训练和推理,建议第一版 MLA 显式恢复 K/V:

kv_full = self.kv_up_proj(kv_latent)
kv_full = kv_full.view(
    bsz,
    -1,
    self.n_local_heads,
    self.qk_nope_head_dim + self.v_head_dim
)
k_nope, v = torch.split(kv_full, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

k = torch.cat(
    [k_nope, k_rope.unsqueeze(2).expand(-1, -1, self.n_local_heads, -1)],
    dim=-1
)
q = torch.cat([q_nope, q_rope], dim=-1)

然后继续使用普通 attention:

q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.qk_head_dim)
weights = F.softmax(scores.float(), dim=-1).type_as(q)
output = weights @ v

这个 naive 版本已经能做到:

推理 cache 只存 latent + 小维度 rope key;
结构上验证 MLA 是否可训练;
代码路径比较容易调试。

但它还没有完全释放 MLA 的推理收益,因为每步 attention 前仍然恢复了完整 K/V。

进一步优化时,可以学习 DeepSeek 的 absorb 思路:

不显式恢复 K_nope;
把 kv_up_proj 中 K_nope 对应的权重吸收到 Q_nope 上;
直接用 Q_nope' 和 latent cache 计算 score;
V 侧也延迟到 score @ latent 后再投影回每个 head 的 V 空间。

也就是把:

Q_nope @ (W_k_up latent)^T

改写成:

(Q_nope @ W_k_up) @ latent^T

这是 MLA 真正适合高性能推理的地方。

11. MiniMind 默认配置下能省多少 cache

当前 MiniMind GQA 每层每 token 的 KV cache 约为:

num_key_value_heads * head_dim * 2
= 4 * 96 * 2
= 768

如果 MLA 取:

kv_lora_rank = 128
qk_rope_head_dim = 32

那么每层每 token cache 约为:

kv_lora_rank + qk_rope_head_dim
= 128 + 32
= 160

压缩比例约为:

160 / 768 = 20.8%

也就是 KV Cache 理论上下降约:

79.2%

这只是按元素数估算,真实收益还会受 dtype、kernel、batch size、显存访问模式、是否显式恢复 K/V 等因素影响。

12. 需要注意的工程问题

如果在 MiniMind 中真正新增 MLA,需要注意几个点。

RoPE buffer 维度

当前 MiniMind 预计算 RoPE 时使用的是:

precompute_freqs_cis(dim=config.head_dim, ...)

MLA 中 RoPE 只作用在 qk_rope_head_dim 上。

因此需要让 RoPE buffer 支持 MLA 维度:

use_mla=False:dim = head_dim
use_mla=True: dim = mla_qk_rope_head_dim

否则 q_rope / k_rope 的最后一维和 cos / sin 对不上。

causal mask

当前 MiniMind 在有 cache 时会让当前 query 看完整历史:

scores[:, :, :, -seq_len:] += causal_mask

MLA 仍然需要同样的 mask 逻辑。

变化的是 K/V 来源,不是 causal LM 的可见性规则。

FlashAttention

naive MLA 恢复完整 K/V 后,理论上仍可以复用 scaled_dot_product_attention

但 optimized MLA 不再显式构造标准 K/V,普通 FlashAttention kernel 未必能直接吃这种 latent cache。

所以第一版可以先关闭或绕开 flash 路径,等结构正确后再优化。

checkpoint 兼容性

MLA 会改变 attention 权重名字和形状。

原来的:

q_proj / k_proj / v_proj / o_proj

会变成:

q_proj / kv_down_proj / kv_up_proj / o_proj

因此旧的 MiniMind 权重不能直接完整加载。

通常有三种选择:

从头预训练 MLA 版本;
写权重迁移脚本做近似初始化;
只把 MLA 作为新实验分支,不兼容旧 checkpoint。

对教学项目来说,第一种最清晰。

训练脚本参数

trainer/train_pretrain.py 当前创建配置时是:

lm_config = MiniMindConfig(
    hidden_size=args.hidden_size,
    num_hidden_layers=args.num_hidden_layers,
    use_moe=bool(args.use_moe)
)

如果要训练 MLA,需要加上命令行参数:

--use_mla
--mla_kv_lora_rank
--mla_qk_rope_head_dim

并传入 MiniMindConfig

SFT、DPO、PPO、GRPO 等训练脚本如果也要支持 MLA,也需要同样补配置。

13. MLA 的代价

MLA 不是无脑替换。

它带来的好处主要在:

长上下文;
大 batch 推理;
高并发服务;
KV Cache 成为瓶颈的场景。

代价也很明确:

结构更复杂;
参数形状不再兼容原 checkpoint;
latent_dim 太小会损失表达能力;
训练稳定性需要重新验证;
optimized MLA 需要专门推理实现才能吃满收益。

所以对 MiniMind 这种小模型教学项目,比较合适的路线是:

先实现 naive MLA,确认 loss 能下降;
再比较相同训练配置下的 perplexity / 生成效果;
最后再考虑 absorb 优化和推理 benchmark。

14. 小结

可以用一句话概括 MLA:

用低维 latent 压缩 Attention 的历史 K/V 表示,从而降低 KV Cache 和推理带宽成本。

它和 MQA/GQA 的区别是:

MQA/GQA:减少 KV head 数。
MLA:压缩 KV 内容本身。

结合 MiniMind 来看:

当前 Attention 缓存完整 xk / xv;
MLA 版本应缓存 kv_latent / k_rope;
RoPE 需要拆成 nope 部分和 rope 部分;
第一版建议显式恢复 K/V,先保证工程可读、可训、可推理。

理解 MLA 的关键不是背公式,而是抓住这条工程主线:

KV Cache 越大,长上下文推理越贵;
GQA 通过减少 KV head 缓解;
MLA 通过压缩 KV 表示进一步缓解;
RoPE 位置部分需要单独处理。

参考资料

发表评论