SAGAN:自注意力生成对抗网络.

SAGAN向生成对抗网络中引入了自注意力机制(Self-Attention),不仅摆脱了卷积层的感受野大小的限制,也使得网络在生成图像的过程中能够自己学习应关注的区域。

对于含有较多几何或结构约束的图像GAN的生成效果较差,这是因为复杂的几何轮廓需要长距离依赖(long-range dependency),卷积层的特点是局部性,受到感受野大小的限制很难提取到图像中的长距离依赖关系。

SAGAN自注意力机制引入模型结构中,有助于对图像区域中长距离、多层次的依赖关系进行建模。此外,生成器和判别器均应用了谱归一化,使得网络满足Lipschitz连续性;训练时遵循TTUR原则,判别器$D$和生成器$G$的学习率是不平衡的。

SAGAN的损失函数采用Hinge损失:

\[\begin{aligned} \mathop{ \max}_{D} & \Bbb{E}_{x \text{~} P_{data}(x)}[\min(0,-1+D(x))] \\ &+ \Bbb{E}_{x \text{~} P_{G}(x)}[\min(0,-1-D(x))] \\ \mathop{ \min}_{G}& -\Bbb{E}_{x \text{~} P_{G}(x)}[D(x)] \end{aligned}\]

1. 自注意力机制

自注意力机制在计算输入位置$i$的特征$y_i$时,考虑所有位置$j$的加权:

\[y_i = \sum_{j}^{} \frac{e^{f(x_i)^Tg(x_j)}}{\sum_j e^{f(x_i)^Tg(x_j)}} h(x_j)\]

自注意力机制的实现步骤如下:

  1. $f(x)$、$g(x)$和$h(x)$通过三个$1\times 1$卷积层实现,$f(x)$和$g(x)$改变了通道数(缩小为$C/8$),$h(x)$维持通道数不变;
  2. 将空间尺寸合并为$H\times W$,将$f(x)$的输出转置后和$g(x)$的输出进行矩阵相乘,经过softmax归一化得到尺寸为$[H\times W,H\times W]$的注意力图;
  3. 将注意力图与$h(x)$的输出进行矩阵相乘,得到尺寸为$[H\times W,C]$的特征图,经过$1\times 1$卷积层并把输出尺寸调整为为$[H,W,C]$;
  4. 最终输出的特征可以通过标量缩放$\gamma$和残差连接构造:$y = γy + x$。

在计算注意力图时,$f(x)$和$g(x)$的输出通道数不影响注意力图的尺寸,较少的通道数会减少参数量和计算量,作者在实验中分别使用$C/k(k=1,2,4,8)$训练后发现对结果影响不大,因此最终选用了$C/8$。

$f(x)$和$g(x)$的注意力图得到的是$[H\times W,H\times W]$的输出,因此表示的是像素点与像素点之间的相关性。当经过了softmax函数之后(注意这里是对每一行单独进行softmax),每一行就代表了一个注意力分别,对应一个特征像素位置($C$个像素通道点)与其它所有像素位置的关系,$H\times W$行对应了$H\times W$个像素位置。

注意力图与$h(x)$的输出进行矩阵相乘,使得$h(x)$的每个特征像素都和其余所有像素建立了联系,结果表示为所有像素按照注意力图提供的注意力分布进行加权组合。

最终的输出为$y = γy + x$,其中$γ$是一个可学习的参数,并且初始化为$0$。网络开始训练时,首先学习局部信息,不采用自注意力模块;随着训练的进行,网络逐渐采用注意力模块学习更多长距离的特征。

作者对图像中的随机五个像素点进行自注意力的可视化,通过对最接近输出层的自注意力模块的注意力图进行可视化,可以发现网络不仅能够区分前景和背景,甚至对一些物体的不同结构也能准确的进行划分:

自注意力机制的实现可参考:

class SelfAttention(nn.Module):
    def __init__(self, in_channels, k=8):
        super(SelfAttention, self).__init__()
        self.inter_channels = in_channels/k
        self.f = nn.Conv2d(in_channels, self.inter_channels, 1)
        self.g = nn.Conv2d(in_channels, self.inter_channels, 1)
        self.h = nn.Conv2d(in_channels, in_channels, 1)
        self.o = nn.Conv2d(in_channels, in_channels, 1)
        self.gamma = torch.zeros(1).requires_grad_(True)

    def forward(self, x):
        b, c, h, w = x.shape
        fx = self.f(x).view(b, self.inter_channels, -1) # [b, c', hw]
        fx = fx.permute(0, 2, 1) # [b, hw, c']
        gx = self.g(x).view(b, self.inter_channels, -1) # [b, c', hw]
        attn = torch.matmul(fx, gx) # [b, hw, hw]
        attn = F.softmax(attn, dim=2) # 按行归一化

        hx = self.h(x).view(b, c, -1) # [b, c, hw]
        hx = hx.permute(0, 2, 1) # [b, hw, c]
        y = torch.matmul(attn, hx) # [b, hw, c]
        y = y.permute(0, 2, 1).contiguous() # [b, c, hw]
        y = y.view(b, c, h, w)
        y = self.o(y)

        return self.gamma*y + x

2. 谱归一化

谱归一化(Spectral Normalization)是指使用谱范数(spectral norm)对网络参数进行归一化:

\[W \leftarrow \frac{W}{||W||_2^2}\]

谱归一化精确地使网络满足Lipschitz连续性Lipschitz连续性保证了函数对于输入扰动的稳定性,即函数的输出变化相对输入变化是缓慢的。

谱范数是一种由向量范数诱导出来的矩阵范数,作用相当于向量的模长:

\[||W||_2 = \mathop{\max}_{x \neq 0} \frac{||Wx||}{||x||}\]

谱范数$||W||_2$的平方的取值为$W^TW$的最大特征值。

model = Model()
def add_sn(m):
        for name, layer in m.named_children():
             m.add_module(name, add_sn(layer))
        if isinstance(m, (nn.Conv2d, nn.Linear)):
             return nn.utils.spectral_norm(m)
        else:
             return m
model = add_sn(model)

SAGAN中,对生成器和判别器均使用了谱归一化。

3. TTUR

在设置优化函数时,应设法保证判别器的判别能力比生成器的生成能力要好。通常的做法是先更新判别器的参数多次,再更新一次生成器的参数。

TTUR (Two Time-Scale Update Rule)是指判别器和生成器的更新次数相同,将判别器的学习率设置得比生成器的学习率更大,此时网络收敛于局部纳什均衡:

\[\begin{aligned} θ_D & \leftarrow θ_D + \alpha \nabla_{θ_D}L(D,G) \\ \theta_G & \leftarrow θ_G - \beta \nabla_{θ_G}L(D,G) \end{aligned}\]

SAGAN中,判别器$D$的学习率设置为$\alpha = 0.0004$,生成器$G$的学习率设置为$\beta = 0.0001$。