BGAN: Boundary-Seeking GAN.
1. Analyzing the GAN Objective
The objective function of a GAN is:
\[\begin{aligned} \mathop{ \min}_{G} \mathop{\max}_{D} L(G,D) & = \Bbb{E}_{x \text{~} P_{data}(x)}[\log D(x)] + \Bbb{E}_{x \text{~} P_{G}(x)}[\log(1-D(x))] \\ & =\int_x (P_{data}(x)\log D(x) + P_{G}(x)\log(1-D(x))) dx \end{aligned}\]We first derive the optimal discriminator $D^{*}$. Since the integral is maximized when the integrand is maximized pointwise, we solve \(\frac{\partial L(G,D)}{\partial D} = 0\), i.e. \(\frac{P_{data}(x)}{D(x)} - \frac{P_{G}(x)}{1-D(x)} = 0\), which gives:
\[D^*(x) = \frac{P_{data}(x)}{P_{data}(x)+P_{G}(x)} \in [0,1]\]If the generator $G$ is also trained to optimality, then \(P_{data}(x)\approx P_{G}(x)\), so the discriminator degenerates to the constant $D^{*}(x)=\frac{1}{2}$ and loses its discriminative ability.
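As a quick sanity check of the closed form (a minimal sketch; the density values `p_data` and `p_g` below are arbitrary examples, not from the text), we can numerically maximize the pointwise integrand $P_{data}\log D + P_{G}\log(1-D)$ over a dense grid of $D$ values and compare with $D^*$:

```python
import numpy as np

def optimal_d(p_data, p_g):
    """Closed-form optimal discriminator at a single point x."""
    return p_data / (p_data + p_g)

def grid_argmax(p_data, p_g, n=1_000_000):
    """Maximize f(D) = p_data*log(D) + p_g*log(1-D) on a dense grid of D."""
    d = np.linspace(1e-6, 1 - 1e-6, n)
    f = p_data * np.log(d) + p_g * np.log(1 - d)
    return d[np.argmax(f)]

p_data, p_g = 0.7, 0.2            # example (assumed) density values at some x
print(optimal_d(p_data, p_g))     # ≈ 0.7778
print(grid_argmax(p_data, p_g))   # ≈ 0.7778, matching the closed form
```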
When the discriminator $D$ attains its optimum $D^{*}$, the objective becomes:
\[\begin{aligned} L(G,D^*) & =\int_x (P_{data}(x)\log D^*(x) + P_{G}(x)\log(1-D^*(x))) dx \\ & =\int_x (P_{data}(x)\log \frac{P_{data}(x)}{P_{data}(x)+P_{G}(x)} + P_{G}(x)\log\frac{P_{G}(x)}{P_{data}(x)+P_{G}(x)}) dx \\ & =\int_x (P_{data}(x)\log \frac{P_{data}(x)}{\frac{P_{data}(x)+P_{G}(x)}{2}} + P_{G}(x)\log\frac{P_{G}(x)}{\frac{P_{data}(x)+P_{G}(x)}{2}}) dx -2\log 2 \\ & = 2D_{JS}[P_{data}(x) || P_G(x)]-2\log 2 \end{aligned}\]where $D_{JS}$ denotes the Jensen–Shannon (JS) divergence. Hence, when the discriminator $D$ is optimal, the GAN loss measures the JS divergence between the data distribution \(P_{data}(x)\) and the generated distribution \(P_G(x)\). If the generator $G$ is also optimal, the JS divergence vanishes and the loss attains its minimum value $-2\log 2$.
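The identity $L(G,D^*) = 2D_{JS}[P_{data} \| P_G] - 2\log 2$ can be verified numerically on discrete distributions (a minimal sketch; the two distributions `p` and `q` below are arbitrary examples standing in for $P_{data}$ and $P_G$):

```python
import numpy as np

def kl(p, q):
    """Kullback-Leibler divergence KL(p || q) for discrete distributions."""
    return np.sum(p * np.log(p / q))

def js(p, q):
    """Jensen-Shannon divergence via the mixture m = (p + q) / 2."""
    m = (p + q) / 2
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

# Example discrete distributions (assumed values)
p = np.array([0.5, 0.3, 0.2])
q = np.array([0.2, 0.2, 0.6])

d_star = p / (p + q)  # optimal discriminator per outcome
loss = np.sum(p * np.log(d_star) + q * np.log(1 - d_star))

print(np.allclose(loss, 2 * js(p, q) - 2 * np.log(2)))  # True
```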
2. Boundary-Seeking GAN
By the discussion above, when the generator $G$ is trained to optimality the discriminator degenerates to the constant $D^{*}(x)=\frac{1}{2}$. It is therefore natural to define the generator's objective directly so that its minimum is attained at $D(x)=\frac{1}{2}$:
\[\begin{aligned} G^* &\leftarrow \mathop{ \min}_{G} \Bbb{E}_{x \text{~} P_G(x)}[ \frac{1}{2}(\log D(x) - \log (1-D(x)))^2 ] \end{aligned}\]A complete PyTorch implementation of BGAN is available in PyTorch-GAN. The BGAN training loop is as follows:
```python
discriminator_loss = torch.nn.BCELoss()

for epoch in range(n_epochs):
    for i, real_imgs in enumerate(dataloader):
        # Labels for real and generated samples
        valid = torch.ones(real_imgs.shape[0], 1)
        fake = torch.zeros(real_imgs.shape[0], 1)

        # Sample latent noise and generate samples
        z = torch.randn(real_imgs.shape[0], latent_dim)
        gen_imgs = generator(z)

        # Train the discriminator
        optimizer_D.zero_grad()
        real_loss = discriminator_loss(discriminator(real_imgs), valid)
        fake_loss = discriminator_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # Train the generator with the boundary-seeking loss
        optimizer_G.zero_grad()
        gen_validity = discriminator(gen_imgs)
        g_loss = 0.5 * torch.mean((torch.log(gen_validity) - torch.log(1 - gen_validity)) ** 2)
        g_loss.backward()
        optimizer_G.step()
```
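As a standalone check (independent of the training loop above) that the boundary-seeking term $\frac{1}{2}(\log D - \log(1-D))^2$ is indeed minimized at the decision boundary $D = \frac{1}{2}$, we can evaluate it over a grid of discriminator outputs:

```python
import torch

def bgan_g_term(d):
    """Pointwise boundary-seeking generator loss 0.5*(log D - log(1-D))^2."""
    return 0.5 * (torch.log(d) - torch.log(1 - d)) ** 2

d = torch.linspace(0.01, 0.99, 99)   # grid of possible discriminator outputs
losses = bgan_g_term(d)
print(d[torch.argmin(losses)].item())  # ≈ 0.5: the loss vanishes at the boundary
```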