DeepViT:构建更深的视觉Transformer.
和CNN可以通过堆叠更多的卷积层来提高性能不同,transformer在层次更深时(如大于$12$层)会很快进入饱和,原因是随着transformer进入深层,计算得到的attention map变得越来越相似。本文作者设计了Re-attention,通过混合不同head生成的attention map以增强各层之间的多样性;基于该方法训练一个$32$层的ViT,在ImageNet上的Top-1 Acc提高了$1.6$个百分点。
作者通过实验发现在ViT中随着深度的加深,不同层之间的attention map变得越来越相似,这种现象称为attention collapse。层间attention map的相似度和四个因素有关:$p$和$q$是两个不同的层、$h$是注意力head、$t$是具体的输入,最后算的是两个层在同样的head和同样的输入下计算的attention map的余弦相似度,其值趋近于$1$时表示这两个attention map非常相似。
\[M_{h, t}^{p, q}=\frac{\mathbf{A}_{h,:, t}^p \mathbf{A}_{h,:, t}^q}{\left\|\mathbf{A}_{h,:, t}^p\right\|\left\|\mathbf{A}_{h,:, t}^q\right\|}\]图1表示每层的attention map与周围$k$个层的对应attention map的相似性。随着深度的增加attention map越来越相似。图2表示随着层数的加深,相似的attention map的数量(红色线)增加,跨层的相似度(黑色线)增大。图3表示同一层不同head之间的相似性都低于$30\%$,它们呈现出足够的多样性。
为了解耦不同层之间的相似性,作者提出两种解决方法。第一种方法是增加自注意力模块的embedding dimension,即增加每个token的表达能力,使得生成的注意力图可以更加多样化,减少每个块的attention map之间的相似性。
作者设计了4种不同的embedding dimension,分别是$256,384,512,768$。如图所示,随着embedding dimension的增长,相似的block的数量在下降,同时模型的性能在上升,注意力坍塌的问题得以缓解。但增加embedding dimension也会显著增加计算成本,带来的性能改进往往会减少,且需要更大的数据量来训练,增加了过拟合的风险。
第二种方法是一个改进的模块Re-attention。注意到同一个层的不同head之间的相似度比较小,关注输入token 的不同方面。如果把不同的head的信息结合起来,利用重新构造一个attention map,能够避免注意力坍塌问题。
Re-attention采用一个可学习的变换矩阵$\Theta$和multi-head attention maps相乘来得到新的attention map,$\Theta$作用在head这个维度上:
\[\operatorname{Re-}\operatorname{Attention}(Q, K, V)=\operatorname{Norm}\left(\Theta^{\top}\left(\operatorname{Softmax}\left(\frac{Q K^{\top}}{\sqrt{d}}\right)\right)\right) V\]用Re-attention取代原始的Self-attention可以显著降低不同层的特征注意力图的相似性。
DeepViT的完整实现可参考vit-pytorch。其中Re-attention的实现如下:
from torch import nn, einsum
from einops import rearrange
from einops.layers.torch import Rearrange
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.dropout = nn.Dropout(dropout)
self.reattn_weights = nn.Parameter(torch.randn(heads, heads))
self.reattn_norm = nn.Sequential(
Rearrange('b h i j -> b i j h'),
nn.LayerNorm(heads),
Rearrange('b i j h -> b h i j')
)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
b, n, _, h = *x.shape, self.heads
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
# attention
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = dots.softmax(dim=-1)
attn = self.dropout(attn)
# re-attention
attn = einsum('b h i j, h g -> b g i j', attn, self.reattn_weights)
attn = self.reattn_norm(attn)
# aggregate and out
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
out = self.to_out(out)
return out