Compute the parameters and FLOPs of the model using thop.
衡量模型的好坏,除了任务特定的性能指标(如准确率),还需要考虑模型的效率,比如模型的参数量和运算量。参数量是指模型的参数个数,描述模型存储所需内存;运算量通常用FLOPs衡量,描述模型使用所需计算力。
注意是FLOPs(floating point operations),指浮点运算数量,通常以GFLOPs ($10^9$)为单位;而不是FLOPS(floating point operations per second),指每秒浮点运算次数,后者通常用于衡量硬件的性能指标。
对于一个卷积核尺寸为$(h \times w \times c_{in})$的卷积层,其输出特征图的尺寸为$(H \times W \times c_{out})$,则该卷积层的:
- 参数量(包含偏置参数):$c_{out} \times (h \times w \times c_{in}+1)$
- FLOPs(包含偏置参数,考虑乘法与加法):$H \times W \times c_{out} \times (2 \times h \times w \times c_{in})$
通常网络中的全连接层参数量较大,需要较大的内存,但其运算量较小;卷积层参数量较小,但运算量较大,是一种计算密集型的操作。此外,还有一些网络结构(如池化和Dropout)没有参数但存在计算。
1. 使用thop库计算模型的参数量和FLOPs
PyTorch-OpCounter是为Pytorch框架设计的模型参数量和运算量统计工具,安装语句如下:
pip install thop
使用语句如下:
from thop import profile
tensor = (torch.rand(1, C, H, W),)
flops, params = profile(model, inputs=tensor)
print('FLOPs =', flops/1e9)
print('params =', params/1e6)
值得一提的是,如果在工程中使用thop库测试模型的参数等信息,在后续保存模型torch.save(model.state_dict())
时也会把total_params和total_ops等注册到网络中,导致直接加载模型model.load_state_dict(state_dict)
时报错:
Missing key(s) in state_dict: "total_ops", "total_params"...
解决办法是在加载模型时指定strict
参数:
model.load_state_dict(state_dict, strict=False)
2. 使用fvcore库计算模型的参数量和FLOPs
fvcore是Facebook开源的轻量级核心库,它提供计算机视觉框架中常见且基本的功能;其中就包括统计模型的参数以及FLOPs等。安装语句如下:
pip install fvcore
使用语句如下:
from fvcore.nn import FlopCountAnalysis, parameter_count_table
tensor = (torch.rand(1, C, H, W),)
flops = FlopCountAnalysis(model, tensor).total()
params = parameter_count_table(model)
print('flops =', flops/1e9)
print(params)