MiniMind-O 学习笔记 16:OmniDataset,文本、语音和语音监督如何进入模型

理解训练,最好的入口是 Dataset。

MiniMind-O 的 Dataset 不只是把文本变成 input_ids。它要同时准备三类东西:

1. Thinker 的文本输入和文本监督
2. Thinker 的用户语音/图像输入
3. Talker 的历史 audio codes 和语音监督

代码集中在:

dataset/omni_dataset.py

核心类是:

class OmniDataset(Dataset):

1. parquet 里有什么

训练数据使用 parquet 格式。你可以把它理解成二进制表格文件,每一行是一条训练样本,每一列是一种字段。

常见的语音样本会有三列:

conversations
question_audios
answer_audios

它们分别表示:

conversations:
  文本对话。提供 Thinker 的文本输入和文本监督。

question_audios:
  用户问题的语音。作为输入模态,给 Thinker 理解。

answer_audios:
  assistant 回复语音的 Mimi codes。作为 Talker 的语音监督。

这三列不要混淆:

用户说的话音频 -> question_audios -> 输入
助手回答的音频 -> answer_audios   -> 监督目标
助手回答的文字 -> conversations   -> 文本监督目标

2. 初始化阶段做了什么

OmniDataset.__init__ 会读取 parquet,并保存一些后续要用的配置:

self.table = pa.concat_tables(...)
self.tokenizer = tokenizer
self.audio_processor = audio_processor
self.vision_processor = vision_processor
self.max_length = max_length

还会保存几个特殊 token:

<|audio_pad|>
<|image_pad|>
audio_stop_token
audio_spk_token

其中 bos_id 和 eos_id 很重要:

self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant\n', add_special_tokens=False).input_ids
self.eos_id = tokenizer(f'{tokenizer.eos_token}\n', add_special_tokens=False).input_ids

它们用来在 token 序列里找到 assistant 回复的开始和结束,从而只对 assistant 的回答计算文本 loss。

3. getitem 的主线

__getitem__ 最终返回:

return input_ids, text_labels, audio_labels, audio_inputs, audio_len, pixel_values, spk_emb

可以先画成:

parquet row
   |
   |-- conversations
   |-- question_audios
   |-- answer_audios
   |
   v
构造 prompt / text_labels
构造 audio_inputs
构造 audio_labels
   |
   v
返回一条训练样本

4. 多轮对话为什么要随机截断

如果一条样本是多轮对话,代码会随机截断到某个 assistant 回合:

asst_indices = [i for i, t in enumerate(conversations) if t['role'] == 'assistant']

然后随机选择一个 assistant 回复作为当前训练终点。

这样做有两个目的:

1. 控制长度,避免超过 max_length。
2. 增加训练多样性,让模型既见到短对话,也见到多轮上下文。

截断之后,当前样本的“最后一个 user”和“最后一个 assistant”就非常关键:

最后一个 user:
  对应 question_audios,用作用户语音输入。

最后一个 assistant:
  对应 answer_audios,用作语音输出监督。

5. question_audios 如何变成输入

用户语音在 question_audios 里,代码取最后一个 user 对应的音频:

audio_bytes = question_audios[user_count - 1]
mel, valid_len = self.load_audio_inputs(audio_bytes)
audio_inputs = mel.unsqueeze(0)

这里的 mel 是 SenseVoice frontend 处理后的 fbank 特征:

audio_inputs shape = (1, T_audio, 560)

它不是 Talker 的 audio code,而是用户输入语音的特征。后面会进入模型:

audio_inputs
-> encode_audio_inputs
-> inject_audio_features
-> Thinker

如果当前样本没有语音,Dataset 会返回一个全 0 dummy:

audio_inputs = zeros(1, 1, 560)
audio_len = 0

模型里会识别全 0 输入并跳过。

6. answer_audios 如何拆成 8 路

assistant 的语音输出监督在 answer_audios 里。

它是扁平排布的 Mimi codes:

tokens = [
  c0_0, c1_0, ..., c7_0,
  c0_1, c1_1, ..., c7_1,
  ...
]

代码会拆成 8 路:

audio_codes_8layers = [[] for _ in range(8)]
for i in range(0, len(tokens) - 7, 8):
    for j in range(8):
        audio_codes_8layers[j].append(tokens[i + j])

拆完以后:

layer0 = [c0_0, c0_1, c0_2, ...]
layer1 = [c1_0, c1_1, c1_2, ...]
...
layer7 = [c7_0, c7_1, c7_2, ...]

每一路最后会追加 audio_stop_token,告诉 Talker 语音什么时候结束。

7. prompt 如何放入音频和图像占位

Dataset 会用 create_chat_prompt 构造聊天 prompt。

如果当前 user 有语音输入,它会在最后一个 user 内容里插入:

<|audio_pad|> * audio_features_length

如果有图像输入,它会把 <image> 替换成:

<|image_pad|> * 64

这些 token 的意义不是让模型读字面字符串,而是给模型预留位置:

<|audio_pad|> 位置后面会被真实音频 hidden 替换
<|image_pad|> 位置后面会被真实图像 hidden 替换

8. text_labels 如何构造

文本监督来自 conversations 里的 assistant 内容。

generate_text_labels 会扫描 token 序列,找到 assistant 回复区间:

<bos>assistant\n
assistant answer
<eos>\n

只有 assistant answer 位置保留 label,其他位置都是 -100

user/system tokens:
  label = -100

assistant answer:
  label = token id

多轮样本里,代码还会把历史 assistant 回复重新 mask 掉,只训练最后一个 assistant。

9. audio_labels 如何构造

语音监督来自 last_audio_codes

代码先初始化:

Y_audio_layers = [[audio_pad_token] * max_length for _ in range(8)]
audio_labels = [[-100] * max_length for _ in range(8)]

然后把 8 路 codes 写进去:

start_pos = assistant_start + layer_idx + 1

也就是说:

codebook 0 从 assistant_start + 1 开始
codebook 1 从 assistant_start + 2 开始
...
codebook 7 从 assistant_start + 8 开始

这就是训练侧的错位展开,对应推理时的错位补齐。

10. 最终为什么是 9 路 input_ids

最后,Dataset 会把文本和音频历史拼起来:

X_audio = torch.tensor([layer[:-1] for layer in Y_audio_layers])
X_text = torch.tensor(input_ids[:-1])
input_ids = torch.cat((X_audio, X_text.unsqueeze(0)), dim=0)

最终单条样本:

input_ids shape = (9, L)

其中:

input_ids[0:8, :]
  Talker 的 8 路历史 audio codes

input_ids[8, :]
  Thinker 的文本 token ids

对应模型里:

text_ids, audio_ids = input_ids[:, 8, :], input_ids[:, :8, :]

11. 本文小结

OmniDataset 做的事情可以用一句话概括:

把一条多模态对话样本,整理成 Thinker 和 Talker 能同时训练的 9 路自回归输入。

其中:

conversations    -> text_ids / text_labels
question_audios  -> audio_inputs,给 Thinker 理解用户语音
answer_audios    -> audio_labels,给 Talker 学会生成语音

下一篇我们进入模型本体,看 MiniMindOmni.forward 如何吃掉这些输入。

发表评论