在深度学习模型的转换和推理部署过程中，我们常常会遇到某些算子不被工具链支持的情况。这时可以用其它算子来等价实现所需的算子：例如工具链不支持 3D 卷积时，可以把它拆分成若干个 2D 卷积。拆分之后，还需要相应地拆分权重字典（state_dict）。本文给出 PyTorch 的实现。
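首先，我们用 2D 卷积自己实现一个 3D 卷积：把每个输入通道在深度方向上切成长度为 kernel_size[0] 的互不重叠的片段，并把片段中的各帧当作 2D 卷积的输入通道。因此这种拆分要求深度方向的步长等于深度方向的卷积核大小，且深度方向不做 padding。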
class MSNConv3d(nn.Module):
    """3D convolution emulated with 2D convolutions.

    The depth axis of every input channel is cut into non-overlapping slices of
    ``kernel_size[0]`` frames. Each slice is fed to an ``nn.Conv2d`` that treats
    those frames as its input channels. There is one single-output ``nn.Conv2d``
    per (output channel, input channel) pair, stored in ``conv2d_list`` at index
    ``out_idx * in_channels + in_idx``. Their results are summed over the input
    channels, which reproduces the channel reduction of a real 3D convolution.

    Only configurations this decomposition can express exactly are accepted:
    depth stride equal to depth kernel size, zero depth padding, depth dilation
    of 1, and ``groups == 1``.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: int or 3-sequence ``(kd, kh, kw)``.
        stride: int or 3-sequence; ``stride[0]`` must equal ``kernel_size[0]``.
        padding: int or 3-sequence; ``padding[0]`` must be 0.
        dilation: int or 3-sequence; ``dilation[0]`` must be 1. The spatial part
            is forwarded to every ``nn.Conv2d``.
        groups: must be 1.
        bias: whether the 2D convolutions carry a bias term.
        padding_mode: forwarded to every ``nn.Conv2d``.

    Raises:
        ValueError: for an unsupported configuration or a malformed size argument.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 padding_mode='zeros'):
        super(MSNConv3d, self).__init__()
        kernel_size = self._to_triple(kernel_size)
        stride = self._to_triple(stride)
        padding = self._to_triple(padding)
        dilation = self._to_triple(dilation)
        # forward() walks the depth axis in disjoint windows of kernel_size[0]
        # frames; any other depth configuration would give silently wrong output.
        if stride[0] != kernel_size[0] or padding[0] != 0 or dilation[0] != 1:
            raise ValueError(
                'MSNConv3d requires stride[0] == kernel_size[0], padding[0] == 0 '
                'and dilation[0] == 1; got kernel_size=%r, stride=%r, padding=%r, '
                'dilation=%r' % (kernel_size, stride, padding, dilation))
        if groups != 1:
            raise ValueError('MSNConv3d only supports groups=1, got %r' % (groups,))
        self.in_chns = in_channels
        self.tmp_chns = kernel_size[0]
        self.out_chns = out_channels
        # Not used by this module; kept only for backward compatibility with
        # code that may read the attribute.
        self.chns = in_channels / stride[0]
        self.conv2d_list = nn.ModuleList([
            nn.Conv2d(self.tmp_chns, 1, kernel_size[1:], stride[1:], padding[1:],
                      dilation[1:], bias=bias, padding_mode=padding_mode)
            for _ in range(in_channels * out_channels)])

    @staticmethod
    def _to_triple(value):
        """Return ``value`` as a 3-tuple; an int is repeated on every axis."""
        if isinstance(value, int):
            return (value, value, value)
        value = tuple(value)
        if len(value) != 3:
            raise ValueError('expected an int or a sequence of 3 ints, got %r' % (value,))
        return value

    def forward(self, input):
        """Apply the emulated 3D convolution.

        Args:
            input: tensor of shape ``(N, in_channels, D, H, W)``.

        Returns:
            Tensor of shape ``(N, out_channels, D // kernel_size[0], H_out, W_out)``,
            where ``H_out``/``W_out`` are the output sizes of the 2D convolutions.

        Raises:
            ValueError: if ``input`` is not 5-D or its depth is shorter than the
                depth kernel.
        """
        if input.dim() != 5:
            raise ValueError('expected a 5-D input (N, C, D, H, W), got shape %r'
                             % (tuple(input.shape),))
        slice_cnt = input.shape[2] // self.tmp_chns
        if slice_cnt == 0:
            raise ValueError('input depth %d is smaller than the depth kernel %d'
                             % (input.shape[2], self.tmp_chns))
        depth_outputs = []
        for islc in range(slice_cnt):
            start = islc * self.tmp_chns
            # (N, in_channels, kd, H, W): the frames seen by this output depth step.
            window = input[:, :, start:start + self.tmp_chns]
            channel_outputs = []
            for iout in range(self.out_chns):
                # Each conv maps (N, kd, H, W) -> (N, 1, H_out, W_out); summing over
                # the input channels completes the 3D convolution for channel iout.
                partial = [self.conv2d_list[iout * self.in_chns + iin](window[:, iin])
                           for iin in range(self.in_chns)]
                channel_outputs.append(torch.stack(partial, dim=0).sum(dim=0))
            depth_outputs.append(torch.cat(channel_outputs, dim=1))
        return torch.stack(depth_outputs, dim=2)
然后进行权重字典（state_dict）的转换。注意 bias 只需要加一次：每个输出通道对应 in_channels 个 2D 卷积，它们的输出会被相加，因此只给第一个输入通道对应的 2D 卷积赋值 bias，其余 2D 卷积的 bias 置 0。
class ConvTestNet(nn.Module):
    """Reference network: a native ``nn.Conv3d`` followed by BatchNorm3d and ReLU.

    Convolution weights are drawn from N(0, sqrt(2 / n)), where n is the product
    of the kernel sizes and the output channel count. BatchNorm layers start with
    weight 1 and bias 0, and Linear biases (if any) start at 0.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Conv3d(2, 16, kernel_size=(8, 3, 3), stride=[8, 1, 1], padding=[0, 1, 1]),
            nn.BatchNorm3d(16),
            nn.ReLU(),
        ]
        self.conv3d = nn.Sequential(*layers)
        for module in self.modules():
            self._init_module(module)

    @staticmethod
    def _init_module(module):
        """Re-initialise one submodule in place according to its type."""
        if isinstance(module, (nn.Conv2d, nn.Conv3d)):
            fan = module.out_channels
            for k in module.kernel_size:
                fan *= k
            module.weight.data.normal_(0, math.sqrt(2. / fan))
        elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d)):
            module.weight.data.fill_(1)
            module.bias.data.zero_()
        elif isinstance(module, nn.Linear):
            module.bias.data.zero_()

    def forward(self, x):
        return self.conv3d(x)
class ConvTestNet2(nn.Module):
    """Same topology as ``ConvTestNet`` but with the Conv3d replaced by ``MSNConv3d``.

    Call :meth:`convert_state` with a ``ConvTestNet`` state dict to load its
    weights into this network.
    """

    def __init__(self):
        super(ConvTestNet2, self).__init__()
        self.conv3d = nn.Sequential(MSNConv3d(2, 16, kernel_size=(8, 3, 3), stride=[8, 1, 1], padding=[0, 1, 1]),
                                    nn.BatchNorm3d(16),
                                    nn.ReLU()
                                    )

    def convert_state(self, state_dict_outside):
        """Load a state dict of the Conv3d-based network into this one.

        Every ``<prefix>.conv2d_list.<idx>.weight|bias`` entry of this model is
        built from the Conv3d entries ``<prefix>.weight`` / ``<prefix>.bias`` of
        ``state_dict_outside``:

        * ``idx`` is split into ``(out_idx, in_idx) = divmod(idx, in_channels)``,
          matching the index layout of ``MSNConv3d.conv2d_list``.
        * The weight is ``W3d[out_idx, in_idx]`` with a leading unit dimension.
        * The bias is added only once per output channel, on the ``in_idx == 0``
          conv. All other 2D biases are zeroed, as are all of them when the
          Conv3d had no bias.

        Every other entry present in both dicts (e.g. BatchNorm weights and
        running statistics) is copied unchanged. The result is loaded with
        ``load_state_dict``. Does not depend on the key order of
        ``state_dict_outside``.

        Args:
            state_dict_outside: mapping of parameter name to tensor, e.g. the
                ``state_dict()`` of a ``ConvTestNet``.

        Raises:
            KeyError: if a Conv3d weight needed for a ``conv2d_list`` entry is missing.
        """
        key_pattern = re.compile(
            r'^(?:(?P<prefix>.*)\.)?conv2d_list\.(?P<idx>\d+)\.(?P<kind>weight|bias)$')
        new_state = self.state_dict()
        for key in list(new_state):
            match = key_pattern.match(key)
            if match is None:
                # Not part of an emulated conv: copy straight across when available
                # (BatchNorm affine params, running stats, ...).
                if key in state_dict_outside:
                    new_state[key] = state_dict_outside[key]
                continue
            base = match.group('prefix') + '.' if match.group('prefix') else ''
            idx = int(match.group('idx'))
            weight3d = state_dict_outside[base + 'weight']
            # Conv3d weight layout: (out_channels, in_channels, kd, kh, kw).
            iout, iin = divmod(idx, weight3d.shape[1])
            if match.group('kind') == 'weight':
                new_state[key] = weight3d[iout, iin].unsqueeze(0)
                continue
            bias3d = state_dict_outside.get(base + 'bias')
            if bias3d is not None and iin == 0:
                new_state[key] = bias3d[iout].unsqueeze(0)
            else:
                # The in_channels partial outputs are summed, so the bias must be
                # contributed by exactly one of them.
                new_state[key] = torch.zeros_like(new_state[key])
        self.load_state_dict(new_state)

    def forward(self, x):
        x = self.conv3d(x)
        return x
最后是测试程序：与 PyTorch 原生的 3D 卷积相比，我们用 2D 卷积实现的 3D 卷积输出基本没有误差。
def test():
    """Check that ConvTestNet2 (2D-conv emulation) matches ConvTestNet (native Conv3d).

    The reference model's state dict is round-tripped through ``torch.save`` /
    ``torch.load`` using an in-memory buffer, so no checkpoint file is written to
    (or clobbered in) the working directory. The state dict is then converted
    with ``ConvTestNet2.convert_state``, and both models are run in eval mode on
    the same random input.

    Returns:
        ``(err_max, err_sum)``: the maximum and the sum of the absolute
        differences between the two outputs (also printed).
    """
    import io

    model1 = ConvTestNet()
    buffer = io.BytesIO()
    torch.save(model1.state_dict(), buffer)
    buffer.seek(0)
    state_dict1 = torch.load(buffer)
    model1.load_state_dict(state_dict1)
    model2 = ConvTestNet2()
    model2.convert_state(state_dict1)
    x = torch.randn(1, 2, 64, 92, 164)
    model1.eval()
    model2.eval()
    # Inference only: avoid building autograd graphs for these large activations.
    with torch.no_grad():
        output1 = model1(x)
        output2 = model2(x)
    diff = torch.abs(output1 - output2)
    err_sum = torch.sum(diff).item()
    err_max = torch.max(diff).item()
    print("max err: %f, total err: %f" % (err_max, err_sum))
    return err_max, err_sum
好了，希望能对大家有所帮助。