Refiner:精炼视觉Transformer中的自注意力机制.

Refiner 是一个针对 Self-attention 机制本身做改进的工作。改进点包括:

Self-attention 机制如下图所示,图片分成一个个patches之后,每个patch可以看做是一个token。对于每个token来讲,Self-attention 机制会建模它与其他所有tokens之间的相关性。这样,一个层的 tokens 之间就能够充分地相互交换信息,从而给模型提供了很好的表达能力。但是问题是:不同的 tokens 会变得越来越相似,尤其是当模型加深时。这种现象也称为过度平滑 (over-smoothing)。在这个工作中,作者通过直接 Refine attention map来解决这个问题。

作者认为过度平滑问题会导致特征在通过模型的不同 Block 时变化会比较缓慢。作者使用CKA 相似度 (Centered Kernel Alignment similarity)衡量每个Block输出的中间 token 和最后1Block 的相似性,并根据第1个和最后1Block 的相似性作归一化。 这样的度量捕获了2个属性:中间特征收敛到最终特征的速度有多快,以及最终特征与第一层特征有多大的不同。

如下图所示是ViT,DeiT,ResNet 3种模型的 CKA 相似度,ViT模型随着层数的加深,中间层与最后一层特征的相似性增加缓慢,但是DeiTResNet 的这种相似性增长迅速,并且最后1层和第1层的差异比 ViT 大。

⚪ Attention Expansion

增加 head 的数量可以有效地提高模型的性能,但是这就会使得注意力图不够全面和准确:假如模型的 Embedding dimension 是固定的,直接增加 head 的数量,会导致每个headdimension变小,削弱了注意力图的表达能力。

为了解决这一难题,作者探索了注意力扩展,使用线性变换矩阵\(W_A \in \mathbb{R}^{H' \times H}\)将 multi-head self-attention map 线性映射到高维空间中。因此,在增加 head 的数量时,也能够保证每个headdimension是不变的。使模型既能享受到更多 head 的好处,又能享受到高嵌入维度的好处。

⚪ Distributed Local Attention (DLA)

作者认为 过度平滑 (over-smoothing) 的第2个原因是ViT模型忽略了 tokens 之间的 local relationship,局部性 (locality) 和空间不变性 (权重共享) 已被证明是 CNN 成功的关键,因此作者想把卷积融入 attention 里面,来同时利用 attention 机制的全局建模能力和卷积操作的局部建模能力。

具体的做法是对attention map的每一个head使用卷积核\(w \in \mathbb{R}^{k \times k}\)进行一步卷积操作,然后使用这个新的 attention mapvaluefeature aggregation 操作。

class DLA(nn.Module):
    def __init__(self, inp, oup, kernel_size = 3, stride=1, expand_ratio = 3, refine_mode='none'):
        super(DLA, self).__init__()
        """
            Distributed Local Attention used for refining the attention map.
        """

        hidden_dim = round(inp * expand_ratio)
        self.expand_ratio = expand_ratio
        self.identity = stride == 1 and inp == oup
        self.inp, self.oup = inp, oup
        self.high_dim_id = False
        self.refine_mode = refine_mode

        if refine_mode == 'conv':
            self.conv = Conv2dSamePadding(hidden_dim, hidden_dim, (kernel_size,kernel_size), stride, (1,1), groups=1, bias=False)
        elif refine_mode == 'conv_exapnd':
            if self.expand_ratio != 1:
                self.conv_exp = Conv2dSamePadding(inp, hidden_dim, 1, 1, bias=False)
                self.bn1 = nn.BatchNorm2d(hidden_dim)   
            self.depth_sep_conv = Conv2dSamePadding(hidden_dim, hidden_dim, (kernel_size,kernel_size), stride, (1,1), groups=hidden_dim, bias=False)
            self.bn2 = nn.BatchNorm2d(hidden_dim)

            self.conv_pro = Conv2dSamePadding(hidden_dim, oup, 1, 1, bias=False)
            self.bn3 = nn.BatchNorm2d(oup)

            self.relu = nn.ReLU6(inplace=True)

    def forward(self, input):
        x= input
        if self.refine_mode == 'conv':
            return self.conv(x)
        else:
            if self.expand_ratio !=1:
                x = self.relu(self.bn1(self.conv_exp(x)))
            x = self.relu(self.bn2(self.depth_sep_conv(x)))
            x = self.bn3(self.conv_pro(x))
            if self.identity:
                return x + input
            else:
                return x

⚪ Refiner

Refiner首先通过 Linear Expansion 来对 attention maphead数量进行扩展,再进行 Head-wise 的卷积操作。然后再通过 Linear Reduction 来对 attention map 进行特征削减来匹配 value 的维度。最后把特征削减之后的 attention mapvalue 作矩阵相乘,得到这个 head 的输出。

class Refined_Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,expansion_ratio = 3, 
                        share_atten=False, apply_transform=True, refine_mode='conv_exapnd', kernel_size=3, head_expand=None):
        """
            refine_mode: "conv" represents only convolution is used for refining the attention map;
                         "conv-expand" represents expansion and conv are used together for refining the attention map; 
            share_atten: If set True, the attention map is not generated; use the attention map generated from the previous block
        """
        super().__init__()
        self.num_heads = num_heads
        self.share_atten = share_atten
        head_dim = dim // num_heads
        self.apply_transform = apply_transform
        
        self.scale = qk_scale or head_dim ** -0.5

        if self.share_atten:
            self.DLA = DLA(self.num_heads,self.num_heads, refine_mode=refine_mode)
            self.adapt_bn = nn.BatchNorm2d(self.num_heads)
            self.qkv = nn.Linear(dim, dim, bias=qkv_bias)
        elif apply_transform:
            self.DLA = DLA(self.num_heads,self.num_heads, kernel_size=kernel_size, refine_mode=refine_mode, expand_ratio=head_expand)
            self.adapt_bn = nn.BatchNorm2d(self.num_heads)
            self.qkv = nn.Linear(dim, dim * expansion_ratio, bias=qkv_bias)
        else:
            self.qkv = nn.Linear(dim, dim * expansion_ratio, bias=qkv_bias)
        
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, atten=None):
        B, N, C = x.shape

        if self.share_atten:
            attn = atten
            attn = self.adapt_bn(self.DLA(attn)) * self.scale 

            v = self.qkv(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
            attn_next = atten

        else:
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]   

            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1) + atten * self.scale if atten is not None else attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            if self.apply_transform:
                attn = self.adapt_bn(self.DLA(attn))  
            attn_next = attn

        x = (attn @ v).transpose(1, 2).reshape(B, attn.shape[-1], C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn_next