UNIT: Unsupervised Image-to-Image Translation Networks.

Image-to-Image Translation can be posed in a supervised or an unsupervised setting. The supervised setting requires a dataset of one-to-one paired images. The authors instead design an unsupervised translation framework, UNIT: given only two unpaired datasets of different styles, the network learns the transformation between the two image styles.

The authors assume that the different image domains share a common latent space: for every pair of corresponding images $x_1, x_2$ there exists a single latent code $z$ in this space that corresponds to both.
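
Formally, the shared-latent-space assumption states that there exist encoders and generators satisfying

\[z = E_1(x_1) = E_2(x_2), \qquad x_1 = G_1(z), \qquad x_2 = G_2(z)\]

so that the translation from domain 1 to domain 2 can be written as $x_{1\to 2} = G_2(E_1(x_1))$.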

The mapping between each image domain and the latent space is implemented with a VAE: two encoders map images into the latent space, and two generators reconstruct images from latent codes. In addition, two discriminators judge the realism of images in each of the two domains.

1. The UNIT Encoder

The two UNIT encoders share weights in their deepest layers. These deep features are generally regarded as carrying high-level semantic information, which is assumed to be shared across the two image domains.

import torch
import torch.nn as nn


class Encoder(nn.Module):
    def __init__(self, in_channels=3, dim=64, n_downsample=2, shared_block=None):
        super(Encoder, self).__init__()

        # Initial convolution block
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels, dim, 7),
            nn.InstanceNorm2d(dim),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Downsampling
        for _ in range(n_downsample):
            layers += [
                nn.Conv2d(dim, dim * 2, 4, stride=2, padding=1),
                nn.InstanceNorm2d(dim * 2),
                nn.ReLU(inplace=True),
            ]
            dim *= 2

        # Residual blocks
        for _ in range(3):
            layers += [ResidualBlock(dim)]

        self.model_blocks = nn.Sequential(*layers)
        self.shared_block = shared_block

    def reparameterization(self, mu):
        # The posterior variance is fixed to the identity, so sampling
        # reduces to adding unit Gaussian noise to the mean.
        return mu + torch.randn_like(mu)

    def forward(self, x):
        x = self.model_blocks(x)
        mu = self.shared_block(x)
        z = self.reparameterization(mu)
        return mu, z

# The deepest residual block is shared by both encoders
shared_dim = opt.dim * 2 ** opt.n_downsample
shared_E = ResidualBlock(features=shared_dim)
E1 = Encoder(dim=opt.dim, n_downsample=opt.n_downsample, shared_block=shared_E)
E2 = Encoder(dim=opt.dim, n_downsample=opt.n_downsample, shared_block=shared_E)
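
The ResidualBlock used above (and in the generators below) is not defined in this excerpt; a minimal sketch, consistent with the PyTorch-GAN reference implementation:

class ResidualBlock(nn.Module):
    def __init__(self, features):
        super(ResidualBlock, self).__init__()
        # Two 3x3 convolutions with instance normalization and a residual (skip) connection
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(features, features, 3),
            nn.InstanceNorm2d(features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(features, features, 3),
            nn.InstanceNorm2d(features),
        )

    def forward(self, x):
        return x + self.conv_block(x)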

2. The UNIT Generator

The two UNIT generators also share weights, in their shallowest layers, which decode the shared high-level semantic information.

class Generator(nn.Module):
    def __init__(self, out_channels=3, dim=64, n_upsample=2, shared_block=None):
        super(Generator, self).__init__()

        self.shared_block = shared_block

        layers = []
        dim = dim * 2 ** n_upsample
        # Residual blocks
        for _ in range(3):
            layers += [ResidualBlock(dim)]

        # Upsampling
        for _ in range(n_upsample):
            layers += [
                nn.ConvTranspose2d(dim, dim // 2, 4, stride=2, padding=1),
                nn.InstanceNorm2d(dim // 2),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            dim = dim // 2

        # Output layer
        layers += [nn.ReflectionPad2d(3), nn.Conv2d(dim, out_channels, 7), nn.Tanh()]

        self.model_blocks = nn.Sequential(*layers)

    def forward(self, x):
        x = self.shared_block(x)
        x = self.model_blocks(x)
        return x

shared_G = ResidualBlock(features=shared_dim)
G1 = Generator(dim=opt.dim, n_upsample=opt.n_downsample, shared_block=shared_G)
G2 = Generator(dim=opt.dim, n_upsample=opt.n_downsample, shared_block=shared_G)

3. The UNIT Discriminator

The UNIT discriminators use the PatchGAN structure proposed in Pix2Pix: the output is an $N \times N$ map in which each element scores the realism of one sub-region (patch) of the input image.

class Discriminator(nn.Module):
    def __init__(self, input_shape):
        super(Discriminator, self).__init__()
        channels, height, width = input_shape
        # Calculate output of image discriminator (PatchGAN)
        self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)

        def discriminator_block(in_filters, out_filters, normalize=True):
            """Returns downsampling layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *discriminator_block(channels, 64, normalize=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            nn.Conv2d(512, 1, 3, padding=1)
        )

    def forward(self, img):
        return self.model(img)

input_shape = (opt.channels, opt.img_height, opt.img_width)
D1 = Discriminator(input_shape)
D2 = Discriminator(input_shape)

4. The UNIT Objective

The UNIT objective decomposes into three parts: a VAE loss, a GAN loss, and a cycle-consistency loss.

\[\begin{aligned} \mathop{ \min}_{G_1,G_2,E_1,E_2} \mathop{\max}_{D_1,D_2} & \mathcal{L}_{\text{VAE}_1}(E_1,G_1) + \mathcal{L}_{\text{GAN}_1}(E_1,G_1,D_1) + \mathcal{L}_{\text{CC}_1}(E_1,G_1,E_2,G_2) \\ & +\mathcal{L}_{\text{VAE}_2}(E_2,G_2) + \mathcal{L}_{\text{GAN}_2}(E_2,G_2,D_2) + \mathcal{L}_{\text{CC}_2}(E_2,G_2,E_1,G_1) \end{aligned}\]

The VAE loss consists of a KL-divergence term on the latent code $z$ and an image reconstruction term:

\[\begin{aligned} \mathcal{L}_{\text{VAE}_1}(E_1,G_1) &= D_{KL}[E_1(x_1)||P(z)] + \Bbb{E}_{x_1 \sim P_{data}(x_1)}[||x_1-G_1(E_1(x_1))||_1 ] \\ \mathcal{L}_{\text{VAE}_2}(E_2,G_2) &= D_{KL}[E_2(x_2)||P(z)] + \Bbb{E}_{x_2 \sim P_{data}(x_2)}[||x_2-G_2(E_2(x_2))||_1 ] \end{aligned}\]
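
Since the encoder's posterior has a fixed unit variance ($z = \mu + \epsilon$, $\epsilon \sim \mathcal{N}(0,I)$), the KL term reduces, up to a constant, to the squared norm of the latent mean. A minimal sketch of the compute_kl helper used in the training loop below, following the PyTorch-GAN reference:

def compute_kl(mu):
    # KL( N(mu, I) || N(0, I) ) is proportional to the mean of the squared latent means
    return torch.mean(torch.pow(mu, 2))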

The GAN loss is the standard binary cross-entropy adversarial loss:

\[\begin{aligned} \mathcal{L}_{\text{GAN}_1}(E_1,G_1,D_1) &= \Bbb{E}_{x_1 \sim P_{data}(x_1)}[\log D_1(x_1)] + \Bbb{E}_{x_2 \sim P_{data}(x_2)}[\log (1- D_1(G_1(E_2(x_2))))] \\ \mathcal{L}_{\text{GAN}_2}(E_2,G_2,D_2) &= \Bbb{E}_{x_2 \sim P_{data}(x_2)}[\log D_2(x_2)] + \Bbb{E}_{x_1 \sim P_{data}(x_1)}[\log (1-D_2(G_2(E_1(x_1))))] \end{aligned}\]

The cycle-consistency loss consists of a KL-divergence term on the re-encoded latent code and a cyclic image reconstruction term:

\[\begin{aligned} \mathcal{L}_{\text{CC}_1}(E_1,G_1,E_2,G_2) = &D_{KL}[E_2(G_2(E_1(x_1)))||P(z)] \\ & + \Bbb{E}_{x_1 \sim P_{data}(x_1)}[||x_1-G_1(E_2(G_2(E_1(x_1))))||_1 ] \\ \mathcal{L}_{\text{CC}_2}(E_2,G_2,E_1,G_1) = &D_{KL}[E_1(G_1(E_2(x_2)))||P(z)] \\ & + \Bbb{E}_{x_2 \sim P_{data}(x_2)}[||x_2-G_2(E_1(G_1(E_2(x_2))))||_1 ] \end{aligned}\]

A complete PyTorch implementation of UNIT is available in PyTorch-GAN; the loss computation and parameter update procedure is given below. Since the discriminators output raw (unnormalized) scores, the adversarial loss is computed on logits:

# Losses
criterion_GAN = torch.nn.BCEWithLogitsLoss()  # discriminators output raw logits (no sigmoid layer)
criterion_pixel = torch.nn.L1Loss()
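
# Loss weights (assumed defaults, following the PyTorch-GAN reference implementation)
lambda_0 = 10   # adversarial loss
lambda_1 = 0.1  # KL on encoded latents
lambda_2 = 100  # image reconstruction loss
lambda_3 = 0.1  # KL on re-encoded (translated) latents
lambda_4 = 100  # cycle-consistency loss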

# Optimizers
optimizer_G = torch.optim.Adam(
    itertools.chain(E1.parameters(), E2.parameters(), G1.parameters(), G2.parameters()),
    lr=opt.lr,
    betas=(opt.b1, opt.b2),
)
optimizer_D1 = torch.optim.Adam(D1.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D2 = torch.optim.Adam(D2.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

# Calculate output of image discriminator (PatchGAN)
patch = (1, opt.img_height // 2 ** 4, opt.img_width // 2 ** 4)

for epoch in range(opt.n_epochs):
    for i, (X1, X2) in enumerate(zip(dataloader_A, dataloader_B)):
        # Adversarial ground truths
        valid = torch.ones(X1.shape[0], *patch).requires_grad_(False)
        fake = torch.zeros(X1.shape[0], *patch).requires_grad_(False)

        # ----------------------------------
        # forward propagation
        # ----------------------------------
        # Get shared latent representation
        mu1, Z1 = E1(X1)
        mu2, Z2 = E2(X2)

        # Reconstruct images
        recon_X1 = G1(Z1)
        recon_X2 = G2(Z2)

        # Translate images
        fake_X1 = G1(Z2)
        fake_X2 = G2(Z1)

        # Cycle translation
        mu1_, Z1_ = E1(fake_X1)
        mu2_, Z2_ = E2(fake_X2)
        cycle_X1 = G1(Z2_)
        cycle_X2 = G2(Z1_)

        # -----------------------
        #  Train Discriminator 1
        # -----------------------

        optimizer_D1.zero_grad()

        loss_D1 = criterion_GAN(D1(X1), valid) + criterion_GAN(D1(fake_X1.detach()), fake)

        loss_D1.backward()
        optimizer_D1.step()

        # -----------------------
        #  Train Discriminator 2
        # -----------------------

        optimizer_D2.zero_grad()

        loss_D2 = criterion_GAN(D2(X2), valid) + criterion_GAN(D2(fake_X2.detach()), fake)

        loss_D2.backward()
        optimizer_D2.step()

        # -------------------------------
        #  Train Generator and Encoder
        # -------------------------------
        optimizer_G.zero_grad()

        # Losses
        loss_GAN_1 = lambda_0 * criterion_GAN(D1(fake_X1), valid)
        loss_GAN_2 = lambda_0 * criterion_GAN(D2(fake_X2), valid)
        loss_KL_1 = lambda_1 * compute_kl(mu1)
        loss_KL_2 = lambda_1 * compute_kl(mu2)
        loss_ID_1 = lambda_2 * criterion_pixel(recon_X1, X1)
        loss_ID_2 = lambda_2 * criterion_pixel(recon_X2, X2)
        loss_KL_1_ = lambda_3 * compute_kl(mu1_)
        loss_KL_2_ = lambda_3 * compute_kl(mu2_)
        loss_cyc_1 = lambda_4 * criterion_pixel(cycle_X1, X1)
        loss_cyc_2 = lambda_4 * criterion_pixel(cycle_X2, X2)

        # Total loss
        loss_G = (
            loss_KL_1
            + loss_KL_2
            + loss_ID_1
            + loss_ID_2
            + loss_GAN_1
            + loss_GAN_2
            + loss_KL_1_
            + loss_KL_2_
            + loss_cyc_1
            + loss_cyc_2
        )

        loss_G.backward()
        optimizer_G.step()
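
After training, translating an image between the two domains only requires one domain's encoder and the other domain's generator. A minimal inference sketch (using the latent mean instead of a sampled code):

# Translate domain-1 images into domain 2 at test time
with torch.no_grad():
    mu1, _ = E1(X1)      # encode into the shared latent space
    X1_to_X2 = G2(mu1)   # decode with the domain-2 generator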