ViT:使用图像块序列的Transformer进行图像分类.

1. 模型介绍

模型的整体结构如上图所示。作者尽可能遵守原始Transformer的结构设计,其目的是使得针对Transformer设计的优化结构可以直接套用。

将输入图像$x \in \Bbb{R}^{H \times W \times C}$划分成若干patch $x_p \in \Bbb{R}^{N \times (P^2 \cdot C)}$,其有效序列长度为$N = \frac{HW}{P^2}$。

将每个patch展平后通过线性映射转化为一个维度为$D$的嵌入向量(patch embedding),并在输入的起始位置增加一个可学习的类别嵌入,该向量在输出时的状态可作为图像的特征表示。在预训练和微调阶段,分类器将其作为输入。

增加$1D$位置编码(position embedding)后输入Transformer的编码器(实验发现$2D$位置编码对结果提升不明显)。预训练时在网络后增加一个MLP线性分类器进行图像分类。

微调时使用更高分辨率的图像。保持每一个图像patch的尺寸不变,这将使输入序列长度增加。Transformer可以输入任意长度的序列,但预训练的位置编码将不再匹配。为此使用$2D$插值调整位置编码。这部分是人为引入的归纳偏置(inductive bias)

2. 实验分析

作者训练了三个不同大小的ViT模型,其参数量如下表所示:

在中等规模的数据集(如ImageNet)上训练,准确率要比基于卷积神经网络的模型(如ResNet)低几个点。这是因为Transformer缺少卷积神经网络的归纳偏置,如平移等变性和局部性(translation equivariance and locality),这使得它在训练数据不足的时候泛化能力不强。作者认为在大尺度($14M$-$300M$)的数据集上训练可以解决这个问题。

实验结果显示,在JFT-300M数据集上预训练后,基于Transformer的分类模型迁移到小数据集任务中超越了基于卷积神经网络的模型:

作者可视化了部分线性嵌入的权重和位置编码,表明模型学习到特征提取和位置敏感的信息。作者分析不同层中注意力平均距离(类似于卷积网络中的感受野大小),发现在浅层模型同时关注近距离和远距离的特征,在深层模型主要关注远距离特征。而卷积神经网络在浅层主要关注近距离特征。

3. 模型实现

ViT的完整实现可参考vit-pytorch

ViT所采用的Transformer编码器为pre-norm的形式:

import torch
from torch import nn

from einops import rearrange

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

ViT模型构建如下:

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = (image_size, image_size)
        patch_height, patch_width = (patch_size, patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)

实例化ViT模型的例子:

import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256, 
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,        # Last dimension of output tensor after linear transformation
    depth = 6,         # Number of Transformer blocks
    heads = 16,        # Number of heads in Multi-head Attention layer
    mlp_dim = 2048,    # Dimension of the MLP (FeedForward) layer
    dropout = 0.1,
    emb_dropout = 0.1,
    pool = 'cls',      # either 'cls' token pooling or 'mean' pooling
    channels = 3,
)

img = torch.randn(1, 3, 256, 256)
preds = v(img) # (1, 1000)