对抗自编码器.
- paper:Adversarial Autoencoders
① 研究背景
VAE的损失函数可以分成两部分:
\[\begin{aligned} \mathcal{L} &= \mathbb{E}_{z \text{~} q(z|x)} [-\log p(x | z)] + KL[q(z|x)||p(z)] \end{aligned}\]其中$\mathbb{E}_{z \text{~} q(z|x)} [-\log p(x | z)]$表示生成模型$p(x|z)$的重构损失,$KL[q(z|x)||p(z)]$表示后验分布$q(z|x)$的正则化项(KL损失)。
② 模型结构
Adversarial Autoencoder (AAE)采用对抗学习的思想构造后验分布$q(z|x)$的正则化项。通过引入一个判别器区分从后验分布中重参数化的隐变量$z$和从先验分布$p(z)$中采样的隐变量。
相比于VAE预设后验分布$q(z|x)$为正态分布(便于计算KL散度),AAE中的先验分布$p(z)$可以选择任意分布,只要保证能够进行采样即可。
③ Pytorch实现
AAE的完整pytorch实现可参考PyTorch-GAN,下面给出其损失函数的计算和参数更新过程:
# 定义网络结构
encoder = Encoder() # 输出重参数化后的z
decoder = Decoder() # 输出重构图像
discriminator = Discriminator() # 输出分类得分
# 定义损失函数
adversarial_loss = torch.nn.BCELoss() # 判别损失
pixelwise_loss = torch.nn.L1Loss() # 重构损失
# 定义优化器
optimizer_G = torch.optim.Adam(
itertools.chain(encoder.parameters(), decoder.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)
)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
for epoch in range(n_epochs):
for i, real_imgs in enumerate(dataloader):
# 构造对抗标签
valid = torch.ones(real_imgs.shape[0], 1)
fake = torch.zeros(real_imgs.shape[0], 1)
encoded_imgs = encoder(real_imgs)
decoded_imgs = decoder(encoded_imgs)
# 训练判别器
z = torch.randn(real_imgs.shape[0], opt.latent_dim) # p(z)可以设置任意分布
real_loss = adversarial_loss(discriminator(z), valid)
fake_loss = adversarial_loss(discriminator(encoded_imgs.detach()), fake)
d_loss = 0.5 * (real_loss + fake_loss)
d_loss.backward()
optimizer_D.step()
# 训练编码器和解码器
optimizer_G.zero_grad()
g_loss = 0.001 * adversarial_loss(discriminator(encoded_imgs), valid) + 0.999 * pixelwise_loss(
decoded_imgs, real_imgs
)
g_loss.backward()
optimizer_G.step()