MUNIT: Multimodal Unsupervised Image-to-Image Translation.
UNIT assumes that image collections of different styles share a common latent space: every image pair $x_1,x_2$ can be mapped to the same latent code $z$ in that space.
The authors argue that this assumption is overly restrictive and cannot capture the correspondences among richer image collections. MUNIT therefore assumes instead that each image $x$ decomposes into a content code $c$ from a content space shared across all domains, and a style code $s$ from a domain-specific style space.
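Under this assumption, translating $x_1$ into domain 2 amounts to recombining its content code with a style code sampled from the domain-2 prior (a standard Gaussian in the paper):

\[x_{1 \to 2} = G_2(E_1^c(x_1), s_2), \quad s_2 \sim p(s_2) = \mathcal{N}(0, I)\]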
The network is trained with two kinds of reconstruction: image reconstruction and latent (code) reconstruction.
For image reconstruction, the images $x_1,x_2$ are encoded into $(c_1,s_1),(c_2,s_2)$ and decoded back into the reconstructions $\hat{x}_1,\hat{x}_2$, on which an L1 reconstruction loss is imposed:
\[\begin{aligned} \mathcal{L}_{\text{recon}}^{x_1} &= \Bbb{E}_{x_1 \sim p(x_1)}[||x_1-G_1(E_1^c(x_1),E_1^s(x_1))||_1 ] \\ \mathcal{L}_{\text{recon}}^{x_2} &= \Bbb{E}_{x_2 \sim p(x_2)}[||x_2-G_2(E_2^c(x_2),E_2^s(x_2))||_1 ] \end{aligned}\]For latent reconstruction, the images $x_1,x_2$ are encoded into $(c_1,s_1),(c_2,s_2)$, the codes are swapped into $(c_1,s_2),(c_2,s_1)$ and decoded into the style-translated images $x_{1 \to 2},x_{2 \to 1}$, which are in turn encoded into $(\hat{c}_1,\hat{s}_2),(\hat{c}_2,\hat{s}_1)$; an L1 reconstruction loss is then imposed on the codes:
\[\begin{aligned} \mathcal{L}_{\text{recon}}^{c_1} &= \Bbb{E}_{c_1 \sim p(c_1),s_2 \sim p(s_2)}[||c_1-E_2^c(G_2(c_1,s_2))||_1 ] \\ \mathcal{L}_{\text{recon}}^{s_2} &= \Bbb{E}_{c_1 \sim p(c_1),s_2 \sim p(s_2)}[||s_2-E_2^s(G_2(c_1,s_2))||_1 ] \\ \mathcal{L}_{\text{recon}}^{c_2} &= \Bbb{E}_{c_2 \sim p(c_2),s_1 \sim p(s_1)}[||c_2-E_1^c(G_1(c_2,s_1))||_1 ] \\ \mathcal{L}_{\text{recon}}^{s_1} &= \Bbb{E}_{c_2 \sim p(c_2),s_1 \sim p(s_1)}[||s_1-E_1^s(G_1(c_2,s_1))||_1 ] \end{aligned}\]In addition, an adversarial loss is applied between the real images $x_1,x_2$ and the translated images $x_{2 \to 1},x_{1 \to 2}$:
\[\begin{aligned} \mathcal{L}_{\text{GAN}}^{x_1} &= \Bbb{E}_{x_1 \sim p(x_1)}[\log D_1(x_1)] + \Bbb{E}_{c_2 \sim p(c_2),s_1 \sim p(s_1)}[\log (1- D_1(G_1(c_2,s_1)))] \\ \mathcal{L}_{\text{GAN}}^{x_2} &= \Bbb{E}_{x_2 \sim p(x_2)}[\log D_2(x_2)] + \Bbb{E}_{c_1 \sim p(c_1),s_2 \sim p(s_2)}[\log (1- D_2(G_2(c_1,s_2)))] \end{aligned}\]The total objective of the network is the following min-max problem:
\[\begin{aligned} \mathop{ \min}_{G_1,G_2,E_1,E_2} \mathop{\max}_{D_1,D_2} &\mathcal{L}_{\text{GAN}}^{x_1} + \mathcal{L}_{\text{GAN}}^{x_2} + \lambda_x(\mathcal{L}_{\text{recon}}^{x_1}+\mathcal{L}_{\text{recon}}^{x_2}) \\ & + \lambda_c(\mathcal{L}_{\text{recon}}^{c_1}+\mathcal{L}_{\text{recon}}^{c_2})+ \lambda_s(\mathcal{L}_{\text{recon}}^{s_1}+\mathcal{L}_{\text{recon}}^{s_2}) \end{aligned}\]The overall architecture of the network is shown in the figure. The generator (decoder) uses AdaIN, i.e. the style code $s$ parameterizes the affine parameters $\gamma,\beta$ of instance normalization:
\[\text{AdaIN}(z,\gamma,\beta) = \gamma \left(\frac{z-\mu(z)}{\sigma(z)}\right)+\beta\]where $\mu(z)$ and $\sigma(z)$ are computed per sample and per channel over spatial locations. A complete PyTorch implementation of MUNIT can be found in PyTorch-GAN.
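Before walking through the library modules, the formula can be written functionally (a minimal sketch, not the reference code; the module actually used, AdaptiveInstanceNorm2d, appears below):

import torch

def adain(z, gamma, beta, eps=1e-5):
    """Minimal AdaIN sketch: z is (B, C, H, W); gamma and beta are (B, C)
    affine parameters that would be predicted from the style code."""
    mu = z.mean(dim=(2, 3), keepdim=True)    # per-sample, per-channel mean
    sigma = z.std(dim=(2, 3), keepdim=True)  # per-sample, per-channel std
    z_norm = (z - mu) / (sigma + eps)        # instance normalization
    return gamma[..., None, None] * z_norm + beta[..., None, None]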
The MUNIT encoders are implemented as follows. The content encoder is fully convolutional, so the content code is a two-dimensional (spatial) feature map; the style encoder ends with global average pooling and a 1×1 convolution acting as a fully connected layer, so the style code is a one-dimensional vector.
import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F

#################################
#        Content Encoder
#################################

class ContentEncoder(nn.Module):
    def __init__(self, in_channels=3, dim=64, n_residual=3, n_downsample=2):
        super(ContentEncoder, self).__init__()

        # Initial convolution block
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels, dim, 7),
            nn.InstanceNorm2d(dim),
            nn.ReLU(inplace=True),
        ]

        # Downsampling
        for _ in range(n_downsample):
            layers += [
                nn.Conv2d(dim, dim * 2, 4, stride=2, padding=1),
                nn.InstanceNorm2d(dim * 2),
                nn.ReLU(inplace=True),
            ]
            dim *= 2

        # Residual blocks
        for _ in range(n_residual):
            layers += [ResidualBlock(dim, norm="in")]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


#################################
#         Style Encoder
#################################

class StyleEncoder(nn.Module):
    def __init__(self, in_channels=3, dim=64, n_downsample=2, style_dim=8):
        super(StyleEncoder, self).__init__()

        # Initial conv block
        layers = [nn.ReflectionPad2d(3), nn.Conv2d(in_channels, dim, 7), nn.ReLU(inplace=True)]

        # Downsampling
        for _ in range(2):
            layers += [nn.Conv2d(dim, dim * 2, 4, stride=2, padding=1), nn.ReLU(inplace=True)]
            dim *= 2

        # Downsampling with constant depth
        for _ in range(n_downsample - 2):
            layers += [nn.Conv2d(dim, dim, 4, stride=2, padding=1), nn.ReLU(inplace=True)]

        # Average pool and output layer
        layers += [nn.AdaptiveAvgPool2d(1), nn.Conv2d(dim, style_dim, 1, 1, 0)]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


#################################
#            Encoder
#################################

class Encoder(nn.Module):
    def __init__(self, in_channels=3, dim=64, n_residual=3, n_downsample=2, style_dim=8):
        super(Encoder, self).__init__()
        self.content_encoder = ContentEncoder(in_channels, dim, n_residual, n_downsample)
        self.style_encoder = StyleEncoder(in_channels, dim, n_downsample, style_dim)

    def forward(self, x):
        content_code = self.content_encoder(x)
        style_code = self.style_encoder(x)
        return content_code, style_code


# Initialize encoders
Enc1 = Encoder(dim=opt.dim, n_downsample=opt.n_downsample, n_residual=opt.n_residual, style_dim=opt.style_dim)
Enc2 = Encoder(dim=opt.dim, n_downsample=opt.n_downsample, n_residual=opt.n_residual, style_dim=opt.style_dim)
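A quick shape check (a sketch, assuming PyTorch-GAN's ResidualBlock is in scope): with the defaults dim=64, n_downsample=2, style_dim=8, a 3×256×256 image yields a 256-channel 64×64 content code and an 8-dimensional style code.

enc = Encoder(dim=64, n_residual=3, n_downsample=2, style_dim=8)
x = torch.randn(1, 3, 256, 256)   # a dummy input image
content_code, style_code = enc(x)
print(content_code.shape)         # torch.Size([1, 256, 64, 64])
print(style_code.shape)           # torch.Size([1, 8, 1, 1])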
The MUNIT generator (decoder) is implemented as follows.
######################################
#   MLP (predicts AdaIN parameters)
######################################

class MLP(nn.Module):
    def __init__(self, input_dim, output_dim, dim=256, n_blk=3, activ="relu"):
        super(MLP, self).__init__()
        layers = [nn.Linear(input_dim, dim), nn.ReLU(inplace=True)]
        for _ in range(n_blk - 2):
            layers += [nn.Linear(dim, dim), nn.ReLU(inplace=True)]
        layers += [nn.Linear(dim, output_dim)]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x.view(x.size(0), -1))


#################################
#            Decoder
#################################

class Decoder(nn.Module):
    def __init__(self, out_channels=3, dim=64, n_residual=3, n_upsample=2, style_dim=8):
        super(Decoder, self).__init__()

        layers = []
        dim = dim * 2 ** n_upsample

        # Residual blocks
        for _ in range(n_residual):
            layers += [ResidualBlock(dim, norm="adain")]

        # Upsampling
        for _ in range(n_upsample):
            layers += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(dim, dim // 2, 5, stride=1, padding=2),
                LayerNorm(dim // 2),
                nn.ReLU(inplace=True),
            ]
            dim = dim // 2

        # Output layer
        layers += [nn.ReflectionPad2d(3), nn.Conv2d(dim, out_channels, 7), nn.Tanh()]

        self.model = nn.Sequential(*layers)

        # Initialize the MLP that predicts the AdaIN parameters
        num_adain_params = self.get_num_adain_params()
        self.mlp = MLP(style_dim, num_adain_params)

    def get_num_adain_params(self):
        """Return the number of AdaIN parameters needed by the model"""
        num_adain_params = 0
        for m in self.modules():
            if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
                num_adain_params += 2 * m.num_features
        return num_adain_params

    def assign_adain_params(self, adain_params):
        """Assign the adain_params to the AdaIN layers in the model"""
        for m in self.modules():
            if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
                # Extract mean and std predictions
                mean = adain_params[:, : m.num_features]
                std = adain_params[:, m.num_features : 2 * m.num_features]
                # Update bias and weight
                m.bias = mean.contiguous().view(-1)
                m.weight = std.contiguous().view(-1)
                # Move pointer to the next layer's parameters
                if adain_params.size(1) > 2 * m.num_features:
                    adain_params = adain_params[:, 2 * m.num_features :]

    def forward(self, content_code, style_code):
        # Update AdaIN parameters from the MLP prediction based on the style code
        self.assign_adain_params(self.mlp(style_code))
        img = self.model(content_code)
        return img


# Initialize generators
Dec1 = Decoder(dim=opt.dim, n_upsample=opt.n_downsample, n_residual=opt.n_residual, style_dim=opt.style_dim)
Dec2 = Decoder(dim=opt.dim, n_upsample=opt.n_downsample, n_residual=opt.n_residual, style_dim=opt.style_dim)
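Continuing the shape check above (again a sketch; note that LayerNorm in the decoder refers to PyTorch-GAN's custom layer, not torch.nn.LayerNorm), the decoder maps the codes back to an image:

dec = Decoder(dim=64, n_residual=3, n_upsample=2, style_dim=8)
img = dec(content_code, style_code)
print(img.shape)   # torch.Size([1, 3, 256, 256]); values in (-1, 1) from Tanh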
The AdaIN operation that fuses the style code into the content features is implemented as follows. (Inside ResidualBlock, passing norm="adain" selects this layer via norm_layer = AdaptiveInstanceNorm2d.)
class AdaptiveInstanceNorm2d(nn.Module):
    """Reference: https://github.com/NVlabs/MUNIT/blob/master/networks.py"""

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(AdaptiveInstanceNorm2d, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        # weight and bias are dynamically assigned
        self.weight = None
        self.bias = None
        # just dummy buffers, not used
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x):
        assert (
            self.weight is not None and self.bias is not None
        ), "Please assign weight and bias before calling AdaIN!"
        b, c, h, w = x.size()
        running_mean = self.running_mean.repeat(b)
        running_var = self.running_var.repeat(b)

        # Apply instance norm
        x_reshaped = x.contiguous().view(1, b * c, h, w)
        out = F.batch_norm(
            x_reshaped, running_mean, running_var, self.weight, self.bias, True, self.momentum, self.eps
        )

        return out.view(b, c, h, w)

    def __repr__(self):
        return self.__class__.__name__ + "(" + str(self.num_features) + ")"
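The reshape trick deserves a note: folding the batch into the channel axis and calling F.batch_norm with weight and bias vectors of length b*c normalizes each (sample, channel) slice independently, i.e. instance normalization with a separate affine transform per sample, which is exactly what the MLP-predicted parameters require. A quick sanity check (a sketch): with γ=1 and β=0, the layer should reduce to plain instance normalization up to numerical tolerance.

adain = AdaptiveInstanceNorm2d(4)
x = torch.randn(2, 4, 8, 8)
adain.weight = torch.ones(2 * 4)    # gamma = 1 for every (sample, channel) pair
adain.bias = torch.zeros(2 * 4)     # beta = 0
out = adain(x)
ref = F.instance_norm(x, eps=1e-5)  # plain instance norm for comparison
print(torch.allclose(out, ref, atol=1e-5))  # True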
The MUNIT discriminator is implemented as follows. It is a multi-scale discriminator: the same PatchGAN architecture is applied to the image at three successively downsampled resolutions.
class MultiDiscriminator(nn.Module):
    def __init__(self, in_channels=3):
        super(MultiDiscriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, normalize=True):
            """Returns downsampling layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        # Extracts three discriminator models
        self.models = nn.ModuleList()
        for i in range(3):
            self.models.add_module(
                "disc_%d" % i,
                nn.Sequential(
                    *discriminator_block(in_channels, 64, normalize=False),
                    *discriminator_block(64, 128),
                    *discriminator_block(128, 256),
                    *discriminator_block(256, 512),
                    nn.Conv2d(512, 1, 3, padding=1)
                ),
            )

        self.downsample = nn.AvgPool2d(in_channels, stride=2, padding=[1, 1], count_include_pad=False)

    def compute_loss(self, x, gt):
        """LSGAN loss: mean squared error between each scale's output and gt"""
        return sum(torch.mean((out - gt) ** 2) for out in self.forward(x))

    def forward(self, x):
        outputs = []
        for m in self.models:
            outputs.append(m(x))
            x = self.downsample(x)
        return outputs


# Initialize discriminators
D1 = MultiDiscriminator()
D2 = MultiDiscriminator()
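Each forward pass returns three patch maps, one per scale; for instance (a sketch), a 256×256 input produces:

D = MultiDiscriminator()
x = torch.randn(1, 3, 256, 256)
for out in D(x):
    print(out.shape)
# torch.Size([1, 1, 16, 16])
# torch.Size([1, 1, 8, 8])
# torch.Size([1, 1, 4, 4])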
The loss computation and parameter updates of MUNIT proceed as follows. Following the PyTorch-GAN reference implementation, the adversarial term is an LSGAN-style least-squares loss computed over the multi-scale discriminator outputs (via compute_loss above) rather than the log loss written earlier, and the style codes used for translation are sampled from the Gaussian prior.
criterion_recon = torch.nn.L1Loss()

# Loss weights
lambda_gan = 1
lambda_id = 10
lambda_style = 1
lambda_cont = 1

# Optimizers
optimizer_G = torch.optim.Adam(
    itertools.chain(Enc1.parameters(), Dec1.parameters(), Enc2.parameters(), Dec2.parameters()),
    lr=opt.lr,
    betas=(opt.b1, opt.b2),
)
optimizer_D1 = torch.optim.Adam(D1.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D2 = torch.optim.Adam(D2.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

for epoch in range(opt.n_epochs):
    for i, (X1, X2) in enumerate(zip(dataloader_A, dataloader_B)):
        # Adversarial ground truths (scalar targets for the LSGAN loss)
        valid = 1
        fake = 0

        # ----------------------------------
        #  Forward propagation
        # ----------------------------------

        # Encode images into content and style codes
        c_code_1, s_code_1 = Enc1(X1)
        c_code_2, s_code_2 = Enc2(X2)

        # Sample style codes from the Gaussian prior
        style_1 = torch.randn(X1.size(0), opt.style_dim, 1, 1)
        style_2 = torch.randn(X2.size(0), opt.style_dim, 1, 1)

        # Reconstruct images
        X11 = Dec1(c_code_1, s_code_1)
        X22 = Dec2(c_code_2, s_code_2)

        # Translate images
        X21 = Dec1(c_code_2, style_1)
        X12 = Dec2(c_code_1, style_2)

        # Re-encode the translated images
        c_code_21, s_code_21 = Enc1(X21)
        c_code_12, s_code_12 = Enc2(X12)

        # -----------------------
        #  Train Discriminator 1
        # -----------------------
        optimizer_D1.zero_grad()
        loss_D1 = D1.compute_loss(X1, valid) + D1.compute_loss(X21.detach(), fake)
        loss_D1.backward()
        optimizer_D1.step()

        # -----------------------
        #  Train Discriminator 2
        # -----------------------
        optimizer_D2.zero_grad()
        loss_D2 = D2.compute_loss(X2, valid) + D2.compute_loss(X12.detach(), fake)
        loss_D2.backward()
        optimizer_D2.step()

        # -------------------------------
        #  Train Generator and Encoder
        # -------------------------------
        optimizer_G.zero_grad()

        # Losses
        loss_GAN_1 = lambda_gan * D1.compute_loss(X21, valid)
        loss_GAN_2 = lambda_gan * D2.compute_loss(X12, valid)
        loss_ID_1 = lambda_id * criterion_recon(X11, X1)
        loss_ID_2 = lambda_id * criterion_recon(X22, X2)
        loss_s_1 = lambda_style * criterion_recon(s_code_21, style_1)
        loss_s_2 = lambda_style * criterion_recon(s_code_12, style_2)
        loss_c_1 = lambda_cont * criterion_recon(c_code_12, c_code_1.detach())
        loss_c_2 = lambda_cont * criterion_recon(c_code_21, c_code_2.detach())

        # Total loss
        loss_G = (
            loss_GAN_1
            + loss_GAN_2
            + loss_ID_1
            + loss_ID_2
            + loss_s_1
            + loss_s_2
            + loss_c_1
            + loss_c_2
        )

        loss_G.backward()
        optimizer_G.step()
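After training, multimodal outputs come for free: fixing the content code and resampling the style code yields diverse translations of the same input. A sketch, assuming the trained modules above:

with torch.no_grad():
    c_code_1, _ = Enc1(X1)
    for k in range(5):
        s2 = torch.randn(X1.size(0), opt.style_dim, 1, 1)  # a fresh style sample
        X12_k = Dec2(c_code_1, s2)                         # one plausible translation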