SAGAN:自注意力生成对抗网络.
SAGAN向生成对抗网络中引入了自注意力机制(Self-Attention),不仅摆脱了卷积层的感受野大小的限制,也使得网络在生成图像的过程中能够自己学习应关注的区域。
对于含有较多几何或结构约束的图像GAN的生成效果较差,这是因为复杂的几何轮廓需要长距离依赖(long-range dependency),卷积层的特点是局部性,受到感受野大小的限制很难提取到图像中的长距离依赖关系。
SAGAN把自注意力机制引入模型结构中,有助于对图像区域中长距离、多层次的依赖关系进行建模。此外,生成器和判别器均应用了谱归一化,使得网络满足Lipschitz连续性;训练时遵循TTUR原则,判别器$D$和生成器$G$的学习率是不平衡的。
SAGAN的损失函数采用Hinge损失:
\[\begin{aligned} \mathop{ \max}_{D} & \Bbb{E}_{x \text{~} P_{data}(x)}[\min(0,-1+D(x))] \\ &+ \Bbb{E}_{x \text{~} P_{G}(x)}[\min(0,-1-D(x))] \\ \mathop{ \min}_{G}& -\Bbb{E}_{x \text{~} P_{G}(x)}[D(x)] \end{aligned}\]1. 自注意力机制
自注意力机制在计算输入位置$i$的特征$y_i$时,考虑所有位置$j$的加权:
\[y_i = \sum_{j}^{} \frac{e^{f(x_i)^Tg(x_j)}}{\sum_j e^{f(x_i)^Tg(x_j)}} h(x_j)\]自注意力机制的实现步骤如下:
- $f(x)$、$g(x)$和$h(x)$通过三个$1\times 1$卷积层实现,$f(x)$和$g(x)$改变了通道数(缩小为$C/8$),$h(x)$维持通道数不变;
- 将空间尺寸合并为$H\times W$,将$f(x)$的输出转置后和$g(x)$的输出进行矩阵相乘,经过softmax归一化得到尺寸为$[H\times W,H\times W]$的注意力图;
- 将注意力图与$h(x)$的输出进行矩阵相乘,得到尺寸为$[H\times W,C]$的特征图,经过$1\times 1$卷积层并把输出尺寸调整为为$[H,W,C]$;
- 最终输出的特征可以通过标量缩放$\gamma$和残差连接构造:$y = γy + x$。
在计算注意力图时,$f(x)$和$g(x)$的输出通道数不影响注意力图的尺寸,较少的通道数会减少参数量和计算量,作者在实验中分别使用$C/k(k=1,2,4,8)$训练后发现对结果影响不大,因此最终选用了$C/8$。
$f(x)$和$g(x)$的注意力图得到的是$[H\times W,H\times W]$的输出,因此表示的是像素点与像素点之间的相关性。当经过了softmax函数之后(注意这里是对每一行单独进行softmax),每一行就代表了一个注意力分别,对应一个特征像素位置($C$个像素通道点)与其它所有像素位置的关系,$H\times W$行对应了$H\times W$个像素位置。
注意力图与$h(x)$的输出进行矩阵相乘,使得$h(x)$的每个特征像素都和其余所有像素建立了联系,结果表示为所有像素按照注意力图提供的注意力分布进行加权组合。
最终的输出为$y = γy + x$,其中$γ$是一个可学习的参数,并且初始化为$0$。网络开始训练时,首先学习局部信息,不采用自注意力模块;随着训练的进行,网络逐渐采用注意力模块学习更多长距离的特征。
作者对图像中的随机五个像素点进行自注意力的可视化,通过对最接近输出层的自注意力模块的注意力图进行可视化,可以发现网络不仅能够区分前景和背景,甚至对一些物体的不同结构也能准确的进行划分:
自注意力机制的实现可参考:
class SelfAttention(nn.Module):
def __init__(self, in_channels, k=8):
super(SelfAttention, self).__init__()
self.inter_channels = in_channels/k
self.f = nn.Conv2d(in_channels, self.inter_channels, 1)
self.g = nn.Conv2d(in_channels, self.inter_channels, 1)
self.h = nn.Conv2d(in_channels, in_channels, 1)
self.o = nn.Conv2d(in_channels, in_channels, 1)
self.gamma = torch.zeros(1).requires_grad_(True)
def forward(self, x):
b, c, h, w = x.shape
fx = self.f(x).view(b, self.inter_channels, -1) # [b, c', hw]
fx = fx.permute(0, 2, 1) # [b, hw, c']
gx = self.g(x).view(b, self.inter_channels, -1) # [b, c', hw]
attn = torch.matmul(fx, gx) # [b, hw, hw]
attn = F.softmax(attn, dim=2) # 按行归一化
hx = self.h(x).view(b, c, -1) # [b, c, hw]
hx = hx.permute(0, 2, 1) # [b, hw, c]
y = torch.matmul(attn, hx) # [b, hw, c]
y = y.permute(0, 2, 1).contiguous() # [b, c, hw]
y = y.view(b, c, h, w)
y = self.o(y)
return self.gamma*y + x
2. 谱归一化
谱归一化(Spectral Normalization)是指使用谱范数(spectral norm)对网络参数进行归一化:
\[W \leftarrow \frac{W}{||W||_2^2}\]谱归一化精确地使网络满足Lipschitz连续性。Lipschitz连续性保证了函数对于输入扰动的稳定性,即函数的输出变化相对输入变化是缓慢的。
谱范数是一种由向量范数诱导出来的矩阵范数,作用相当于向量的模长:
\[||W||_2 = \mathop{\max}_{x \neq 0} \frac{||Wx||}{||x||}\]谱范数$||W||_2$的平方的取值为$W^TW$的最大特征值。
model = Model()
def add_sn(m):
for name, layer in m.named_children():
m.add_module(name, add_sn(layer))
if isinstance(m, (nn.Conv2d, nn.Linear)):
return nn.utils.spectral_norm(m)
else:
return m
model = add_sn(model)
在SAGAN中,对生成器和判别器均使用了谱归一化。
3. TTUR
在设置优化函数时,应设法保证判别器的判别能力比生成器的生成能力要好。通常的做法是先更新判别器的参数多次,再更新一次生成器的参数。
TTUR (Two Time-Scale Update Rule)是指判别器和生成器的更新次数相同,将判别器的学习率设置得比生成器的学习率更大,此时网络收敛于局部纳什均衡:
\[\begin{aligned} θ_D & \leftarrow θ_D + \alpha \nabla_{θ_D}L(D,G) \\ \theta_G & \leftarrow θ_G - \beta \nabla_{θ_G}L(D,G) \end{aligned}\]在SAGAN中,判别器$D$的学习率设置为$\alpha = 0.0004$,生成器$G$的学习率设置为$\beta = 0.0001$。