MiniMind-O 学习笔记 17:从 forward 看懂 Thinker-Talker 如何协作

前面已经看过 Dataset 如何产出 9 路输入。现在进入模型本体:

model/model_omni.py

核心入口是:

MiniMindOmni.forward(...)

这篇只抓主线:一次 forward 里,Thinker 和 Talker 是如何协作的。

1. forward 的输入有哪些

训练时,模型收到的主要输入是:

input_ids
audio_inputs
audio_lens
pixel_values
spk_emb

它们分别表示:

input_ids:
  Dataset 产出的 9 路输入。

audio_inputs:
  用户语音的 fbank 特征,给 Thinker。

audio_lens:
  用户语音有效长度。

pixel_values:
  用户图像输入,给 Thinker。

spk_emb:
  音色条件,给 Talker。

注意两个名字很容易混:

audio_inputs:
  用户输入语音特征。

audio_ids:
  Talker 历史输出音频 codes。

前者用于理解用户语音,后者用于自回归生成语音。

2. 第一步:拆分 9 路 input_ids

训练时 input_ids 形状是:

(B, 9, T)

代码拆成:

text_ids, audio_ids = input_ids[:, 8, :], input_ids[:, :8, :]

也就是:

text_ids:
  shape = (B, T)
  给 Thinker 使用。

audio_ids:
  shape = (B, 8, T)
  给 Talker 使用。

如果推理时只传二维文本输入,代码会自动构造全 pad 的 audio_ids

3. Thinker:先把文本 token 变成 hidden

Thinker 本质上就是 MiniMind 的语言模型主干:

object.__setattr__(self, 'thinker', self.model)

所以:

self.thinker == self.model

forward 里先做文本 embedding:

hidden_states = self.thinker.dropout(self.thinker.embed_tokens(text_ids))

此时:

hidden_states shape = (B, T, hidden_size)

这还是普通文本 token 的 hidden states。接下来会把音频和图像特征注入进去。

4. 用户语音如何注入 Thinker

如果有 audio_inputs,代码会先编码:

audio_features = self.encode_audio_inputs(audio_inputs, audio_lens)

这里大致是:

audio_inputs: (B, A, 560)
-> SenseVoice encoder
-> audio_proj
-> List[Tensor(A_i, hidden_size)]

然后注入:

hidden_states = self.inject_audio_features(text_ids, hidden_states, audio_features, seq_length)

它会在 text_ids 里找到连续的 <|audio_pad|>,把这些位置的 hidden state 替换成真实音频特征。

可以画成:

原始 hidden:
  [文本] [audio_pad] [audio_pad] [文本]

注入后:
  [文本] [audio_feat_0] [audio_feat_1] [文本]

这样,用户语音就进入了 Thinker 的序列。

5. 图像如何注入 Thinker

图像走类似路径:

pixel_values
-> SigLIP vision encoder
-> vision_proj
-> 替换 <|image_pad|> 位置

在 prompt 里,一张图通常占 64 个 <|image_pad|>

所以 Thinker 最终看到的是一条统一序列:

文本 hidden + 音频 hidden + 图像 hidden

Transformer 不需要特别区分这些来源,只在同一个 hidden 序列里建模。

6. bridge_states 在哪里保存

Thinker 会逐层运行:

for i, layer in enumerate(self.thinker.layers):
    hidden_states, present = layer(...)
    if i == self.config.bridge_layer:
        bridge_states = hidden_states

默认:

bridge_layer = num_hidden_layers // 2 - 1

8 层 Thinker 时:

bridge_layer = 3

也就是保存第 4 层输出,而不是最后一层。

原因是:

中间层更像通用语义表示;
最后层更贴近文本 token 预测。

Talker 要的是语义条件,所以默认取中间层。

7. Thinker 输出文本 logits

Thinker 跑完所有层后:

h_thinker = self.thinker.norm(hidden_states)
text_logits = self.thinker.lm_head(h_thinker)

text_logits 用来训练或采样文本 token。

训练时它和 text_labels 算交叉熵。

推理时它会采样出下一个文本 token。

8. Talker:融合语义和历史语音

Talker 部分从这里开始:

talker_emb = self.talker.embed_tokens(audio_ids)

audio_ids 是历史输出音频 codes,不是用户输入语音。它告诉 Talker:

前面已经怎么说了。

bridge_states 告诉 Talker:

接下来要说什么。

两路信息融合:

hidden_states = (
    self.talker.embed_proj(bridge_states) * self.talker.text_scale
    + self.talker.codec_proj(talker_emb) * self.talker.audio_scale
)

可以画成:

bridge_states -> embed_proj  ----+
                                 +-> Talker hidden
audio_ids     -> embed_tokens
              -> codec_proj  ----+

如果有 spk_emb,还会在 audio_spk_token 位置注入音色条件。

9. Talker 输出 8 路 audio_logits

融合后的 hidden 会经过 Talker 自己的 Transformer layers:

for layer in self.talker.layers:
    hidden_states, present = layer(...)

最后:

h_talker = self.talker.norm(hidden_states)
audio_logits = self.talker.lm_head(h_talker)

audio_logits 是长度为 8 的 list:

audio_logits[0] -> codebook 0
audio_logits[1] -> codebook 1
...
audio_logits[7] -> codebook 7

训练时,它和 audio_labels 算 loss。推理时,从中采样 Mimi codes。

10. forward 最后返回什么

最后返回:

out.logits = text_logits
out.audio_logits = audio_logits
out.past_key_values = presents

也就是:

Thinker 的文本预测分布
Talker 的 8 路音频 code 预测分布
Thinker + Talker 的 KV cache

11. forward 总图

input_ids
   |
   |-- text_ids  -> Thinker
   |-- audio_ids -> Talker
   |
   v
Thinker:
  text embedding
  + audio feature injection
  + image feature injection
  -> thinker layers
  -> text_logits
  -> bridge_states

Talker:
  bridge_states
  + historical audio_ids
  + optional spk_emb
  -> talker layers
  -> 8-way audio_logits

12. 本文小结

MiniMindOmni.forward 的主线很清楚:

先跑 Thinker,得到文本输出和中间语义表示;
再跑 Talker,把语义表示和历史音频 codes 融合,生成 8 路音频 code logits。

下一篇我们专门看 Talker:它的结构、为什么要接收 audio_ids,以及它如何在推理时流式生成语音。

发表评论