VT:基于Token的图像表示和处理.
本文作者提出了Visual Transformers,把图片建模为语义视觉符号 (semantic visual tokens),使用Transformer来建模tokens之间的关系,从而把问题定义在语义符号空间 (semantic token space)中,目的是在图像中表示和处理高级概念 (high-level concepts),在token空间中建模高级概念之间的联系。
具体地,使用空间注意力机制将特征图转换成一组紧凑的语义tokens,再把这些tokens输入一个Transformer,从而将语义概念编码在视觉tokens中,而不是对所有图像中的所有概念进行建模。从而能够关注图像中那些相对重要区域,而不是像CNN那样平等地对待所有的像素。
对于一张给定图片,首先通过卷积操作得到其low-level特征,把获得的特征图输入给VT。首先通过一个tokenizer,把这些特征图的像素转化为 visual tokens,每个 token 代表图片中的一个语义概念 (semantic concept);这些 token 通过Transformer处理后,输出的也是一堆 visual tokens。这些 visual tokens可以直接应用于图像分类任务,或者通过 Projector 投影回特征图进行语义分割任务。
1. Tokenizer
作者首先设计了一种Filter-based Tokenizer。对于输入图像$\mathbf{X} \in \mathbb{R}^{H W \times C}$,首先对其应用1 × 1卷积 $\mathbf{W}_A \in \mathbb{R}^{C \times L}$,然后对$HW$个长度为$L$的向量应用softmax函数得到$\mathbf{A} \in \mathbb{R}^{H W \times L}$,即把每一个像素$\mathbf{X}_p \in \mathbb{R}^{C}$映射到$L$个semantic group中的一个。再把它转置以后与输入进行矩阵乘法得到
\[\mathbf{T}=\underbrace{\operatorname{softmax}_{H W}\left(\mathbf{X} \mathbf{W}_A\right.}_{\mathbf{A} \in \mathbb{R}^{H W \times L}})^T \mathbf{X}=\mathbf{A}^T \mathbf{X} \in \mathbb{R}^{L \times C}\]作者又设计了Recurrent Tokenizer,使用上一层的token \(\mathbf{T}_{i n} \in \mathbb{R}^{L \times C}\)来指导这一层的token \(\mathbf{T} \in \mathbb{R}^{L \times C}\)的生成。首先使用\(\mathbf{T}_{i n}\)与矩阵\(\mathbf{W}_{T \rightarrow R} \in \mathbb{R}^{C \times C}\)相乘得到矩阵\(\mathbf{W}_{R} \in \mathbb{R}^{L \times C}\),再把它当做上面的 1 × 1卷积\(\mathbf{W}_A\)与输入作用得到这一层的token。Recurrent Tokenizer的表达式和示意图如下:
\[\begin{gathered} \mathbf{W}_R=\mathbf{T}_{i n} \mathbf{W}_{\mathbf{T} \rightarrow \mathbf{R}} \\ \mathbf{T}=\operatorname{softmax}_{H W}\left(\mathbf{X} \mathbf{W}_R\right)^T \mathbf{X} \end{gathered}\]2. Projector
对于一些需要像素级别预测的视觉任务,比如分割等,需要得到pixel-level细节信息,只有 visual tokens 提供的信息是不够的。所以再通过Projector把Transformer输出的 visual tokens 反变换称为Feature map。
\(\mathbf{X}_{\text {in }}, \mathbf{X}_{\text {out }} \in \mathbb{R}^{H W \times C}\)分别是输入和输出特征图,在得到\(\mathbf{X}_{\text {out }}\)的过程中,使用了\(\mathbf{X}_{\text {in }}\),Transformer的输出\(\mathbf{T}\)只是为了得到残差。
\[\mathbf{X}_{\text {out }}=\mathbf{X}_{i n}+\operatorname{softmax}_L\left(\left(\mathbf{X}_{i n} \mathbf{W}_Q\right)\left(\mathbf{T} \mathbf{W}_K\right)^T\right) \mathbf{T}\]3. 在视觉模型中使用VT
可以把VT添加到现有模型里面,比如ResNet变成visual-transformer-ResNets (VT-ResNets)。具体方法是把ResNet网络的最后一个stage的所有的卷积层变成VT module。比如ResNet-18的stage 4结束后得到的feature map是14×14×256,可以使用16个visual token,且其channel数都设为1024。所以最后Transformer会输出得到16个visual tokens $\mathbf{T}_{\text {out }} \in \mathbb{R}^{16 \times 1024}$。
也可以把VT添加到分割任务的FPN模块中,只需要把FPN中的卷积替换成VT module即可。在实做中使用8个visual tokens,且其channel数都设为1024。然后输出的visual tokens被投影回原始特征图,用于执行分割任务。与最初的FPN相比,VT-FPN的计算成本要小得多,因为只对极少数量的visual tokens而不是所有像素进行操作。
一个用于图像分类的VT-ResNet模型构建如下:
class ViTResNet(nn.Module):
def __init__(self, block, num_classes=10, dim = 128, num_tokens = 8, mlp_dim = 256, heads = 8, depth = 6, emb_dropout = 0.1, dropout= 0.1):
super(ViTResNet, self).__init__()
self.in_planes = 16
self.L = num_tokens
self.cT = dim
self.backbone = ResNet()
# Tokenization parameters
self.token_wA = nn.Parameter(torch.empty(BATCH_SIZE_TRAIN, self.L, 64),requires_grad = True)
torch.nn.init.xavier_uniform_(self.token_wA)
self.token_wV = nn.Parameter(torch.empty(BATCH_SIZE_TRAIN, 64, self.cT),requires_grad = True)
torch.nn.init.xavier_uniform_(self.token_wV)
self.pos_embedding = nn.Parameter(torch.empty(1, (num_tokens + 1), dim))
torch.nn.init.normal_(self.pos_embedding, std = .02)
self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
self.to_cls_token = nn.Identity()
self.nn1 = nn.Linear(dim, num_classes)
torch.nn.init.xavier_uniform_(self.nn1.weight)
torch.nn.init.normal_(self.nn1.bias, std = 1e-6)
def forward(self, img, mask = None):
x = self.backbone(x)
x = rearrange(x, 'b c h w -> b (h w) c')
#Tokenization
wa = rearrange(self.token_wA, 'b l c -> b c l')
A= torch.einsum('bij,bjk->bik', x, wa)
A = rearrange(A, 'b hw l -> b l hw')
A = A.softmax(dim=-1)
# VV:(b, hw, cT)
VV= torch.einsum('bij,bjk->bik', x, self.token_wV)
# T:(b, L, cT = 128)
T = torch.einsum('bij,bjk->bik', A, VV)
# cls_tokens:(b, 1, 128)
cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
# x:(b, L+1, cT = 128)
x = torch.cat((cls_tokens, T), dim=1)
x += self.pos_embedding
x = self.dropout(x)
# x:(b, L+1, cT = 128)
x = self.transformer(x, mask)
# x:(b, cT = 128)
x = self.to_cls_token(x[:, 0])
x = self.nn1(x)
return x