前面几篇已经看过:
音频如何变成 Mimi codes Dataset 如何构造 9 路输入 forward 如何先跑 Thinker 再跑 Talker Talker 如何生成语音
最后回到训练脚本。
MiniMind-O 的训练入口主要是:
trainer/train.sh trainer/train_sft_omni.py trainer/trainer_utils.py
README 推荐 mini 数据集时使用三阶段训练:
T2A -> A2A audio_proj -> A2A full
这篇就解释为什么要这么训,以及代码里对应到哪里。
1. 三阶段先看一句话
可以先这样记:
T2A: 先让模型学会“把文本回复说出来”。 A2A audio_proj: 再让模型学会“把用户语音映射到 LLM hidden space”。 A2A full: 最后联合微调整个“听、想、说”链路。
如果一开始就全量 A2A,模型要同时学太多新东西:
Talker 还不会稳定生成语音 audio_proj 还没对齐 Thinker 还不适应语音输入 text_loss 和 audio_loss 同时拉扯
三阶段训练就是把问题拆小。
2. 第一阶段:T2A
T2A 是 Text-to-Audio。
命令大致是:
train_sft_omni.py \ --data_path ../dataset/sft_t2a_mini.parquet \ --from_weight llm \ --save_weight sft_zero \ --mode all
这个阶段从纯 LLM 权重开始:
from_weight = llm
输入主要是文本对话,监督有两种:
conversations 中 assistant 文本 -> text_loss answer_audios 中 assistant 语音 codes -> audio_loss
这阶段要学的是:
Thinker 会输出文本回答; Talker 会根据 Thinker hidden 生成语音 codes。
如果加载的是纯 LLM 权重,里面没有 Talker。代码会把 Thinker 后几层复制给 Talker:
Talker层初始化: 复制thinker layers[4:8] -> talker layers[0:4]
这让 Talker 不至于从完全随机开始。
3. 第二阶段:A2A audio_proj
A2A 是 Audio-to-Audio。用户输入从文本变成语音。
命令大致是:
train_sft_omni.py \ --data_path ../dataset/sft_a2a_mini.parquet \ --from_weight sft_zero \ --save_weight sft_zero \ --mode audio_proj
关键是:
mode = audio_proj
代码里会冻结所有参数,只训练 audio_proj:
for p in model.parameters():
p.requires_grad = False
for p in model.audio_proj.parameters():
p.requires_grad = True
为什么只训它?
因为用户语音会先过 SenseVoice:
用户语音 -> SenseVoice encoder -> audio_proj -> Thinker hidden space
SenseVoice 的输出空间和 MiniMind hidden space 不天然一致。audio_proj 的作用就是把语音特征翻译成 Thinker 能理解的 hidden。
这个阶段的目标是:
先让模型听得懂。
主体模型先不乱动,只训练语音投影层,训练会更稳。
4. 第三阶段:A2A full
第三阶段仍然用 A2A 数据,但改为全量训练:
train_sft_omni.py \ --data_path ../dataset/sft_a2a_mini.parquet \ --from_weight sft_zero \ --save_weight sft_zero \ --mode all
完整链路是:
用户语音 -> SenseVoice -> audio_proj -> Thinker -> text_logits -> bridge_states -> Talker -> audio_logits
监督仍然是:
text_loss: assistant 文本回复 audio_loss: assistant 语音 Mimi codes
这阶段让整个链路一起调整:
audio_proj 更适配 Thinker 更适应语音输入 Talker 更适应语音输入场景下的回答 文本和语音输出更一致
5. 训练脚本的主流程
train_sft_omni.py 的主线可以画成:
解析参数 -> 初始化 DDP / seed -> 构造 OmniConfig -> init_omni_model -> 根据 mode 冻结或打开参数 -> 构造 OmniDataset -> DataLoader -> train_epoch
最核心的 forward 在 train_epoch 中:
res = model(
input_ids,
audio_inputs=audio_inputs,
audio_lens=audio_lens,
pixel_values=pixel_values,
spk_emb=spk_emb
)
这里的 batch 来自 OmniDataset:
input_ids: (B, 9, L) labels: (B, L) audio_labels: (B, 8, L) audio_inputs: (B, A, 560)
6. text_loss 怎么算
Thinker 输出:
res.logits
它和 labels 算交叉熵:
text_loss_raw = loss_fct(res.logits.view(-1, res.logits.size(-1)), labels.view(-1)) text_mask = (labels.view(-1) != -100).float() text_loss = (text_loss_raw * text_mask).sum() / (text_mask.sum() + 1e-9)
labels == -100 的位置不参与训练。
所以 text_loss 只训练最后一个 assistant 的文本回复。
7. audio_loss 怎么算
Talker 输出:
res.audio_logits
它是长度为 8 的 list,每一路对应一个 codebook。
训练代码逐路计算:
for i, al in enumerate(res.audio_logits):
target_flat = audio_labels[:, i, :].reshape(-1)
layer_loss = loss_fct(al_flat, target_flat)
audio_labels == -100 的位置不参与训练。
如果目标是 stop token:
stop_mask = (target_flat == 2050).float() weighted_loss = layer_loss * valid_mask * (1 + stop_mask * 9)
stop token 权重更大,因为模型必须学会什么时候停止说话。
最后 8 路平均:
audio_loss = audio_loss / 8
8. 总 loss
总 loss 是:
loss = text_loss + audio_loss + res.aux_loss
其中:
text_loss: Thinker 文本输出监督。 audio_loss: Talker 语音 codes 监督。 aux_loss: MoE 辅助 loss,非 MoE 时基本为 0。
然后正常反向传播、梯度裁剪、optimizer step。
9. 保存权重
训练会保存两类文件:
../out/{save_weight}_768.pth
下一阶段训练和推理使用。
../checkpoints/{save_weight}_768_resume.pth
断点续训使用,包含 optimizer/scaler/epoch/step。
保存时会过滤掉外部 encoder:
SenseVoice SigLIP Mimi
这些外部模块从本地预训练目录重新加载,不需要写进 Omni 权重。
10. 本文小结
MiniMind-O 的训练可以总结成:
T2A: 让模型先会说。 A2A audio_proj: 让模型听得懂。 A2A full: 让听、想、说联合对齐。
代码上则是:
OmniDataset 产出 9 路输入 MiniMindOmni.forward 输出 text_logits + audio_logits 训练循环计算 text_loss + audio_loss 不同阶段通过 mode 控制训练哪些参数
到这里,MiniMind-O 的主链路已经完整串起来了:
数据 -> Dataset -> forward -> Talker 生成 -> 三阶段训练