BEGAN: Boundary Equilibrium GAN.

1. Introducing energy-based models into GANs

An energy-based model fits a set of real data $x_1,x_2,\cdots,x_n \sim P_{data}(x)$ with the following energy distribution:

\[q_{\theta}(x) = \frac{e^{-U_{\theta}(x)}}{Z_{\theta}},Z_{\theta} = \int e^{-U_{\theta}(x)}dx\]

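Here $U_{\theta}(x)$ is a parameterized energy function and $Z_{\theta}$ is the partition function (the normalization constant). Intuitively, real data lie where the energy function is lowest. Through adversarial training we want the generated data $\hat{x}_1,\hat{x}_2,\cdots,\hat{x}_n$ to have energy that is as low as possible too.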
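To make the definition concrete, the following is a minimal sketch that evaluates $q_{\theta}(x)$ on a 1-D grid. The quadratic energy, the grid range, and the Riemann-sum approximation of $Z_{\theta}$ are all assumptions chosen for illustration; they are not part of BEGAN.

```python
import math

def energy(x, mu=1.0):
    """Hypothetical toy energy U(x) = (x - mu)^2 / 2, lowest at x = mu."""
    return 0.5 * (x - mu) ** 2

def energy_density(xs, dx):
    """Return q(x) = exp(-U(x)) / Z on a grid, with Z from a Riemann sum."""
    unnorm = [math.exp(-energy(x)) for x in xs]
    z = sum(unnorm) * dx  # numerical approximation of the partition function
    return [u / z for u in unnorm], z

dx = 0.01
xs = [-9.0 + i * dx for i in range(2001)]  # grid over [-9, 11]
q, z = energy_density(xs, dx)

# The density peaks where the energy is lowest.
x_mode = xs[q.index(max(q))]
```

For this toy energy the density is a unit-variance Gaussian centered at `mu`, so the highest-density point coincides with the energy minimum.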

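We use the discriminator $D(x)$ to fit the energy function $U_{\theta}(x)$ and the generator $G(z)$ to construct the generated distribution $P_G(x)$. The discriminator's objective is then to minimize the energy of the real data distribution while maximizing the energy of the generated data distribution: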

\[D^* \leftarrow \mathop{ \min}_{D} \mathbb{E}_{x \sim P_{data}(x)} [ D(x)]- \mathbb{E}_{x \sim P_G(x)}[D(x) ]\]

Meanwhile, the generator's objective is to minimize the energy of the generated data distribution:

\[G^* \leftarrow \mathop{ \min}_{G} \mathbb{E}_{x \sim P_G(x)}[D(x) ]\]

Thus, from the energy-based perspective, the GAN objective can be written as:

\[\begin{aligned} D^* &\leftarrow \mathop{ \min}_{D} \mathbb{E}_{x \sim P_{data}(x)} [ D(x)]- \mathbb{E}_{x \sim P_G(x)}[D(x) ] \\ G^* &\leftarrow \mathop{ \min}_{G} \mathbb{E}_{x \sim P_G(x)}[D(x) ] \end{aligned}\]
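As a rough illustration of the two objectives, the sketch below computes both losses from batches of energies that are assumed to have been evaluated already. The function name `energy_gan_losses` and the sample numbers are hypothetical; expectations are replaced by batch means.

```python
def mean(values):
    """Batch mean, standing in for the expectation."""
    return sum(values) / len(values)

def energy_gan_losses(real_energies, fake_energies):
    """Return (d_loss, g_loss) for an energy-based GAN.

    d_loss: low energy on real data, high energy on generated data.
    g_loss: low energy on generated data.
    """
    d_loss = mean(real_energies) - mean(fake_energies)
    g_loss = mean(fake_energies)
    return d_loss, g_loss

# Hypothetical energies D(x) for two real and two generated samples.
d_loss, g_loss = energy_gan_losses([0.2, 0.4], [1.0, 2.0])
```

The two players pull the generated-sample energy in opposite directions: the discriminator's loss falls as that energy rises, while the generator's loss falls as it drops.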

2. Boundary Equilibrium GAN (BEGAN)

BEGAN's discriminator takes the form of an autoencoder, and the energy function is the L1 reconstruction loss of a sample:

\[U(x) = |D(x)-x| = |\text{Dec}(\text{Enc}(x))-x|\]
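The reconstruction energy can be sketched with a toy linear autoencoder on 2-D points. The projection direction `DIRECTION`, the function names, and the averaging over coordinates (mirroring a mean-reduced L1 loss) are assumptions for illustration; a real BEGAN discriminator would be a convolutional autoencoder.

```python
DIRECTION = (0.6, 0.8)  # hypothetical unit vector spanning the 1-D code space

def encode(x):
    """Enc: project a 2-D point onto DIRECTION, giving a scalar code."""
    return x[0] * DIRECTION[0] + x[1] * DIRECTION[1]

def decode(h):
    """Dec: map the scalar code back to a 2-D point on the line."""
    return (h * DIRECTION[0], h * DIRECTION[1])

def energy(x):
    """U(x) = |Dec(Enc(x)) - x|, averaged over coordinates."""
    x_hat = decode(encode(x))
    return sum(abs(a - b) for a, b in zip(x_hat, x)) / len(x)

on_manifold = energy((3.0, 4.0))    # lies on the line, reconstructs almost exactly
off_manifold = energy((4.0, -3.0))  # orthogonal to the line, reconstructs poorly
```

Samples the autoencoder can reconstruct receive low energy, and samples it cannot receive high energy, which is the property the discriminator is trained to give real and generated images respectively.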

BEGAN's generator has the same architecture as the discriminator's decoder.

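The BEGAN objective is as follows, where $D(x)$ is shorthand for the reconstruction energy $U(x)$ defined above: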

\[\begin{aligned} D^* &\leftarrow \mathop{ \min}_{D} \mathbb{E}_{x \sim P_{data}(x)} [ D(x)]- k_t \mathbb{E}_{x \sim P_G(x)}[D(x) ] \\ G^* &\leftarrow \mathop{ \min}_{G} \mathbb{E}_{x \sim P_G(x)}[D(x) ] \end{aligned}\]

At the start of training ($t=0$) we set $k_0=0$, so the discriminator only minimizes the energy of real images. Afterwards, $k_t$ is updated as:

\[k_{t+1} = k_t + \lambda (\gamma D(x)-D(G(z)))\]

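Intuitively, $k_t$ increases only when the energy of the generated images $D(G(z))$ is smaller than $\gamma$ times the energy of the real images $D(x)$, which makes the discriminator put more weight on raising the energy of generated images.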
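The update rule can be checked on a few numbers. The sketch below is a plain-Python version of a single step; the helper name `update_k` and the example energies are hypothetical, the default `gamma` and `lambda_k` are taken from the training code in this post, and the clamp to $[0, 1]$ follows that code as well.

```python
def update_k(k, real_energy, fake_energy, gamma=0.75, lambda_k=0.001):
    """One step of k_{t+1} = k_t + lambda * (gamma * D(x) - D(G(z))), clamped to [0, 1]."""
    k = k + lambda_k * (gamma * real_energy - fake_energy)
    return min(max(k, 0.0), 1.0)

# Fake energy below gamma * real energy: k grows.
k_up = update_k(0.5, real_energy=1.0, fake_energy=0.5)

# Fake energy above gamma * real energy: k shrinks.
k_down = update_k(0.5, real_energy=1.0, fake_energy=0.9)

# Starting from k = 0, a negative step is clamped back to 0.
k_floor = update_k(0.0, real_energy=1.0, fake_energy=0.9)
```

Because `lambda_k` is small, each step changes `k` only slightly, so the balance between the two discriminator terms shifts gradually over many iterations.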

The BEGAN training procedure is given below:

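import torch

# generator, discriminator, optimizer_G, optimizer_D, dataloader,
# n_epochs and latent_dim are assumed to be defined elsewhere.

# BEGAN hyperparameters
gamma = 0.75       # target ratio of fake to real reconstruction energy
lambda_k = 0.001   # step size for updating k
k = 0.0            # weight of the fake-image term in the discriminator loss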

for epoch in range(n_epochs):
    for i, real_imgs in enumerate(dataloader):
        # Sample latent noise and generate images
        z = torch.randn(real_imgs.shape[0], latent_dim)
        gen_imgs = generator(z)

        # Train the discriminator
        optimizer_D.zero_grad()
        # Discriminator loss: L1 reconstruction energies of real and generated images
        d_real = discriminator(real_imgs)
        d_fake = discriminator(gen_imgs.detach())
        d_loss_real = torch.mean(torch.abs(d_real - real_imgs))
        d_loss_fake = torch.mean(torch.abs(d_fake - gen_imgs.detach()))
        d_loss = d_loss_real - k * d_loss_fake
        # Update the discriminator parameters
        d_loss.backward()
        optimizer_D.step()

        # Update the balancing coefficient k
        diff = gamma * d_loss_real - d_loss_fake
        k = k + lambda_k * diff.item()
        k = min(max(k, 0), 1)  # constrain k to the interval [0, 1]

        # Train the generator
        optimizer_G.zero_grad()
        gen_imgs = generator(z)
        g_loss = torch.mean(torch.abs(discriminator(gen_imgs) - gen_imgs))
        g_loss.backward()
        optimizer_G.step()

For a complete PyTorch implementation of BEGAN, see PyTorch-GAN.