Twins:重新思考视觉Transformer中的空间注意力设计.

相较于CNN来说,Transformer由于其能高效地捕获远距离依赖的特性,近期在计算机视觉领域也引领了一波潮流。Transformer主要是依靠Self-Attention去捕获各个token之间的关系,但是这种Global Self-Attention的计算复杂度太高,不利于在token数目较多的密集检测任务(分割、检测)中使用。

基于以上考虑,目前主流有两种应对方法:

  1. 一种是以SwinTransformer为代表的Locally-Grouped Self-Attention。其在不重叠的窗口内计算Self-Attention,当窗口大小固定时,整体的计算复杂度将下降,然后再通过其他方法去实现窗口间的互动,例如SwinTransformer中的Shift-Window方法。但这种方法的缺点在于窗口的大小会不一致,不利于现代深度学习框架的优化和加速。
  2. 一种是以PVT为代表的Sub-Sampled Version Self-Attention。其在计算Self-Attention前,会先对QKV Token进行下采样,从而降低计算复杂度。

本文整体思路可以认为是PVT+SwinTransformer的结合:在局部窗口内部计算Self-Attention(SwinTransformer),同时对每个窗口内部的特征进行压缩,然后再使用一个全局Attention机制去捕获各个窗口的关系(PVT)。

⚪ Twins-PCPVT

PVT中的Global Sub-Sample Attention是十分高效的,当配合上合适的Positional Encodings(Conditional Positional Encoding)时,其能取得媲美甚至超过目前SOTATransformer结构。

PVT通过逐步融合各个Patch的方式,形成了一种多尺度的结构,使得其更适合用于密集预测任务例如目标检测或者是语义分割,其继承了ViTDeiTLearnable Positional Encoding的设计,所有的Layer均直接使用Global Attention机制,并通过Spatial Reduction的方式去降低计算复杂度。

作者通过实验发现,PVTSwinTransformer的性能差异主要来自于PVT没有采用一个合适的Positional Encoding方式,通过采用Conditional Positional Encoding(CPE)去替换PVT中的PEPVT即可获得与当前最好的SwinTransformer相近的性能。

⚪ Twins-SVT

更进一步,基于Separable Depthwise Convolution的思想,本文提出了一个Spatially Separable Self-Attention(SSSA)。该模块仅包含矩阵乘法,在现代深度学习框架下能够得到优化和加速。通过提出的Spatially Separable Self-Attention(SSSA)去缓解Self-Attention的计算复杂度过高的问题。SSSA由两个部分组成:Locally-Grouped Self-Attention(LSA)Global Sub-Sampled Attention(GSA)

(1) Locally-Grouped Self-Attention(LSA)

首先将2D feature map划分为多个Sub-Windows,并仅在Window内部进行Self-Attention计算,计算量会大大减少,由$O(H^2W^2d)$下降至$O(k_1k_2HWd)$,其中$k_1=\frac{H}{m},k_2=\frac{W}{n}$,当$k_1,k_2$固定时,计算复杂度将仅与$HW$呈线性关系。

class LocalAttention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., patch_size = 7):
        super().__init__()
        inner_dim = dim_head *  heads
        self.patch_size = patch_size
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
        self.to_kv = nn.Conv2d(dim, inner_dim * 2, 1, bias = False)

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, fmap):
        shape, p = fmap.shape, self.patch_size
        b, n, x, y, h = *shape, self.heads
        x, y = map(lambda t: t // p, (x, y))

        fmap = rearrange(fmap, 'b c (x p1) (y p2) -> (b x y) c p1 p2', p1 = p, p2 = p)

        q, k, v = (self.to_q(fmap), *self.to_kv(fmap).chunk(2, dim = 1))
        q, k, v = map(lambda t: rearrange(t, 'b (h d) p1 p2 -> (b h) (p1 p2) d', h = h), (q, k, v))

        dots = einsum('b i d, b j d -> b i j', q, k) * self.scale

        attn = dots.softmax(dim = - 1)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b x y h) (p1 p2) d -> b (h d) (x p1) (y p2)', h = h, x = x, y = y, p1 = p, p2 = p)
        return self.to_out(out)

(2) Global Sub-Sampled Attention(GSA)

LSA缺乏各个Window之间的信息交互,比较简单的一个方法是,在LSA后面再接一个Global Self-Attention Layer,这种方法在实验中被证明也是有效的,但是其计算复杂度会较高。

另一个思路是,将每个Window提取一个维度较低的特征作为各个window的表征,然后基于这个表征再去与各个window进行交互,相当于Self-Attention中的Key的作用,这样一来,计算复杂度会下降至:$O(mnHWd)=O(\frac{H^2W^2d}{k_1k_2})$。

这种方法实际上相当于对feature map进行下采样,因此被命名为Global Sub-Sampled Attention。 综合使用LSAGSA,可以取得类似于Separable Convolution(Depth-wise+Point-wise)的效果,整体的计算复杂度为:$O(\frac{H^2W^2d}{k_1k_2}+k_1k_2HWd)$。同时有:$\frac{H^2W^2d}{k_1k_2}+k_1k_2HWd \geq 2HWd\sqrt{HW}$,当且仅当$k_1k_2 = \sqrt{HW}$。

考虑到分类任务中,$H=W=224$是比较常规的设置,同时使用方形框,则有$k_1=k_2$,第一个stagefeature map大小为$56$,可得$k_1=k_2=\sqrt{56}=7$。 当然可以针对各个Stage去设定其窗口大小,不过为了简单性,所有的$k$均设置为$7$。

class GlobalAttention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., k = 7):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
        self.to_kv = nn.Conv2d(dim, inner_dim * 2, k, stride = k, bias = False)

        self.dropout = nn.Dropout(dropout)

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        shape = x.shape
        b, n, _, y, h = *shape, self.heads
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1))

        q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = h), (q, k, v))

        dots = einsum('b i d, b j d -> b i j', q, k) * self.scale

        attn = dots.softmax(dim = -1)
        attn = self.dropout(attn)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, y = y)
        return self.to_out(out)

Twins-SVT的完整实现可参考vit-pytorch