FcaNet:频域通道注意力网络.

卷积神经网络中的通道注意力机制通常使用全局平均池化(GAP)提取不同通道的统计信息,但是平均值信息不足以表示各个通道的信息,比如不同的通道可能具有相同的平均值,而它们对应的语义信息可能是完全不同的。

本文作者证明GAP是离散余弦变换DCT的一个特例,等价于DCT的最低频率分量,并进一步在频域中提出了一种频域通道注意力机制;进一步通过探讨使用不同数量的频率分量及其不同组合的影响,提出了选择频率分量的两步准则。

1. 离散余弦变换

图像既可以表示在空间域,又可以表示在频率域。空域是指图像的像素表示,可以在其上直接对像素值进行处理,例如卷积、池化、插值等操作。频域是指将图像看作信号,将像素值看作分别沿图像的高度和宽度方向正弦函数的叠加。

根据傅里叶级数,任何信号都可以表示成一系列具有不同振幅及频率的正弦函数的和;因此,通过二维离散傅里叶变换(DFT)可以得到图像的频谱图。频谱图中不同半径代表不同频率分量,低频信息在中心,高频信息在边缘。

离散余弦变换(DCT)是离散傅里叶变换的一种特殊形式,是导出DFT余弦项的一种变换。相比于DFTDCT能够把图像更重要的信息聚集在一起,去掉一些由于间断导致的高频信息,实现信息的聚焦。

本文作者证明了图像在空域的全局平均池化等价于DCT的最低频分量。记图像像素点$x_{i,j}$,则DCT计算图像中每个输入点的加权和,其中余弦部分相当于权重:

\[f_{h,w} = \sum_{i=0}^{H-1} \sum_{j=0}^{W-1} x_{i,j} \cos(\frac{\pi h}{H}(i+\frac{1}{2})) \cos(\frac{\pi w}{W}(j+\frac{1}{2}))\]

通过所有频率分量的在空间点$(i,j)$的叠加可以恢复像素值,从而实现DCT的逆变换:

\[x_{i,j} = \sum_{i=0}^{H-1} \sum_{j=0}^{W-1} f_{h,w} \cos(\frac{\pi h}{H}(i+\frac{1}{2})) \cos(\frac{\pi w}{W}(j+\frac{1}{2}))\]

如果只考虑最低频分量,令$h=w=0$,则$f_{0,0}$表示为:

\[f_{0,0} = \sum_{i=0}^{H-1} \sum_{j=0}^{W-1} x_{i,j} \cos(0) \cos(0) = \sum_{i=0}^{H-1} \sum_{j=0}^{W-1} x_{i,j} = GAP(x)\cdot HW\]

上式等价于全局平均池化,因此GAP也可用DCT的最低频率分量表示:

\[GAP(x) = \frac{f_{0,0}}{HW}\]

DCT的余弦项为:

\[B_{h,w}^{i,j} = \cos(\frac{\pi h}{H}(i+\frac{1}{2})) \cos(\frac{\pi w}{W}(j+\frac{1}{2}))\]

则可把图像表示为DCT的不同频率项之和:

\[\begin{aligned} x_{i,j} &= \sum_{i=0}^{H-1} \sum_{j=0}^{W-1} f_{h,w} \cos(\frac{\pi h}{H}(i+\frac{1}{2})) \cos(\frac{\pi w}{W}(j+\frac{1}{2})) \\ &= f_{0,0}B_{0,0}^{i,j} + f_{0,1}B_{0,1}^{i,j} + \cdots + f_{H-1,W-1}B_{H-1,W-1}^{i,j} \end{aligned}\]

2. 多光谱通道注意力 Multi-spectral channel attention

既然图像可以表示为DCT的不同频率项之和,则可以用不同频率项构造注意力机制中的通道统计量。特别地,只考虑最低频率分量则等价于使用GAP的通道注意力:

作者设计的多光谱通道注意力如下图所示,首先选择应用离散余弦变换后Top-k个性能最佳的频率分量标号,然后把输入特征沿通道划分为$k$等份,对每份计算其对应的DCT频率分量,并与对应的特征分组相乘。

在选择频率分量时,作者提出了两步准则:

  1. 首先分别计算出通道注意力中每个频率分量的结果;
  2. 再根据所得结果筛选出Top-k个性能最佳的频率分量。

在单个频率分量的测试中,结果表明还是最低频的信息效果最好,这也验证了神经网络偏好低频信息的结论。其他频率分量对网络也是有贡献的,因此可以将这些信息嵌入特征。

值得一提的是,单个频率分量的测试都是在$7 \times 7$特征图上进行的,这是因为ImageNet上最小特征尺寸是$7 \times 7$,对应$49$个频率分量。在实验中选择通过$7 \times 7$特征图验证的最好频率分量标号即可,因为在应用DCT权重前通过特征的自适应池化对不同尺寸的特征图除以$7$做了归一化,使得分辨率不影响结果,比如$14 \times 14$特征图上的标号$(2,2)$等价于$7 \times 7$特征图上的标号$(1,1)$。通过实验获得的标号排序如下:

def get_freq_indices(method):
    assert method in ['top1','top2','top4','top8','top16','top32',
                      'bot1','bot2','bot4','bot8','bot16','bot32',
                      'low1','low2','low4','low8','low16','low32']
    num_freq = int(method[3:])
    if 'top' in method:
        all_top_indices_x = [0,0,6,0,0,1,1,4,5,1,3,0,0,0,3,2,4,6,3,5,5,2,6,5,5,3,3,4,2,2,6,1]
        all_top_indices_y = [0,1,0,5,2,0,2,0,0,6,0,4,6,3,5,2,6,3,3,3,5,1,1,2,4,2,1,1,3,0,5,3]
        mapper_x = all_top_indices_x[:num_freq]
        mapper_y = all_top_indices_y[:num_freq]
    elif 'low' in method:
        all_low_indices_x = [0,0,1,1,0,2,2,1,2,0,3,4,0,1,3,0,1,2,3,4,5,0,1,2,3,4,5,6,1,2,3,4]
        all_low_indices_y = [0,1,0,1,2,0,1,2,2,3,0,0,4,3,1,5,4,3,2,1,0,6,5,4,3,2,1,0,6,5,4,3]
        mapper_x = all_low_indices_x[:num_freq]
        mapper_y = all_low_indices_y[:num_freq]
    elif 'bot' in method:
        all_bot_indices_x = [6,1,3,3,2,4,1,2,4,4,5,1,4,6,2,5,6,1,6,2,2,4,3,3,5,5,6,2,5,5,3,6]
        all_bot_indices_y = [6,4,4,6,6,3,1,4,4,5,6,5,2,2,5,1,4,3,5,0,3,1,1,2,4,2,1,1,5,3,3,3]
        mapper_x = all_bot_indices_x[:num_freq]
        mapper_y = all_bot_indices_y[:num_freq]
    else:
        raise NotImplementedError
    return mapper_x, mapper_y

在获得每个频率分量的性能后,需要确定所选择的最佳频率分量数。实验报告了Top-k最高性能的频率成分,结果表明应用多光谱的结果优于GAP(对应$k=1$),最终选择16个频率分量。

多光谱通道注意力实现如下:

class MultiSpectralAttentionLayer(torch.nn.Module):
    def __init__(self, channel, dct_h, dct_w, reduction = 16, freq_sel_method = 'top16'):
        super(MultiSpectralAttentionLayer, self).__init__()
        self.reduction = reduction
        self.dct_h = dct_h
        self.dct_w = dct_w

        mapper_x, mapper_y = get_freq_indices(freq_sel_method)
        mapper_x = [temp_x * (dct_h // 7) for temp_x in mapper_x] 
        mapper_y = [temp_y * (dct_w // 7) for temp_y in mapper_y]
        # make the frequencies in different sizes are identical to a 7x7 frequency space
        # eg, (2,2) in 14x14 is identical to (1,1) in 7x7

        self.dct_layer = MultiSpectralDCTLayer(dct_h, dct_w, mapper_x, mapper_y, channel)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        n,c,h,w = x.shape
        x_pooled = x
        if h != self.dct_h or w != self.dct_w:
            x_pooled = torch.nn.functional.adaptive_avg_pool2d(x, (self.dct_h, self.dct_w))
        y = self.dct_layer(x_pooled)
        y = self.fc(y).view(n, c, 1, 1)
        return x * y.expand_as(x)


class MultiSpectralDCTLayer(nn.Module):
    """
    Generate dct filters
    """
    def __init__(self, height, width, mapper_x, mapper_y, channel):
        super(MultiSpectralDCTLayer, self).__init__()
        assert len(mapper_x) == len(mapper_y)
        assert channel % len(mapper_x) == 0
        self.num_freq = len(mapper_x)

        # fixed DCT init
        self.register_buffer('weight', self.get_dct_filter(height, width, mapper_x, mapper_y, channel))
        
        # fixed random init
        # self.register_buffer('weight', torch.rand(channel, height, width))

        # learnable DCT init
        # self.register_parameter('weight', self.get_dct_filter(height, width, mapper_x, mapper_y, channel))
        
        # learnable random init
        # self.register_parameter('weight', torch.rand(channel, height, width))

    def forward(self, x):
        assert len(x.shape) == 4, 'x must been 4 dimensions, but got ' + str(len(x.shape))
        # n, c, h, w = x.shape
        x = x * self.weight
        result = torch.sum(x, dim=[2,3])
        return result

    def build_filter(self, pos, freq, POS):
        result = math.cos(math.pi * freq * (pos + 0.5) / POS) / math.sqrt(POS) 
        if freq == 0:
            return result
        else:
            return result * math.sqrt(2)
    
    def get_dct_filter(self, tile_size_x, tile_size_y, mapper_x, mapper_y, channel):
        dct_filter = torch.zeros(channel, tile_size_x, tile_size_y)
        c_part = channel // len(mapper_x)
        for i, (u_x, v_y) in enumerate(zip(mapper_x, mapper_y)):
            for t_x in range(tile_size_x):
                for t_y in range(tile_size_y):
                    dct_filter[i * c_part: (i+1)*c_part, t_x, t_y] = self.build_filter(t_x, u_x, tile_size_x) * self.build_filter(t_y, v_y, tile_size_y)
        return dct_filter