CCT:使用紧凑的Transformer避免大数据依赖.

通常认为Transformer这种结构是data hungry的,即想到得到良好的性能就需要大量的数据来训练。Transformer缺少了CNNinductive biases的能力,也就没有了CNN的平移等变性 (Translation equivariance),因而需要大量的数据才能完成训练。

本文作者提出了更紧凑的ViT设计,通过优化tokenizationCNNTransformer结合起来,可以不再需要大量的训练数据,以解决data hungry的问题。本文以极小的数据集 (比如CIFAR10/100, MNIST) 来从头训练更小的Transformer模型,也可以达到相似的性能。

tokenization的优化体现在两个方面,分别是pathc tokenizationclass tokenization。基于此作者分别提出了Compact Convolutional Transformers (CCT)Compact Vision Transformers (CVT)结构。

1. Compact Vision Transformers (CVT)

CVTViT的基础上引入了序列池化 SeqPool,该方法将Transformer Encoder产生的基于顺序的信息进行池化。记编码器的输出为$x_L=f\left(x_0\right) \in \mathbb{R}^{b \times n \times d}$,其中$n$为序列长度,$d$是特征维度。把输出通过一个线性层$g\left(x_L\right) \in \mathbb{R}^{d \times 1}$和softmax激活函数:

\[x_L^{\prime}=\operatorname{softmax}\left(g\left(x_L\right)^T\right) \in R^{b \times 1 \times n}\]

上式相当于构造了输出序列特征的权重向量,然后对序列特征进行加权平均:

\[z=x_L^{\prime} x_L=\operatorname{softmax}\left(g\left(x_L\right)^T\right) \times x_L \in R^{b \times 1 \times d}\]

相当于对输出特征沿着序列维度进行了池化,生成的特征可以用于后续的分类任务。

self.attention_pool = nn.Linear(self.embedding_dim, 1)

x = self.encoder(x) # (b, n, d)
attn_weights = rearrange(self.attention_pool(x), 'b n 1 -> b n')
x = einsum('b n, b n d -> b d', attn_weights.softmax(dim = 1), x)

2. Compact Convolutional Transformers (CCT)

CCT为了给模型引入inductive bias,在给图片分patch的环节使用了卷积层。实现时通过多个卷积层堆叠起来,第1层的通道数设置为64,最后一层通道数设置为Transformerembedding dimension。使用卷积层可以使得模型更好地保留局部的空间信息,可以不再需要借助位置编码来保存这部分位置信息。

class Tokenizer(nn.Module):
    def __init__(self,
                 kernel_size, stride, padding,
                 pooling_kernel_size=3, pooling_stride=2, pooling_padding=1,
                 n_conv_layers=1,
                 n_input_channels=3,
                 n_output_channels=64,
                 in_planes=64,
                 activation=None,
                 max_pool=True,
                 conv_bias=False):
        super().__init__()

        n_filter_list = [n_input_channels] + \
                        [in_planes for _ in range(n_conv_layers - 1)] + \
                        [n_output_channels]

        n_filter_list_pairs = zip(n_filter_list[:-1], n_filter_list[1:])

        self.conv_layers = nn.Sequential(
            *[nn.Sequential(
                nn.Conv2d(chan_in, chan_out,
                          kernel_size=(kernel_size, kernel_size),
                          stride=(stride, stride),
                          padding=(padding, padding), bias=conv_bias),
                nn.Identity() if not exists(activation) else activation(),
                nn.MaxPool2d(kernel_size=pooling_kernel_size,
                             stride=pooling_stride,
                             padding=pooling_padding) if max_pool else nn.Identity()
            )
                for chan_in, chan_out in n_filter_list_pairs
            ])

        self.apply(self.init_weight)

    def sequence_length(self, n_channels=3, height=224, width=224):
        return self.forward(torch.zeros((1, n_channels, height, width))).shape[1]

    def forward(self, x):
        return rearrange(self.conv_layers(x), 'b c h w -> b (h w) c')

    @staticmethod
    def init_weight(m):
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight)

CCT的完整实现可参考vit-pytorch