CvT:向视觉Transformer中引入卷积.

卷积神经网络具有一些优良的特征,比如局部感受野、共享卷积权重、空间下采样等,从而在视觉任务上具有突出的表现。而视觉Transformer能够捕捉全局信息,比卷积网络具有更强的表示能力,因此往往需要更多的数据量以支持训练。本文提出了一种融合卷积与视觉Transformer的模型CvT,通过构造多阶段的层次结构,在ImageNet-1kImageNet-22k数据集上训练,达到了SOTA的性能。

CvT结构引入了两种卷积操作,分别叫做Convolutional Token EmbeddingConvolutional Projection

1. Convolutional Token Embedding

在每个stage中会进行下面的操作:输入的2D token map会先进入Convolutional Token Embedding这个层,相当于在2D reshaped token map上做一次卷积操作。这个层的输入是个reshape2Dtoken。再通过一个Layer Normalization。卷积的目的是保证在每个阶段都减小token的数量,也就是减小feature resolution;在每个阶段都扩大token width,也就是扩大feature dimension。这样实现的效果就和CNN差不多,都是随着层数的加深而逐渐减小feature resolution和逐渐增加feature dimension

假设前一层的输出维度是$x_{i-1}\in R^{H_{i-1}\times W_{i-1} \times C_{i-1}}$,先通过Convolutional Token Embedding的常规的卷积操作得到$f(x_{i-1})\in R^{H_{i}\times W_{i} \times C_{i}}$,再把它flatten成一个$H_iW_i\times C_i$的张量并进行Layer Normalization操作,得到的结果进入下面的第$i$个stageTransformer Block的操作。这些操作的目的是保证在每个阶段都减小token的数量,也就是减小feature resolution;在每个阶段都扩大token width,也就是扩大feature dimension

2. Convolutional Projection

在每个stage中,Convolutional Token Embedding的输出会再通过Convolutional Transformer Blocks。这个结构长得和普通TransformerBlock差不多,只是把普通TransformerBlock中的Linear Projection操作换成了Convolutional Projection操作,说白了就是用Depth-wise separable convolution操作来代替了Linear Projection操作。

具体来讲,token首先reshape2Dtoken map,再分别通过3Depthwise-separable Convolution (kernel $=s\times s$)变成querykeyvalue值。最后再把这些querykeyvalue值通过flatten操作。

对于常规的Convolution所需的参数量和计算量分别是$s^2C^2$和$O(s^2C^2T)$。式中$C$是tokenchannel dimension,$T$是token的数量。Depthwise-separable Convolution所需的参数量和计算量分别是$s^2C$和$O(s^2CT)$。

为了使得模型进一步简化,作者又提出了Squeezed convolutional projection操作。在计算query时,采用的Depthwise-separable Convolutionstride值为1。在计算keyvalue时,采用的Depthwise-separable Convolutionstride值为2。按照这种方式,token的数量对于keyvalue来说可以减少4倍,性能只有很少的下降。

此外,CvT不再采用位置编码(卷积的zero-padding操作可以暗含位置信息);class token只加在最后一个stage里面。

CvT的完整实现可参考vit-pytorch

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim, dim * mult, 1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(dim * mult, dim, 1),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class DepthWiseConv2d(nn.Module):
    def __init__(self, dim_in, dim_out, kernel_size, padding, stride, bias = True):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
            nn.BatchNorm2d(dim_in),
            nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
        )
    def forward(self, x):
        return self.net(x)
        
class Attention(nn.Module):
    def __init__(self, dim, proj_kernel, kv_proj_stride, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        padding = proj_kernel // 2
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_q = DepthWiseConv2d(dim, inner_dim, proj_kernel, padding = padding, stride = 1, bias = False)
        self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, proj_kernel, padding = padding, stride = kv_proj_stride, bias = False)

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        shape = x.shape
        b, n, _, y, h = *shape, self.heads
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1))
        q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = h), (q, k, v))

        dots = einsum('b i d, b j d -> b i j', q, k) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, y = y)
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, proj_kernel, kv_proj_stride, depth, heads, dim_head = 64, mlp_mult = 4, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, proj_kernel = proj_kernel, kv_proj_stride = kv_proj_stride, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class CvT(nn.Module):
    def __init__(
        self,
        *,
        num_classes,
        s1_emb_dim = 64,        # stage 1 - dimension
        s1_emb_kernel = 7,      # stage 1 - conv kernel
        s1_emb_stride = 4,      # stage 1 - conv stride
        s1_proj_kernel = 3,     # stage 1 - attention ds-conv kernel size
        s1_kv_proj_stride = 2,  # stage 1 - attention key / value projection stride
        s1_heads = 1,           # stage 1 - heads
        s1_depth = 1,           # stage 1 - depth
        s1_mlp_mult = 4,        # stage 1 - feedforward expansion factor
        s2_emb_dim = 192,       # stage 2 - (same as above)
        s2_emb_kernel = 3,
        s2_emb_stride = 2,
        s2_proj_kernel = 3,
        s2_kv_proj_stride = 2,
        s2_heads = 3,
        s2_depth = 2,
        s2_mlp_mult = 4,
        s3_emb_dim = 384,       # stage 3 - (same as above)
        s3_emb_kernel = 3,
        s3_emb_stride = 2,
        s3_proj_kernel = 3,
        s3_kv_proj_stride = 2,
        s3_heads = 4,
        s3_depth = 10,
        s3_mlp_mult = 4,
        dropout = 0.
    ):
        super().__init__()
        dim = 3
        layers = []

        for prefix in ('s1_', 's2_', 's3_'):
            layers.append(nn.Sequential(
                nn.Conv2d(dim, config[prefix+'emb_dim'], kernel_size = config[prefix+'emb_kernel'], padding = (config[prefix+'emb_kernel'] // 2), stride = config[prefix+'emb_stride']),
                LayerNorm(config[prefix+'emb_dim']),
                Transformer(dim = config[prefix+'emb_dim'], proj_kernel = config[prefix+'proj_kernel'], kv_proj_stride = config[prefix+'kv_proj_stride'], depth = config[prefix+'depth'], heads = config[prefix+'heads'], mlp_mult = config[prefix+'mlp_mult'], dropout = dropout)
            ))

            dim = config[prefix+'emb_dim']

        self.layers = nn.Sequential(*layers)

        self.to_logits = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            Rearrange('... () () -> ...'),
            nn.Linear(dim, num_classes)
        )

    def forward(self, x):
        latents = self.layers(x)
        return self.to_logits(latents)