StarGAN: A Unified Framework for Multi-Domain Image-to-Image Translation
In image-to-image translation, most methods learn a single fixed mapping, e.g. converting zebras into horses. To cover translations among multiple domains, such methods must train a separate model for every ordered pair of domains, so the number of networks and the computational cost grow quadratically with the number of domains.
This paper proposes StarGAN, a single model that performs image translation across multiple domains. During training, every image is paired with a domain label vector, each element of which encodes one attribute (e.g. hair color or gender for face images); an illustrative encoding is sketched below.
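For instance, on a face dataset such as CelebA the domain label can be a multi-hot vector over a few selected attributes. The encoding below is only a hypothetical example; the attribute names and their order are assumptions for illustration, not part of the original code:
import torch

# Hypothetical 5-dimensional domain label (c_dim = 5):
# [black_hair, blond_hair, brown_hair, male, young]
# A young woman with blond hair would be encoded as:
label = torch.tensor([0.0, 1.0, 0.0, 0.0, 1.0])
# The target label fed to the generator simply flips the desired attributes,
# e.g. "make the hair black instead of blond":
target_label = torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0])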
1. The Overall Structure of StarGAN
StarGAN consists of a single discriminator and a single generator. The discriminator judges whether an image is real and, for real images, additionally predicts its domain label; the generator takes an image together with a target domain label and produces an image in that domain.
StarGAN is trained with a cycle-consistent procedure. The input image and a target label are fed to the generator to produce a translated image in the target domain; this translated image and the original label are then fed to the same generator to reconstruct the input image. The discriminator, in turn, tries to distinguish the translated images from the real input images.
2. The Network Design of StarGAN
The StarGAN generator is an encoder-decoder built from residual blocks. The label $c$ is spatially replicated and concatenated with the input image along the channel dimension.
import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    """Residual block used in the bottleneck of the StarGAN generator."""

    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        conv_block = [
            nn.Conv2d(in_features, in_features, 3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(in_features, affine=True, track_running_stats=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_features, in_features, 3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(in_features, affine=True, track_running_stats=True),
        ]
        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)


class GeneratorResNet(nn.Module):
    def __init__(self, img_shape=(3, 128, 128), res_blocks=9, c_dim=5):
        super(GeneratorResNet, self).__init__()
        channels, img_size, _ = img_shape

        # Initial convolution block: the input has channels + c_dim channels
        # because the label map is concatenated with the image
        model = [
            nn.Conv2d(channels + c_dim, 64, 7, stride=1, padding=3, bias=False),
            nn.InstanceNorm2d(64, affine=True, track_running_stats=True),
            nn.ReLU(inplace=True),
        ]

        # Downsampling
        curr_dim = 64
        for _ in range(2):
            model += [
                nn.Conv2d(curr_dim, curr_dim * 2, 4, stride=2, padding=1, bias=False),
                nn.InstanceNorm2d(curr_dim * 2, affine=True, track_running_stats=True),
                nn.ReLU(inplace=True),
            ]
            curr_dim *= 2

        # Residual blocks
        for _ in range(res_blocks):
            model += [ResidualBlock(curr_dim)]

        # Upsampling
        for _ in range(2):
            model += [
                nn.ConvTranspose2d(curr_dim, curr_dim // 2, 4, stride=2, padding=1, bias=False),
                nn.InstanceNorm2d(curr_dim // 2, affine=True, track_running_stats=True),
                nn.ReLU(inplace=True),
            ]
            curr_dim = curr_dim // 2

        # Output layer
        model += [nn.Conv2d(curr_dim, channels, 7, stride=1, padding=3), nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, x, c):
        # Replicate the label vector spatially and concatenate it with the image
        c = c.view(c.size(0), c.size(1), 1, 1)
        c = c.repeat(1, 1, x.size(2), x.size(3))
        x = torch.cat((x, c), 1)
        return self.model(x)
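A quick sanity check of the generator's input/output shapes (a minimal usage sketch using the class above; the batch size and random tensors are arbitrary):
# Sanity check of the generator's input/output shapes
G = GeneratorResNet(img_shape=(3, 128, 128), res_blocks=9, c_dim=5)
x = torch.randn(4, 3, 128, 128)          # a batch of 4 RGB images
c = torch.randint(0, 2, (4, 5)).float()  # 4 binary target label vectors
y = G(x, c)
print(y.shape)  # torch.Size([4, 3, 128, 128]), values in (-1, 1) due to Tanh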
The StarGAN discriminator uses the PatchGAN structure proposed in Pix2Pix: its adversarial output is an $N \times N$ map, each element of which scores the realism of one patch of the input image. In addition, the discriminator predicts the domain label with a second output head.
class Discriminator(nn.Module):
    def __init__(self, img_shape=(3, 128, 128), c_dim=5, n_strided=6):
        super(Discriminator, self).__init__()
        channels, img_size, _ = img_shape

        def discriminator_block(in_filters, out_filters):
            """Returns downsampling layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1), nn.LeakyReLU(0.01)]
            return layers

        layers = discriminator_block(channels, 64)
        curr_dim = 64
        for _ in range(n_strided - 1):
            layers.extend(discriminator_block(curr_dim, curr_dim * 2))
            curr_dim *= 2
        self.model = nn.Sequential(*layers)

        # Output 1: PatchGAN realism map (a Sigmoid is appended so that the
        # BCELoss used in the training loop below can be applied directly)
        self.out1 = nn.Sequential(nn.Conv2d(curr_dim, 1, 3, padding=1, bias=False), nn.Sigmoid())
        # Output 2: domain label prediction
        kernel_size = img_size // 2 ** n_strided
        self.out2 = nn.Sequential(nn.Conv2d(curr_dim, c_dim, kernel_size, bias=False), nn.Sigmoid())

    def forward(self, img):
        feature_repr = self.model(img)
        out_adv = self.out1(feature_repr)
        out_cls = self.out2(feature_repr)
        return out_adv, out_cls.view(out_cls.size(0), -1)
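Similarly, a shape check of the discriminator's two outputs (a minimal usage sketch, assuming the default 128×128 input and n_strided=6):
D = Discriminator(img_shape=(3, 128, 128), c_dim=5, n_strided=6)
out_adv, out_cls = D(torch.randn(4, 3, 128, 128))
print(out_adv.shape)  # torch.Size([4, 1, 2, 2]): a 2x2 PatchGAN realism map
print(out_cls.shape)  # torch.Size([4, 5]): predicted probability of each domain label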
3. The Objective Functions of StarGAN
The discriminator's objective consists of an adversarial loss and a domain classification loss; the generator's objective consists of an adversarial loss, a domain classification loss, and a reconstruction (cycle-consistency) loss. Below, $y^t$ denotes the target domain label and $y^s$ the original (source) label of image $x$:
\[\begin{aligned} \mathop{\max}_{D} \quad & \Bbb{E}_{x \text{~} P_{data}(x)}[\log D(x)] + \Bbb{E}_{x \text{~} P_{data}(x)}[\log (1-D(G(x, y^t)))] \\ &+ \lambda_{cls} \Bbb{E}_{(x,y) \text{~} P_{data}(x,y)}[\log D_{y}(x)] \\ \mathop{\min}_{G} \quad & -\Bbb{E}_{x \text{~} P_{data}(x)}[\log D(G(x, y^t))] - \lambda_{cls} \Bbb{E}_{x \text{~} P_{data}(x)}[\log D_{y^t}(G(x, y^t))] \\ &+ \lambda_{rec} \Bbb{E}_{x \text{~} P_{data}(x)}[\|x-G(G(x, y^t),y^s)\|_1] \end{aligned}\]
A complete PyTorch implementation of StarGAN is available in PyTorch-GAN; the loss computation and parameter-update steps are given below:
# Losses
gan_loss = torch.nn.BCELoss()
cycle_loss = torch.nn.L1Loss()

# Loss weights (the StarGAN paper uses lambda_cls = 1, lambda_rec = 10)
lambda_cls, lambda_rec = 1, 10

# Initialize models
generator = GeneratorResNet(img_shape=img_shape, res_blocks=opt.residual_blocks, c_dim=c_dim)
discriminator = Discriminator(img_shape=img_shape, c_dim=c_dim)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

# Shape of the image discriminator's (PatchGAN) output
patch = (1, opt.img_height // 2 ** 6, opt.img_width // 2 ** 6)

for epoch in range(opt.n_epochs):
    for i, (imgs, labels) in enumerate(dataloader):
        # Adversarial ground truths
        valid = torch.ones(imgs.shape[0], *patch).requires_grad_(False)
        fake = torch.zeros(imgs.shape[0], *patch).requires_grad_(False)

        # ----------------------------------
        #  Forward propagation
        # ----------------------------------
        # Sample random binary target labels as generator inputs
        sampled_c = torch.randint(0, 2, (imgs.size(0), c_dim)).float()
        # Translate images into the sampled target domain
        fake_imgs = generator(imgs, sampled_c)
        # Reconstruct the original images from the translated ones
        recov_imgs = generator(fake_imgs, labels)

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()
        # Real images
        real_validity, pred_cls = discriminator(imgs)
        # Fake images
        fake_validity, _ = discriminator(fake_imgs.detach())
        # Adversarial loss
        loss_D_adv = gan_loss(real_validity, valid) + gan_loss(fake_validity, fake)
        # Classification loss (real images against their true labels)
        loss_D_cls = gan_loss(pred_cls, labels)
        # Total loss
        loss_D = loss_D_adv + lambda_cls * loss_D_cls
        loss_D.backward()
        optimizer_D.step()

        # -------------------------------
        #  Train Generator
        # -------------------------------
        optimizer_G.zero_grad()
        # Discriminator evaluates the translated images
        fake_validity, pred_cls = discriminator(fake_imgs)
        # Adversarial loss
        loss_G_adv = gan_loss(fake_validity, valid)
        # Classification loss (translated images should match the target labels)
        loss_G_cls = gan_loss(pred_cls, sampled_c)
        # Reconstruction (cycle-consistency) loss
        loss_G_rec = cycle_loss(recov_imgs, imgs)
        # Total loss
        loss_G = loss_G_adv + lambda_cls * loss_G_cls + lambda_rec * loss_G_rec
        loss_G.backward()
        optimizer_G.step()
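After training, translating an image to any target domain only requires a single forward pass of the generator with the desired label vector. A minimal sketch using the variables defined above (the variable names target_c and translated, and the attribute layout of the label vector, are hypothetical):
generator.eval()
with torch.no_grad():
    img, _ = next(iter(dataloader))              # take a batch of input images
    target_c = torch.zeros(img.size(0), c_dim)   # desired target domain labels
    target_c[:, 0] = 1.0                         # e.g. turn on the first attribute
    translated = generator(img, target_c)        # images translated to the target domain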