MiniMind-O 学习笔记 19:三阶段训练,T2A、A2A audio_proj 与 A2A full

前面几篇已经看过:

音频如何变成 Mimi codes
Dataset 如何构造 9 路输入
forward 如何先跑 Thinker 再跑 Talker
Talker 如何生成语音

最后回到训练脚本。

MiniMind-O 的训练入口主要是:

trainer/train.sh
trainer/train_sft_omni.py
trainer/trainer_utils.py

README 推荐 mini 数据集时使用三阶段训练:

T2A
-> A2A audio_proj
-> A2A full

这篇就解释为什么要这么训,以及代码里对应到哪里。

1. 三阶段先看一句话

可以先这样记:

T2A:
  先让模型学会“把文本回复说出来”。

A2A audio_proj:
  再让模型学会“把用户语音映射到 LLM hidden space”。

A2A full:
  最后联合微调整个“听、想、说”链路。

如果一开始就全量 A2A,模型要同时学太多新东西:

Talker 还不会稳定生成语音
audio_proj 还没对齐
Thinker 还不适应语音输入
text_loss 和 audio_loss 同时拉扯

三阶段训练就是把问题拆小。

2. 第一阶段:T2A

T2A 是 Text-to-Audio。

命令大致是:

train_sft_omni.py \
  --data_path ../dataset/sft_t2a_mini.parquet \
  --from_weight llm \
  --save_weight sft_zero \
  --mode all

这个阶段从纯 LLM 权重开始:

from_weight = llm

输入主要是文本对话,监督有两种:

conversations 中 assistant 文本
  -> text_loss

answer_audios 中 assistant 语音 codes
  -> audio_loss

这阶段要学的是:

Thinker 会输出文本回答;
Talker 会根据 Thinker hidden 生成语音 codes。

如果加载的是纯 LLM 权重,里面没有 Talker。代码会把 Thinker 后几层复制给 Talker:

Talker层初始化: 复制thinker layers[4:8] -> talker layers[0:4]

这让 Talker 不至于从完全随机开始。

3. 第二阶段:A2A audio_proj

A2A 是 Audio-to-Audio。用户输入从文本变成语音。

命令大致是:

train_sft_omni.py \
  --data_path ../dataset/sft_a2a_mini.parquet \
  --from_weight sft_zero \
  --save_weight sft_zero \
  --mode audio_proj

关键是:

mode = audio_proj

代码里会冻结所有参数,只训练 audio_proj

for p in model.parameters():
    p.requires_grad = False
for p in model.audio_proj.parameters():
    p.requires_grad = True

为什么只训它?

因为用户语音会先过 SenseVoice:

用户语音
-> SenseVoice encoder
-> audio_proj
-> Thinker hidden space

SenseVoice 的输出空间和 MiniMind hidden space 不天然一致。audio_proj 的作用就是把语音特征翻译成 Thinker 能理解的 hidden。

这个阶段的目标是:

先让模型听得懂。

主体模型先不乱动,只训练语音投影层,训练会更稳。

4. 第三阶段:A2A full

第三阶段仍然用 A2A 数据,但改为全量训练:

train_sft_omni.py \
  --data_path ../dataset/sft_a2a_mini.parquet \
  --from_weight sft_zero \
  --save_weight sft_zero \
  --mode all

完整链路是:

用户语音
-> SenseVoice
-> audio_proj
-> Thinker
-> text_logits
-> bridge_states
-> Talker
-> audio_logits

监督仍然是:

text_loss:
  assistant 文本回复

audio_loss:
  assistant 语音 Mimi codes

这阶段让整个链路一起调整:

audio_proj 更适配
Thinker 更适应语音输入
Talker 更适应语音输入场景下的回答
文本和语音输出更一致

5. 训练脚本的主流程

train_sft_omni.py 的主线可以画成:

解析参数
-> 初始化 DDP / seed
-> 构造 OmniConfig
-> init_omni_model
-> 根据 mode 冻结或打开参数
-> 构造 OmniDataset
-> DataLoader
-> train_epoch

最核心的 forward 在 train_epoch 中:

res = model(
    input_ids,
    audio_inputs=audio_inputs,
    audio_lens=audio_lens,
    pixel_values=pixel_values,
    spk_emb=spk_emb
)

这里的 batch 来自 OmniDataset

input_ids:
  (B, 9, L)

labels:
  (B, L)

audio_labels:
  (B, 8, L)

audio_inputs:
  (B, A, 560)

6. text_loss 怎么算

Thinker 输出:

res.logits

它和 labels 算交叉熵:

text_loss_raw = loss_fct(res.logits.view(-1, res.logits.size(-1)), labels.view(-1))
text_mask = (labels.view(-1) != -100).float()
text_loss = (text_loss_raw * text_mask).sum() / (text_mask.sum() + 1e-9)

labels == -100 的位置不参与训练。

所以 text_loss 只训练最后一个 assistant 的文本回复。

7. audio_loss 怎么算

Talker 输出:

res.audio_logits

它是长度为 8 的 list,每一路对应一个 codebook。

训练代码逐路计算:

for i, al in enumerate(res.audio_logits):
    target_flat = audio_labels[:, i, :].reshape(-1)
    layer_loss = loss_fct(al_flat, target_flat)

audio_labels == -100 的位置不参与训练。

如果目标是 stop token:

stop_mask = (target_flat == 2050).float()
weighted_loss = layer_loss * valid_mask * (1 + stop_mask * 9)

stop token 权重更大,因为模型必须学会什么时候停止说话。

最后 8 路平均:

audio_loss = audio_loss / 8

8. 总 loss

总 loss 是:

loss = text_loss + audio_loss + res.aux_loss

其中:

text_loss:
  Thinker 文本输出监督。

audio_loss:
  Talker 语音 codes 监督。

aux_loss:
  MoE 辅助 loss,非 MoE 时基本为 0。

然后正常反向传播、梯度裁剪、optimizer step。

9. 保存权重

训练会保存两类文件:

../out/{save_weight}_768.pth
  下一阶段训练和推理使用。

../checkpoints/{save_weight}_768_resume.pth
  断点续训使用,包含 optimizer/scaler/epoch/step。

保存时会过滤掉外部 encoder:

SenseVoice
SigLIP
Mimi

这些外部模块从本地预训练目录重新加载,不需要写进 Omni 权重。

10. 本文小结

MiniMind-O 的训练可以总结成:

T2A:
  让模型先会说。

A2A audio_proj:
  让模型听得懂。

A2A full:
  让听、想、说联合对齐。

代码上则是:

OmniDataset 产出 9 路输入
MiniMindOmni.forward 输出 text_logits + audio_logits
训练循环计算 text_loss + audio_loss
不同阶段通过 mode 控制训练哪些参数

到这里,MiniMind-O 的主链路已经完整串起来了:

数据
-> Dataset
-> forward
-> Talker 生成
-> 三阶段训练

发表评论