在Pytorch中统计模型大小有一个非常好用的工具opcounters,opcounters用法也非常简单,这篇博客介绍opcounters用法。
安装计算量统计包
pip install thop pip install --upgrade git+https://github.com/sovrasov/flops-counter.pytorch.git pip install --upgrade git+https://github.com/Lyken17/pytorch-OpCounter.git pip install onnx
使用示例
from torchvision.models import resnet50 from thop import profile, clevar_style model = resnet50() input = torch.randn(1, 3, 224, 224) # (batch_size, num_channel, Height, Width) flops, params = profile(model, inputs=(input, )) print('flops: {}, params: {}'.format(flops, params))
实现原理
该工具为每一种基本操作都定义了参数统计和运算量计算,如果模型中有自定义的特殊运算类,也可以定义自己的运算统计规则
class ModuleName(nn.Module): # your definition def count_model(model, x, y): # your rule here flops, params = profile(model, inputs=(input, ), custom_ops={ModuleName: count_model})
比如,对于常规二维卷积,它的统计代码如下所示:
def count_conv2d(m, x, y): x = x[0] cin = m.in_channels cout = m.out_channels kh, kw = m.kernel_size batch_size = x.size()[0] out_h = y.size(2) out_w = y.size(3) # ops per output element # kernel_mul = kh * kw * cin # kernel_add = kh * kw * cin - 1 kernel_ops = multiply_adds * kh * kw bias_ops = 1 if m.bias is not None else 0 ops_per_element = kernel_ops + bias_ops # total ops # num_out_elements = y.numel() output_elements = batch_size * out_w * out_h * cout total_ops = output_elements * ops_per_element * cin // m.groups m.total_ops = torch.Tensor([int(total_ops)])