EBGAN: Energy-Based Generative Adversarial Networks.

1. Introducing Energy-Based Models into GANs

An energy-based model fits a batch of real data $x_1, x_2, \cdots, x_n \sim P_{data}(x)$ with a distribution of the following form:

\[q_{\theta}(x) = \frac{e^{-U_{\theta}(x)}}{Z_{\theta}}, \qquad Z_{\theta} = \int e^{-U_{\theta}(x)}dx\]

Here $U_{\theta}(x)$ is a parameterized energy function and $Z_{\theta}$ is the partition function (normalization constant). Intuitively, the real data lie where the energy function takes its smallest values. Through adversarial training, we want the generated data $\hat{x}_1,\hat{x}_2,\cdots,\hat{x}_n$ to have energies that are as small as possible.
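For instance, the quadratic energy $U_{\theta}(x) = \frac{(x-\mu)^2}{2\sigma^2}$ with $\theta = (\mu,\sigma)$ recovers the Gaussian distribution:

\[q_{\theta}(x) = \frac{1}{Z_{\theta}} e^{-\frac{(x-\mu)^2}{2\sigma^2}}, \qquad Z_{\theta} = \int e^{-\frac{(x-\mu)^2}{2\sigma^2}}dx = \sqrt{2\pi}\sigma\]

that is, $q_{\theta}(x) = \mathcal{N}(x;\mu,\sigma^2)$: the low-energy region around $\mu$ is exactly where the density is highest.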

We use a discriminator $D(x)$ to fit the energy function $U_{\theta}(x)$ and a generator $G(z)$ to construct the generated distribution $P_G(x)$. The discriminator's objective is to minimize the energy of the real data distribution while maximizing the energy of the generated distribution:

\[D^* \leftarrow \mathop{\min}_{D} \mathbb{E}_{x \sim P_{data}(x)} [D(x)] - \mathbb{E}_{x \sim P_G(x)}[D(x)]\]

At the same time, the generator's objective is to minimize the energy of the generated distribution:

\[G^* \leftarrow \mathop{\min}_{G} \mathbb{E}_{x \sim P_G(x)}[D(x)]\]

From the energy-model perspective, the GAN objective can thus be written as:

\[\begin{aligned} D^* &\leftarrow \mathop{\min}_{D} \mathbb{E}_{x \sim P_{data}(x)} [D(x)] - \mathbb{E}_{x \sim P_G(x)}[D(x)] \\ G^* &\leftarrow \mathop{\min}_{G} \mathbb{E}_{x \sim P_G(x)}[D(x)] \end{aligned}\]
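As a minimal sketch (the toy 1-D networks below are placeholders for illustration, not the EBGAN models defined later), these two objectives translate directly into batch-level losses:

import torch
import torch.nn as nn

# Toy placeholders: D maps a sample to a scalar energy, G maps noise to a sample
D = nn.Linear(1, 1)
G = nn.Linear(4, 1)

real_x = torch.randn(8, 1)   # a batch of "real" samples
z = torch.randn(8, 4)        # a batch of latent noise

# Discriminator: minimize the energy of real data, maximize the energy of generated data
d_loss = D(real_x).mean() - D(G(z).detach()).mean()

# Generator: minimize the energy of generated data
g_loss = D(G(z)).mean()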

2. Energy-based GAN (EBGAN)

The EBGAN discriminator takes the form of an autoencoder, and the energy function is the reconstruction loss of a sample:

\[U(x) = \|D(x)-x\| = \|\mathrm{Dec}(\mathrm{Enc}(x))-x\|\]

Intuitively, if an image can be reconstructed well by the autoencoder, the discriminator regards it as real; the reconstruction error is then small and can be viewed as the image's "energy", whose minimum value is $0$.

The EBGAN discriminator is implemented as follows (note: it also outputs the encoded features, which are needed later when constructing the loss function):

import torch
import torch.nn as nn

# opt is assumed to hold the hyperparameters (opt.channels, opt.img_size, opt.latent_dim, opt.batch_size)
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # Downsampling
        self.down = nn.Sequential(nn.Conv2d(opt.channels, 64, 3, 2, 1), nn.ReLU())

        # Embedding layer
        self.down_size = opt.img_size // 2
        down_dim = 64 * self.down_size ** 2
        self.embedding = nn.Linear(down_dim, 32)

        # Fully-connected layers
        self.fc = nn.Sequential(
            nn.BatchNorm1d(32, 0.8),
            nn.ReLU(inplace=True),
            nn.Linear(32, down_dim),
            nn.BatchNorm1d(down_dim),
            nn.ReLU(inplace=True),
        )

        # Upsampling
        self.up = nn.Sequential(nn.Upsample(scale_factor=2), nn.Conv2d(64, opt.channels, 3, 1, 1))

    def forward(self, img):
        out = self.down(img)
        embedding = self.embedding(out.view(out.size(0), -1))
        out = self.fc(embedding)
        out = self.up(out.view(out.size(0), 64, self.down_size, self.down_size))
        return out, embedding
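A quick shape check on the discriminator (assuming, for illustration, opt.img_size = 32 and opt.channels = 1):

discriminator = Discriminator()
imgs = torch.randn(16, 1, 32, 32)   # a batch of 16 single-channel 32x32 images
recon, embedding = discriminator(imgs)
print(recon.shape)       # torch.Size([16, 1, 32, 32]) -- reconstructed images
print(embedding.shape)   # torch.Size([16, 32])        -- encoder features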

The EBGAN generator is a standard GAN generator, implemented as follows:

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.init_size = opt.img_size // 4
        self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))

        self.conv_blocks = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, noise):
        out = self.l1(noise)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img
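A quick shape check on the generator (assuming, for illustration, opt.latent_dim = 100, opt.img_size = 32 and opt.channels = 1):

generator = Generator()
z = torch.randn(16, 100)   # a batch of 16 latent vectors
gen_imgs = generator(z)
print(gen_imgs.shape)      # torch.Size([16, 1, 32, 32]), values in [-1, 1] because of Tanh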

3. The EBGAN Objective Function

From the energy-model perspective, the GAN objective is:

\[\begin{aligned} D^* &\leftarrow \mathop{\min}_{D} \mathbb{E}_{x \sim P_{data}(x)} [D(x)] - \mathbb{E}_{x \sim P_G(x)}[D(x)] \\ G^* &\leftarrow \mathop{\min}_{G} \mathbb{E}_{x \sim P_G(x)}[D(x)] \end{aligned}\]

Note that the EBGAN energy function is constructed from a mean squared error, so its minimum value is $0$. Directly optimizing the objective above tends to push the energy of generated samples toward $\infty$, which destabilizes training.

In practice, the energy of generated samples is therefore capped at a margin $m$:

\[\begin{aligned} D^* &\leftarrow \mathop{\min}_{D} \mathbb{E}_{x \sim P_{data}(x)} [D(x)] + \mathbb{E}_{x \sim P_G(x)}[\max(0, m-D(x))] \\ G^* &\leftarrow \mathop{\min}_{G} \mathbb{E}_{x \sim P_G(x)}[D(x)] \end{aligned}\]
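Since $m$ is a constant, the hinge term can be rewritten as

\[\max(0, m-D(x)) = m - \min(D(x), m)\]

so the discriminator objective is equivalent, up to the additive constant $m$, to minimizing $\mathbb{E}_{x \sim P_{data}(x)}[D(x)] - \mathbb{E}_{x \sim P_G(x)}[\min(D(x), m)]$; this is the form implemented with torch.clamp in the training code below.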

In addition, the authors design a repelling regularizer, whose goal is to make the generated samples as diverse as possible when training the generator. Concretely, for generated samples $x_i, x_j$, the feature codes $S_i, S_j$ extracted by the discriminator's encoder should be as dissimilar as possible, measured by the squared cosine similarity:

\[f_{PT}(S) = \frac{1}{N(N-1)}\sum_i \sum_{j \neq i} \left(\frac{S_i^T S_j}{\|S_i\| \cdot \|S_j\|}\right)^2\]

This term is added to the generator's objective as the pulling-away loss:

\[\begin{aligned} D^* &\leftarrow \mathop{\min}_{D} \mathbb{E}_{x \sim P_{data}(x)} [D(x)] + \mathbb{E}_{x \sim P_G(x)}[\max(0, m-D(x))] \\ G^* &\leftarrow \mathop{\min}_{G} \mathbb{E}_{(x_i,x_j) \sim P_G(x)}\left[\left(\frac{E(x_i)^T E(x_j)}{\|E(x_i)\| \cdot \|E(x_j)\|}\right)^2\right] + \mathbb{E}_{x \sim P_G(x)}[D(x)] \end{aligned}\]

The pulling-away loss is implemented as follows:

def pullaway_loss(embeddings):
    # embeddings -> [batch, dims]
    norm = torch.sqrt(torch.sum(embeddings ** 2, -1, keepdim=True))
    normalized_emb = embeddings / norm
    # pairwise cosine similarities -> [batch, batch]
    similarity = torch.matmul(normalized_emb, normalized_emb.transpose(1, 0))
    batch_size = embeddings.size(0)
    # average the squared off-diagonal similarities over the N(N-1) ordered pairs
    # (the diagonal contributes exactly batch_size, since each self-similarity is 1)
    loss_pt = (torch.sum(similarity ** 2) - batch_size) / (batch_size * (batch_size - 1))
    return loss_pt
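A quick sanity check on random embeddings (illustrative shapes only):

emb = torch.randn(16, 32)   # a batch of 16 embeddings, as returned by the discriminator's encoder
print(pullaway_loss(emb))   # mean squared off-diagonal cosine similarity; 0 would mean mutually orthogonal codes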

The EBGAN training procedure is given below:

lambda_pt = 0.1                           # weight of the pulling-away loss
margin = max(1, opt.batch_size / 64.0)    # energy margin m for generated samples

# Define the reconstruction loss
pixelwise_loss = nn.MSELoss()  # mean squared error, used as the energy

for epoch in range(n_epochs):
    for i, real_imgs in enumerate(dataloader):
        # Sample noise and generate a batch of images
        z = torch.randn(real_imgs.shape[0], latent_dim)
        gen_imgs = generator(z)

        # Train the discriminator
        optimizer_D.zero_grad()
        # Compute the discriminator loss (reconstruction energies)
        real_recon, _ = discriminator(real_imgs)
        fake_recon, _ = discriminator(gen_imgs.detach())
        d_loss_real = pixelwise_loss(real_recon, real_imgs)
        d_loss_fake = pixelwise_loss(fake_recon, gen_imgs.detach())
        # Equivalent (up to the constant margin) to d_loss_real + max(0, margin - d_loss_fake)
        d_loss = d_loss_real - torch.clamp(d_loss_fake, max=margin)
        # Update the discriminator parameters
        d_loss.backward()
        optimizer_D.step()

        # Train the generator
        optimizer_G.zero_grad()
        fake_recon, img_embeddings = discriminator(gen_imgs)
        g_loss_fake = pixelwise_loss(fake_recon, gen_imgs)
        g_loss_pull = pullaway_loss(img_embeddings)
        g_loss = g_loss_fake + lambda_pt * g_loss_pull
        g_loss.backward()
        optimizer_G.step()

The complete PyTorch implementation of EBGAN can be found in PyTorch-GAN.