PixelDA: Pixel-Level Domain Adaptation via GAN.
PixelDA performs domain adaptation by adding noise to source-domain images to synthesize images in the target domain. The underlying assumption is that images in the source and target domains share the same content but differ in style.
The PixelDA model is structured as a generative adversarial network. The generator takes a source-domain image together with random noise and translates it into a target-domain image; the discriminator judges whether an input target-domain image is real. In addition, a task-specific network (usually a classifier) is introduced to assist the generator during training.
The generator of PixelDA receives a source-domain image and a random noise vector, and constructs a target-domain image by injecting the noise into the image.
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, in_features=64):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_features, in_features, 3, 1, 1),
            nn.BatchNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_features, in_features, 3, 1, 1),
            nn.BatchNorm2d(in_features),
        )

    def forward(self, x):
        # Identity shortcut around two conv-BN layers
        return x + self.block(x)

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # Fully-connected layer which constructs image channel shaped output from noise
        self.fc = nn.Linear(opt.latent_dim, opt.channels * opt.img_size ** 2)
        self.l1 = nn.Sequential(nn.Conv2d(opt.channels * 2, 64, 3, 1, 1), nn.ReLU(inplace=True))
        resblocks = []
        for _ in range(opt.n_residual_blocks):
            resblocks.append(ResidualBlock())
        self.resblocks = nn.Sequential(*resblocks)
        self.l2 = nn.Sequential(nn.Conv2d(64, opt.channels, 3, 1, 1), nn.Tanh())

    def forward(self, img, z):
        # Project the noise to an image-shaped tensor and concatenate it
        # with the source image along the channel dimension
        gen_input = torch.cat((img, self.fc(z).view(*img.shape)), 1)
        out = self.l1(gen_input)
        out = self.resblocks(out)
        img_ = self.l2(out)
        return img_
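As a quick shape check, the following sketch (not part of the original implementation) instantiates the generator under a hypothetical opt configuration with 32×32 RGB images; the names and values in opt are assumptions for illustration. The output has exactly the same shape as the input image:

# Hypothetical configuration -- values assumed for illustration only
from types import SimpleNamespace
opt = SimpleNamespace(latent_dim=10, channels=3, img_size=32, n_residual_blocks=6)

generator = Generator()
img = torch.randn(4, opt.channels, opt.img_size, opt.img_size)  # batch of source-domain images
z = torch.randn(4, opt.latent_dim)                              # batch of noise vectors
fake = generator(img, z)
print(fake.shape)  # torch.Size([4, 3, 32, 32]) -- same shape as the source images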
The discriminator of PixelDA evaluates whether target-domain images are real. It adopts the PatchGAN structure proposed by Pix2Pix: the discriminator is a fully convolutional network whose output is an $N \times N$ matrix, where each element corresponds to a patch of the input image and scores the realism of that patch.
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def block(in_features, out_features, normalization=True):
            """Discriminator block: strided conv + LeakyReLU (+ InstanceNorm)"""
            layers = [nn.Conv2d(in_features, out_features, 3, stride=2, padding=1), nn.LeakyReLU(0.2, inplace=True)]
            if normalization:
                layers.append(nn.InstanceNorm2d(out_features))
            return layers

        self.model = nn.Sequential(
            *block(opt.channels, 64, normalization=False),
            *block(64, 128),
            *block(128, 256),
            *block(256, 512),
            nn.Conv2d(512, 1, 3, 1, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        # Each spatial position of the output scores one patch of the input
        validity = self.model(img)
        return validity
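Continuing with the same assumed opt, the discriminator's four stride-2 blocks downsample the input by a factor of $2^4 = 16$, so a 32×32 input yields a 2×2 grid of patch scores:

discriminator = Discriminator()
imgs = torch.randn(4, opt.channels, opt.img_size, opt.img_size)
validity = discriminator(imgs)
print(validity.shape)  # torch.Size([4, 1, 2, 2]) -- one realism score per image region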
The discriminator of PixelDA is trained with the standard cross-entropy GAN loss; the generator, besides its adversarial term, also minimizes an auxiliary classification loss (multi-class cross-entropy) for the task network $T$:
\[\begin{aligned} \mathop{\max}_{D} \quad & \mathbb{E}_{x^t \sim P_{data}^t(x)}[\log D(x^t)] + \mathbb{E}_{(x^s,z) \sim (P_{data}^s(x),P_z(z))}[\log(1-D(G(x^s,z)))] \\ \mathop{\min}_{G} \quad & -\mathbb{E}_{(x^s,z) \sim (P_{data}^s(x),P_z(z))}[\log D(G(x^s,z))] - \mathbb{E}_{(x,y) \sim (P_{data}(x),P_{data}(y))}[\log T_y(x)] \end{aligned}\]

A complete PyTorch implementation of PixelDA is available in PyTorch-GAN. The loss computation and parameter updates are shown below; note that the reference implementation substitutes a mean-squared-error adversarial loss (the least-squares GAN form) for the cross-entropy terms above:
import itertools

# Loss functions: MSE for the adversarial term (LSGAN form),
# multi-class cross-entropy for the auxiliary task
adversarial_loss = torch.nn.MSELoss()
task_loss = torch.nn.CrossEntropyLoss()

# Loss weights (values taken from the PyTorch-GAN reference implementation)
lambda_adv = 1
lambda_task = 0.1

# Optimizers: the generator and the classifier are updated jointly
optimizer_G = torch.optim.Adam(
    itertools.chain(generator.parameters(), classifier.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)
)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

# Calculate output shape of image discriminator (PatchGAN):
# four stride-2 blocks downsample the input by 2^4
patch = (1, opt.img_height // 2 ** 4, opt.img_width // 2 ** 4)

for epoch in range(opt.n_epochs):
    for i, ((imgs_A, labels_A), (imgs_B, labels_B)) in enumerate(zip(dataloader_A, dataloader_B)):
        # Adversarial ground truths
        valid = torch.ones(imgs_A.shape[0], *patch).requires_grad_(False)
        fake = torch.zeros(imgs_A.shape[0], *patch).requires_grad_(False)

        # Generate a batch of images
        z = torch.randn(imgs_A.shape[0], opt.latent_dim)
        fake_B = generator(imgs_A, z)

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_loss = adversarial_loss(discriminator(imgs_B), valid)
        fake_loss = adversarial_loss(discriminator(fake_B.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        # -----------------
        #  Train Generator
        # -----------------
        optimizer_G.zero_grad()

        # Perform task on translated source image
        label_pred = classifier(fake_B)

        # Calculate the task loss: classify both the translated and the
        # original source images against the source labels
        task_loss_ = (task_loss(label_pred, labels_A) + task_loss(classifier(imgs_A), labels_A)) / 2

        # Loss measures generator's ability to fool the discriminator
        g_loss = lambda_adv * adversarial_loss(discriminator(fake_B), valid) + lambda_task * task_loss_

        g_loss.backward()
        optimizer_G.step()
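After training, adaptation quality is measured by running the jointly trained classifier directly on real target-domain images. Below is a minimal evaluation sketch, assuming a hypothetical held-out target-domain loader test_loader_B:

# Evaluation sketch; test_loader_B is a hypothetical held-out target-domain loader
correct, total = 0, 0
classifier.eval()
with torch.no_grad():
    for imgs_B, labels_B in test_loader_B:
        preds = classifier(imgs_B).argmax(dim=1)
        correct += (preds == labels_B).sum().item()
        total += labels_B.size(0)
print("Accuracy on target domain: %.4f" % (correct / total))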