UNIT: Unsupervised Image-to-Image Translation Networks.
Image-to-image translation comes in supervised and unsupervised variants. Supervised translation requires a dataset of paired images. The authors of this paper propose UNIT, an unsupervised image-to-image translation framework: given only two datasets of different styles (domains), the network learns the mapping between the two domains.
The authors assume that the two image domains share a common latent space, i.e., every pair of corresponding images $x_1, x_2$ can be mapped to the same latent code $z$ in that space.
The mapping between each image domain and the latent space is implemented with VAEs: two encoders map images into the latent space, and two generators reconstruct images from latent codes. In addition, two discriminators judge whether images of the two domains look real.
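Under this assumption, corresponding images can be recovered from a common code, and translation amounts to composing one domain's encoder with the other domain's generator:

\[z = E_1(x_1) = E_2(x_2), \quad x_1 = G_1(z), \quad x_2 = G_2(z)\]

\[x_{1 \to 2} = G_2(E_1(x_1)), \quad x_{2 \to 1} = G_1(E_2(x_2))\]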
1. The UNIT Encoders
The two encoders in UNIT share weights in their deepest layers. These deep features are generally regarded as carrying high-level semantic information, which is assumed to be common across the two image domains.
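Both the encoders and the generators below stack ResidualBlock modules, and one such block is shared across domains. The block itself is not shown in the excerpt; a minimal sketch consistent with the PyTorch-GAN implementation (two 3×3 convolutions with instance normalization and a skip connection, an assumption rather than the verbatim source) might look like this:

import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    def __init__(self, features):
        super(ResidualBlock, self).__init__()
        # Two 3x3 convolutions with instance normalization; the input is added back (skip connection)
        self.conv_block = nn.Sequential(
            nn.Conv2d(features, features, 3, stride=1, padding=1),
            nn.InstanceNorm2d(features),
            nn.ReLU(inplace=True),
            nn.Conv2d(features, features, 3, stride=1, padding=1),
            nn.InstanceNorm2d(features),
        )

    def forward(self, x):
        return x + self.conv_block(x)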
class Encoder(nn.Module):
    def __init__(self, in_channels=3, dim=64, n_downsample=2, shared_block=None):
        super(Encoder, self).__init__()
        # Initial convolution block
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels, dim, 7),
            nn.InstanceNorm2d(dim),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Downsampling
        for _ in range(n_downsample):
            layers += [
                nn.Conv2d(dim, dim * 2, 4, stride=2, padding=1),
                nn.InstanceNorm2d(dim * 2),
                nn.ReLU(inplace=True),
            ]
            dim *= 2
        # Residual blocks
        for _ in range(3):
            layers += [ResidualBlock(dim)]
        self.model_blocks = nn.Sequential(*layers)
        # Deepest residual block, shared with the encoder of the other domain
        self.shared_block = shared_block

    def reparameterization(self, mu):
        # Sample z ~ N(mu, I) by adding unit-variance Gaussian noise to the predicted mean
        z = torch.randn_like(mu)
        return z + mu

    def forward(self, x):
        x = self.model_blocks(x)
        mu = self.shared_block(x)
        z = self.reparameterization(mu)
        return mu, z
# The shared latent dimensionality equals the encoder output channels
shared_dim = opt.dim * 2 ** opt.n_downsample
shared_E = ResidualBlock(features=shared_dim)
E1 = Encoder(dim=opt.dim, n_downsample=opt.n_downsample, shared_block=shared_E)
E2 = Encoder(dim=opt.dim, n_downsample=opt.n_downsample, shared_block=shared_E)
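As a sanity check, one can verify that the two encoders literally hold the same deepest block, and inspect the latent shape. The values below (opt.dim = 64, opt.n_downsample = 2, a 3×256×256 input) are illustrative assumptions, not taken from the original text:

assert E1.shared_block is E2.shared_block  # same module object, so its weights are shared

x1 = torch.randn(1, 3, 256, 256)
mu1, z1 = E1(x1)
print(mu1.shape)  # torch.Size([1, 256, 64, 64]): channels 64 * 2**2, spatial size 256 / 2**2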
2. The UNIT Generators
The two generators in UNIT also share weights: their shallow layers, which operate directly on the shared high-level semantic representation from the latent space, are shared.
class Generator(nn.Module):
    def __init__(self, out_channels=3, dim=64, n_upsample=2, shared_block=None):
        super(Generator, self).__init__()
        # Shallowest residual block, shared with the generator of the other domain
        self.shared_block = shared_block
        layers = []
        dim = dim * 2 ** n_upsample
        # Residual blocks
        for _ in range(3):
            layers += [ResidualBlock(dim)]
        # Upsampling
        for _ in range(n_upsample):
            layers += [
                nn.ConvTranspose2d(dim, dim // 2, 4, stride=2, padding=1),
                nn.InstanceNorm2d(dim // 2),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            dim = dim // 2
        # Output layer
        layers += [nn.ReflectionPad2d(3), nn.Conv2d(dim, out_channels, 7), nn.Tanh()]
        self.model_blocks = nn.Sequential(*layers)

    def forward(self, x):
        x = self.shared_block(x)
        x = self.model_blocks(x)
        return x
shared_G = ResidualBlock(features=shared_dim)
G1 = Generator(dim=opt.dim, n_upsample=opt.n_downsample, shared_block=shared_G)
G2 = Generator(dim=opt.dim, n_upsample=opt.n_downsample, shared_block=shared_G)
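Translation between domains then amounts to decoding one domain's latent code with the other domain's generator. Continuing the illustrative example above (shapes assume 256×256 inputs with n_downsample = 2):

# Decode the domain-1 latent code with the domain-2 generator: domain 1 -> domain 2 translation
x12 = G2(z1)
print(x12.shape)  # torch.Size([1, 3, 256, 256])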
3. The UNIT Discriminators
The UNIT discriminators adopt the PatchGAN structure proposed in Pix2Pix: the output is an $N \times N$ map whose elements each correspond to a patch of the input image and score the realism of that patch.
class Discriminator(nn.Module):
    def __init__(self, input_shape):
        super(Discriminator, self).__init__()
        channels, height, width = input_shape
        # Output shape of the image discriminator (PatchGAN): four stride-2 convolutions
        self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)

        def discriminator_block(in_filters, out_filters, normalize=True):
            """Returns downsampling layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *discriminator_block(channels, 64, normalize=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            # Final 1-channel map of patch scores (raw logits, no sigmoid)
            nn.Conv2d(512, 1, 3, padding=1)
        )

    def forward(self, img):
        return self.model(img)
input_shape = (opt.channels, opt.img_height, opt.img_width)  # (C, H, W) of the training images
D1 = Discriminator(input_shape)
D2 = Discriminator(input_shape)
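With 256×256 RGB inputs, the four stride-2 convolutions shrink the spatial resolution by a factor of 16, so each image is scored as a 16×16 grid of patch logits. A quick shape check (the 256×256 resolution is an illustrative assumption):

d = Discriminator((3, 256, 256))
print(d.output_shape)                        # (1, 16, 16)
print(d(torch.randn(1, 3, 256, 256)).shape)  # torch.Size([1, 1, 16, 16])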
4. The UNIT Objective
The UNIT objective can be decomposed into three parts: a VAE loss, a GAN loss, and a cycle-consistency loss.
\[\begin{aligned} \mathop{ \min}_{G_1,G_2,E_1,E_2} \mathop{\max}_{D_1,D_2} & \mathcal{L}_{\text{VAE}_1}(E_1,G_1) + \mathcal{L}_{\text{GAN}_1}(E_1,G_1,D_1) + \mathcal{L}_{\text{CC}_1}(E_1,G_1,E_2,G_2) \\ & +\mathcal{L}_{\text{VAE}_2}(E_2,G_2) + \mathcal{L}_{\text{GAN}_2}(E_2,G_2,D_2) + \mathcal{L}_{\text{CC}_2}(E_2,G_2,E_1,G_1) \end{aligned}\]The VAE loss consists of a KL-divergence term on the latent code $z$ and an image reconstruction term:
\[\begin{aligned} \mathcal{L}_{\text{VAE}_1}(E_1,G_1) &= D_{KL}[E_1(x_1)||P(z)] + \Bbb{E}_{x_1 \sim P_{data}(x_1)}[||x_1-G_1(E_1(x_1))||_1 ] \\ \mathcal{L}_{\text{VAE}_2}(E_2,G_2) &= D_{KL}[E_2(x_2)||P(z)] + \Bbb{E}_{x_2 \sim P_{data}(x_2)}[||x_2-G_2(E_2(x_2))||_1 ] \end{aligned}\]The GAN loss is the standard binary cross-entropy adversarial loss, applied to images translated from the other domain:
\[\begin{aligned} \mathcal{L}_{\text{GAN}_1}(E_1,G_1,D_1) &= \Bbb{E}_{x_1 \sim P_{data}(x_1)}[\log D_1(x_1)] + \Bbb{E}_{x_2 \sim P_{data}(x_2)}[\log (1- D_1(G_1(E_2(x_2))))] \\ \mathcal{L}_{\text{GAN}_2}(E_2,G_2,D_2) &= \Bbb{E}_{x_2 \sim P_{data}(x_2)}[\log D_2(x_2)] + \Bbb{E}_{x_1 \sim P_{data}(x_1)}[\log (1-D_2(G_2(E_1(x_1))))] \end{aligned}\]The cycle-consistency loss consists of a KL-divergence term on the re-encoded latent code and a cycle-reconstruction term on the image:
\[\begin{aligned} \mathcal{L}_{\text{CC}_1}(E_1,G_1,E_2,G_2) = &D_{KL}[E_2(G_2(E_1(x_1)))||P(z)] \\ & + \Bbb{E}_{x_1 \sim P_{data}(x_1)}[||x_1-G_1(E_2(G_2(E_1(x_1))))||_1 ] \\ \mathcal{L}_{\text{CC}_2}(E_2,G_2,E_1,G_1) = &D_{KL}[E_1(G_1(E_2(x_2)))||P(z)] \\ & + \Bbb{E}_{x_2 \sim P_{data}(x_2)}[||x_2-G_2(E_1(G_1(E_2(x_2))))||_1 ] \end{aligned}\]A complete PyTorch implementation of UNIT is available in PyTorch-GAN; the loss computation and parameter-update steps are shown below:
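The loop relies on a KL helper and on loss weights lambda_0 ... lambda_4 that are defined earlier in the reference script. Since the encoder predicts only the mean mu (the variance is fixed at 1), the KL divergence to the standard normal prior is, up to a constant factor, the mean of mu squared. A sketch of these definitions; the weight values follow the PyTorch-GAN defaults and should be treated as assumptions:

def compute_kl(mu):
    # KL divergence between N(mu, I) and N(0, I), up to scale: mean of mu^2
    return torch.mean(torch.pow(mu, 2))

# Loss weights (PyTorch-GAN defaults; treat as assumptions)
lambda_0 = 10    # adversarial (GAN) loss
lambda_1 = 0.1   # KL on encoded images
lambda_2 = 100   # image reconstruction (identity) loss
lambda_3 = 0.1   # KL on encoded translated images
lambda_4 = 100   # cycle-consistency loss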
import itertools

# Losses
# The discriminators output raw logits, so use the logits version of binary cross-entropy
criterion_GAN = torch.nn.BCEWithLogitsLoss()
criterion_pixel = torch.nn.L1Loss()

# Optimizers
optimizer_G = torch.optim.Adam(
    itertools.chain(E1.parameters(), E2.parameters(), G1.parameters(), G2.parameters()),
    lr=opt.lr,
    betas=(opt.b1, opt.b2),
)
optimizer_D1 = torch.optim.Adam(D1.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D2 = torch.optim.Adam(D2.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

# Output shape of the image discriminator (PatchGAN)
patch = (1, opt.img_height // 2 ** 4, opt.img_width // 2 ** 4)

for epoch in range(opt.n_epochs):
    for i, (X1, X2) in enumerate(zip(dataloader_A, dataloader_B)):
        # Adversarial ground truths
        valid = torch.ones(X1.shape[0], *patch).requires_grad_(False)
        fake = torch.zeros(X1.shape[0], *patch).requires_grad_(False)

        # ----------------------------------
        #  Forward propagation
        # ----------------------------------
        # Get shared latent representations
        mu1, Z1 = E1(X1)
        mu2, Z2 = E2(X2)
        # Reconstruct images within each domain
        recon_X1 = G1(Z1)
        recon_X2 = G2(Z2)
        # Translate images across domains
        fake_X1 = G1(Z2)
        fake_X2 = G2(Z1)
        # Cycle translation: re-encode the translated images and decode back
        mu1_, Z1_ = E1(fake_X1)
        mu2_, Z2_ = E2(fake_X2)
        cycle_X1 = G1(Z2_)
        cycle_X2 = G2(Z1_)

        # -----------------------
        #  Train Discriminator 1
        # -----------------------
        optimizer_D1.zero_grad()
        loss_D1 = criterion_GAN(D1(X1), valid) + criterion_GAN(D1(fake_X1.detach()), fake)
        loss_D1.backward()
        optimizer_D1.step()

        # -----------------------
        #  Train Discriminator 2
        # -----------------------
        optimizer_D2.zero_grad()
        loss_D2 = criterion_GAN(D2(X2), valid) + criterion_GAN(D2(fake_X2.detach()), fake)
        loss_D2.backward()
        optimizer_D2.step()

        # -------------------------------
        #  Train Encoders and Generators
        # -------------------------------
        optimizer_G.zero_grad()
        # Losses
        loss_GAN_1 = lambda_0 * criterion_GAN(D1(fake_X1), valid)
        loss_GAN_2 = lambda_0 * criterion_GAN(D2(fake_X2), valid)
        loss_KL_1 = lambda_1 * compute_kl(mu1)
        loss_KL_2 = lambda_1 * compute_kl(mu2)
        loss_ID_1 = lambda_2 * criterion_pixel(recon_X1, X1)
        loss_ID_2 = lambda_2 * criterion_pixel(recon_X2, X2)
        loss_KL_1_ = lambda_3 * compute_kl(mu1_)
        loss_KL_2_ = lambda_3 * compute_kl(mu2_)
        loss_cyc_1 = lambda_4 * criterion_pixel(cycle_X1, X1)
        loss_cyc_2 = lambda_4 * criterion_pixel(cycle_X2, X2)
        # Total loss
        loss_G = (
            loss_KL_1
            + loss_KL_2
            + loss_ID_1
            + loss_ID_2
            + loss_GAN_1
            + loss_GAN_2
            + loss_KL_1_
            + loss_KL_2_
            + loss_cyc_1
            + loss_cyc_2
        )
        loss_G.backward()
        optimizer_G.step()
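After training, translating an image from domain 1 to domain 2 only requires E1 and G2 (and symmetrically E2 and G1 for the reverse direction). A minimal inference sketch, reusing the same sampling scheme and dataloader convention as the loop above:

E1.eval()
G2.eval()
with torch.no_grad():
    x1 = next(iter(dataloader_A))  # a batch of domain-1 images
    _, z1 = E1(x1)                 # encode into the shared latent space
    x1_to_2 = G2(z1)               # decode with the domain-2 generator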