Refiner:精炼视觉Transformer中的自注意力机制.
Refiner 是一个针对 Self-attention 机制本身做改进的工作。改进点包括:
- Attention Expansion:通过增大channel数来增加 head 的数量。
- Distributed Local Attention:通过融入卷积来增强 attention 局部信息的交互。
Self-attention 机制如下图所示,图片分成一个个patches之后,每个patch可以看做是一个token。对于每个token来讲,Self-attention 机制会建模它与其他所有tokens之间的相关性。这样,一个层的 tokens 之间就能够充分地相互交换信息,从而给模型提供了很好的表达能力。但是问题是:不同的 tokens 会变得越来越相似,尤其是当模型加深时。这种现象也称为过度平滑 (over-smoothing)。在这个工作中,作者通过直接 Refine attention map来解决这个问题。
作者认为过度平滑问题会导致特征在通过模型的不同 Block 时变化会比较缓慢。作者使用CKA 相似度 (Centered Kernel Alignment similarity)衡量每个Block输出的中间 token 和最后1个 Block 的相似性,并根据第1个和最后1个 Block 的相似性作归一化。 这样的度量捕获了2个属性:中间特征收敛到最终特征的速度有多快,以及最终特征与第一层特征有多大的不同。
如下图所示是ViT,DeiT,ResNet 3种模型的 CKA 相似度,ViT模型随着层数的加深,中间层与最后一层特征的相似性增加缓慢,但是DeiT 和 ResNet 的这种相似性增长迅速,并且最后1层和第1层的差异比 ViT 大。
⚪ Attention Expansion
增加 head 的数量可以有效地提高模型的性能,但是这就会使得注意力图不够全面和准确:假如模型的 Embedding dimension 是固定的,直接增加 head 的数量,会导致每个head的dimension变小,削弱了注意力图的表达能力。
为了解决这一难题,作者探索了注意力扩展,使用线性变换矩阵\(W_A \in \mathbb{R}^{H' \times H}\)将 multi-head self-attention map 线性映射到高维空间中。因此,在增加 head 的数量时,也能够保证每个head的dimension是不变的。使模型既能享受到更多 head 的好处,又能享受到高嵌入维度的好处。
⚪ Distributed Local Attention (DLA)
作者认为 过度平滑 (over-smoothing) 的第2个原因是ViT模型忽略了 tokens 之间的 local relationship,局部性 (locality) 和空间不变性 (权重共享) 已被证明是 CNN 成功的关键,因此作者想把卷积融入 attention 里面,来同时利用 attention 机制的全局建模能力和卷积操作的局部建模能力。
具体的做法是对attention map的每一个head使用卷积核\(w \in \mathbb{R}^{k \times k}\)进行一步卷积操作,然后使用这个新的 attention map 与 value 作 feature aggregation 操作。
class DLA(nn.Module):
def __init__(self, inp, oup, kernel_size = 3, stride=1, expand_ratio = 3, refine_mode='none'):
super(DLA, self).__init__()
"""
Distributed Local Attention used for refining the attention map.
"""
hidden_dim = round(inp * expand_ratio)
self.expand_ratio = expand_ratio
self.identity = stride == 1 and inp == oup
self.inp, self.oup = inp, oup
self.high_dim_id = False
self.refine_mode = refine_mode
if refine_mode == 'conv':
self.conv = Conv2dSamePadding(hidden_dim, hidden_dim, (kernel_size,kernel_size), stride, (1,1), groups=1, bias=False)
elif refine_mode == 'conv_exapnd':
if self.expand_ratio != 1:
self.conv_exp = Conv2dSamePadding(inp, hidden_dim, 1, 1, bias=False)
self.bn1 = nn.BatchNorm2d(hidden_dim)
self.depth_sep_conv = Conv2dSamePadding(hidden_dim, hidden_dim, (kernel_size,kernel_size), stride, (1,1), groups=hidden_dim, bias=False)
self.bn2 = nn.BatchNorm2d(hidden_dim)
self.conv_pro = Conv2dSamePadding(hidden_dim, oup, 1, 1, bias=False)
self.bn3 = nn.BatchNorm2d(oup)
self.relu = nn.ReLU6(inplace=True)
def forward(self, input):
x= input
if self.refine_mode == 'conv':
return self.conv(x)
else:
if self.expand_ratio !=1:
x = self.relu(self.bn1(self.conv_exp(x)))
x = self.relu(self.bn2(self.depth_sep_conv(x)))
x = self.bn3(self.conv_pro(x))
if self.identity:
return x + input
else:
return x
⚪ Refiner
Refiner首先通过 Linear Expansion 来对 attention map 的head数量进行扩展,再进行 Head-wise 的卷积操作。然后再通过 Linear Reduction 来对 attention map 进行特征削减来匹配 value 的维度。最后把特征削减之后的 attention map 和 value 作矩阵相乘,得到这个 head 的输出。
class Refined_Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,expansion_ratio = 3,
share_atten=False, apply_transform=True, refine_mode='conv_exapnd', kernel_size=3, head_expand=None):
"""
refine_mode: "conv" represents only convolution is used for refining the attention map;
"conv-expand" represents expansion and conv are used together for refining the attention map;
share_atten: If set True, the attention map is not generated; use the attention map generated from the previous block
"""
super().__init__()
self.num_heads = num_heads
self.share_atten = share_atten
head_dim = dim // num_heads
self.apply_transform = apply_transform
self.scale = qk_scale or head_dim ** -0.5
if self.share_atten:
self.DLA = DLA(self.num_heads,self.num_heads, refine_mode=refine_mode)
self.adapt_bn = nn.BatchNorm2d(self.num_heads)
self.qkv = nn.Linear(dim, dim, bias=qkv_bias)
elif apply_transform:
self.DLA = DLA(self.num_heads,self.num_heads, kernel_size=kernel_size, refine_mode=refine_mode, expand_ratio=head_expand)
self.adapt_bn = nn.BatchNorm2d(self.num_heads)
self.qkv = nn.Linear(dim, dim * expansion_ratio, bias=qkv_bias)
else:
self.qkv = nn.Linear(dim, dim * expansion_ratio, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, atten=None):
B, N, C = x.shape
if self.share_atten:
attn = atten
attn = self.adapt_bn(self.DLA(attn)) * self.scale
v = self.qkv(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
attn_next = atten
else:
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1) + atten * self.scale if atten is not None else attn.softmax(dim=-1)
attn = self.attn_drop(attn)
if self.apply_transform:
attn = self.adapt_bn(self.DLA(attn))
attn_next = attn
x = (attn @ v).transpose(1, 2).reshape(B, attn.shape[-1], C)
x = self.proj(x)
x = self.proj_drop(x)
return x, attn_next