CoGAN: Coupled Generative Adversarial Networks

Coupled GAN (CoGAN) learns a joint distribution over images from multiple domains using only unpaired samples drawn from each domain, which makes cross-domain image translation possible without paired training data.

CoGAN consists of two GANs that share part of their weights, which reduces the number of parameters. Each generator learns the data distribution of its own domain, and each discriminator judges whether an input is a real sample from the corresponding domain. The objective function is:

\[\begin{aligned} \mathop{\min}_{G_1,G_2} \mathop{\max}_{D_1,D_2} \; & \Bbb{E}_{x_1 \sim P_{data}(x_1)}[\log D_1(x_1)] + \Bbb{E}_{z \sim P_{Z}(z)}[\log(1-D_1(G_1(z)))] \\ & + \Bbb{E}_{x_2 \sim P_{data}(x_2)}[\log D_2(x_2)] + \Bbb{E}_{z \sim P_{Z}(z)}[\log(1-D_2(G_2(z)))] \end{aligned}\]
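In practice this minimax objective is optimized by alternating gradient steps on the two discriminators and the two generators, using the usual binary cross-entropy formulation; the coupling comes from drawing a single noise vector z and feeding it to both generators, so that they output a corresponding pair of images. Below is a minimal sketch of one training step, assuming coupled generator/discriminator modules like those defined later in this section, one optimizer per side, and a hypothetical unpaired batch (imgs1, imgs2) from the two domains; BCE-with-logits is used here because the discriminator heads output raw scores, and the reference implementation may use a different loss.

import torch
import torch.nn as nn

adversarial_loss = nn.BCEWithLogitsLoss()

def train_step(coupled_G, coupled_D, optimizer_G, optimizer_D, imgs1, imgs2, latent_dim):
    batch_size = imgs1.size(0)
    valid = torch.ones(batch_size, 1)   # target for real samples / "fool the discriminator"
    fake = torch.zeros(batch_size, 1)   # target for generated samples

    # A single z is shared by both generators: this is the coupling.
    z = torch.randn(batch_size, latent_dim)

    # --- Generator step: make both discriminators output "real" ---
    optimizer_G.zero_grad()
    gen1, gen2 = coupled_G(z)
    v1, v2 = coupled_D(gen1, gen2)
    g_loss = (adversarial_loss(v1, valid) + adversarial_loss(v2, valid)) / 2
    g_loss.backward()
    optimizer_G.step()

    # --- Discriminator step: separate real from generated in each domain ---
    optimizer_D.zero_grad()
    v1_real, v2_real = coupled_D(imgs1, imgs2)
    v1_fake, v2_fake = coupled_D(gen1.detach(), gen2.detach())
    d_loss = (adversarial_loss(v1_real, valid) + adversarial_loss(v2_real, valid)
              + adversarial_loss(v1_fake, fake) + adversarial_loss(v2_fake, fake)) / 4
    d_loss.backward()
    optimizer_D.step()

    return g_loss.item(), d_loss.item()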

In the generator, the shallow layers learn the high-level semantics of the image while the deep layers learn its low-level details; the discriminator is the opposite, with its deep layers learning the semantic features.

Through weight sharing, the two generators and the two discriminators are forced to learn common high-level semantic features: the shallow layers of the generators are shared and the deep layers of the discriminators are shared, while each network keeps its own layers for extracting the low-level details of its domain.

For a complete PyTorch implementation of CoGAN, see PyTorch-GAN; the core modules are shown below.

import torch.nn as nn

# `opt` holds the command-line options of the surrounding training script
# (it provides opt.img_size, opt.latent_dim and opt.channels).

class CoupledGenerators(nn.Module):
    def __init__(self):
        super(CoupledGenerators, self).__init__()

        self.init_size = opt.img_size // 4
        self.fc = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))

        # Shallow layers shared by both generators (common high-level semantics)
        self.shared_conv = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
        )
        # Domain-specific output layers: render the low-level details of each domain
        self.G1 = nn.Sequential(
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )
        self.G2 = nn.Sequential(
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, noise):
        out = self.fc(noise)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img_emb = self.shared_conv(out)
        img1 = self.G1(img_emb)
        img2 = self.G2(img_emb)
        return img1, img2

class CoupledDiscriminators(nn.Module):
    def __init__(self):
        super(CoupledDiscriminators, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            block.extend([nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)])
            return block

        # Domain-specific convolutional features (low-level details of each domain)
        self.D1 = nn.Sequential(
            *discriminator_block(opt.channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )
        self.D2 = nn.Sequential(
            *discriminator_block(opt.channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )
        # The height and width of downsampled image
        ds_size = opt.img_size // 2 ** 4
        # Deep output layer shared by both discriminators (common high-level semantics)
        self.shared_conv = nn.Linear(128 * ds_size ** 2, 1)

    def forward(self, img1, img2):
        # Determine validity of first image
        out = self.D1(img1)
        out = out.view(out.shape[0], -1)
        validity1 = self.shared_conv(out)
        # Determine validity of second image
        out = self.D2(img2)
        out = out.view(out.shape[0], -1)
        validity2 = self.shared_conv(out)

        return validity1, validity2
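As a quick usage sketch (assuming `opt` carries img_size, latent_dim and channels as in the training script): a single noise vector produces a corresponding image in each domain, and the coupled discriminators return one raw validity score per domain.

import torch

coupled_G = CoupledGenerators()
coupled_D = CoupledDiscriminators()

z = torch.randn(16, opt.latent_dim)           # one shared latent code for both domains
img1, img2 = coupled_G(z)                     # corresponding pair of generated images
validity1, validity2 = coupled_D(img1, img2)  # one score per domain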