CvT: Introducing Convolutions to Vision Transformers
Convolutional neural networks have several desirable properties, such as local receptive fields, shared convolution weights, and spatial downsampling, which make them highly effective on vision tasks. Vision Transformers, in contrast, capture global information and offer greater representational capacity, but because they lack these convolutional inductive biases they typically require far more training data. This paper proposes CvT, a model that fuses convolutions with the Vision Transformer: by building a multi-stage hierarchical structure and training on ImageNet-1k and ImageNet-22k, it reaches state-of-the-art performance.
The CvT architecture introduces two convolution-based operations, called Convolutional Token Embedding and Convolutional Projection.
1. Convolutional Token Embedding
Each stage performs the following operations: the incoming tokens are reshaped into a 2D token map and fed through the Convolutional Token Embedding layer, which amounts to running a convolution over the 2D-reshaped token map, followed by a Layer Normalization. The purpose of this convolution is to reduce the number of tokens (i.e., the feature resolution) at each stage while widening the tokens (i.e., increasing the feature dimension). The resulting behavior is much like a CNN's: as the network gets deeper, the feature resolution gradually shrinks and the feature dimension gradually grows.
Concretely, suppose the previous stage outputs $x_{i-1}\in R^{H_{i-1}\times W_{i-1} \times C_{i-1}}$. The Convolutional Token Embedding applies an ordinary convolution to obtain $f(x_{i-1})\in R^{H_{i}\times W_{i} \times C_{i}}$, which is then flattened into an $H_iW_i\times C_i$ tensor and passed through Layer Normalization; the result feeds into the Transformer blocks of stage $i$.
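As a minimal sketch of one such embedding step (using the stage-1 settings from the implementation below: a 7×7 convolution with stride 4 mapping 3 channels to 64), the shapes work out as follows; the tensor sizes here are only illustrative.

import torch
import torch.nn as nn
from einops import rearrange

x = torch.randn(1, 3, 224, 224)                          # input image, B x C x H x W
conv_embed = nn.Conv2d(3, 64, kernel_size = 7, stride = 4, padding = 3)

token_map = conv_embed(x)                                # 1 x 64 x 56 x 56: fewer tokens, wider channels
tokens = rearrange(token_map, 'b c h w -> b (h w) c')    # flatten to B x (H*W) x C = 1 x 3136 x 64
norm = nn.LayerNorm(64)
tokens = norm(tokens)                                    # normalize before the stage-1 Transformer blocks
print(tokens.shape)                                      # torch.Size([1, 3136, 64])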
2. Convolutional Projection
Within each stage, the output of the Convolutional Token Embedding then goes through a stack of Convolutional Transformer Blocks. These look almost identical to a standard Transformer block; the only change is that the Linear Projection is replaced by a Convolutional Projection, i.e., a depth-wise separable convolution takes the place of the linear projection.
Concretely, the tokens are first reshaped into a 2D token map and then passed through three depth-wise separable convolutions (kernel size $s\times s$) to produce the query, key, and value, which are finally flattened back into token sequences.
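A rough, shapes-only sketch of this projection (simplified relative to the full Attention module later; the sizes are illustrative):

import torch
import torch.nn as nn
from einops import rearrange

C, H, W, s = 64, 56, 56, 3
tokens = torch.randn(1, H * W, C)                          # token sequence, B x T x C
x = rearrange(tokens, 'b (h w) c -> b c h w', h = H)       # back to a 2D token map

# one depth-wise separable convolution: depth-wise s x s conv followed by a point-wise 1 x 1 conv
dw_separable = nn.Sequential(
    nn.Conv2d(C, C, kernel_size = s, padding = s // 2, groups = C),
    nn.Conv2d(C, C, kernel_size = 1)
)

q = rearrange(dw_separable(x), 'b c h w -> b (h w) c')     # flattened query; key and value use their own convs
print(q.shape)                                             # torch.Size([1, 3136, 64])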
A standard convolution here would require $s^2C^2$ parameters and $O(s^2C^2T)$ FLOPs, where $C$ is the channel dimension of the tokens and $T$ is the number of tokens. The depth-wise separable convolution requires only $s^2C$ parameters and $O(s^2CT)$ FLOPs.
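As a quick sanity check of these counts with hypothetical values $s=3$ and $C=64$ (only the depth-wise kernels are counted in the formula above; the 1×1 point-wise convolution adds a further $C^2$ weights, the same cost as the linear projection it replaces):

s, C = 3, 64                         # hypothetical kernel size and channel dimension

standard_params = s * s * C * C      # s^2 * C^2 = 36,864 weights for a full s x s convolution
depthwise_params = s * s * C         # s^2 * C   = 576 weights for the depth-wise part
pointwise_params = C * C             # 1x1 point-wise conv, 4,096 weights

print(standard_params, depthwise_params, pointwise_params)   # 36864 576 4096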
To simplify the model further, the authors also propose a squeezed convolutional projection: the depth-wise separable convolution that produces the query uses stride 1, while the ones producing the key and value use stride 2. This cuts the number of key and value tokens by a factor of 4, with only a small drop in accuracy.
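With illustrative numbers, for a 56×56 token map the savings in the attention matrix look like this:

H = W = 56
q_tokens = H * W                              # query keeps stride 1: 3136 tokens
kv_tokens = (H // 2) * (W // 2)               # key/value use stride 2: 784 tokens, a 4x reduction

attn_entries_full = q_tokens * q_tokens       # 9,834,496 attention scores without squeezing
attn_entries_squeezed = q_tokens * kv_tokens  # 2,458,624 attention scores with squeezing
print(attn_entries_full, attn_entries_squeezed)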
In addition, CvT drops positional encodings altogether (the zero-padding in the convolutions implicitly carries positional information), and the class token is added only in the last stage.
A complete implementation of CvT is available in vit-pytorch; the version below follows it (note that this implementation replaces the class token with global average pooling before the classification head).
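The code relies on the imports below and on a channel-wise LayerNorm helper: every module here operates on B×C×H×W token maps rather than flattened sequences, so normalization is applied over the channel dimension (this mirrors the helper defined in vit-pytorch's cvt.py).

import torch
from torch import nn, einsum
from einops import rearrange
from einops.layers.torch import Rearrange

class LayerNorm(nn.Module):
    # LayerNorm over the channel dimension of a B x C x H x W feature map
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b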
class PreNorm(nn.Module):
    # Pre-normalization wrapper: apply the channel-wise LayerNorm, then the wrapped module.
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)
class FeedForward(nn.Module):
    # Feed-forward network implemented with 1x1 convolutions, since tokens stay in 2D map form.
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim, dim * mult, 1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(dim * mult, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)
class DepthWiseConv2d(nn.Module):
    # Depth-wise separable convolution: depth-wise conv + BatchNorm + point-wise (1x1) conv.
    def __init__(self, dim_in, dim_out, kernel_size, padding, stride, bias = True):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
            nn.BatchNorm2d(dim_in),
            nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
        )

    def forward(self, x):
        return self.net(x)
class Attention(nn.Module):
    # Multi-head self-attention with Convolutional Projection for q, k and v.
    # The key/value projection uses stride kv_proj_stride (the "squeezed" projection).
    def __init__(self, dim, proj_kernel, kv_proj_stride, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        padding = proj_kernel // 2
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_q = DepthWiseConv2d(dim, inner_dim, proj_kernel, padding = padding, stride = 1, bias = False)
        self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, proj_kernel, padding = padding, stride = kv_proj_stride, bias = False)

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        shape = x.shape
        b, n, _, y, h = *shape, self.heads            # x is B x C x H x W; y is the map width
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1))
        q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = h), (q, k, v))

        dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, y = y)   # back to a 2D token map
        return self.to_out(out)
class Transformer(nn.Module):
    # A stack of pre-norm attention + feed-forward blocks with residual connections.
    def __init__(self, dim, proj_kernel, kv_proj_stride, depth, heads, dim_head = 64, mlp_mult = 4, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, proj_kernel = proj_kernel, kv_proj_stride = kv_proj_stride, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x
class CvT(nn.Module):
    def __init__(
        self,
        *,
        num_classes,
        s1_emb_dim = 64,        # stage 1 - dimension
        s1_emb_kernel = 7,      # stage 1 - conv kernel
        s1_emb_stride = 4,      # stage 1 - conv stride
        s1_proj_kernel = 3,     # stage 1 - attention ds-conv kernel size
        s1_kv_proj_stride = 2,  # stage 1 - attention key / value projection stride
        s1_heads = 1,           # stage 1 - heads
        s1_depth = 1,           # stage 1 - depth
        s1_mlp_mult = 4,        # stage 1 - feedforward expansion factor
        s2_emb_dim = 192,       # stage 2 - (same as above)
        s2_emb_kernel = 3,
        s2_emb_stride = 2,
        s2_proj_kernel = 3,
        s2_kv_proj_stride = 2,
        s2_heads = 3,
        s2_depth = 2,
        s2_mlp_mult = 4,
        s3_emb_dim = 384,       # stage 3 - (same as above)
        s3_emb_kernel = 3,
        s3_emb_stride = 2,
        s3_proj_kernel = 3,
        s3_kv_proj_stride = 2,
        s3_heads = 4,
        s3_depth = 10,
        s3_mlp_mult = 4,
        dropout = 0.
    ):
        super().__init__()
        config = dict(locals())   # collect the stage hyperparameters so they can be looked up by prefix

        dim = 3
        layers = []

        for prefix in ('s1_', 's2_', 's3_'):
            layers.append(nn.Sequential(
                # Convolutional Token Embedding: strided conv + channel-wise LayerNorm
                nn.Conv2d(dim, config[prefix + 'emb_dim'], kernel_size = config[prefix + 'emb_kernel'], padding = (config[prefix + 'emb_kernel'] // 2), stride = config[prefix + 'emb_stride']),
                LayerNorm(config[prefix + 'emb_dim']),
                # Convolutional Transformer Blocks for this stage
                Transformer(dim = config[prefix + 'emb_dim'], proj_kernel = config[prefix + 'proj_kernel'], kv_proj_stride = config[prefix + 'kv_proj_stride'], depth = config[prefix + 'depth'], heads = config[prefix + 'heads'], mlp_mult = config[prefix + 'mlp_mult'], dropout = dropout)
            ))
            dim = config[prefix + 'emb_dim']

        self.layers = nn.Sequential(*layers)

        self.to_logits = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),              # global average pooling instead of a class token
            Rearrange('... () () -> ...'),
            nn.Linear(dim, num_classes)
        )

    def forward(self, x):
        latents = self.layers(x)
        return self.to_logits(latents)
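A quick smoke test with the default hyperparameters (a dummy 224×224 image; shapes only):

model = CvT(num_classes = 1000)

img = torch.randn(1, 3, 224, 224)    # dummy RGB image
logits = model(img)                  # three stages of conv embedding + transformer blocks, then pooling
print(logits.shape)                  # torch.Size([1, 1000])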