通过局部多尺度重构进行掩码图像建模.

现有的MIM方法通常训练成本很高,而实践中期望其能够从海量无标记数据(如网络上随机爬取的图像)中学习通用的知识,所以高昂的预训练成本限制了其工业落地。MIM方法的计算量在于编码器和解码器,由于解码器可以很小,已有的加速预训练的方法都通过降低编码器的计算量来加速编码过程。

不同于已有的思路,本文转换视角,从表征学习过程本身来深入思考现有方法的不足。现有全部MIM方法只在顶层引入重构任务,使得较低层无法获得直接的指导,从而只能通过缓慢的学习过程来学习patch表征及语义关联,拖累了整体的表征学习过程。尤其对于一些金字塔型主干网络,其较低层往往有着远多于顶层的patch(如Swin-2243136(最底层)vs 49(顶层))。

实际上较低层在表征学习中扮演关键角色:1)良好学习的较低层可以将知识传递给较高层以促进其学习;2)在下游任务微调时,较高层通常快速适应到新任务中,而较低层变化较慢,需要在预训练时就得到充分学习。

为了更直观地展现模型不同层对patch间语义关联的学习程度,考察不同层的query patchkey patch之间的标准化互信息(Normalized Mutual Information,NMI)。较高的NMI值意味着注意力强烈地依赖于query patch,如图所示已有的很多经典模型在较低层的注意力并不像顶层一样强烈地依赖于query patch

在掩码图像建模中较低层的学习很关键,然而目前所有MIM方法都只显式地指导顶层的学习。为此,考虑到重构任务需要patch之间的语义推理才能完成,将重构任务引入多个局部层以显式地进行有意义的指导。进一步发现直接地将顶层的重构任务引入到多个局部层增益不明显,原因可能是多个不同的局部层需要学习不同粒度的信息。为此考虑从原始输入中提取不同尺度的监督信号来指导多个局部层的学习。

具体的,对于原始输入,为了获得监督信号,已有方法通常首先将图像划分为不重叠的区域,该划分与构造编码器输入的划分对齐。然后使用恰当的特征描述算子(如像素标准化,HOG或预训练的codebook)提取每个区域的特征作为监督信号。通常可以认为在粗糙划分下每个区域捕捉原始输入相对high-level的语义信息,比如目标的部分或整体形状;而精细划分下每个区域捕捉相对low-level的语义信息,比如边,角或纹理。

编码器只输入可见patch。从原始输入中构造多尺度的监督信号来分别用于多个局部层的重构,令较低层重构细尺度的监督信号而较高层重构粗尺度的监督信号。对于金字塔结构的模型,通常已经划分为多个stage,将重构任务用于每个stage的末端;对于柱状架构,参照金字塔架构的经验,选择部分层进行重构。

解码器由三部分组成:推理部分(transformer blocks)+缩放部分(Deconvolution/Pool)+预测部分(MLP)。推理部分负责基于可见patch的表征推理被遮挡patch的信息;缩放部分是处理特征尺度与监督信号尺度不一致的情况,比如ViT这种柱状结构每层特征尺度不变而监督信号尺度是变化的,当不匹配时需要使用反卷积或池化操作进行上/下采样;预测部分负责整合放缩后的预测来作为最终输出。

实验在柱状架构ViT以及金字塔架构Swin上验证LocalMIM的有效性,出于其简单性,只考察像素标准化和HOG两种特征描述算子。LocalMIM比已有模型显著更加高效。具体的,就ImageNet-1Ktop-1微调准确率而言,LocalMIM分别以3.1倍和5.6倍的加速达到MAEMaskFeat的最佳表现,以3.6倍和6.4倍的加速达到SimMIM192GreenMIM的最佳表现。相较于其他模型,LocalMIM也以显著更少的预训练时长实现了可比较的表现。

实验进一步在训练过程中对选择的层进行梯度截断,即每阶段的参数只能接收来自该阶段重构任务的反传梯度,接收不到来自更高层的梯度。即便没有全局的反向传播梯度只使用局部的监督梯度也能很好的指导主干网络每层的表征学习,这一方面展现了引入的局部监督任务的优越性,另一方面也展现了神经网络解耦训练的可能性。