ClusterGAN: Latent Space Clustering in Generative Adversarial Networks.

ClusterGAN samples its latent variable from a mixture of a one-hot discrete variable and a continuous latent variable, and trains a GAN jointly with an encoder that projects data back into the latent space. This joint training makes clustering in the latent space possible.

1. Network Architecture

ClusterGAN consists of a generator, a discriminator, and an encoder.

The generator $G$ samples a latent variable composed of a discrete part $z_c$ and a continuous part $z_n$ and maps it to a generated image $x_g$; the discriminator $D$ distinguishes generated images $x_g$ from real images $x_r$; the encoder $E$ maps a generated image $x_g$ back to a reconstructed discrete code $\hat{z}_c$ and a reconstructed continuous code $\hat{z}_n$.
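As a concrete point of reference, here is a minimal sketch of the three networks for flattened 28×28 grayscale images. The MLP architectures, layer widths, and class names are illustrative assumptions, not the exact configuration from the paper:

import torch
import torch.nn as nn

class Generator(nn.Module):
    # Maps the concatenated latent code (z_n, z_c) to a flattened image
    def __init__(self, latent_dim=10, n_c=10, img_dim=28 * 28):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim + n_c, 256), nn.LeakyReLU(0.2),
            nn.Linear(256, 512), nn.LeakyReLU(0.2),
            nn.Linear(512, img_dim), nn.Sigmoid(),
        )

    def forward(self, zn, zc):
        return self.model(torch.cat((zn, zc), dim=1))

class Encoder(nn.Module):
    # Projects an image back to (z_n, softmax(z_c), z_c logits)
    def __init__(self, latent_dim=10, n_c=10, img_dim=28 * 28):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Linear(img_dim, 512), nn.LeakyReLU(0.2),
            nn.Linear(512, 256), nn.LeakyReLU(0.2),
        )
        self.zn_head = nn.Linear(256, latent_dim)
        self.zc_head = nn.Linear(256, n_c)

    def forward(self, img):
        h = self.backbone(img)
        zn = self.zn_head(h)
        zc_logits = self.zc_head(h)
        return zn, torch.softmax(zc_logits, dim=1), zc_logits

class Discriminator(nn.Module):
    # Outputs the probability that an image is real
    def __init__(self, img_dim=28 * 28):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(img_dim, 256), nn.LeakyReLU(0.2),
            nn.Linear(256, 1), nn.Sigmoid(),
        )

    def forward(self, img):
        return self.model(img).squeeze(1)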

The latent variable is sampled from this mixed distribution as follows:

import torch

def sample_z(shape=64, latent_dim=10, n_c=10, fix_class=-1, req_grad=False):
    assert (fix_class == -1 or (fix_class >= 0 and fix_class < n_c)), \
        "Requested class %i outside bounds." % fix_class
    # Sample the continuous part z_n from a standard normal distribution
    zn = torch.randn((shape, latent_dim)).requires_grad_(req_grad)
    # Build the discrete part z_c as a pure one-hot vector,
    # together with the integer class indices zc_idx
    zc_FT = torch.zeros((shape, n_c))
    zc_idx = torch.empty(shape, dtype=torch.long)
    if fix_class == -1:
        # Draw the class of each sample uniformly at random
        zc_idx = zc_idx.random_(n_c)
        zc_FT = zc_FT.scatter_(1, zc_idx.unsqueeze(1), 1.)
    else:
        # Fix all samples to the requested class
        zc_idx[:] = fix_class
        zc_FT[:, fix_class] = 1
    zc = zc_FT.requires_grad_(req_grad)
    # Return the components of the latent space variable
    return zn, zc, zc_idx
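For example, a call with the default arguments draws a batch of 64 latent codes, while fix_class pins every sample to one cluster, which is useful for generating images from a single class:

zn, zc, zc_idx = sample_z(shape=64, latent_dim=10, n_c=10)
print(zn.shape)      # torch.Size([64, 10]) -- continuous part z_n
print(zc.shape)      # torch.Size([64, 10]) -- one-hot part z_c
print(zc_idx.shape)  # torch.Size([64])     -- integer class indices

zn3, zc3, _ = sample_z(shape=64, latent_dim=10, n_c=10, fix_class=3)  # all samples from class 3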

After training, latent variables sampled from the latent space exhibit a clustering structure: each one-hot choice of $z_c$ corresponds to one cluster.
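Concretely, once training has converged, clustering a data point amounts to encoding it and taking the argmax of the discrete code. A minimal sketch, assuming the Encoder defined above (with trained weights in practice; the random inputs here are placeholders):

# Cluster assignment at inference time
encoder = Encoder(latent_dim=10, n_c=10)  # in practice: a trained encoder
x = torch.rand(64, 28 * 28)               # a batch of flattened images
with torch.no_grad():
    _, zc_prob, zc_logits = encoder(x)
    cluster_labels = zc_logits.argmax(dim=1)  # one cluster index per image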

2. Loss Function

The objective of ClusterGAN decomposes into three parts: the adversarial loss, an L2 reconstruction loss on the continuous code $\hat{z}_n$, and a cross-entropy loss on the discrete code $\hat{z}_c$:

\[\begin{aligned} \mathop{\min}_{G,E} \mathop{\max}_{D} \; & \mathbb{E}_{x \sim P_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{(z_n,z_c) \sim P(z)} [\log(1-D(G(z_n,z_c)))] \\ & + \beta_n \mathbb{E}_{(z_n,z_c) \sim P(z)} \left[\|z_n-E^{n}(G(z_n,z_c))\|_2^2\right] \\ & - \beta_c \mathbb{E}_{(z_n,z_c) \sim P(z)} \left[z_c^{\top} \log E^{c}(G(z_n,z_c))\right] \end{aligned}\]
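Because $z_c$ is one-hot, the last term is exactly the familiar classification cross-entropy between the encoder's class prediction and the sampled class index, which is why the implementation below can apply torch.nn.CrossEntropyLoss to the encoder logits directly. A quick numerical check of this equivalence (shapes and values are arbitrary):

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)          # encoder logits for E^c(G(z_n, z_c)), 4 samples
zc_idx = torch.randint(0, 10, (4,))  # sampled class indices
zc = F.one_hot(zc_idx, 10).float()   # the corresponding one-hot z_c

# -E[z_c . log softmax(logits)] computed explicitly ...
manual = -(zc * F.log_softmax(logits, dim=1)).sum(dim=1).mean()
# ... equals the built-in cross-entropy on integer targets
builtin = F.cross_entropy(logits, zc_idx)
assert torch.allclose(manual, builtin)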

A complete PyTorch implementation of ClusterGAN is available in PyTorch-GAN; the loss computation and parameter update steps are shown below:
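The training loop below assumes the network sketches from Section 1 plus the following definitions. The hyperparameter values (learning rate, Adam betas, beta_n, beta_c) are illustrative stand-ins, not tuned settings, and opt mimics parsed command-line arguments:

import itertools
from types import SimpleNamespace
import torch

# Illustrative hyperparameters (stand-ins for command-line arguments)
opt = SimpleNamespace(n_epochs=200, lr=1e-4, b1=0.5, b2=0.9)
latent_dim, n_c = 10, 10   # dimensions of z_n and z_c
betan, betac = 10.0, 10.0  # loss weights beta_n and beta_c

generator = Generator(latent_dim=latent_dim, n_c=n_c)
encoder = Encoder(latent_dim=latent_dim, n_c=n_c)
discriminator = Discriminator()
# `dataloader` is assumed to yield (images, labels) batches of flattened images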

# Loss functions
bce_loss = torch.nn.BCELoss()          # adversarial loss
xe_loss = torch.nn.CrossEntropyLoss()  # cross-entropy loss on the discrete code
mse_loss = torch.nn.MSELoss()          # L2 reconstruction loss on the continuous code

# Optimizers (the generator and encoder are updated jointly)
optimizer_G = torch.optim.Adam(
    itertools.chain(generator.parameters(), encoder.parameters()),
    lr=opt.lr, betas=(opt.b1, opt.b2),
)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

for epoch in range(opt.n_epochs):
    for i, (real_imgs, _) in enumerate(dataloader):  # labels are unused
        # Adversarial ground truths
        valid = torch.ones(real_imgs.shape[0]).requires_grad_(False)
        fake = torch.zeros(real_imgs.shape[0]).requires_grad_(False)

        # ----------------------------------
        # forward propagation
        # ----------------------------------
        # Sample random latent variables
        zn, zc, zc_idx = sample_z(shape=real_imgs.shape[0],
                                  latent_dim=latent_dim,
                                  n_c=n_c)
        # Generate a batch of images
        gen_imgs = generator(zn, zc)        

        # -----------------------
        #  Train Discriminator
        # -----------------------
        optimizer_D.zero_grad()

        # Discriminator outputs on real and generated samples;
        # detach() keeps generator gradients out of the discriminator update
        D_gen = discriminator(gen_imgs.detach())
        D_real = discriminator(real_imgs)

        real_loss = bce_loss(D_real, valid)
        fake_loss = bce_loss(D_gen, fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # -------------------------------
        #  Train Generator and Encoder
        # -------------------------------
        optimizer_G.zero_grad()
        enc_gen_zn, enc_gen_zc, enc_gen_zc_logits = encoder(gen_imgs)

        # Calculate losses for z_n, z_c
        zn_loss = mse_loss(enc_gen_zn, zn)
        zc_loss = xe_loss(enc_gen_zc_logits, zc_idx)

        # Adversarial loss for the generator: try to fool the discriminator
        D_gen = discriminator(gen_imgs)
        gan_loss = bce_loss(D_gen, valid)
        # Total generator/encoder loss: adversarial + weighted latent reconstruction
        g_loss = gan_loss + betan * zn_loss + betac * zc_loss
        g_loss.backward()
        optimizer_G.step()