图像分类的残差注意力网络.

Attention的出发点是将注意力集中在部分显著或者是感兴趣的图像点上。其实卷积网络本身就自带Attention效果,以分类网络为例,高层feature map所激活的pixel也恰好集中在分类任务相关的区域。

本文提出一种可堆叠的Residual Attention Module模块,在普通的ResNet网络中,增加侧分支,侧分支通过一系列的卷积和池化操作,逐渐提取高层特征并增大模型的感受野。

高层特征的激活对应位置能够反映attention的区域,然后再对这种具有attention特征的feature map进行上采样,使其大小回到原始feature map的大小。

attention map对应到原始图片的每一个位置上,与原来的feature map进行element-wise product的操作,相当于一个权重器,增强有意义的特征,抑制无意义的信息。

Attention Module分为两个分支,右边的分支就是普通的卷积网络,即主干分支,叫做Trunk Branch。左边的分支是为了得到一个掩码mask,该掩码的作用是得到输入特征的attention map,所以叫做Mask Branch,这个Mask Branch包含down sampleup sample的过程,目的是为了保证和右边分支的输出大小一致。

得到Attention mapmask以后,可以直接用mask和主干分支进行一个element-wise product的操作,即$M(x) \cdot T(x)$,来对特征做一次权重操作。但是这样导致的问题就是:$M(x)$的掩码是通过最后的sigmoid函数得到的,$M(x)$值在$[0, 1]$之间,连续多个Module模块直接相乘的话会导致feature map的值越来越小,同时也有可能打破原有网络的特性,使得网络的性能降低。比较好的方式就借鉴ResNet恒等映射的方法:

\[H(x) = (1+M(x)) \cdot T(x)\]

其中$M(x)$为Soft Mask Branch的输出,$T(x)$为Trunk Branch的输出,那么当$M(x)=0$时,该层的输入就等于$T(x)$,因此该层的效果不可能比原始的$T(x)$差,这一点也借鉴了ResNet中恒等映射的思想,同时这样的加法,也使得Trunk Branch输出的feature map中显著的特征更加显著,增加了特征的判别性。经过这种残差结构的堆叠,能够很容易的将模型的深度达到很深的层次,具有非常好的性能。

def attention_block(input, input_channels=None, output_channels=None, encoder_depth=1):
    p = 1
    t = 2
    r = 1
    if input_channels is None:
        input_channels = input.get_shape()[-1].value
    if output_channels is None:
        output_channels = input_channels
    # First Residual Block
    for i in range(p):
        input = residual_block(input)
    # Trunk Branch
    output_trunk = input
    for i in range(t):
        output_trunk = residual_block(output_trunk)
        
    # Soft Mask Branch
    ## encoder
    ### first down sampling
    output_soft_mask = MaxPool2D(padding='same')(input)  # 32x32
    for i in range(r):
        output_soft_mask = residual_block(output_soft_mask)
 
    skip_connections = []
    for i in range(encoder_depth - 1):
        ## skip connections
        output_skip_connection = residual_block(output_soft_mask)
        skip_connections.append(output_skip_connection)
        ## down sampling
        output_soft_mask = MaxPool2D(padding='same')(output_soft_mask)
        for _ in range(r):
            output_soft_mask = residual_block(output_soft_mask)
 
    ## decoder
    skip_connections = list(reversed(skip_connections))
    for i in range(encoder_depth - 1):
        ## upsampling
        for _ in range(r):
            output_soft_mask = residual_block(output_soft_mask)
        output_soft_mask = UpSampling2D()(output_soft_mask)
        ## skip connections
        output_soft_mask = Add()([output_soft_mask, skip_connections[i]])
 
    ### last upsampling
    for i in range(r):
        output_soft_mask = residual_block(output_soft_mask)
    output_soft_mask = UpSampling2D()(output_soft_mask)
 
    ## Output
    output_soft_mask = Conv2D(input_channels, (1, 1))(output_soft_mask)
    output_soft_mask = Conv2D(input_channels, (1, 1))(output_soft_mask)
    output_soft_mask = Activation('sigmoid')(output_soft_mask)
 
    # Attention: (1 + output_soft_mask) * output_trunk
    output = Lambda(lambda x: x + 1)(output_soft_mask)
    output = Multiply()([output, output_trunk])  #
 
    # Last Residual Block
    for i in range(p):
        output = residual_block(output)
 
    return output