VT:基于Token的图像表示和处理.

本文作者提出了Visual Transformers,把图片建模为语义视觉符号 (semantic visual tokens),使用Transformer来建模tokens之间的关系,从而把问题定义在语义符号空间 (semantic token space)中,目的是在图像中表示和处理高级概念 (high-level concepts),在token空间中建模高级概念之间的联系。

具体地,使用空间注意力机制将特征图转换成一组紧凑的语义tokens,再把这些tokens输入一个Transformer,从而将语义概念编码在视觉tokens中,而不是对所有图像中的所有概念进行建模。从而能够关注图像中那些相对重要区域,而不是像CNN那样平等地对待所有的像素。

对于一张给定图片,首先通过卷积操作得到其low-level特征,把获得的特征图输入给VT。首先通过一个tokenizer,把这些特征图的像素转化为 visual tokens,每个 token 代表图片中的一个语义概念 (semantic concept);这些 token 通过Transformer处理后,输出的也是一堆 visual tokens。这些 visual tokens可以直接应用于图像分类任务,或者通过 Projector 投影回特征图进行语义分割任务。

1. Tokenizer

作者首先设计了一种Filter-based Tokenizer。对于输入图像$\mathbf{X} \in \mathbb{R}^{H W \times C}$,首先对其应用1 × 1卷积 $\mathbf{W}_A \in \mathbb{R}^{C \times L}$,然后对$HW$个长度为$L$的向量应用softmax函数得到$\mathbf{A} \in \mathbb{R}^{H W \times L}$,即把每一个像素$\mathbf{X}_p \in \mathbb{R}^{C}$映射到$L$个semantic group中的一个。再把它转置以后与输入进行矩阵乘法得到

\[\mathbf{T}=\underbrace{\operatorname{softmax}_{H W}\left(\mathbf{X} \mathbf{W}_A\right.}_{\mathbf{A} \in \mathbb{R}^{H W \times L}})^T \mathbf{X}=\mathbf{A}^T \mathbf{X} \in \mathbb{R}^{L \times C}\]

作者又设计了Recurrent Tokenizer,使用上一层的token \(\mathbf{T}_{i n} \in \mathbb{R}^{L \times C}\)来指导这一层的token \(\mathbf{T} \in \mathbb{R}^{L \times C}\)的生成。首先使用\(\mathbf{T}_{i n}\)与矩阵\(\mathbf{W}_{T \rightarrow R} \in \mathbb{R}^{C \times C}\)相乘得到矩阵\(\mathbf{W}_{R} \in \mathbb{R}^{L \times C}\),再把它当做上面的 1 × 1卷积\(\mathbf{W}_A\)与输入作用得到这一层的tokenRecurrent Tokenizer的表达式和示意图如下:

\[\begin{gathered} \mathbf{W}_R=\mathbf{T}_{i n} \mathbf{W}_{\mathbf{T} \rightarrow \mathbf{R}} \\ \mathbf{T}=\operatorname{softmax}_{H W}\left(\mathbf{X} \mathbf{W}_R\right)^T \mathbf{X} \end{gathered}\]

2. Projector

对于一些需要像素级别预测的视觉任务,比如分割等,需要得到pixel-level细节信息,只有 visual tokens 提供的信息是不够的。所以再通过ProjectorTransformer输出的 visual tokens 反变换称为Feature map

\(\mathbf{X}_{\text {in }}, \mathbf{X}_{\text {out }} \in \mathbb{R}^{H W \times C}\)分别是输入和输出特征图,在得到\(\mathbf{X}_{\text {out }}\)的过程中,使用了\(\mathbf{X}_{\text {in }}\),Transformer的输出\(\mathbf{T}\)只是为了得到残差。

\[\mathbf{X}_{\text {out }}=\mathbf{X}_{i n}+\operatorname{softmax}_L\left(\left(\mathbf{X}_{i n} \mathbf{W}_Q\right)\left(\mathbf{T} \mathbf{W}_K\right)^T\right) \mathbf{T}\]

3. 在视觉模型中使用VT

可以把VT添加到现有模型里面,比如ResNet变成visual-transformer-ResNets (VT-ResNets)。具体方法是把ResNet网络的最后一个stage的所有的卷积层变成VT module。比如ResNet-18stage 4结束后得到的feature map14×14×256,可以使用16visual token,且其channel数都设为1024。所以最后Transformer会输出得到16visual tokens $\mathbf{T}_{\text {out }} \in \mathbb{R}^{16 \times 1024}$。

也可以把VT添加到分割任务的FPN模块中,只需要把FPN中的卷积替换成VT module即可。在实做中使用8visual tokens,且其channel数都设为1024。然后输出的visual tokens被投影回原始特征图,用于执行分割任务。与最初的FPN相比,VT-FPN的计算成本要小得多,因为只对极少数量的visual tokens而不是所有像素进行操作。

一个用于图像分类的VT-ResNet模型构建如下:

class ViTResNet(nn.Module):
    def __init__(self, block, num_classes=10, dim = 128, num_tokens = 8, mlp_dim = 256, heads = 8, depth = 6, emb_dropout = 0.1, dropout= 0.1):
        super(ViTResNet, self).__init__()
        self.in_planes = 16
        self.L = num_tokens
        self.cT = dim
        
        self.backbone = ResNet()
    
        # Tokenization parameters
        self.token_wA = nn.Parameter(torch.empty(BATCH_SIZE_TRAIN, self.L, 64),requires_grad = True)
        torch.nn.init.xavier_uniform_(self.token_wA)
        self.token_wV = nn.Parameter(torch.empty(BATCH_SIZE_TRAIN, 64, self.cT),requires_grad = True)
        torch.nn.init.xavier_uniform_(self.token_wV)        
        
        self.pos_embedding = nn.Parameter(torch.empty(1, (num_tokens + 1), dim))
        torch.nn.init.normal_(self.pos_embedding, std = .02)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
        self.to_cls_token = nn.Identity()

        self.nn1 = nn.Linear(dim, num_classes)
        torch.nn.init.xavier_uniform_(self.nn1.weight)
        torch.nn.init.normal_(self.nn1.bias, std = 1e-6)
    
        
    def forward(self, img, mask = None):
        x = self.backbone(x) 
        x = rearrange(x, 'b c h w -> b (h w) c')

        #Tokenization 
        wa = rearrange(self.token_wA, 'b l c -> b c l')
        A= torch.einsum('bij,bjk->bik', x, wa) 
        A = rearrange(A, 'b hw l -> b l hw')
        A = A.softmax(dim=-1)

        # VV:(b, hw, cT)
        VV= torch.einsum('bij,bjk->bik', x, self.token_wV)  

        # T:(b, L, cT = 128)
        T = torch.einsum('bij,bjk->bik', A, VV)  

        # cls_tokens:(b, 1, 128)
        cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)

        # x:(b, L+1, cT = 128)
        x = torch.cat((cls_tokens, T), dim=1)
        x += self.pos_embedding
        x = self.dropout(x)

        # x:(b, L+1, cT = 128)
        x = self.transformer(x, mask)

        # x:(b, cT = 128)
        x = self.to_cls_token(x[:, 0])   

        x = self.nn1(x)
        return x