SimMIM:一种掩码图像建模的简单框架.

SimMIM 是一个 MIM 任务上的预训练 CV 模型。这个模型直接回归预测原始像素 RGB 值。作者在这篇论文中探讨究竟是什么使得 MIM 任务使目标网络能学到更好的 visual representation。得出了以下结论:

  1. MIM 任务中,mask patch 的大小如果是32×32,就能使得预训练任务成为一个 strong pre-text task,非常有利于预训练大模型性能的提升。
  2. 直接回归预测原始像素 RGB 值的效果并不比复杂设计的patch分类方法差。
  3. 预测头可以设计成轻量化的模型,比如一个线性层,它的表现不比 heavy 模型差。

SimMIMMasking Strategy 是把 maskpatches 替换成可学习的 mask token vector,并随着网络一起训练。mask 的基本单位仍然是 Image Patches,对于 ViT 模型,masked patch size 使用32×32;对于 Swin Transformer 模型,masked patch size 使用4×4-32×32

目标网络的架构实际使用了 ViT 模型和 Swin Transformer 模型。希望 Prediction head 的输出就是重建之后的原图,所以为了预测输入图像在 full-resolution 下的所有像素值,将 feature map 映射回原始分辨率,并由这个 feature map 负责对相应原始像素的预测。比如当使用 Swin Transformer Encoder 时,输出是 downsample 32倍的 feature map。此时要先通过1×1的卷积输出维度是3072=3×32×32。再使用L1 loss

\[L=\frac{1}{\Omega(x_M)} ||y_M-x_M||_1\]
class SimMIM(nn.Module):
    def __init__(
        self,
        *,
        encoder, # 传入ViT
        masking_ratio = 0.5
    ):
        super().__init__()
        assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
        self.masking_ratio = masking_ratio

        # extract some hyperparameters and functions from encoder (vision transformer to be trained)
        self.encoder = encoder
        num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]

        self.to_patch = encoder.to_patch_embedding[0]
        self.patch_to_emb = nn.Sequential(*encoder.to_patch_embedding[1:])

        pixel_values_per_patch = encoder.to_patch_embedding[2].weight.shape[-1]

        # simple linear head
        self.mask_token = nn.Parameter(torch.randn(encoder_dim))
        self.to_pixels = nn.Linear(encoder_dim, pixel_values_per_patch)

    def forward(self, img):
        device = img.device

        # get patches
        patches = self.to_patch(img)
        batch, num_patches, *_ = patches.shape

        # for indexing purposes
        batch_range = torch.arange(batch, device = device)[:, None]

        # get positions
        pos_emb = self.encoder.pos_embedding[:, 1:(num_patches + 1)]

        # patch to encoder tokens and add positions
        tokens = self.patch_to_emb(patches)
        tokens = tokens + pos_emb

        # prepare mask tokens
        mask_tokens = repeat(self.mask_token, 'd -> b n d', b = batch, n = num_patches)
        mask_tokens = mask_tokens + pos_emb

        # calculate of patches needed to be masked, and get positions (indices) to be masked
        num_masked = int(self.masking_ratio * num_patches)
        masked_indices = torch.rand(batch, num_patches, device = device).topk(k = num_masked, dim = -1).indices
        masked_bool_mask = torch.zeros((batch, num_patches), device = device).scatter_(-1, masked_indices, 1).bool()

        # mask tokens
        tokens = torch.where(masked_bool_mask[..., None], mask_tokens, tokens)

        # attend with vision transformer
        encoded = self.encoder.transformer(tokens)

        # get the masked tokens
        encoded_mask_tokens = encoded[batch_range, masked_indices]

        # small linear projection for predicted pixel values
        pred_pixel_values = self.to_pixels(encoded_mask_tokens)

        # get the masked patches for the final reconstruction loss
        masked_patches = patches[batch_range, masked_indices]

        # calculate reconstruction loss
        recon_loss = F.l1_loss(pred_pixel_values, masked_patches) / num_masked
        return recon_loss

作者首先研究了不同 masking strategy 对表征学习的影响。最佳的 random masking strategy 是挡住$50\%$的原图。此外,当 mask patch size=32 时,mask ratio 在$10\%-70\%$时都能够取得很不错的结果。作者认为一个 mask 中心的像素距离边界可见像素是足够远的,因此可以强迫网络学习到一些 long-range 的关系,即使 mask 掉的像素足够多。相对较小的 patch 尺寸有利于微调性能,但总体精度不如较大的 patch 高。进一步将 patch 大小增加导致观测精度下降,可能是由于预测距离太大。

作者提出了一种 AvgDist 度量,该度量测量掩码像素到最近的可见像素的平均欧氏距离。不同掩码策略与不同掩蔽率的 AvgDist 如图(a)所示。从图中可以看出,所有的 masking strategy 的AvgDist 都随着 masking ratio 的增大而平滑增大。对于随机掩码策略,当 patch size 较小 (如48) 时, AvgDist 相对较低,且随着掩码率的增加而增长缓慢。另一方面,当 patch size 较大时 (如64),很小的 mask ratio (如$10\%$) 仍然会产生较大的 AvgDist

图(b)绘制了微调精度和 AvgDist 度量之间的关系,它遵循山脊 (ridge) 形状。微调精度高的条目大致分布在 AvgDist 的$[10,20]$范围内,而 AvgDist 越小或越高的条目表现越差。这表明掩码图像建模中的预测距离应该适中,既不要太大,也不要太小。AugDist 太小的话,网络可能会学习到太多捷径,AugDist 太大的话,网络可能会很难学习。实际使用的 mask ratio=0.6patch size=32

下图对比了不同结构的 Projection head 对结果的影响。作者依次尝试了 Linear layer2MLPinverseSwin-TinverseSwin-B 架构。发现参数量大的 Projection head 会带来更低的 loss,但是 Top-1Accuracy 反而变低了。这意味着预训练 impainting 的能力并不代表下游任务的能力。

下图对比了不同的** Projection resolution** 对结果的影响。大范围的分辨率 (12×12-192×192) 都能表现良好。性能只有在6×6的分辨率的低分辨率下才会下降,可能是因为6×6的分辨率丢弃了太多信息。 这些结果反映了下游图像分类任务所需的信息粒度。

作者探究 SimMIM 模型通过预训练 masked image modeling 任务获得了一种什么样的能力。下图每一行里面的 mask 分为 Random mask、挡住主要物体的 mask、挡住全部主要物体的 mask。结果显示:

  1. 如果使用 Random mask,物体的形状和纹理可以得到重建。但是,unmasked 部分因为模型没有学习这部分的重建,导致最终结果出现了很多的缺陷。
  2. 如果挡住主要物体的 mask,模型仍然能够重建出物体的部分。
  3. 如果挡住全部主要物体的 mask,则模型就使用背景去填充。

下图对比了只重建 masked patches (Prediction),或者重建所有的 patches (Reconstruction) 的结果。只重建 masked patches 的效果更好。

下图对比了不同大小的 mask patches 的重建结果。注意所有实验 mask ratio=0.6,结果发现当 mask patches 较小时,可以得到更好的重建结果。