GAN的TTUR训练方法和FID评估指标.
1. Two Time-Scale Update Rule (TTUR)
在GAN的训练过程中,当判别器收敛到一个局部极小值时,如果生成器的变化足够慢,则判别器仍然收敛。
在设置优化函数时,应设法保证判别器的判别能力比生成器的生成能力要好。通常的做法是先更新判别器的参数多次,再更新一次生成器的参数。
本文作者提出了一种更简单的学习策略,即将判别器的学习率设置得比生成器的学习率更大,使得判别器收敛速度更快。
所提出的双时间尺度更新规则(Two Time-Scale Update Rule, TTUR)表示如下:
\[\begin{aligned} θ_D & \leftarrow θ_D + \alpha \nabla_{θ_D}L(D,G) \\ \theta_G & \leftarrow θ_G - \beta \nabla_{θ_G}L(D,G) \end{aligned}\]作者证明了在上述TTUR训练下,当判别器和生成器具有不同的学习率$\alpha > \beta$时,网络收敛于局部纳什均衡。
2. Fréchet Inception Distance (FID)
为生成模型设定合适的性能度量是比较困难的,需要衡量真实数据分布$P_{data}(x)$和生成数据分布$P_G(x)$之间的距离。
本文作者通过计算数据分布$p(x)$在一个多项式基$f(x)$上的矩$\int p(x)f(x)dx$来构造生成模型的性能指标,称为Fréchet Inception Distance (FID)。
作者使用Inception模型的编码层(分类层之前)提取图像的视觉特征$x$,并将多项式基$f(x)$取前两项,对应特征的均值和协方差。
由于高斯分布是给定均值和协方差的最大熵分布,因此假设编码特征服从多维高斯分布。两个高斯分布的差异可以通过Fréchet距离 (也称为Wasserstein-2距离) 来衡量。
记从真实数据分布$P_{data}(x)$中获得的特征分布为$N(m,C)$,从生成数据分布$P_G(x)$中获得的特征分布为$N(m_w,C_w)$,则两个分布之间的FID距离定义为:
\[d^2((m,C),(m_w,C_w)) = ||m-m_w||_2^2+Tr(C+C_w-2(CC_w)^{1/2})\]FID值越小,表明两种数据分布的相似程度越高。下图给出了向图像中增加不同程度的噪声时,FID值的变化情况:
在GAN模型中FID值的计算方法归纳如下:
- 取一批合成图像输入Inception网络,取分类层之前的特征来计算均值和协方差$(m_w,C_w)$;
- 取一批真实图像输入Inception网络,取分类层之前的特征来计算均值和协方差$(m,C)$;
- 根据上述公式计算Fréchet距离。
使用pytorch-fid
库可以便捷地实现FID值的计算:
# pip install pytorch-fid
python -m pytorch_fid path/to/dataset1 path/to/dataset2 --dims 2048
其中dims
指定了特征的维数,不同维数对应不同层的特征:
64: first max pooling features
192: second max pooling featurs
768: pre-aux classifier features
2048: final average pooling features (this is the default)