Semi-Supervised Learning via Context-Conditional Generative Adversarial Networks.

This paper proposes a semi-supervised feature learning method based on predicting pixels from their context, the Context-Conditional GAN (CC-GAN), which can generate arbitrary image regions conditioned on the surrounding pixels. The authors combine the Context Encoder with the Auxiliary Classifier GAN (ACGAN).

The generator is an autoencoder that takes a masked input image $x$ and produces the in-painted result; the discriminator takes an image $x$ and an image label $y$, and judges both whether $x$ is a real image (binary classification) and whether it belongs to the label $y$ (multi-class classification).

1. Network Architecture

The generator is a convolutional autoencoder:

import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, channels=3):
        super(Generator, self).__init__()

        def downsample(in_feat, out_feat, normalize=True):
            # Strided convolution halves the spatial resolution
            layers = [nn.Conv2d(in_feat, out_feat, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.BatchNorm2d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2))
            return layers

        def upsample(in_feat, out_feat, normalize=True):
            # Transposed convolution doubles the spatial resolution
            layers = [nn.ConvTranspose2d(in_feat, out_feat, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.BatchNorm2d(out_feat, 0.8))
            layers.append(nn.ReLU())
            return layers

        # Four downsampling and four upsampling stages, so the output
        # has the same resolution as the (masked) input image
        self.model = nn.Sequential(
            *downsample(channels, 64, normalize=False),
            *downsample(64, 128),
            *downsample(128, 256),
            *downsample(256, 512),
            *upsample(512, 256),
            *upsample(256, 128),
            *upsample(128, 64),
            *upsample(64, 64),
            nn.Conv2d(64, channels, 3, 1, 1),
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x)
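
As a quick sanity check (a minimal sketch; the $128 \times 128$ image size is an assumption following the PyTorch-GAN default), the autoencoder should preserve the input resolution:

# Sanity check (sketch): the in-painting generator must return an image
# with the same resolution as its masked input
generator = Generator(channels=3)
x = torch.randn(8, 3, 128, 128)   # a batch of masked images
assert generator(x).shape == x.shape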

The discriminator follows the PatchGAN structure: it receives the complete in-painted image and outputs a $4 \times 4$ feature map used to judge both the authenticity and the class of the image:

class Discriminator(nn.Module):
    def __init__(self, input_shape, n_classes):
        super(Discriminator, self).__init__()

        channels, height, width = input_shape
        # Output size of the PatchGAN feature map (five 2x pooling stages)
        patch_h, patch_w = int(height / 2 ** 5), int(width / 2 ** 5)
        self.output_shape = (1, patch_h, patch_w)

        def discriminator_block(in_filters, out_filters, num_convs, normalize):
            """Returns the layers of one discriminator block"""
            layers = []
            for i in range(num_convs):
                layers.append(nn.Conv2d(in_filters, out_filters, 3, 1, 1))
                if normalize:
                    layers.append(nn.InstanceNorm2d(out_filters))
                layers.append(nn.LeakyReLU(0.2, inplace=True))
                in_filters = out_filters
            layers.append(nn.MaxPool2d(2))
            return layers

        layers = []
        in_filters = channels
        for out_filters, num_convs, normalize in [(64, 1, False), (128, 1, True), (256, 2, True), (512, 2, True), (512, 2, True)]:
            layers.extend(discriminator_block(in_filters, out_filters, num_convs, normalize))
            in_filters = out_filters

        layers.append(nn.Conv2d(out_filters, 1, 3, 1, 1))
        self.model = nn.Sequential(*layers)

        # Output heads: adversarial (real/fake) and auxiliary (class) layers
        self.adv_layer = nn.Sequential(nn.Linear(patch_h*patch_w, 1), nn.Sigmoid())
        self.aux_layer = nn.Sequential(nn.Linear(patch_h*patch_w, n_classes), nn.Softmax(dim=1))

    def forward(self, img):
        out = self.model(img)
        validity = self.adv_layer(out.view(out.size(0), -1))
        label = self.aux_layer(out.view(out.size(0), -1))
        return validity, label
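
The two output heads can be checked in the same way (again a sketch; the input size and `n_classes=10` are assumed values):

# The adversarial head returns one real/fake score per image and the
# auxiliary head a probability distribution over the n_classes labels
discriminator = Discriminator(input_shape=(3, 128, 128), n_classes=10)
validity, label = discriminator(torch.randn(8, 3, 128, 128))
print(validity.shape)   # torch.Size([8, 1])
print(label.shape)      # torch.Size([8, 10])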

2. Loss Functions

The loss function of the context-conditional encoder consists of an L2 reconstruction loss, an adversarial loss, and a classification loss.

The L2 reconstruction loss captures the overall structure of the missing region, making the in-painted region structurally consistent with its surroundings. During training, a binary mask $\hat{M}$ is randomly generated for the input image (1 marks the missing region, 0 marks the input pixels). Denoting the context encoder by $F$, the reconstruction loss is:

\[\mathcal{L}_{rec}(x) = ||\hat{M} \odot (x-F((1-\hat{M})\odot x))||_2^2\]
In the implementation, the missing region is filled with the value $-1$ (the lower bound of the generator's Tanh output range), and the reconstruction loss is computed between the generated image and the original image:

import numpy as np

def apply_random_mask(imgs):
    # Randomly choose the top-left corner of a square hole per image
    idx = np.random.randint(0, opt.img_size - opt.mask_size, (imgs.shape[0], 2))
    masked_imgs = imgs.clone()
    for i, (y1, x1) in enumerate(idx):
        y2, x2 = y1 + opt.mask_size, x1 + opt.mask_size
        masked_imgs[i, :, y1:y2, x1:x2] = -1
    return masked_imgs

pixelwise_loss = torch.nn.MSELoss()
masked_imgs = apply_random_mask(imgs)
gen_imgs = generator(masked_imgs)
g_pixel = pixelwise_loss(gen_imgs, imgs)
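
The snippet above applies the L2 loss to the whole image for simplicity. A variant closer to the formula, which restricts the loss to the missing region $\hat{M}$, could track the mask explicitly (the helper below is illustrative and not part of the original code):

# Illustrative sketch: also return the binary mask M (1 = missing region)
# and compute the reconstruction loss over the hole only, as in L_rec
def apply_random_mask_with_M(imgs):
    idx = np.random.randint(0, opt.img_size - opt.mask_size, (imgs.shape[0], 2))
    masks = torch.zeros_like(imgs)
    masked_imgs = imgs.clone()
    for i, (y1, x1) in enumerate(idx):
        y2, x2 = y1 + opt.mask_size, x1 + opt.mask_size
        masks[i, :, y1:y2, x1:x2] = 1
        masked_imgs[i, :, y1:y2, x1:x2] = -1
    return masked_imgs, masks

masked_imgs, masks = apply_random_mask_with_M(imgs)
gen_imgs = generator(masked_imgs)
g_pixel = ((masks * (imgs - gen_imgs)) ** 2).mean()   # L2 over the hole only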

The adversarial loss and the classification loss are the same as in the Auxiliary Classifier GAN (ACGAN):

# Loss functions
adversarial_loss = torch.nn.BCELoss()
auxiliary_loss = torch.nn.CrossEntropyLoss()
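
# -- Setup (an assumed minimal configuration, not part of the original post) --
generator = Generator()
discriminator = Discriminator(input_shape=(3, opt.img_size, opt.img_size), n_classes=opt.n_classes)
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
# dataloader is assumed to yield (imgs, labels) batches of labeled images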

for epoch in range(opt.n_epochs):
    for i, (imgs, labels) in enumerate(dataloader):
        # Adversarial ground truths (targets need no gradients)
        valid = torch.ones(imgs.shape[0], 1, requires_grad=False)
        fake = torch.zeros(imgs.shape[0], 1, requires_grad=False)

        # Generate a batch of images
        masked_imgs = apply_random_mask(imgs)
        gen_imgs = generator(masked_imgs)

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()

        # Loss for real images
        real_pred, real_aux = discriminator(imgs)
        d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2

        # Loss for fake (in-painted) images; the class label is unchanged,
        # since the generated image in-paints a real labeled image
        fake_pred, fake_aux = discriminator(gen_imgs.detach())
        d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, labels)) / 2

        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # -----------------
        #  Train Generator
        # -----------------
        optimizer_G.zero_grad()

        # Loss measures the generator's ability to fool the discriminator
        validity, pred_label = discriminator(gen_imgs)
        g_loss = 0.5 * (adversarial_loss(validity, valid) + auxiliary_loss(pred_label, labels))
        # Add the L2 reconstruction loss from above
        g_pixel = pixelwise_loss(gen_imgs, imgs)
        g_loss += g_pixel
        g_loss.backward()
        optimizer_G.step()

A complete PyTorch implementation of the Context-Conditional GAN can be found in PyTorch-GAN.
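
After training, the auxiliary head of the discriminator serves directly as the semi-supervised classifier; a minimal inference sketch (`test_imgs` is an assumed batch of unlabeled images):

# Classify unlabeled images with the trained auxiliary head
discriminator.eval()
with torch.no_grad():
    _, class_probs = discriminator(test_imgs)
    pred_labels = class_probs.argmax(dim=1)   # predicted class per image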