我们在深度学习模型转换和推理过程中,常常会遇到算子不被工具链支持的情况。这时可以用其它算子等价地实现所需的算子:比如 3D 卷积不被工具链支持时,可以将其拆分成多个 2D 卷积;拆分之后,还需要对权重字典(state_dict)做相应的拆分。这篇博客贴一下 PyTorch 的实现。
首先我们自己用2D卷积实现一个3D卷积
class MSNConv3d(nn.Module):
    """Emulate ``nn.Conv3d`` with a bank of ``nn.Conv2d`` layers.

    One ``nn.Conv2d(kernel_size[0], 1, ...)`` is created for every
    (output channel, input channel) pair and stored in ``conv2d_list`` at
    index ``out_idx * in_channels + in_idx``.  In ``forward`` each depth
    window of ``kernel_size[0]`` frames is fed to the 2D convolution as its
    channel dimension, and the results are summed over the input channels.

    Because the per-input-channel results are summed, a bias that should be
    applied once per output channel must be stored in only one of the
    ``in_channels`` 2D convolutions of that output channel (the others must
    hold zero).

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Int or 3-sequence ``(depth, height, width)``.
        stride: Int or 3-sequence ``(depth, height, width)``.
        padding: Int or 3-sequence ``(depth, height, width)``; zero padding.
        dilation: Only ``1`` is supported.
        groups: Only ``1`` is supported.
        bias: Whether the 2D convolutions carry a bias term.
        padding_mode: Only ``'zeros'`` is supported.

    Raises:
        ValueError: If a size argument is not an int or a 3-sequence, or if
            ``dilation``, ``groups`` or ``padding_mode`` request something
            this emulation cannot reproduce.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros'):
        super(MSNConv3d, self).__init__()
        kernel_size = self._to_triple(kernel_size, 'kernel_size')
        stride = self._to_triple(stride, 'stride')
        padding = self._to_triple(padding, 'padding')
        # These options would silently change the result if ignored, so
        # reject anything other than the plain convolution case.
        if self._to_triple(dilation, 'dilation') != (1, 1, 1):
            raise ValueError("MSNConv3d only supports dilation=1")
        if groups != 1:
            raise ValueError("MSNConv3d only supports groups=1")
        if padding_mode != 'zeros':
            raise ValueError("MSNConv3d only supports padding_mode='zeros'")

        self.in_chns = in_channels
        self.tmp_chns = kernel_size[0]
        self.out_chns = out_channels
        # Not used by forward(); kept so existing attribute access still works.
        self.chns = in_channels / stride[0]
        self.stride_d = stride[0]
        self.pad_d = padding[0]
        self.kernel_hw = kernel_size[1:3]
        self.stride_hw = stride[1:3]
        self.pad_hw = padding[1:3]
        self.conv2d_list = nn.ModuleList([
            nn.Conv2d(self.tmp_chns, 1, self.kernel_hw, self.stride_hw,
                      self.pad_hw, 1, bias=bias)
            for _ in range(self.in_chns * self.out_chns)
        ])

    @staticmethod
    def _to_triple(value, name):
        """Return ``value`` as a 3-tuple, repeating it when it is an int."""
        if isinstance(value, int):
            return (value, value, value)
        value = tuple(value)
        if len(value) != 3:
            raise ValueError("%s must be an int or a sequence of 3 ints" % name)
        return value

    def forward(self, input):
        """Apply the emulated 3D convolution.

        Args:
            input: Tensor of shape ``(N, in_channels, D, H, W)``.

        Returns:
            Tensor of shape ``(N, out_channels, D_out, H_out, W_out)``.

        Raises:
            ValueError: If ``input`` is not 5-dimensional.
        """
        if input.dim() != 5:
            raise ValueError("MSNConv3d expects a 5D input (N, C, D, H, W)")

        # Zero-pad along depth; the 2D convolutions only pad height/width.
        if self.pad_d > 0:
            batch, chns, _, height, width = input.shape
            zeros = input.new_zeros([batch, chns, self.pad_d, height, width])
            input = torch.cat([zeros, input, zeros], dim=2)

        batch, _, depth, height, width = input.shape
        # Number of depth windows, same formula as a dilation-1 convolution.
        slice_cnt = max((depth - self.tmp_chns) // self.stride_d + 1, 0)
        out_h = (height + 2 * self.pad_hw[0] - self.kernel_hw[0]) // self.stride_hw[0] + 1
        out_w = (width + 2 * self.pad_hw[1] - self.kernel_hw[1]) // self.stride_hw[1] + 1

        output_slc = input.new_zeros(
            [batch, self.out_chns, slice_cnt, self.in_chns, out_h, out_w])
        for iout in range(self.out_chns):
            for islc in range(slice_cnt):
                start = islc * self.stride_d
                for iin in range(self.in_chns):
                    # (N, kernel_depth, H, W): the depth window becomes the
                    # channel axis of the 2D convolution.
                    data_slice = input[:, iin, start:start + self.tmp_chns, :, :]
                    kernel_id = iout * self.in_chns + iin
                    output_slc[:, iout, islc, iin] = \
                        self.conv2d_list[kernel_id](data_slice)[:, 0]

        # Sum the contributions of all input channels.
        return torch.sum(output_slc, dim=3)
然后我们进行权重字典的转存。这里注意 bias 只需要加一次:每个输出通道对应 in_channels 个 2D 卷积,它们的输出会被求和,因此只让其中第一个 2D 卷积保留原 3D 卷积的 bias,其余 2D 卷积的 bias 置 0。
class ConvTestNet(nn.Module):
    """Reference network: ``nn.Conv3d`` -> ``nn.BatchNorm3d`` -> ``nn.ReLU``."""

    def __init__(self):
        super(ConvTestNet, self).__init__()
        self.conv3d = nn.Sequential(
            nn.Conv3d(2, 16, kernel_size=(8, 3, 3), stride=[8, 1, 1], padding=[0, 1, 1]),
            nn.BatchNorm3d(16),
            nn.ReLU()
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.Conv3d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm3d)):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x):
        """Run ``x`` through the Conv3d/BatchNorm3d/ReLU stack."""
        x = self.conv3d(x)
        return x


class ConvTestNet2(nn.Module):
    """Same topology as ``ConvTestNet`` with the Conv3d replaced by ``MSNConv3d``."""

    def __init__(self):
        super(ConvTestNet2, self).__init__()
        self.conv3d = nn.Sequential(
            MSNConv3d(2, 16, kernel_size=(8, 3, 3), stride=[8, 1, 1], padding=[0, 1, 1]),
            nn.BatchNorm3d(16),
            nn.ReLU()
        )

    @staticmethod
    def _find_entry(state_dict, name):
        """Return the tensor stored under ``name`` (optionally behind a dotted prefix), or None."""
        suffix = '.' + name
        for key in state_dict:
            if key == name or key.endswith(suffix):
                return state_dict[key]
        return None

    def convert_state(self, state_dict_outside):
        """Load a state dict saved from the Conv3d-based network into this model.

        Entries that exist under the same key and with the same shape in both
        models (for example the BatchNorm parameters and running statistics)
        are copied unchanged.  Each Conv3d weight of shape
        ``(out, in, kd, kh, kw)`` is split into ``out * in`` 2D kernels of
        shape ``(1, kd, kh, kw)``; kernel ``idx`` receives
        ``weight[idx // in, idx % in]``.  The Conv3d bias is stored only in the
        first 2D convolution of every output channel and all other biases are
        zeroed, because the 2D outputs are summed over the input channels.

        Args:
            state_dict_outside: Mapping of parameter names to tensors, as
                returned by ``state_dict()`` of the Conv3d-based model.
        """
        state_dict_inside = self.state_dict()

        # Copy everything that needs no conversion (same key, same shape).
        for key in list(state_dict_inside):
            val_out = state_dict_outside.get(key)
            if val_out is not None and val_out.shape == state_dict_inside[key].shape:
                state_dict_inside[key] = val_out

        # Conv3d prefix in the source dict -> MSNConv3d prefix in this model.
        # Prefixes without a Conv3d weight in the source dict are skipped.
        conv3d_name_list_out = ['conv3d.0', 'conv3d.3']
        conv3d_name_list_in = ['conv3d.0', 'conv3d.3']

        for key_out, key_in in zip(conv3d_name_list_out, conv3d_name_list_in):
            weight_3d = self._find_entry(state_dict_outside, key_out + '.weight')
            if weight_3d is None:
                continue
            bias_3d = self._find_entry(state_dict_outside, key_out + '.bias')
            in_chns = weight_3d.shape[1]
            pattern = re.compile(
                re.escape(key_in) + r'\.conv2d_list\.(\d+)\.(weight|bias)$')

            for key_in_iter in list(state_dict_inside):
                match_obj = pattern.match(key_in_iter)
                if match_obj is None:
                    continue
                out_idx, in_idx = divmod(int(match_obj.group(1)), in_chns)
                if match_obj.group(2) == 'weight':
                    state_dict_inside[key_in_iter] = \
                        weight_3d[out_idx, in_idx, :, :, :].unsqueeze(0)
                elif bias_3d is not None and in_idx == 0:
                    # The bias is added exactly once per output channel.
                    state_dict_inside[key_in_iter] = bias_3d[out_idx].unsqueeze(0)
                else:
                    current = state_dict_inside[key_in_iter]
                    state_dict_inside[key_in_iter] = current.new_zeros(current.shape)

        self.load_state_dict(state_dict_inside)
        return

    def forward(self, x):
        """Run ``x`` through the MSNConv3d/BatchNorm3d/ReLU stack."""
        x = self.conv3d(x)
        return x
最后是我们的测试程序,我们实现的3D卷积和原版3D卷积基本没有误差
def test():
    """Compare the Conv3d network with its 2D-convolution emulation.

    Builds ``ConvTestNet``, round-trips its weights through ``model1.ckpt``,
    converts them into ``ConvTestNet2`` and prints the maximum and summed
    absolute difference of the two outputs on one random input.
    """
    reference = ConvTestNet()
    torch.save(reference.state_dict(), "model1.ckpt")
    weights = torch.load("model1.ckpt")
    reference.load_state_dict(weights)

    emulated = ConvTestNet2()
    emulated.convert_state(weights)

    sample = torch.randn(1, 2, 64, 92, 164)

    reference.eval()
    expected = reference(sample)
    emulated.eval()
    actual = emulated(sample)

    abs_err = torch.abs(expected - actual)
    err_sum = torch.sum(abs_err).item()
    err_max = torch.max(abs_err).item()
    print("max err: %f, total err: %f" % (err_max, err_sum))
OK,希望能有帮助