MAE: 掩码自编码器是可扩展的视觉学习者.
本文设计了一种应用于计算机视觉的自监督学习方法,掩码自编码器(masked autoencoder, MAE)。MAE接收随机遮挡部分patch的图像为输入,并重构原始图像。
MAE的整个网络采用非对称的编码器-解码器结构。编码器只对未遮挡的图像块进行操作;解码器是轻量级的,旨在从编码特征和遮挡token中重建输入图像。
相比于语言任务,图像的信息密度低,即挡住图片的一部分patches,可以很容易地通过看它周围的 patches 而想象出它的样子来。因此通常对图像进行较大比例的遮挡(如$75\%$),此时掩码重建任务具有一定的难度,而且可以较大程度地减少了计算量和内存消耗,并降低预训练时间。
编码器采用ViT结构,只输入未遮挡的图像块序列,因此能够使用有限的内存和计算训练非常大的编码器。编码器的特征和用于表示遮挡图像块的遮挡token组合后作为解码器的输入,通过一组轻量级的Transformer模块重构原始图像。预训练完成后,解码器可以被丢弃,只使用编码器提取图像特征用于下游任务。
解码器输出的每一个元素表示一个遮挡图像块的像素值向量,损失函数计算原始图像和重构图像上遮挡部分的像素的均方误差。作者指出先计算出每个 patch 的像素值的均值和方差,并使用它们去归一化这个 patch 的每个像素值。最后再使用归一化的像素值进行 MSE Loss 计算,能够提高特征表示的质量。
在实际实现时,通过线性映射和位置编码为每一个图像块生成一个token,对token序列随机打乱(记住打乱顺序)后,根据掩码率删除序列的最后一部分,其保留的部分便是未遮挡的图像块序列,用作编码器的输入。
MAE的具体实现过程为:
- 首先通过Linear Projection和位置编码得到 image tokens。
- 随机 shuffle 这些 tokens,按照 masking ratio 扔掉最后的一部分。
- 把 unmasked patches 输出到 Encoder 中,得到这些 tokens 的表征。
- 把 Encoder 的输出结合 masked tokens (可学习的向量),执行 unshuffle 操作恢复顺序,再一起输入到 Decoder 中。
- shuffle 和 unshuffle 操作的时间开销可忽略不计。
class MAE(nn.Module):
def __init__(
self,
*,
encoder, # 传入ViT
decoder_dim,
masking_ratio = 0.75,
decoder_depth = 1,
decoder_heads = 8,
decoder_dim_head = 64
):
super().__init__()
assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
self.masking_ratio = masking_ratio
# extract some hyperparameters and functions from encoder (vision transformer to be trained)
self.encoder = encoder
num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]
self.to_patch = encoder.to_patch_embedding[0]
self.patch_to_emb = nn.Sequential(*encoder.to_patch_embedding[1:])
pixel_values_per_patch = encoder.to_patch_embedding[2].weight.shape[-1]
# decoder parameters
self.decoder_dim = decoder_dim
self.enc_to_dec = nn.Linear(encoder_dim, decoder_dim) if encoder_dim != decoder_dim else nn.Identity()
self.mask_token = nn.Parameter(torch.randn(decoder_dim))
self.decoder = Transformer(dim = decoder_dim, depth = decoder_depth, heads = decoder_heads, dim_head = decoder_dim_head, mlp_dim = decoder_dim * 4)
self.decoder_pos_emb = nn.Embedding(num_patches, decoder_dim)
self.to_pixels = nn.Linear(decoder_dim, pixel_values_per_patch)
def forward(self, img):
device = img.device
# get patches
patches = self.to_patch(img)
batch, num_patches, *_ = patches.shape
# patch to encoder tokens and add positions
tokens = self.patch_to_emb(patches)
if self.encoder.pool == "cls":
tokens += self.encoder.pos_embedding[:, 1:(num_patches + 1)]
elif self.encoder.pool == "mean":
tokens += self.encoder.pos_embedding.to(device, dtype=tokens.dtype)
# calculate of patches needed to be masked, and get random indices, dividing it up for mask vs unmasked
num_masked = int(self.masking_ratio * num_patches)
rand_indices = torch.rand(batch, num_patches, device = device).argsort(dim = -1)
masked_indices, unmasked_indices = rand_indices[:, :num_masked], rand_indices[:, num_masked:]
# get the unmasked tokens to be encoded
batch_range = torch.arange(batch, device = device)[:, None]
tokens = tokens[batch_range, unmasked_indices]
# get the patches to be masked for the final reconstruction loss
masked_patches = patches[batch_range, masked_indices]
# attend with vision transformer
encoded_tokens = self.encoder.transformer(tokens)
# project encoder to decoder dimensions, if they are not equal - the paper says you can get away with a smaller dimension for decoder
decoder_tokens = self.enc_to_dec(encoded_tokens)
# reapply decoder position embedding to unmasked tokens
unmasked_decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices)
# repeat mask tokens for number of masked, and add the positions using the masked indices derived above
mask_tokens = repeat(self.mask_token, 'd -> b n d', b = batch, n = num_masked)
mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices)
# concat the masked tokens to the decoder tokens and attend with decoder
decoder_tokens = torch.zeros(batch, num_patches, self.decoder_dim, device=device)
decoder_tokens[batch_range, unmasked_indices] = unmasked_decoder_tokens
decoder_tokens[batch_range, masked_indices] = mask_tokens
decoded_tokens = self.decoder(decoder_tokens)
# splice out the mask tokens and project to pixel values
mask_tokens = decoded_tokens[batch_range, masked_indices]
pred_pixel_values = self.to_pixels(mask_tokens)
# calculate reconstruction loss
recon_loss = F.mse_loss(pred_pixel_values, masked_patches)
return recon_loss
下面展示一些恢复结果:
作者针对不同的掩码率进行了实验。有趣的是,对图像进行约$75\%$的遮挡能够获得最好的效果,这和自然语言处理中使用的较低掩码率不同(BERT约$15\%$)。这可能是因为较大的遮挡使得模型必须学习有用的通用表示,而不是简单地通过线条或纹理来完成任务。
作者进一步进行了一些消融实验,其中fit表示对模型进行端到端的微调;lin表示仅微调输出端的线性层。
- 表(a)和(b)调整了解码器的深度和宽度;结果表明足够深且更窄的 Decoder能够在 fine-tuning 时获得较好的性能;
- 表(c)测试了编码器输入是否使用遮挡token;结果表明效果变差,这可能是因为在这种情况下预训练和部署之间存在差距。即在预训练的输入中有很大一部分是mask tokens,这在测试图像中是不存在的。
- 表(d)测试了不同的重构目标;结果表明使用归一化的像素值进行 MSE Loss计算效果更好。
- 表(e)测试了不同的数据增强;结果表明只使用cropping-only就比较好,在 MAE 中数据增强的角色其实是由 random masking 来扮演的,每个 iteration 的 mask 都不同,所以就相当于是产生了新的训练样本。
- 表(f)测试了不同的掩码采样方法,包括随机采样、按块采样和网格采样。结果表明简单的随机抽样最适合 MAE 模型。
作者比较了MAE和其他自监督模型的表现。对于 ViT-B 模型,所有的方法性能相似,但是对于更大的 ViT-L 模型,性能差距就拉开了,证明了 MAE 对于大模型的泛化性能。