对抗自编码器.

① 研究背景

VAE的损失函数可以分成两部分:

\[\begin{aligned} \mathcal{L} &= \mathbb{E}_{z \text{~} q(z|x)} [-\log p(x | z)] + KL[q(z|x)||p(z)] \end{aligned}\]

其中$\mathbb{E}_{z \text{~} q(z|x)} [-\log p(x | z)]$表示生成模型$p(x|z)$的重构损失,$KL[q(z|x)||p(z)]$表示后验分布$q(z|x)$的正则化项(KL损失)。

② 模型结构

Adversarial Autoencoder (AAE)采用对抗学习的思想构造后验分布$q(z|x)$的正则化项。通过引入一个判别器区分从后验分布中重参数化的隐变量$z$和从先验分布$p(z)$中采样的隐变量。

相比于VAE预设后验分布$q(z|x)$为正态分布(便于计算KL散度),AAE中的先验分布$p(z)$可以选择任意分布,只要保证能够进行采样即可。

③ Pytorch实现

AAE的完整pytorch实现可参考PyTorch-GAN,下面给出其损失函数的计算和参数更新过程:

# 定义网络结构
encoder = Encoder() # 输出重参数化后的z
decoder = Decoder() # 输出重构图像
discriminator = Discriminator() # 输出分类得分

# 定义损失函数
adversarial_loss = torch.nn.BCELoss() # 判别损失
pixelwise_loss = torch.nn.L1Loss() # 重构损失

# 定义优化器
optimizer_G = torch.optim.Adam(
    itertools.chain(encoder.parameters(), decoder.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)
)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

for epoch in range(n_epochs):
    for i, real_imgs in enumerate(dataloader):
        # 构造对抗标签
        valid = torch.ones(real_imgs.shape[0], 1)
        fake = torch.zeros(real_imgs.shape[0], 1)
         
        encoded_imgs = encoder(real_imgs)
        decoded_imgs = decoder(encoded_imgs)

        # 训练判别器
        z = torch.randn(real_imgs.shape[0], opt.latent_dim) # p(z)可以设置任意分布
        real_loss = adversarial_loss(discriminator(z), valid)
        fake_loss = adversarial_loss(discriminator(encoded_imgs.detach()), fake)
        d_loss = 0.5 * (real_loss + fake_loss)
        d_loss.backward()
        optimizer_D.step()

        # 训练编码器和解码器
        optimizer_G.zero_grad()
        g_loss = 0.001 * adversarial_loss(discriminator(encoded_imgs), valid) + 0.999 * pixelwise_loss(
            decoded_imgs, real_imgs
        )
        g_loss.backward()
        optimizer_G.step()