CCT:使用紧凑的Transformer避免大数据依赖.
通常认为Transformer这种结构是data hungry的,即想到得到良好的性能就需要大量的数据来训练。Transformer缺少了CNN的inductive biases的能力,也就没有了CNN的平移等变性 (Translation equivariance),因而需要大量的数据才能完成训练。
本文作者提出了更紧凑的ViT设计,通过优化tokenization把CNN和Transformer结合起来,可以不再需要大量的训练数据,以解决data hungry的问题。本文以极小的数据集 (比如CIFAR10/100, MNIST) 来从头训练更小的Transformer模型,也可以达到相似的性能。
对tokenization的优化体现在两个方面,分别是pathc tokenization和class tokenization。基于此作者分别提出了Compact Convolutional Transformers (CCT)和Compact Vision Transformers (CVT)结构。
1. Compact Vision Transformers (CVT)
CVT在ViT的基础上引入了序列池化 SeqPool,该方法将Transformer Encoder产生的基于顺序的信息进行池化。记编码器的输出为$x_L=f\left(x_0\right) \in \mathbb{R}^{b \times n \times d}$,其中$n$为序列长度,$d$是特征维度。把输出通过一个线性层$g\left(x_L\right) \in \mathbb{R}^{d \times 1}$和softmax激活函数:
\[x_L^{\prime}=\operatorname{softmax}\left(g\left(x_L\right)^T\right) \in R^{b \times 1 \times n}\]上式相当于构造了输出序列特征的权重向量,然后对序列特征进行加权平均:
\[z=x_L^{\prime} x_L=\operatorname{softmax}\left(g\left(x_L\right)^T\right) \times x_L \in R^{b \times 1 \times d}\]相当于对输出特征沿着序列维度进行了池化,生成的特征可以用于后续的分类任务。
self.attention_pool = nn.Linear(self.embedding_dim, 1)
x = self.encoder(x) # (b, n, d)
attn_weights = rearrange(self.attention_pool(x), 'b n 1 -> b n')
x = einsum('b n, b n d -> b d', attn_weights.softmax(dim = 1), x)
2. Compact Convolutional Transformers (CCT)
CCT为了给模型引入inductive bias,在给图片分patch的环节使用了卷积层。实现时通过多个卷积层堆叠起来,第1层的通道数设置为64,最后一层通道数设置为Transformer的embedding dimension。使用卷积层可以使得模型更好地保留局部的空间信息,可以不再需要借助位置编码来保存这部分位置信息。
class Tokenizer(nn.Module):
def __init__(self,
kernel_size, stride, padding,
pooling_kernel_size=3, pooling_stride=2, pooling_padding=1,
n_conv_layers=1,
n_input_channels=3,
n_output_channels=64,
in_planes=64,
activation=None,
max_pool=True,
conv_bias=False):
super().__init__()
n_filter_list = [n_input_channels] + \
[in_planes for _ in range(n_conv_layers - 1)] + \
[n_output_channels]
n_filter_list_pairs = zip(n_filter_list[:-1], n_filter_list[1:])
self.conv_layers = nn.Sequential(
*[nn.Sequential(
nn.Conv2d(chan_in, chan_out,
kernel_size=(kernel_size, kernel_size),
stride=(stride, stride),
padding=(padding, padding), bias=conv_bias),
nn.Identity() if not exists(activation) else activation(),
nn.MaxPool2d(kernel_size=pooling_kernel_size,
stride=pooling_stride,
padding=pooling_padding) if max_pool else nn.Identity()
)
for chan_in, chan_out in n_filter_list_pairs
])
self.apply(self.init_weight)
def sequence_length(self, n_channels=3, height=224, width=224):
return self.forward(torch.zeros((1, n_channels, height, width))).shape[1]
def forward(self, x):
return rearrange(self.conv_layers(x), 'b c h w -> b (h w) c')
@staticmethod
def init_weight(m):
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
CCT的完整实现可参考vit-pytorch。