InfoGAN:通过最大化互信息实现可插值的表示学习.

InfoGAN可以根据给定条件生成具有某一特征的图像。不同于通常的条件GAN会预先给定图像的标签或类别,在InfoGAN中图像的条件是通过一个编码器$Q$(若条件是离散类别,则可看作分类器)得到的。

InfoGAN的生成器接收随机噪声$z$和条件编码$c$,生成给定条件$c$时的图像$G(z,c)$;判别器$D(x)$接收图像$x$,判断图像$x$是否为真实图像。编码器则把生成图像编码为对应的输入条件$\hat{c}=E(G(c))$。

在实现时编码器和判别器参数共享,仅在最后一层采用不同的网络层。

img_shape = (opt.channels, opt.img_size, opt.img_size)

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        input_dim = opt.latent_dim + opt.n_classes + opt.code_dim

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(input_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, noise, code):
        gen_input = torch.cat((noise, code), -1)
        img = self.model(gen_input)
        img = img.view(img.size(0), *img_shape)
        return img

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 128),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Output layers
        self.adv_layer = nn.Sequential(nn.Linear(128, 1), nn.Sigmoid())
        self.latent_layer = nn.Sequential(nn.Linear(128, opt.code_dim))

    def forward(self, img):
        out = self.model(img.view(img.size(0), -1))
        validity = self.adv_layer(out)
        latent_code = self.latent_layer(out)
        return validity, latent_code

InfoGAN在原GAN的目标函数中额外引入了条件编码$c$和生成图像$G(z,c)$的互信息$I(c;G(z,c))$,如下:

\[\begin{aligned} \mathop{ \min}_{G} \mathop{\max}_{D} & \Bbb{E}_{x \text{~} P_{data}(x)}[\log D(x)] + \Bbb{E}_{z \text{~} P_{Z}(z)}[\log(1-D(G(z,c)))] \\ & -\lambda I(c;G(z,c)) \end{aligned}\]

互信息$I(c;G(z,c))$衡量了随机变量$c$由于已知随机变量$G(z,c)$而减少的不确定性。生成器通过最大化互信息使得$G(z,c)$中包含尽可能多的$c$的信息量。

互信息仍然是难以直接处理的。不妨寻找互信息的一个下界:

\[\begin{aligned} I(c;G(z,c)) &= I(c;x)\\ &= \iint q(c,x) \log \frac{q(c,x)}{q(c)q(x)}dcdx \\ &= \iint q(x|c) q(c) \log \frac{q(c|x)}{q(c)}dcdx \\ &= \iint q(x|c) q(c) \log \frac{p(c|x)q(c|x)}{p(c|x)q(c)}dcdx \\ &= \iint q(x|c) q(c) \log \frac{p(c|x)}{q(c)}dcdx + \iint q(x|c) q(c) \log \frac{q(c|x)}{p(c|x)}dcdx \\ &= \iint q(x|c) q(c) \log \frac{p(c|x)}{q(c)}dcdx + \iint q(c|x) q(x) \log \frac{q(c|x)}{p(c|x)}dcdx \\&=\iint q(x|c) q(c) \log \frac{p(c|x)}{q(c)}dcdx + \int q(c|x) \log \frac{q(c|x)}{p(c|x)}dc \\ &= \iint q(x|c) q(c) \log \frac{p(c|x)}{q(c)}dcdx + D_{KL}[q(c|x)||p(c|x)] \\ &\geq \iint q(x|c) q(c) \log \frac{p(c|x)}{q(c)}dcdx \\ &= \iint q(x|c) q(c) \log p(c|x)dcdx - \iint q(x|c) q(c) \log q(c)dcdx \\ & = \iint q(x|c) q(c) \log p(c|x)dcdx - \iint q(c) \log q(c)dc \\ &= \iint q(x|c) q(c) \log p(c|x)dcdx + Const. \end{aligned}\]

最大化互信息等价于最大化互信息的一个下界。其中$p(c|x)$可以任意指定分布,不妨取正态分布$p(c|x)$~\(\mathcal{N}(c;Q(x),\sigma^2)\),其中$Q(x)$是一个带参数的编码器。此时互信息的下界表示为:

\[\begin{aligned} I(c;G(z,c))= I(c;x) &\geq \iint q(x|c) q(c) \log p(c|x)dcdx \\ &= \iint q(x|c) q(c) \log \mathcal{N}(c;Q(x),\sigma^2)dcdx \\ &= \iint q(x|c) q(c) \log \frac{1}{\sqrt{2\pi}\sigma}e^{-\frac{||c-Q(x)||^2}{2\sigma^2}} dcdx \\ &= \int q(c) \log \frac{1}{\sqrt{2\pi}\sigma}e^{-\frac{||c-Q(x)||^2}{2\sigma^2}} dc \\ & \leftrightarrow - \Bbb{E}_{c \text{~} q(c)}[||c-Q(x)||^2] \end{aligned}\]

在实现时互信息的一个下界可以用条件编码的重构误差表示。至此InfoGAN的目标函数写作:

\[\begin{aligned} \mathop{ \min}_{G,Q} \mathop{\max}_{D} & \Bbb{E}_{x \text{~} P_{data}(x)}[\log D(x)] + \Bbb{E}_{z \text{~} P_{Z}(z)}[\log(1-D(G(z,c)))] \\ & +\lambda \Bbb{E}_{z \text{~} P_{Z}(z)}[||c-Q(G(z,c))||^2] \end{aligned}\]

InfoGAN的完整pytorch实现可参考PyTorch-GAN,下面给出其损失函数的计算和参数更新过程:

# Loss functions
adversarial_loss = torch.nn.BCELoss()
continuous_loss = torch.nn.MSELoss()

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_info = torch.optim.Adam(
    itertools.chain(generator.parameters(), discriminator.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)
)

for epoch in range(opt.n_epochs):
    for i, real_imgs in enumerate(dataloader):
        # Adversarial ground truths
        valid = torch.ones(real_imgs.shape[0], 1).requires_grad_.(False)
        fake = torch.zeros(real_imgs.shape[0], 1).requires_grad_.(False)

        # Sample noise and code as generator input
        z = torch.randn(real_imgs.shape[0], latent_dim)
        gen_codes = torch.Tensor(real_imgs.shape[0],opt.code_dim).uniform_(-1,1)

        # Generate a batch of images
        gen_imgs = generator(z, gen_codes)

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()

        # Loss for real images
        real_pred, _ = discriminator(real_imgs)
        d_real_loss = adversarial_loss(real_pred, valid)

        # Loss for fake images
        fake_pred, _ = discriminator(gen_imgs.detach())
        d_fake_loss = adversarial_loss(fake_pred, fake)

        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # -----------------
        #  Train Generator
        # -----------------
        optimizer_G.zero_grad()

        # Loss measures generator's ability to fool the discriminator
        validity, _ = discriminator(gen_imgs)
        g_loss = adversarial_loss(validity, valid)
        g_loss.backward()
        optimizer_G.step()

        # ------------------
        # Information Loss
        # ------------------
        optimizer_info.zero_grad()
        _, pred_code = discriminator(gen_imgs)
        info_loss = lambda_con * continuous_loss(pred_code, gen_codes)
        info_loss.backward()
        optimizer_info.step()