DeepViT:构建更深的视觉Transformer.

CNN可以通过堆叠更多的卷积层来提高性能不同,transformer在层次更深时(如大于$12$层)会很快进入饱和,原因是随着transformer进入深层,计算得到的attention map变得越来越相似。本文作者设计了Re-attention,通过混合不同head生成的attention map以增强各层之间的多样性;基于该方法训练一个$32$层的ViT,在ImageNet上的Top-1 Acc提高了$1.6$个百分点。

作者通过实验发现在ViT中随着深度的加深,不同层之间的attention map变得越来越相似,这种现象称为attention collapse。层间attention map的相似度和四个因素有关:$p$和$q$是两个不同的层、$h$是注意力head、$t$是具体的输入,最后算的是两个层在同样的head和同样的输入下计算的attention map的余弦相似度,其值趋近于$1$时表示这两个attention map非常相似。

\[M_{h, t}^{p, q}=\frac{\mathbf{A}_{h,:, t}^p \mathbf{A}_{h,:, t}^q}{\left\|\mathbf{A}_{h,:, t}^p\right\|\left\|\mathbf{A}_{h,:, t}^q\right\|}\]

图1表示每层的attention map与周围$k$个层的对应attention map的相似性。随着深度的增加attention map越来越相似。图2表示随着层数的加深,相似的attention map的数量(红色线)增加,跨层的相似度(黑色线)增大。图3表示同一层不同head之间的相似性都低于$30\%$,它们呈现出足够的多样性。

为了解耦不同层之间的相似性,作者提出两种解决方法。第一种方法是增加自注意力模块的embedding dimension,即增加每个token的表达能力,使得生成的注意力图可以更加多样化,减少每个块的attention map之间的相似性。

作者设计了4种不同的embedding dimension,分别是$256,384,512,768$。如图所示,随着embedding dimension的增长,相似的block的数量在下降,同时模型的性能在上升,注意力坍塌的问题得以缓解。但增加embedding dimension也会显著增加计算成本,带来的性能改进往往会减少,且需要更大的数据量来训练,增加了过拟合的风险。

第二种方法是一个改进的模块Re-attention。注意到同一个层的不同head之间的相似度比较小,关注输入token 的不同方面。如果把不同的head的信息结合起来,利用重新构造一个attention map,能够避免注意力坍塌问题。

Re-attention采用一个可学习的变换矩阵$\Theta$和multi-head attention maps相乘来得到新的attention map,$\Theta$作用在head这个维度上:

\[\operatorname{Re-}\operatorname{Attention}(Q, K, V)=\operatorname{Norm}\left(\Theta^{\top}\left(\operatorname{Softmax}\left(\frac{Q K^{\top}}{\sqrt{d}}\right)\right)\right) V\]

Re-attention取代原始的Self-attention可以显著降低不同层的特征注意力图的相似性。

DeepViT的完整实现可参考vit-pytorch。其中Re-attention的实现如下:

from torch import nn, einsum
from einops import rearrange
from einops.layers.torch import Rearrange

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.dropout = nn.Dropout(dropout)

        self.reattn_weights = nn.Parameter(torch.randn(heads, heads))

        self.reattn_norm = nn.Sequential(
            Rearrange('b h i j -> b i j h'),
            nn.LayerNorm(heads),
            Rearrange('b i j h -> b h i j')
        )

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        # attention

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = dots.softmax(dim=-1)
        attn = self.dropout(attn)

        # re-attention

        attn = einsum('b h i j, h g -> b g i j', attn, self.reattn_weights)
        attn = self.reattn_norm(attn)

        # aggregate and out

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out =  self.to_out(out)
        return out