使用OpCounter和flops-counter评估pytorch模型大小

在Pytorch中统计模型大小有一个非常好用的工具opcounters,opcounters用法也非常简单,这篇博客介绍opcounters用法。

安装计算量统计包

    pip install thop
    pip install --upgrade git+https://github.com/sovrasov/flops-counter.pytorch.git
    pip install --upgrade git+https://github.com/Lyken17/pytorch-OpCounter.git
    pip install onnx

使用示例

from torchvision.models import resnet50
from thop import profile, clevar_style

model = resnet50()
input = torch.randn(1, 3, 224, 224) # (batch_size, num_channel, Height, Width)
flops, params = profile(model, inputs=(input, )) 
print('flops: {}, params: {}'.format(flops, params))

实现原理

该工具为每一种基本操作都定义了参数统计和运算量计算,如果模型中有自定义的特殊运算类,也可以定义自己的运算统计规则

class ModuleName(nn.Module):
    # your definition
def count_model(model, x, y):
    # your rule here

flops, params = profile(model, inputs=(input, ), 
                        custom_ops={ModuleName: count_model})

比如,对于常规二维卷积,它的统计代码如下所示:

def count_conv2d(m, x, y):
    x = x[0]
    cin = m.in_channels
    cout = m.out_channels
    kh, kw = m.kernel_size
    batch_size = x.size()[0]
    out_h = y.size(2)
    out_w = y.size(3)
    # ops per output element
    # kernel_mul = kh * kw * cin
    # kernel_add = kh * kw * cin - 1
    kernel_ops = multiply_adds * kh * kw
    bias_ops = 1 if m.bias is not None else 0
    ops_per_element = kernel_ops + bias_ops
    # total ops
    # num_out_elements = y.numel()
    output_elements = batch_size * out_w * out_h * cout
    total_ops = output_elements * ops_per_element * cin // m.groups
    m.total_ops = torch.Tensor([int(total_ops)])

发表评论