MAE: 掩码自编码器是可扩展的视觉学习者.

本文设计了一种应用于计算机视觉的自监督学习方法,掩码自编码器(masked autoencoder, MAE)。MAE接收随机遮挡部分patch的图像为输入,并重构原始图像。

MAE的整个网络采用非对称的编码器-解码器结构。编码器只对未遮挡的图像块进行操作;解码器是轻量级的,旨在从编码特征和遮挡token中重建输入图像。

相比于语言任务,图像的信息密度低,即挡住图片的一部分patches,可以很容易地通过看它周围的 patches 而想象出它的样子来。因此通常对图像进行较大比例的遮挡(如$75\%$),此时掩码重建任务具有一定的难度,而且可以较大程度地减少了计算量和内存消耗,并降低预训练时间。

编码器采用ViT结构,只输入未遮挡的图像块序列,因此能够使用有限的内存和计算训练非常大的编码器。编码器的特征和用于表示遮挡图像块的遮挡token组合后作为解码器的输入,通过一组轻量级的Transformer模块重构原始图像。预训练完成后,解码器可以被丢弃,只使用编码器提取图像特征用于下游任务。

解码器输出的每一个元素表示一个遮挡图像块的像素值向量,损失函数计算原始图像和重构图像上遮挡部分的像素的均方误差。作者指出先计算出每个 patch 的像素值的均值和方差,并使用它们去归一化这个 patch 的每个像素值。最后再使用归一化的像素值进行 MSE Loss 计算,能够提高特征表示的质量。

在实际实现时,通过线性映射和位置编码为每一个图像块生成一个token,对token序列随机打乱(记住打乱顺序)后,根据掩码率删除序列的最后一部分,其保留的部分便是未遮挡的图像块序列,用作编码器的输入。

MAE的具体实现过程为:

  1. 首先通过Linear Projection和位置编码得到 image tokens
  2. 随机 shuffle 这些 tokens,按照 masking ratio 扔掉最后的一部分。
  3. unmasked patches 输出到 Encoder 中,得到这些 tokens 的表征。
  4. Encoder 的输出结合 masked tokens (可学习的向量),执行 unshuffle 操作恢复顺序,再一起输入到 Decoder 中。
  5. shuffleunshuffle 操作的时间开销可忽略不计。
class MAE(nn.Module):
    def __init__(
        self,
        *,
        encoder, # 传入ViT
        decoder_dim,
        masking_ratio = 0.75,
        decoder_depth = 1,
        decoder_heads = 8,
        decoder_dim_head = 64
    ):
        super().__init__()
        assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
        self.masking_ratio = masking_ratio

        # extract some hyperparameters and functions from encoder (vision transformer to be trained)

        self.encoder = encoder
        num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]

        self.to_patch = encoder.to_patch_embedding[0]
        self.patch_to_emb = nn.Sequential(*encoder.to_patch_embedding[1:])

        pixel_values_per_patch = encoder.to_patch_embedding[2].weight.shape[-1]

        # decoder parameters
        self.decoder_dim = decoder_dim
        self.enc_to_dec = nn.Linear(encoder_dim, decoder_dim) if encoder_dim != decoder_dim else nn.Identity()
        self.mask_token = nn.Parameter(torch.randn(decoder_dim))
        self.decoder = Transformer(dim = decoder_dim, depth = decoder_depth, heads = decoder_heads, dim_head = decoder_dim_head, mlp_dim = decoder_dim * 4)
        self.decoder_pos_emb = nn.Embedding(num_patches, decoder_dim)
        self.to_pixels = nn.Linear(decoder_dim, pixel_values_per_patch)

    def forward(self, img):
        device = img.device

        # get patches
        patches = self.to_patch(img)
        batch, num_patches, *_ = patches.shape

        # patch to encoder tokens and add positions
        tokens = self.patch_to_emb(patches)
        if self.encoder.pool == "cls":
            tokens += self.encoder.pos_embedding[:, 1:(num_patches + 1)]
        elif self.encoder.pool == "mean":
            tokens += self.encoder.pos_embedding.to(device, dtype=tokens.dtype) 

        # calculate of patches needed to be masked, and get random indices, dividing it up for mask vs unmasked
        num_masked = int(self.masking_ratio * num_patches)
        rand_indices = torch.rand(batch, num_patches, device = device).argsort(dim = -1)
        masked_indices, unmasked_indices = rand_indices[:, :num_masked], rand_indices[:, num_masked:]

        # get the unmasked tokens to be encoded
        batch_range = torch.arange(batch, device = device)[:, None]
        tokens = tokens[batch_range, unmasked_indices]

        # get the patches to be masked for the final reconstruction loss
        masked_patches = patches[batch_range, masked_indices]

        # attend with vision transformer
        encoded_tokens = self.encoder.transformer(tokens)

        # project encoder to decoder dimensions, if they are not equal - the paper says you can get away with a smaller dimension for decoder
        decoder_tokens = self.enc_to_dec(encoded_tokens)

        # reapply decoder position embedding to unmasked tokens
        unmasked_decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices)

        # repeat mask tokens for number of masked, and add the positions using the masked indices derived above
        mask_tokens = repeat(self.mask_token, 'd -> b n d', b = batch, n = num_masked)
        mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices)

        # concat the masked tokens to the decoder tokens and attend with decoder
        decoder_tokens = torch.zeros(batch, num_patches, self.decoder_dim, device=device)
        decoder_tokens[batch_range, unmasked_indices] = unmasked_decoder_tokens
        decoder_tokens[batch_range, masked_indices] = mask_tokens
        decoded_tokens = self.decoder(decoder_tokens)

        # splice out the mask tokens and project to pixel values
        mask_tokens = decoded_tokens[batch_range, masked_indices]
        pred_pixel_values = self.to_pixels(mask_tokens)

        # calculate reconstruction loss
        recon_loss = F.mse_loss(pred_pixel_values, masked_patches)
        return recon_loss

下面展示一些恢复结果:

作者针对不同的掩码率进行了实验。有趣的是,对图像进行约$75\%$的遮挡能够获得最好的效果,这和自然语言处理中使用的较低掩码率不同(BERT约$15\%$)。这可能是因为较大的遮挡使得模型必须学习有用的通用表示,而不是简单地通过线条或纹理来完成任务。

作者进一步进行了一些消融实验,其中fit表示对模型进行端到端的微调;lin表示仅微调输出端的线性层。

作者比较了MAE和其他自监督模型的表现。对于 ViT-B 模型,所有的方法性能相似,但是对于更大的 ViT-L 模型,性能差距就拉开了,证明了 MAE 对于大模型的泛化性能。