BiGAN:使用双向GAN进行对抗特征学习.
BiGAN既可以将隐空间的噪声分布映射到任意复杂的数据分布,又可以将数据映射回隐空间,以此学习有价值的特征表示。
1. 网络结构
该模型包括编码器、生成器(解码器)、判别器三部分。
- 编码器:把真实图像$x$编码成隐空间特征$z$;
- 生成器:把$z$解码成重构图像;
- 判别器:给定图像$x$和编码$z$,区分是编码器还是解码器提供的。
2. 目标函数
BiGAN的目标函数为:
\[\begin{aligned} \mathop{ \min}_{G,E} \mathop{\max}_{D} \Bbb{E}_{x \text{~} P_{data}(x)}[\Bbb{E}_{z \text{~} P_{E}(\cdot | x)}[\log D(x,z)]] + \Bbb{E}_{z \text{~} P_{Z}(z)}[\Bbb{E}_{x \text{~} P_{G}(\cdot | x)}[\log(1-D(x,z))]] \\ =\mathop{ \min}_{G,E} \mathop{\max}_{D} \Bbb{E}_{x \text{~} P_{data}(x)}[\log D(x,E(x))] + \Bbb{E}_{z \text{~} P_{Z}(z)}[\log(1-D(G(z),z))] \end{aligned}\]根据GAN的训练技巧,将真假样本的标签翻转过来,在训练时对于生成器能提供更大的梯度,因此将目标函数调整为:
\[\begin{aligned} \mathop{ \min}_{G,E} \mathop{\max}_{D} \Bbb{E}_{x \text{~} P_{data}(x)}[\log (1-D(x,E(x)))] + \Bbb{E}_{z \text{~} P_{Z}(z)}[\log D(G(z),z)] \end{aligned}\]3. Pytorch实现
# 定义网络结构
encoder = Encoder() # 输出编码向量
decoder = Decoder() # 输出重构图像
discriminator = Discriminator() # 输出分类得分
# 定义损失函数
adversarial_loss = torch.nn.BCELoss() # 判别损失
# 定义优化器
optimizer_G = torch.optim.Adam(
itertools.chain(encoder.parameters(), decoder.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)
)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
for epoch in range(n_epochs):
for i, real_imgs in enumerate(dataloader):
# 构造对抗标签
valid = torch.ones(real_imgs.shape[0], 1)
fake = torch.zeros(real_imgs.shape[0], 1)
encoded_imgs = encoder(real_imgs)
decoded_imgs = decoder(encoded_imgs)
z = torch.randn(real_imgs.shape[0], opt.latent_dim)
sampled_imgs = decoder(z)
# 训练判别器
real_loss = adversarial_loss(discriminator(encoded_imgs.detach(), real_imgs), fake)
fake_loss = adversarial_loss(discriminator(sampled_imgs.detach(), z), valid)
d_loss = 0.5 * (real_loss + fake_loss)
d_loss.backward()
optimizer_D.step()
# 训练编码器和解码器
optimizer_G.zero_grad()
real_loss = adversarial_loss(discriminator(encoded_imgs, real_imgs), fake)
fake_loss = adversarial_loss(discriminator(sampled_imgs, z), valid)
g_loss = -0.5 * (real_loss + fake_loss)
g_loss.backward()
optimizer_G.step()