BEiT:图像Transformer中的BERT预训练.

BEiTdVAE和基于BERTMIM(Mask Image Model)两个无监督模型的结合体,旨在通过被掩码掉的图像恢复图像的视觉标志来实现图像的预训练。BEiT主要由两部分构成:

BEIT的流程如图所示,它的上侧是一个dVAE模型,下侧是一个类似BERTEncoderdVAETokenizerDecoder组成,其中Tokenizer的作用是将图像的每个Patch编码成离散的视觉标志,Decoder的作用将视觉标志恢复成输入图像。BERT的输入是含有被掩码的图像的所有patch,预测的是dVAE生成的视觉标志。

图像Patch是将一个图像拆分成若干个不同的图像块,然后它们会被送到Transformer中进行模型的训练。对于一个图像$x\in R^{H\times W\times C}$,被分成$N$个图像块,其中$N=HW/P^2$,$P$是图像块的分辨率。这时可以用图像块组成的向量序列$x^p \in R^{N\times P^2C}$来表示输入图像。在BEiT中,原始图像的大小是$224\times 224$,图像Patch的大小是$16\times 16$,因此每个图像被分成了$14\times 14$个图像栅格。

另一方面,BEiT通过一个dVAE图像也抽象为一个信息密集的载体,即视觉标志(Visual Token)。dVAETokenizer用于将图像编码成视觉标志,Decoder用于将视觉标志还原成输入图像。具体的讲,每个图像$x\in R^{H\times W\times C}$可以表示为由$N$个时间标志组成的离散向量,表示为$z = [z_1,…,z_N] \in V^{N}$,其中字典$V$的大小$8192$。

BEIT 使用了Masked Image Modeling的自监督训练方式,随机盖住一些 tokens,让 BERT 模型预测盖住的tokens

  1. 把输入图像$x$编码为图像patch $x^p \in R^{N\times P^2C}$ 和视觉标志$z = [z_1,…,z_N] \in V^{N}$;
  2. 生成掩码$M$随机盖住$40\%$的图像patch,将其替换为可学习编码$e\in R^D$;
  3. 把掩码操作后的图像patch $x^M$通过BEiT编码器得到编码表示$h\in R^N$;
  4. 把盖住位置的输出编码表示$h^M$通过一个分类器,预测盖住的patch的相应的visual token
  5. 通过交叉熵损失最小化计算预测的 token 与真实的 token 之间的差异。

BEIT 并不是完全随机地盖住$40\%$,而是采取了 blockwise masking 的方法,即每次循环先通过下列算法计算出$s,r,a,t$,然后盖住$i \in [t,t+a], j \in [l,l+b]$的部分,直到盖住的部分超过了$40\%$为止。

BEIT的总损失函数包括视觉标志的重构损失与dVAE的变分下界损失:

\[\sum_{\left(x_i, \tilde{x}_i\right) \in \mathcal{D}}(\underbrace{E_{z_i \sim q_\phi\left(z \mid x_i\right)}\left[\log p_\psi\left(x_i \mid z_i\right)\right]}_{\text {Stage 1: Visual Token Reconstruction }}+\underbrace{\log p_\theta\left(\hat{z}_i \mid \tilde{x}_i\right)}_{\text {Stage 2: Masked Image Modeling }})\]

作者进行了一系列消融实现:

  1. Blockwise masking在两种下游任务的微调(分类和分割)中都是有利的,特别是在语义分割上。
  2. 盖住一个patch直接进行像素级的回归任务精度稍微变差了,说明预测 visual tokens 而不是直接进行pixel level的回归任务才是 BEIT 的关键。
  3. 不进行自监督预训练,即直接恢复所有image patches,性能也会下降。

下图可视化BEIT模型不同reference pointsattention map,方法是拿出BEIT的最后一个layer,假定一个参考点,选定它所在的patch,对应的attention map行拿出来,代表这个patch attend to所有patch的程度,再reshape成正方形。可以发现仅仅是预训练完以后,BEIT 就能够使用 self-attention 来区分不同的语义区域。通过BEIT获得的这些知识有可能提高微调模型的泛化能力,特别是在小数据集上。