ViT: Image Classification with Transformers on Sequences of Image Patches.
1. Model Overview
The overall architecture of the model is shown in the figure above. The authors follow the original Transformer design as closely as possible, so that architectural optimizations developed for Transformers can be reused directly.
The input image $x \in \Bbb{R}^{H \times W \times C}$ is split into a sequence of patches $x_p \in \Bbb{R}^{N \times (P^2 \cdot C)}$, giving an effective sequence length of $N = \frac{HW}{P^2}$.
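As a quick sanity check of these shapes (the image and patch sizes below are illustrative, not taken from the text), einops can perform the splitting in one line:
import torch
from einops import rearrange

x = torch.randn(1, 3, 224, 224)   # (B, C, H, W): a hypothetical 224x224 RGB image
# split into non-overlapping 16x16 patches and flatten each patch
patches = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = 16, p2 = 16)
print(patches.shape)              # torch.Size([1, 196, 768]): N = HW/P^2 = 196, P^2*C = 768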
Each patch is flattened and mapped by a linear projection to a $D$-dimensional embedding vector (the patch embedding). A learnable class embedding is prepended to the input sequence; its state at the output serves as the image representation, which the classification head takes as input during both pre-training and fine-tuning.
A $1D$ position embedding is then added before the sequence is fed to the Transformer encoder (experiments showed that $2D$ position embeddings do not yield a noticeable improvement). During pre-training, an MLP classification head is attached after the network for image classification.
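A minimal sketch of the class token and the learnable $1D$ position embedding (dimensions are illustrative; the full version appears in the implementation in Section 3):
import torch

N, D = 196, 768                             # sequence length and embedding dim (illustrative)
patch_emb = torch.randn(1, N, D)            # output of the patch embedding layer
cls_token = torch.zeros(1, 1, D)            # learnable class embedding (an nn.Parameter in practice)
pos_emb = torch.zeros(1, N + 1, D)          # learnable 1D position embedding (an nn.Parameter in practice)

tokens = torch.cat([cls_token, patch_emb], dim = 1) + pos_emb   # (1, N+1, D), fed to the encoder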
Fine-tuning uses higher-resolution images. The patch size is kept fixed, so the input sequence becomes longer. The Transformer can process sequences of arbitrary length, but the pre-trained position embeddings no longer match; they are therefore adjusted with $2D$ interpolation. This resolution adjustment is one of the few places where an inductive bias is introduced by hand.
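The interpolation can be done by reshaping the pre-trained patch position embeddings back into their 2D grid and resizing that grid; the helper below is an illustrative sketch (the function name and shapes are assumptions, not the vit-pytorch API):
import torch
import torch.nn.functional as F

def resize_pos_embedding(pos_emb, old_grid, new_grid):
    # pos_emb: (1, 1 + old_grid**2, dim); the class-token embedding is kept as-is,
    # the patch embeddings are treated as a 2D grid and resized with bicubic interpolation
    cls_emb, patch_emb = pos_emb[:, :1], pos_emb[:, 1:]
    dim = patch_emb.shape[-1]
    patch_emb = patch_emb.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)   # (1, dim, g, g)
    patch_emb = F.interpolate(patch_emb, size = (new_grid, new_grid), mode = 'bicubic', align_corners = False)
    patch_emb = patch_emb.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, dim)
    return torch.cat([cls_emb, patch_emb], dim = 1)

# e.g. pre-trained at 224/16 = 14x14 patches, fine-tuned at 384/16 = 24x24 patches
new_pos = resize_pos_embedding(torch.randn(1, 1 + 14 * 14, 768), old_grid = 14, new_grid = 24)
print(new_pos.shape)   # torch.Size([1, 577, 768])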
2. Experimental Analysis
The authors trained three ViT models of different sizes (ViT-Base, ViT-Large and ViT-Huge); their parameter counts are listed in the table below:
When trained on medium-sized datasets (such as ImageNet), the accuracy is a few points lower than that of convolutional models (such as ResNet). This is because the Transformer lacks the inductive biases built into convolutional networks, such as translation equivariance and locality, which weakens its generalization when training data is insufficient. The authors argue that training on large-scale datasets ($14M$-$300M$ images) overcomes this problem.
The experiments show that, after pre-training on the JFT-300M dataset, the Transformer-based classifier surpasses convolution-based models when transferred to smaller downstream datasets:
The authors visualize some of the learned linear-embedding filters and position embeddings, showing that the model learns feature extraction and position-sensitive information. They also analyze the mean attention distance in different layers (analogous to the receptive-field size of a convolutional network): in shallow layers the model attends to both nearby and distant features, while in deep layers it mainly attends to distant features; convolutional networks, in contrast, mainly capture nearby features in their shallow layers.
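As an illustration of the mean-attention-distance metric, the sketch below (a hypothetical helper, not code from the paper) weights the pairwise distances between patch positions by the attention probabilities and averages over queries, yielding one value per head in patch units (multiply by the patch size to get pixels):
import torch

def mean_attention_distance(attn, grid):
    # attn: (heads, N, N) attention weights over N = grid*grid patch tokens (class token excluded)
    coords = torch.stack(torch.meshgrid(
        torch.arange(grid), torch.arange(grid), indexing = 'ij'), dim = -1).reshape(-1, 2).float()
    dist = torch.cdist(coords, coords)                 # (N, N) pairwise distances between patch positions
    return (attn * dist).sum(dim = -1).mean(dim = -1)  # attention-weighted distance per query, averaged over queries

attn = torch.softmax(torch.randn(8, 196, 196), dim = -1)  # dummy attention map for a 14x14 patch grid
print(mean_attention_distance(attn, grid = 14))            # one mean distance per head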
3. Model Implementation
A complete implementation of ViT is available in vit-pytorch.
The Transformer encoder used by ViT takes the pre-norm form:
import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale  # scaled dot-product scores
        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')  # merge the heads back into the channel dimension
        return self.to_out(out)
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x  # pre-norm attention block with residual connection
            x = ff(x) + x    # pre-norm feed-forward block with residual connection
        return x
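As a quick usage check (dimensions are illustrative), the encoder maps a token sequence to an output of the same shape:
encoder = Transformer(dim = 768, depth = 2, heads = 8, dim_head = 64, mlp_dim = 2048)
tokens = torch.randn(1, 197, 768)      # (batch, N + 1 tokens, dim)
print(encoder(tokens).shape)           # torch.Size([1, 197, 768]); the shape is unchanged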
The ViT model is built as follows:
class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = (image_size, image_size)
        patch_height, patch_width = (patch_size, patch_size)
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # split the image into patches and project each patch to a dim-dimensional embedding
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        # prepend the learnable class token and add the 1D position embedding
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        # classify from the class token (or from the mean over all tokens)
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
        x = self.to_latent(x)
        return self.mlp_head(x)
An example of instantiating a ViT model:
import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,          # Last dimension of output tensor after linear transformation
    depth = 6,           # Number of Transformer blocks
    heads = 16,          # Number of heads in Multi-head Attention layer
    mlp_dim = 2048,      # Dimension of the MLP (FeedForward) layer
    dropout = 0.1,
    emb_dropout = 0.1,
    pool = 'cls',        # either 'cls' token pooling or 'mean' pooling
    channels = 3,
)

img = torch.randn(1, 3, 256, 256)
preds = v(img)  # (1, 1000)