Compute the parameters and FLOPs of the model using thop.

衡量模型的好坏,除了任务特定的性能指标(如准确率),还需要考虑模型的效率,比如模型的参数量和运算量。参数量是指模型的参数个数,描述模型存储所需内存运算量通常用FLOPs衡量,描述模型使用所需计算力

注意是FLOPs(floating point operations),指浮点运算数量,通常以GFLOPs ($10^9$)为单位;而不是FLOPS(floating point operations per second),指每秒浮点运算次数,后者通常用于衡量硬件的性能指标。

对于一个卷积核尺寸为$(h \times w \times c_{in})$的卷积层,其输出特征图的尺寸为$(H \times W \times c_{out})$,则该卷积层的:

通常网络中的全连接层参数量较大,需要较大的内存,但其运算量较小;卷积层参数量较小,但运算量较大,是一种计算密集型的操作。此外,还有一些网络结构(如池化和Dropout)没有参数但存在计算。

1. 使用thop库计算模型的参数量和FLOPs

PyTorch-OpCounter是为Pytorch框架设计的模型参数量和运算量统计工具,安装语句如下:

pip install thop

使用语句如下:

from thop import profile

tensor = (torch.rand(1, C, H, W),)
flops, params = profile(model, inputs=tensor)
print('FLOPs =', flops/1e9)
print('params =', params/1e6)

值得一提的是,如果在工程中使用thop库测试模型的参数等信息,在后续保存模型torch.save(model.state_dict())时也会把total_paramstotal_ops等注册到网络中,导致直接加载模型model.load_state_dict(state_dict)时报错:

Missing key(s) in state_dict: "total_ops", "total_params"...

解决办法是在加载模型时指定strict参数:

model.load_state_dict(state_dict, strict=False)

2. 使用fvcore库计算模型的参数量和FLOPs

fvcoreFacebook开源的轻量级核心库,它提供计算机视觉框架中常见且基本的功能;其中就包括统计模型的参数以及FLOPs等。安装语句如下:

pip install fvcore

使用语句如下:

from fvcore.nn import FlopCountAnalysis, parameter_count_table

tensor = (torch.rand(1, C, H, W),)
flops = FlopCountAnalysis(model, tensor).total()
params = parameter_count_table(model)
print('flops =', flops/1e9)
print(params)