Pytorch模型蒸馏Distillation

网络模型在部署时会通过剪枝蒸馏等方式加快推理速度,模型蒸馏大概可以分为通道蒸馏、特征蒸馏和目标蒸馏。这里需要特别强调的是,蒸馏的student网络学习的是teacher的泛化能力,而不是过拟合训练数据。这篇博客会以pytorch代码为基础,介绍常用的模型蒸馏方法。

一、蒸馏的概念

知识蒸馏的概念最早由Hinton在NIPS2014年提出,旨在把一个大模型或者多个模型ensemble学到的知识迁移到另一个轻量级单模型上,方便部署。这里涉及到如下几个概念:

  • teacher – 原始模型或者模型ensemble
  • student – 新模型
  • transfer set – 用来迁移teacher知识、训练student的数据集合
  • soft target – teacher输出的预测结果(一般是softmax之后的概率)
  • hard target – 样本原本的标签
  • temperature – 蒸馏目标函数中的超参数
  • born-gain network – 蒸馏的一种,指student和teacher的结构和尺寸完全一样
  • teacher annealing – 防止student的表现被teacher限制,在蒸馏时逐渐减少soft targets 的权重

二、蒸馏的经验

试验证实,soft target可以起到正则化的作用(不用soft target的时候需要early stopping,用soft target后稳定收敛)

数据过少的话无法完整表达teacher学到的知识,需要增加无监督数据(用teacher的预测作为标签)或者进行数据增强,可以使用的方法有 1. 增加Mask 2. 用相同POS标签的词替换 3. 随记n-gram采用

超参数T越大,越能学到teacher模型的泛化信息。

三、目标蒸馏

这里给一个简单的示例代码:

import torch
import torch.nn as nn
import numpy as np

from torch.nn import CrossEntropyLoss
from torch.utils.data import TensorDataset,DataLoader,SequentialSampler

class model(nn.Module):
	def __init__(self,input_dim,hidden_dim,output_dim):
		super(model,self).__init__()
		self.layer1 = nn.LSTM(input_dim,hidden_dim,output_dim,batch_first = True)
		self.layer2 = nn.Linear(hidden_dim,output_dim)
	def forward(self,inputs):
		layer1_output,layer1_hidden = self.layer1(inputs)
		layer2_output = self.layer2(layer1_output)
		layer2_output = layer2_output[:,-1,:]#取出一个batch中每个句子最后一个单词的输出向量即该句子的语义向量!!!!!!!!!
		return layer2_output

#建立小模型
model_student = model(input_dim = 2,hidden_dim = 8,output_dim = 4)

#建立大模型(此处仍然使用LSTM代替,可以使用训练好的BERT等复杂模型)
model_teacher = model(input_dim = 2,hidden_dim = 16,output_dim = 4)

#设置输入数据,此处只使用随机生成的数据代替
inputs = torch.randn(4,6,2)
true_label = torch.tensor([0,1,0,0])

#生成dataset
dataset = TensorDataset(inputs,true_label)

#生成dataloader
sampler = SequentialSampler(inputs)
dataloader = DataLoader(dataset = dataset,sampler = sampler,batch_size = 2)

loss_fun = CrossEntropyLoss()
criterion  = nn.KLDivLoss()
optimizer = torch.optim.SGD(model_student.parameters(),lr = 0.1,momentum = 0.9)

for step,batch in enumerate(dataloader):
	inputs = batch[0]
	labels = batch[1]
	
	#分别使用学生模型和教师模型对输入数据进行计算
	output_student = model_student(inputs)
	output_teacher = model_teacher(inputs)
	
	#计算学生模型和真实标签之间的交叉熵损失函数值
	loss_hard = loss_fun(output_student,labels)
	
	#计算学生模型预测结果和教师模型预测结果之间的KL散度
	loss_soft = criterion(output_student,output_teacher)
	
	loss = 0.9*loss_soft + 0.1*loss_hard
	print(loss)
	optimizer.zero_grad()
	loss.backward()
	optimizer.step()

四、特征蒸馏

五、通道蒸馏

六、参考资料

发表评论