BGAN: Boundary-Seeking GAN.

1. Analyzing the GAN Objective Function

The objective function of GAN is:

\[\begin{aligned} \mathop{ \min}_{G} \mathop{\max}_{D} L(G,D) & = \Bbb{E}_{x \text{~} P_{data}(x)}[\log D(x)] + \Bbb{E}_{x \text{~} P_{G}(x)}[\log(1-D(x))] \\ & =\int_x (P_{data}(x)\log D(x) + P_{G}(x)\log(1-D(x))) dx \end{aligned}\]

We first solve for the optimal discriminator $D^{*}$. For a fixed generator, the optimum can be found by maximizing the integrand pointwise in $D(x)$, so we set \(\frac{\partial L(G,D)}{\partial D} = 0\), which gives:

\[D^*(x) = \frac{P_{data}(x)}{P_{data}(x)+P_{G}(x)} \in [0,1]\]
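Spelling out this pointwise maximization (a brief check of the step above): for a fixed $x$, write $a = P_{data}(x)$ and $b = P_{G}(x)$, so the integrand is $f(D) = a\log D + b\log(1-D)$ and

\[\frac{\partial f}{\partial D} = \frac{a}{D} - \frac{b}{1-D} = 0 \quad \Rightarrow \quad D^* = \frac{a}{a+b} = \frac{P_{data}(x)}{P_{data}(x)+P_{G}(x)}.\]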

If the generator $G$ is also trained to its optimum, then \(P_{data}(x) \approx P_{G}(x)\), and the discriminator degenerates to the constant $D^{*}(x)=\frac{1}{2}$, losing its ability to discriminate.

When the discriminator $D$ attains its optimum $D^{*}$, the objective function becomes:

\[\begin{aligned} L(G,D^*) & =\int_x (P_{data}(x)\log D^*(x) + P_{G}(x)\log(1-D^*(x))) dx \\ & =\int_x (P_{data}(x)\log \frac{P_{data}(x)}{P_{data}(x)+P_{G}(x)} + P_{G}(x)\log\frac{P_{G}(x)}{P_{data}(x)+P_{G}(x)}) dx \\ & =\int_x (P_{data}(x)\log \frac{P_{data}(x)}{\frac{P_{data}(x)+P_{G}(x)}{2}} + P_{G}(x)\log\frac{P_{G}(x)}{\frac{P_{data}(x)+P_{G}(x)}{2}}) dx -2\log 2 \\ & = 2D_{JS}[P_{data}(x) || P_G(x)]-2\log 2 \end{aligned}\]

where $D_{JS}$ denotes the Jensen-Shannon (JS) divergence, and the constant $-2\log 2$ appears because $\int_x P_{data}(x)dx = \int_x P_{G}(x)dx = 1$. Therefore, when the discriminator $D$ is optimal, the GAN loss measures the JS divergence between the real distribution \(P_{data}(x)\) and the generated distribution \(P_G(x)\). If the generator $G$ is also optimal, the loss attains its minimum value $-2\log 2$.
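As a quick numerical sanity check (a minimal NumPy sketch on discrete distributions, not part of the original derivation), the loss at the optimal discriminator indeed equals $2D_{JS}-2\log 2$, and reduces to $-2\log 2$ when the two distributions coincide:

```python
import numpy as np

def gan_loss_at_optimal_D(p_data, p_g):
    """L(G, D*) for two discrete distributions."""
    d_star = p_data / (p_data + p_g)
    return np.sum(p_data * np.log(d_star) + p_g * np.log(1 - d_star))

def js_divergence(p, q):
    m = (p + q) / 2
    kl = lambda a, b: np.sum(a * np.log(a / b))
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

p = np.array([0.1, 0.4, 0.5])
q = np.array([0.3, 0.3, 0.4])
print(gan_loss_at_optimal_D(p, q))              # equals the next line
print(2 * js_divergence(p, q) - 2 * np.log(2))
print(gan_loss_at_optimal_D(p, p))              # -2*log(2) ≈ -1.3863
```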

2. Boundary-Seeking GAN

As discussed above, when the generator $G$ is trained to its optimum, the discriminator degenerates to the constant $D^{*}(x)=\frac{1}{2}$. It is therefore natural to define the generator objective directly so that its minimum is attained at $D(x)=\frac{1}{2}$:

\[\begin{aligned} G^* &\leftarrow \mathop{ \min}_{G} \Bbb{E}_{x \text{~} P_G(x)}[ \frac{1}{2}(\log D(x) - \log (1-D(x)))^2 ] \end{aligned}\]
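The minimizer of this objective is immediate from the squared log-odds form:

\[\frac{1}{2}\left(\log D(x) - \log(1-D(x))\right)^2 = \frac{1}{2}\left(\log \frac{D(x)}{1-D(x)}\right)^2 \geq 0,\]

with equality exactly when \(\frac{D(x)}{1-D(x)} = 1\), i.e. \(D(x)=\frac{1}{2}\). In other words, the generator is pushed toward samples that the current discriminator places on its decision boundary, hence the name boundary-seeking.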

A complete PyTorch implementation of BGAN is available in PyTorch-GAN. The BGAN training loop is given below:

```python
import torch

# generator, discriminator, optimizer_G, optimizer_D, dataloader,
# latent_dim and n_epochs are assumed to be defined as in PyTorch-GAN
discriminator_loss = torch.nn.BCELoss()

for epoch in range(n_epochs):
    for i, real_imgs in enumerate(dataloader):
        # target labels: 1 for real samples, 0 for generated samples
        valid = torch.ones(real_imgs.shape[0], 1)
        fake = torch.zeros(real_imgs.shape[0], 1)

        # sample noise and generate a batch of images
        z = torch.randn(real_imgs.shape[0], latent_dim)
        gen_imgs = generator(z)

        # train the discriminator
        optimizer_D.zero_grad()
        # standard GAN discriminator loss on real and (detached) generated samples
        real_loss = discriminator_loss(discriminator(real_imgs), valid)
        fake_loss = discriminator_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        # update the discriminator parameters
        d_loss.backward()
        optimizer_D.step()

        # train the generator
        optimizer_G.zero_grad()
        gen_validity = discriminator(gen_imgs)
        # boundary-seeking loss: minimized when D(G(z)) = 1/2
        g_loss = 0.5 * torch.mean((torch.log(gen_validity) - torch.log(1 - gen_validity)) ** 2)
        g_loss.backward()
        optimizer_G.step()
```
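In practice the boundary-seeking loss can blow up when the discriminator output saturates at 0 or 1. A common safeguard (a suggested tweak, not part of the referenced PyTorch-GAN implementation) is to clamp the output before taking logarithms:

```python
# keep D(G(z)) strictly inside (0, 1) to avoid log(0) = -inf
gen_validity = discriminator(gen_imgs).clamp(1e-7, 1 - 1e-7)
g_loss = 0.5 * torch.mean((torch.log(gen_validity) - torch.log(1 - gen_validity)) ** 2)
```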