ECCV 2020 Tutorial on PyTorch Performance Tuning Guide.

目录:

  1. use async data loading / augmentation
  2. enable cuDNN autotuner
  3. increase batch size
  4. remove unnecessary computation
  5. use DistributedDataParallel instead of DataParallel
  6. efficiently zero-out gradients
  7. apply PyTorch JIT to fuse pointwise operations
  8. checkpoint to recompute intermediates

1. use async data loading / augmentation

PytorchDataLoader支持异步的数据加载和增强操作,默认情况下有:

{num_workers=0, pin_memory=False}

下表是训练MNIST图像分类实验中不同参数的对照试验(环境PyTorch 1.6 + NVIDIA Quadro RTX 8000):

2. enable cuDNN autotuner

在训练卷积神经网络时,cuDNN支持多种不同的算法计算卷积,使用调校工具autotuner可以运行一个较小的benchmark检测这些算法,并从中选择表现最好的算法。

对于卷积神经网络,只需要设置:

torch.backends.cudnn.benchmark = True

下表是使用nn.Conv2d(64,3)处理大小为(32,64,64,64)数据的对照试验(环境PyTorch 1.6 + NVIDIA Quadro RTX 8000):

3. increase batch size

GPU内存允许的情况下增加batch size,通常结合混合精度训练使batch size更大。该方法通常结合学习率策略或更换优化方法:

4. remove unnecessary computation

batch norm中会有rescalereshift操作,因此其之前的卷积层中的bias参数可以被合并:

5. use DistributedDataParallel instead of DataParallel

DataParallel针对单一进程开启多个线程,用一个CPU核驱动多个GPU,总体还是在这些GPU上运行单一python进程。

DistributedDataParallel同时开启多个进程,用多个CPU核分别驱动多个GPU,每个GPU上都运行一个进程。

6. efficiently zero-out gradients

每次更新时需要进行梯度置零,通常使用以下语句:

model.zero_grad() 或 optimizer.zero_grad()

上述语句会对每一个参数执行memset(为新申请的内存做初始化工作),反向传播更新梯度时使用‘$+=$’操作(读+写)。为提高效率,可以将梯度置零语句替换成:

for param in model.parameters():
    param.grad = None

上述语句不会对每个参数执行memset,并且反向传播更新梯度时使用‘$=$’操作(写)。

7. apply PyTorch JIT to fuse pointwise operations

PyTorch JIT能够将逐点操作(pointwise operations)融合到单个CUDA核上,从而减小执行时间。

如下图,只需要在执行语句前加上@torch.jit.script便可以实现(环境PyTorch 1.6 + NVIDIA Quadro RTX 8000):

8. checkpoint to recompute intermediates

在常规的训练过程中,前向传播会存储中间运算的输出值用以反向传播,这一步需要更多的内存,从而限制了训练时batch size的大小;在反向传播更新参数时不再需要额外的运算。

torch.utils.checkpoint提供了checkpoint操作。在前向传播只存储部分中间运算的输出值,减小了对内存的占用,可以使用更大的batch size;反向传播时需要额外的运算。

checkpoint操作是一种用时间换内存的操作。通常需要选择对合适的操作进行,如较小的重复计算代价(re-computation cost)和较大的内存占用(memory footprint),包括激活函数、上下采样和较小堆积深度(accumulation depth)的矩阵向量运算。