BiGAN: Adversarial Feature Learning with a Bidirectional GAN.

A BiGAN can map a noise distribution in the latent space to an arbitrarily complex data distribution, and it can also map data back into the latent space; this inverse mapping is what lets it learn useful feature representations.

1. Network Architecture

The model consists of three parts: an encoder, a generator (decoder), and a discriminator. The encoder E maps a data sample x to a latent code E(x), the generator G maps a latent code z to a sample G(z), and the discriminator receives joint pairs (x, z) and tries to tell encoder pairs (x, E(x)) apart from generator pairs (G(z), z).
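A minimal sketch of the three networks is given below, matching the class names Encoder, Decoder, and Discriminator used in the training code of Section 3. The multilayer-perceptron layout, the flattened 28x28 grayscale input, and the latent dimension of 100 are illustrative assumptions, not part of the original design.

import torch
import torch.nn as nn

latent_dim = 100   # assumed; should match opt.latent_dim in the training loop below
img_dim = 28 * 28  # assumed flattened image size

class Encoder(nn.Module):
    # E: x -> z, maps an image to a latent code
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(img_dim, 512), nn.LeakyReLU(0.2),
            nn.Linear(512, latent_dim),
        )

    def forward(self, x):
        return self.model(x.view(x.size(0), -1))

class Decoder(nn.Module):
    # G: z -> x, maps a latent code to an image (Tanh assumes inputs normalized to [-1, 1])
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 512), nn.LeakyReLU(0.2),
            nn.Linear(512, img_dim), nn.Tanh(),
        )

    def forward(self, z):
        return self.model(z)

class Discriminator(nn.Module):
    # D: (x, z) -> [0, 1], scores a joint image-latent pair
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(img_dim + latent_dim, 512), nn.LeakyReLU(0.2),
            nn.Linear(512, 1), nn.Sigmoid(),  # probability output, required by the BCELoss used later
        )

    def forward(self, x, z):
        return self.model(torch.cat([x.view(x.size(0), -1), z], dim=1))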

2. Objective Function

The BiGAN objective is:

\[\begin{aligned} \mathop{\min}_{G,E} \mathop{\max}_{D} \Bbb{E}_{x \sim P_{data}(x)}[\Bbb{E}_{z \sim P_{E}(\cdot | x)}[\log D(x,z)]] + \Bbb{E}_{z \sim P_{Z}(z)}[\Bbb{E}_{x \sim P_{G}(\cdot | z)}[\log(1-D(x,z))]] \\ =\mathop{\min}_{G,E} \mathop{\max}_{D} \Bbb{E}_{x \sim P_{data}(x)}[\log D(x,E(x))] + \Bbb{E}_{z \sim P_{Z}(z)}[\log(1-D(G(z),z))] \end{aligned}\]

The second line uses the fact that the encoder and generator are deterministic: a real sample x is always paired with its code E(x), and a noise sample z is always paired with the generated sample G(z), so the inner expectations reduce to single evaluations.

A common GAN training trick is to swap the real and fake labels, which provides larger gradients to the generator during training; the objective is therefore adjusted to:

\[\begin{aligned} \mathop{\min}_{G,E} \mathop{\max}_{D} \Bbb{E}_{x \sim P_{data}(x)}[\log (1-D(x,E(x)))] + \Bbb{E}_{z \sim P_{Z}(z)}[\log D(G(z),z)] \end{aligned}\]
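Writing the two terms with the binary cross-entropy (BCE) loss used in the implementation below, where generator pairs (G(z), z) carry the label 1 and encoder pairs (x, E(x)) carry the label 0, gives the discriminator loss and the encoder/generator loss:

\[\begin{aligned} \mathcal{L}_D &= \Bbb{E}_{x \sim P_{data}(x)}[\mathrm{BCE}(D(x,E(x)), 0)] + \Bbb{E}_{z \sim P_{Z}(z)}[\mathrm{BCE}(D(G(z),z), 1)] \\ &= -\Bbb{E}_{x \sim P_{data}(x)}[\log(1-D(x,E(x)))] - \Bbb{E}_{z \sim P_{Z}(z)}[\log D(G(z),z)] \\ \mathcal{L}_{G,E} &= -\mathcal{L}_D \end{aligned}\]

Minimizing L_D trains the discriminator toward the flipped labels, while minimizing L_{G,E} = -L_D pushes the encoder and generator in the opposite direction; these two quantities are exactly the d_loss and g_loss computed in the code below.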

3. PyTorch Implementation

import itertools
import torch

# Encoder, Decoder and Discriminator are the networks defined above;
# opt (lr, b1, b2, latent_dim), n_epochs and dataloader (yielding batches of images)
# are assumed to be set up elsewhere.

# Define the networks
encoder = Encoder()              # outputs the latent code E(x)
decoder = Decoder()              # outputs the generated / reconstructed image
discriminator = Discriminator()  # outputs a score for a joint pair (x, z)

# Define the loss function
adversarial_loss = torch.nn.BCELoss()  # binary cross-entropy on the discriminator output

# Define the optimizers: the encoder and decoder share one optimizer
optimizer_G = torch.optim.Adam(
    itertools.chain(encoder.parameters(), decoder.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)
)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

for epoch in range(n_epochs):
    for i, real_imgs in enumerate(dataloader):
        # Adversarial labels (flipped convention: encoder pairs -> 0, generator pairs -> 1)
        valid = torch.ones(real_imgs.shape[0], 1)
        fake = torch.zeros(real_imgs.shape[0], 1)

        encoded_imgs = encoder(real_imgs)                    # E(x)
        decoded_imgs = decoder(encoded_imgs)                 # reconstruction G(E(x)); not used in the losses
        z = torch.randn(real_imgs.shape[0], opt.latent_dim)  # z ~ P_Z(z)
        sampled_imgs = decoder(z)                            # G(z)

        # Train the discriminator: label encoder pairs (x, E(x)) as 0 and generator pairs (G(z), z) as 1
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(real_imgs, encoded_imgs.detach()), fake)
        fake_loss = adversarial_loss(discriminator(sampled_imgs.detach(), z), valid)
        d_loss = 0.5 * (real_loss + fake_loss)
        d_loss.backward()
        optimizer_D.step()

        # Train the encoder and decoder: minimize the negative of the discriminator loss,
        # i.e. the flipped objective given above
        optimizer_G.zero_grad()
        real_loss = adversarial_loss(discriminator(real_imgs, encoded_imgs), fake)
        fake_loss = adversarial_loss(discriminator(sampled_imgs, z), valid)
        g_loss = -0.5 * (real_loss + fake_loss)
        g_loss.backward()
        optimizer_G.step()
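After training, the encoder can serve as a feature extractor for downstream tasks, which is the point of adversarial feature learning. Below is a minimal sketch of a linear probe on top of the frozen encoder; the probe, the 10-class task, and labeled_dataloader are illustrative assumptions rather than part of BiGAN itself.

# Use the trained encoder as a fixed feature extractor and fit a linear probe on top
encoder.eval()
probe = torch.nn.Linear(opt.latent_dim, 10)  # assuming a 10-class downstream task
probe_optimizer = torch.optim.Adam(probe.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()

for imgs, labels in labeled_dataloader:      # assumed labeled dataloader for the downstream task
    with torch.no_grad():
        features = encoder(imgs)             # E(x) used as features; the encoder stays frozen
    logits = probe(features)
    loss = criterion(logits, labels)

    probe_optimizer.zero_grad()
    loss.backward()
    probe_optimizer.step()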