BEGAN:边界平衡GAN.
1. 将能量模型引入GAN
能量模型是指使用如下能量分布拟合一批真实数据$x_1,x_2,\cdots,x_n$~\(P_{data}(x)\):
\[q_{\theta}(x) = \frac{e^{-U_{\theta}(x)}}{Z_{\theta}},Z_{\theta} = \int e^{-U_{\theta}(x)}dx\]其中$U_{\theta}(x)$是带参数的能量函数;$Z_{\theta}$是配分函数(归一化因子)。直观地,真实数据分布在能量函数中势最小的位置。我们希望通过对抗训练使得生成数据$\hat{x}_1,\hat{x}_2,\cdots \hat{x}_n$的势也尽可能小。
使用判别器$D(x)$拟合能量函数$U_{\theta}(x)$,使用生成器$G(x)$构造生成分布$P_G(x)$。则判别器的目标函数为最小化真实数据分布的能量,并最大化生成数据分布的能量:
\[D^* \leftarrow \mathop{ \min}_{D} \Bbb{E}_{x \text{~} P_{data}(x)} [ D(x)]- \Bbb{E}_{x \text{~} P_G(x)}[D(x) ]\]与此同时生成器的目标函数为最小化生成数据分布的能量:
\[G^* \leftarrow \mathop{ \min}_{G} \Bbb{E}_{x \text{~} P_G(x)}[D(x) ]\]至此,在能量模型的角度下,GAN的目标函数写作:
\[\begin{aligned} D^* &\leftarrow \mathop{ \min}_{D} \Bbb{E}_{x \text{~} P_{data}(x)} [ D(x)]- \Bbb{E}_{x \text{~} P_G(x)}[D(x) ] \\ G^* &\leftarrow \mathop{ \min}_{G} \Bbb{E}_{x \text{~} P_G(x)}[D(x) ] \end{aligned}\]2. Boundary Equilibrium GAN (BEGAN)
BEGAN的判别器采用自编码器的形式,且能量函数采用样本的L1损失:
\[U(x) = |D(x)-x| = |Dec(Enc(x))-x|\]BEGAN的生成器与判别器的解码器结构相同:
BEGAN的目标函数为:
\[\begin{aligned} D^* &\leftarrow \mathop{ \min}_{D} \Bbb{E}_{x \text{~} P_{data}(x)} [ D(x)]- k_t \Bbb{E}_{x \text{~} P_G(x)}[D(x) ] \\ G^* &\leftarrow \mathop{ \min}_{G} \Bbb{E}_{x \text{~} P_G(x)}[D(x) ] \end{aligned}\]训练的初始阶段$t=0$,$k_0=0$,判别器最小化真实图像的能量。之后更新$k_t$:
\[k_{t+1} = k_t + \lambda (\gamma D(x)-D(G(z)))\]直观地,当生成图像的能量$D(G(z))$小于$\gamma$倍真实图像的能量$D(x)$时,$k_t$才会变大,使得判别器考虑增大生成图像的能量。
下面给出BEGAN的训练过程:
# BEGAN hyper parameters
gamma = 0.75
lambda_k = 0.001
k = 0.0
for epoch in range(n_epochs):
for i, real_imgs in enumerate(dataloader):
# 采样并生成样本
z = torch.randn(real_imgs.shape[0], latent_dim)
gen_imgs = generator(z)
# 训练判别器
optimizer_D.zero_grad()
# 计算判别器的损失
d_real = discriminator(real_imgs)
d_fake = discriminator(gen_imgs.detach())
d_loss_real = torch.mean(torch.abs(d_real - real_imgs))
d_loss_fake = torch.mean(torch.abs(d_fake - gen_imgs.detach()))
d_loss = d_loss_real - k * d_loss_fake
d_loss = d_loss_real - torch.clamp(d_loss_fake, max=margin)
# 更新判别器参数
d_loss.backward()
optimizer_D.step()
# 更新超参数k
diff = torch.mean(gamma * d_loss_real - d_loss_fake)
k = k + lambda_k * diff.item()
k = min(max(k, 0), 1) # Constraint to interval [0, 1]
# 训练生成器
optimizer_G.zero_grad()
gen_imgs = generator(z)
g_loss = torch.mean(torch.abs(discriminator(gen_imgs) - gen_imgs))
g_loss.backward()
optimizer_G.step()
BEGAN的完整pytorch实现可参考PyTorch-GAN。