Pytorch使用2d卷积来实现3d卷积

我们在深度学习模型转换和推理过程中常常会遇到算子不被工具链支持的情况,这时我们可以通过其它算子来等价实现我们想要的算子,比如3d卷积不被工具链支持,我们可以将其拆分成2d卷积,拆分之后我们需要进行权重字典的拆分。这篇博客贴一下pytorch的实现。

首先我们自己用2D卷积实现一个3D卷积

class MSNConv3d(nn.Module):
    """3D convolution emulated with a bank of 2D convolutions.

    Each (output-channel, input-channel) pair owns one nn.Conv2d whose
    input channels are the temporal (depth) taps of the 3D kernel, so the
    3D convolution is computed as a sum over input channels of 2D
    convolutions applied to non-overlapping depth windows.

    NOTE(review): forward() only reads batch element 0, and it assumes
    stride[0] == kernel_size[0] (non-overlapping depth windows) and that
    the input depth is a multiple of kernel_size[0] — confirm with callers.
    dilation, groups, bias and padding_mode are accepted for interface
    compatibility with nn.Conv3d but are not forwarded to the 2D convs.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 padding_mode='zeros'):
        super(MSNConv3d, self).__init__()
        self.in_chns = in_channels
        # Depth taps of the 3D kernel become the 2D conv's input channels.
        self.tmp_chns = kernel_size[0]
        self.out_chns = out_channels
        # Integer division: this is a channel count, and the original float
        # division ("/") was a latent bug if the value is ever used as a size.
        self.chns = in_channels // stride[0]
        # One Conv2d per (out_channel, in_channel) pair; kernel index is
        # iout * in_channels + iin (see forward and convert_state).
        self.conv2d_list = nn.ModuleList(
            [nn.Conv2d(self.tmp_chns, 1, kernel_size[1:3], stride[1:3], padding[1:3], 1)
             for _ in range(self.in_chns * self.out_chns)])

    def forward(self, input):
        """Apply the emulated 3D convolution.

        input: 5D tensor (batch, in_chns, depth, H, W); only batch 0 is used.
        Returns a 5D tensor (1, out_chns, depth // kernel_size[0], H', W').
        """
        shape_in = input.shape
        # Number of non-overlapping depth windows.
        slice_cnt = int(shape_in[2] / self.tmp_chns)

        # Per-(out, slice, in) partial results, summed over `in` below.
        output_slc = input.new_zeros([self.out_chns, slice_cnt, self.in_chns, shape_in[3], shape_in[4]])

        for iout in range(0, self.out_chns):
            for islc in range(0, slice_cnt):
                for iin in range(0, self.in_chns):
                    # fetch one depth window of one input channel: (tmp_chns, H, W)
                    data_slice = input[0, iin, (islc * self.tmp_chns):((islc + 1) * self.tmp_chns), :, :]
                    if len(data_slice.shape) == 3:
                        data_slice = data_slice.unsqueeze(0)
                    # select kernel for this (out, in) channel pair
                    kernel_id = iout * self.in_chns + iin
                    # Index out the (batch, channel) dims explicitly: the original
                    # squeeze() would also drop a spatial dimension of size 1 and
                    # break the assignment for 1-wide/1-tall feature maps.
                    output_slc[iout, islc, iin, :, :] = self.conv2d_list[kernel_id](data_slice)[0, 0]

        # Sum the per-input-channel partials -> (1, out_chns, slice_cnt, H', W').
        output = torch.sum(output_slc, dim=(2,), keepdim=False).unsqueeze(0)
        return output

然后我们进行权重字典的转存(这里注意bias只需要加一次)

class ConvTestNet(nn.Module):
    """Reference network: a genuine nn.Conv3d followed by BN + ReLU.

    Serves as the ground truth that the 2D-convolution-based
    re-implementation (ConvTestNet2) is compared against.
    """

    def __init__(self):
        super(ConvTestNet, self).__init__()
        self.conv3d = nn.Sequential(
            nn.Conv3d(2, 16, kernel_size=(8, 3, 3), stride=[8, 1, 1], padding=[0, 1, 1]),
            nn.BatchNorm3d(16),
            nn.ReLU(),
        )

        # He-style initialisation: conv weights ~ N(0, sqrt(2/fan)) with
        # fan = out_channels * prod(kernel_size); BN scale 1, shift 0;
        # linear biases zeroed. (Conv2d/Linear branches are kept for parity
        # even though this net contains neither.)
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Conv3d)):
                fan = module.out_channels * math.prod(module.kernel_size)
                module.weight.data.normal_(0, math.sqrt(2. / fan))
            elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d)):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, nn.Linear):
                module.bias.data.zero_()

    def forward(self, x):
        return self.conv3d(x)


class ConvTestNet2(nn.Module):
    """2D-convolution-based re-implementation of ConvTestNet.

    The nn.Conv3d of the reference net is replaced by MSNConv3d; the 3D
    weights are transplanted into the per-(out, in)-channel 2D kernels
    via convert_state().
    """

    def __init__(self):
        super(ConvTestNet2, self).__init__()
        self.conv3d = nn.Sequential(
            MSNConv3d(2, 16, kernel_size=(8, 3, 3), stride=[8, 1, 1], padding=[0, 1, 1]),
            nn.BatchNorm3d(16),
            nn.ReLU(),
        )

    def convert_state(self, state_dict_outside):
        """Transplant a ConvTestNet state dict into this model.

        Each 3D kernel slice val_out[iout, iin] of shape (kd, kh, kw)
        becomes the weight (1, kd, kh, kw) of the 2D conv at index
        iout * in_channels + iin. The 3D bias is applied exactly once per
        output channel (on the iin == 0 kernel); every other 2D bias is
        zeroed so the summed partials add the bias a single time.

        NOTE(review): relies on '.weight' preceding '.bias' for each
        module in the incoming state dict (PyTorch insertion order), so
        last_w is set before the bias entries are processed. 'conv3d.3'
        matches no module here — presumably the BatchNorm parameters are
        meant to stay at their defaults, which coincide with ConvTestNet's
        init (weight 1, bias 0); verify if the reference net is trained.
        """
        state_dict_inside = self.state_dict()
        # Prefixes of the modules whose parameters must be remapped.
        conv3d_name_list_out = ['conv3d.0', 'conv3d.3']
        conv3d_name_list_in = ['conv3d.0', 'conv3d.3']

        # Hoisted out of the nested loops: extract the flat kernel index.
        weight_re = re.compile(r'.*conv2d_list\.(\d+)\.weight')
        bias_re = re.compile(r'.*conv2d_list\.(\d+)\.bias')

        for key_out, key_in in zip(conv3d_name_list_out, conv3d_name_list_in):
            last_w = 0  # in_channels of the most recently matched 3D weight
            for key_out_iter in state_dict_outside:
                val_out = state_dict_outside[key_out_iter]
                if key_out in key_out_iter and 'weight' in key_out_iter:
                    for key_in_iter in state_dict_inside:
                        if key_in in key_in_iter and 'weight' in key_in_iter:
                            idx_in = int(weight_re.match(key_in_iter)[1])
                            # (out_channels, in_channels) of the 3D weight.
                            shape_out = val_out.shape[0:2]
                            last_w = shape_out[1]
                            idx_out = [idx_in // shape_out[1], idx_in % shape_out[1]]
                            # 3D slice (kd, kh, kw) -> 2D weight (1, kd, kh, kw).
                            state_dict_inside[key_in_iter] = val_out[idx_out[0], idx_out[1], :, :, :].unsqueeze(0)
                elif key_out in key_out_iter and 'bias' in key_out_iter:
                    for key_in_iter in state_dict_inside:
                        if key_in in key_in_iter and 'bias' in key_in_iter:
                            idx_in = int(bias_re.match(key_in_iter)[1])
                            idx_out = idx_in // last_w
                            if idx_in % last_w == 0:
                                # First input channel of this output channel
                                # carries the 3D bias ...
                                state_dict_inside[key_in_iter] = val_out[idx_out].unsqueeze(0)
                            else:
                                # ... all other partial kernels get zero bias.
                                state_dict_inside[key_in_iter][0] = 0

        self.load_state_dict(state_dict_inside)

    def forward(self, x):
        return self.conv3d(x)

最后是我们的测试程序。测试结果表明,我们用2D卷积实现的3D卷积和原版3D卷积基本没有误差:

def test():
    """Compare the 2D-conv emulation against a real nn.Conv3d.

    Builds the reference net, round-trips its state dict through disk
    (mimicking a real conversion workflow), transplants the weights into
    the emulated net, runs both on the same random input and prints the
    max and total absolute error.
    """
    model1 = ConvTestNet()
    torch.save(model1.state_dict(), "model1.ckpt")

    state_dict1 = torch.load("model1.ckpt")
    model1.load_state_dict(state_dict1)

    model2 = ConvTestNet2()
    model2.convert_state(state_dict1)

    # (batch, channels, depth, H, W); depth 64 is a multiple of the
    # temporal kernel size 8, as MSNConv3d.forward requires.
    # Renamed from `input`, which shadowed the builtin.
    x = torch.randn(1, 2, 64, 92, 164)

    model1.eval()
    output1 = model1(x)

    model2.eval()
    output2 = model2(x)

    # Compute the element-wise error once instead of twice.
    abs_err = torch.abs(output1 - output2)
    err_sum = torch.sum(abs_err).item()
    err_max = torch.max(abs_err).item()

    print("max err: %f,  total err: %f" % (err_max, err_sum))

OK,以上就是用2D卷积等价实现3D卷积的完整过程,希望能对你有帮助。

发表评论