PixelDA: Pixel-Level Domain Adaptation via GANs

PixelDA performs domain adaptation by adding noise to source-domain images to synthesize target-domain images. The underlying assumption is that source- and target-domain images keep the same content but differ in style.

The PixelDA model takes the overall form of a generative adversarial network. The generator receives a source-domain image together with random noise and translates it into a target-domain image; the discriminator judges whether a given target-domain image is real. An additional task-specific network (typically a classifier) is introduced to assist the generator's training.

The generator of PixelDA receives a source-domain image and a random noise vector, and constructs the target-domain image by injecting the noise into the image.

import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    def __init__(self, in_features=64):
        super(ResidualBlock, self).__init__()
        # Two 3x3 convolutions with batch norm; channel count and spatial size are preserved
        self.block = nn.Sequential(
            nn.Conv2d(in_features, in_features, 3, 1, 1),
            nn.BatchNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_features, in_features, 3, 1, 1),
            nn.BatchNorm2d(in_features),
        )

    def forward(self, x):
        # Skip connection around the block
        return x + self.block(x)

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # Fully-connected layer that maps the noise vector to an image-shaped tensor
        self.fc = nn.Linear(opt.latent_dim, opt.channels * opt.img_size ** 2)

        # First conv layer takes the source image concatenated with the noise image
        self.l1 = nn.Sequential(nn.Conv2d(opt.channels * 2, 64, 3, 1, 1), nn.ReLU(inplace=True))

        resblocks = []
        for _ in range(opt.n_residual_blocks):
            resblocks.append(ResidualBlock())
        self.resblocks = nn.Sequential(*resblocks)

        self.l2 = nn.Sequential(nn.Conv2d(64, opt.channels, 3, 1, 1), nn.Tanh())

    def forward(self, img, z):
        # Reshape the projected noise to image shape and concatenate it with the source image
        gen_input = torch.cat((img, self.fc(z).view(*img.shape)), 1)
        out = self.l1(gen_input)
        out = self.resblocks(out)
        img_ = self.l2(out)
        return img_
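
As a quick sanity check (a minimal sketch: opt is assumed to be the usual argparse namespace, with hypothetical values mirroring the PyTorch-GAN defaults), the generator maps a batch of source images plus noise vectors to images of the same shape:

from types import SimpleNamespace

# Hypothetical hyperparameters, mirroring the PyTorch-GAN defaults
opt = SimpleNamespace(channels=3, img_size=32, latent_dim=10, n_residual_blocks=6)

generator = Generator()
imgs = torch.randn(8, opt.channels, opt.img_size, opt.img_size)  # a fake batch of source images
z = torch.randn(8, opt.latent_dim)                               # one noise vector per image
print(generator(imgs, z).shape)  # torch.Size([8, 3, 32, 32]) -- same shape as the source batch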

The discriminator of PixelDA distinguishes real target-domain images from generated ones. It adopts the PatchGAN structure proposed in Pix2Pix: the discriminator is designed as a fully convolutional network whose output is an $N \times N$ matrix, each element of which corresponds to one patch of the input image and rates that patch's realism.

class Discriminator(nn.Module):
    def __init__(self, in_channels=3):
        super(Discriminator, self).__init__()

        def block(in_features, out_features, normalization=True):
            """Downsampling block: strided conv + LeakyReLU (+ optional InstanceNorm)"""
            layers = [nn.Conv2d(in_features, out_features, 3, stride=2, padding=1), nn.LeakyReLU(0.2, inplace=True)]
            if normalization:
                layers.append(nn.InstanceNorm2d(out_features))
            return layers

        self.model = nn.Sequential(
            *block(in_channels, 64, normalization=False),
            *block(64, 128),
            *block(128, 256),
            *block(256, 512),
            # 1-channel patch map: each spatial element rates one patch of the input
            nn.Conv2d(512, 1, 3, 1, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        validity = self.model(img)
        return validity
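
With the same assumed opt as above, the four stride-2 blocks downsample the input by a factor of $2^4 = 16$, so a $32 \times 32$ image yields a $2 \times 2$ patch map; a quick check:

discriminator = Discriminator(in_channels=opt.channels)
validity = discriminator(torch.randn(8, opt.channels, opt.img_size, opt.img_size))
print(validity.shape)  # torch.Size([8, 1, 2, 2]) -- one realism score per image patch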

The discriminator of PixelDA is trained with the standard cross-entropy GAN loss; the generator, besides its cross-entropy adversarial loss, also carries an auxiliary classification loss (a multi-class cross-entropy) from the task network $T$:

\[\begin{aligned} \mathop{\max}_{D} \quad & \Bbb{E}_{x^t \sim P_{data}^t(x)}[\log D(x^t)] + \Bbb{E}_{(x^s,z) \sim (P_{data}^s(x),P_z(z))}[\log(1-D(G(x^s,z)))] \\ \mathop{\min}_{G} \quad & -\Bbb{E}_{(x^s,z) \sim (P_{data}^s(x),P_z(z))}[\log D(G(x^s,z))] - \Bbb{E}_{(x,y) \sim (P_{data}(x),P_{data}(y))}[\log T_y(x)] \end{aligned}\]
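
The task network $T$ above is the classifier that appears in the training loop below but is not defined in this post. A minimal sketch, modeled on the convolutional classifier of the PyTorch-GAN reference implementation (opt.n_classes is an assumed hyperparameter, e.g. 10 for digit classification):

class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()

        def block(in_features, out_features, normalization=True):
            """Same downsampling block as the discriminator"""
            layers = [nn.Conv2d(in_features, out_features, 3, stride=2, padding=1), nn.LeakyReLU(0.2, inplace=True)]
            if normalization:
                layers.append(nn.InstanceNorm2d(out_features))
            return layers

        self.model = nn.Sequential(
            *block(opt.channels, 64, normalization=False),
            *block(64, 128),
            *block(128, 256),
            *block(256, 512),
        )

        input_size = opt.img_size // 2 ** 4
        # Class logits (CrossEntropyLoss applies log-softmax internally)
        self.output_layer = nn.Linear(512 * input_size ** 2, opt.n_classes)

    def forward(self, img):
        feature = self.model(img)
        return self.output_layer(feature.view(feature.size(0), -1))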

A complete PyTorch implementation of PixelDA can be found in PyTorch-GAN; the loss computation and parameter-update steps are given below:

import itertools

# Loss functions
# Note: instead of the cross-entropy GAN loss in the formula above, the reference
# implementation uses a least-squares (LSGAN-style) adversarial loss
adversarial_loss = torch.nn.MSELoss()
task_loss = torch.nn.CrossEntropyLoss()

# Optimizers
optimizer_G = torch.optim.Adam(
    itertools.chain(generator.parameters(), classifier.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)
)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

# Output shape of the image discriminator (PatchGAN): four stride-2 blocks downsample by 2^4
patch = (1, opt.img_size // 2 ** 4, opt.img_size // 2 ** 4)

# Relative weights of the adversarial and task losses
lambda_adv, lambda_task = 1, 0.1

for epoch in range(opt.n_epochs):
    for i, ((imgs_A, labels_A), (imgs_B, labels_B)) in enumerate(zip(dataloader_A, dataloader_B)):
        # Adversarial ground truths
        valid = torch.ones(imgs_A.shape[0], *patch).requires_grad_(False)
        fake = torch.zeros(imgs_A.shape[0], *patch).requires_grad_(False)

        # Generate a batch of images
        z = torch.randn(imgs_A.shape[0], opt.latent_dim)
        fake_B = generator(imgs_A, z)

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_loss = adversarial_loss(discriminator(imgs_B), valid)
        fake_loss = adversarial_loss(discriminator(fake_B.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        # -----------------
        #  Train Generator
        # -----------------
        optimizer_G.zero_grad()

        # Perform task on translated source image
        label_pred = classifier(fake_B)

        # Calculate the task loss
        task_loss_ = (task_loss(label_pred, labels_A) + task_loss(classifier(imgs_A), labels_A)) / 2

        # Loss measures generator's ability to fool the discriminator
        g_loss = lambda_adv * adversarial_loss(discriminator(fake_B), valid) + lambda_task * task_loss_
        g_loss.backward()
        optimizer_G.step()
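
After training, adaptation quality is usually reported as the classifier's accuracy on real target-domain test data, whose labels were never used during training. A hedged sketch, assuming a hypothetical test_loader_B over labeled target-domain test images:

# Evaluate the task classifier on held-out target-domain data
classifier.eval()
correct, total = 0, 0
with torch.no_grad():
    for imgs_B, labels_B in test_loader_B:
        pred = classifier(imgs_B).argmax(dim=1)
        correct += (pred == labels_B).sum().item()
        total += labels_B.size(0)
print(f"Target-domain accuracy: {correct / total:.4f}")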