LeViT:以卷积网络的形式进行快速推理的视觉Transformer.
本文的目的是减小视觉Transformer在不同设备上的推理时间 (inference speed)。作者对比了3个模型的准确率 v.s. 处理速度 (每秒处理的图片数量) 。所提LeViT模型做到了更好的精度-速度权衡,在达到高吞吐率的前提下进一步提升了性能。
作者首先做了把CNN和Transformer结构嫁接的实验,每个模型都是ResNet和DeiT-S模型的混合,作者发现当ResNet取前2个stage,DeiT-S取6个层的时候,参数量只有13.1M,每秒就能够处理1048张图片,而且可以达到最高的精度80.9。
这5组实验的训练曲线如图所示。作者发现嫁接的模型在一开始的训练曲线很接近ResNet-50;而在训练的末期又很接近DeiT-S的曲线。作者假设嫁接模型的卷积层的inductive biases的能力使得模型在浅层能很好地学习到low-level的信息,这些信息使得输入给 Transformer Block 的patch embedding 更有意义,使整个嫁接模型收敛得更快了。而后面的Transformer Block则提升了模型的整体精度。
基于此作者在视觉Transformer中引入了卷积操作,做成了一个类似于CNN中的LeNet的架构(卷积+下采样),因此称之为LeViT。
该模型的主要特点为:
- 借助卷积层提取特征:输入图片的维度是$3×224×224$,先通过4层卷积+BN+GELU得到维度为$256×14×14$的张量。作为输入进入Transformer Blocks中。
- 不使用class token:Transformer 的输出维度是$512×4×4$的张量,通过Average Pooling得到512维的向量。
- 注意力模块有两种形式。一种是不改变空间尺寸的注意力,一种是空间尺寸减半的注意力(shrink attention)。
- Transformer中的MLP操作非常耗时,而且作用只是升降张量的维度;所以统一换成Conv 1×1+BN,并把expansion ratio由4调到2。
- 把位置编码替换为Attention bias,即用绝对值编码位置的相对距离:
LeViT的完整实现可参考vit-pytorch,其中注意力模块的实现过程如下:
class Attention(nn.Module):
def __init__(self, dim, fmap_size, heads = 8, dim_key = 32, dim_value = 64, dropout = 0., dim_out = None, downsample = False):
super().__init__()
inner_dim_key = dim_key * heads
inner_dim_value = dim_value * heads
dim_out = default(dim_out, dim)
self.heads = heads
self.scale = dim_key ** -0.5
self.to_q = nn.Sequential(nn.Conv2d(dim, inner_dim_key, 1, stride = (2 if downsample else 1), bias = False), nn.BatchNorm2d(inner_dim_key))
self.to_k = nn.Sequential(nn.Conv2d(dim, inner_dim_key, 1, bias = False), nn.BatchNorm2d(inner_dim_key))
self.to_v = nn.Sequential(nn.Conv2d(dim, inner_dim_value, 1, bias = False), nn.BatchNorm2d(inner_dim_value))
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
out_batch_norm = nn.BatchNorm2d(dim_out)
nn.init.zeros_(out_batch_norm.weight)
self.to_out = nn.Sequential(
nn.GELU(),
nn.Conv2d(inner_dim_value, dim_out, 1),
out_batch_norm,
nn.Dropout(dropout)
)
# positional bias
self.pos_bias = nn.Embedding(fmap_size * fmap_size, heads)
q_range = torch.arange(0, fmap_size, step = (2 if downsample else 1))
k_range = torch.arange(fmap_size)
q_pos = torch.stack(torch.meshgrid(q_range, q_range, indexing = 'ij'), dim = -1)
k_pos = torch.stack(torch.meshgrid(k_range, k_range, indexing = 'ij'), dim = -1)
q_pos, k_pos = map(lambda t: rearrange(t, 'i j c -> (i j) c'), (q_pos, k_pos))
rel_pos = (q_pos[:, None, ...] - k_pos[None, :, ...]).abs()
x_rel, y_rel = rel_pos.unbind(dim = -1)
pos_indices = (x_rel * fmap_size) + y_rel
self.register_buffer('pos_indices', pos_indices)
def apply_pos_bias(self, fmap):
bias = self.pos_bias(self.pos_indices)
bias = rearrange(bias, 'i j h -> () h i j')
return fmap + (bias / self.scale)
def forward(self, x):
b, n, *_, h = *x.shape, self.heads
q = self.to_q(x)
y = q.shape[2]
qkv = (q, self.to_k(x), self.to_v(x))
q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = h), qkv)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
dots = self.apply_pos_bias(dots)
attn = self.attend(dots)
attn = self.dropout(attn)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h (x y) d -> b (h d) x y', h = h, y = y)
return self.to_out(out)
MLP构建如下:
class FeedForward(nn.Module):
def __init__(self, dim, mult, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(dim, dim * mult, 1),
nn.Hardswish(),
nn.Dropout(dropout),
nn.Conv2d(dim * mult, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
Transformer结构如下:
class Transformer(nn.Module):
def __init__(self, dim, fmap_size, depth, heads, dim_key, dim_value, mlp_mult = 2, dropout = 0., dim_out = None, downsample = False):
super().__init__()
dim_out = default(dim_out, dim)
self.layers = nn.ModuleList([])
self.attn_residual = (not downsample) and dim == dim_out
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, fmap_size = fmap_size, heads = heads, dim_key = dim_key, dim_value = dim_value, dropout = dropout, downsample = downsample, dim_out = dim_out),
FeedForward(dim_out, mlp_mult, dropout = dropout)
]))
def forward(self, x):
for attn, ff in self.layers:
attn_res = (x if self.attn_residual else 0)
x = attn(x) + attn_res
x = ff(x) + x
return x
LeViT结构如下,包含蒸馏token:
class LeViT(nn.Module):
def __init__(
self,
*,
image_size,
num_classes,
dim,
depth,
heads,
mlp_mult,
stages = 3,
dim_key = 32,
dim_value = 64,
dropout = 0.,
num_distill_classes = None
):
super().__init__()
dims = cast_tuple(dim, stages)
depths = cast_tuple(depth, stages)
layer_heads = cast_tuple(heads, stages)
assert all(map(lambda t: len(t) == stages, (dims, depths, layer_heads))), 'dimensions, depths, and heads must be a tuple that is less than the designated number of stages'
self.conv_embedding = nn.Sequential(
nn.Conv2d(3, 32, 3, stride = 2, padding = 1),
nn.Conv2d(32, 64, 3, stride = 2, padding = 1),
nn.Conv2d(64, 128, 3, stride = 2, padding = 1),
nn.Conv2d(128, dims[0], 3, stride = 2, padding = 1)
)
fmap_size = image_size // (2 ** 4)
layers = []
for ind, dim, depth, heads in zip(range(stages), dims, depths, layer_heads):
is_last = ind == (stages - 1)
layers.append(Transformer(dim, fmap_size, depth, heads, dim_key, dim_value, mlp_mult, dropout))
if not is_last:
next_dim = dims[ind + 1]
layers.append(Transformer(dim, fmap_size, 1, heads * 2, dim_key, dim_value, dim_out = next_dim, downsample = True))
fmap_size = ceil(fmap_size / 2)
self.backbone = nn.Sequential(*layers)
self.pool = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
Rearrange('... () () -> ...')
)
self.distill_head = nn.Linear(dim, num_distill_classes) if exists(num_distill_classes) else always(None)
self.mlp_head = nn.Linear(dim, num_classes)
def forward(self, img):
x = self.conv_embedding(img)
x = self.backbone(x)
x = self.pool(x)
out = self.mlp_head(x)
distill = self.distill_head(x)
if exists(distill):
return out, distill
return out