CCNet:语义分割中的交叉注意力机制.

语义分割工作中的Non Local机制虽然可以很大程度上解决感受野问题,但是在计算复杂度上受限严重。为了减少计算量,最简单的方法就是减小通道数、降低分辨率,但是这些方法会造成信息损失,导致模型结构降低。CCNet的提出正是为了解决计算复杂度的问题。

1. Criss-cross attention

Criss-cross attention机制是计算一个点的横纵位置的attention信息,而不是与所有点进行交互。

首先主干网络的输出$X$经过一个卷积减少通道数,得到特征图$H∈ [C×W ×H]$。$H$经过三个$1\times 1$卷积模块分别生成$Q$、$K$和$V$,其中$Q, K∈ [C’×W ×H]$,$C’$设置为$C$的八分之一,用于减少计算量。接着,$QK$通过Affinity操作计算生成$A$。

对于Affinity操作:在$Q$中的每一个位置$u$,都可以在channel轴得到一个向量$Q_{u} \in [C’]$,同时可以从$K$中提取与位置$u$处于同一行、列的向量$\Omega_{u}\in [(H+W-1),C’]$。则第$i$个位置的Affinity计算为:

\[d_{i, u}=Q_u \Omega_{i, u}^T\]

经过Softmax激活后,得到注意力图 $A∈ [(H+W −1)×W ×H]$。

对于生成的$V∈[C×W ×H]$,同样对于每一个位置$u$可以在channel轴上得到一个向量集 $\Phi_u∈[(H+W −1)×C]$,将这个向量集与生成的$A$相乘,完成Aggregation操作,最后再加上原始输入$H$,输出生成的$H’$。

import torch
import torch.nn as nn
import torch.nn.functional as F

# 用于抑制行列注意力中对自身的一次重复计算
def INF(B,H,W):
     return -torch.diag(torch.tensor(float("inf")).repeat(H),0).unsqueeze(0).repeat(B*W,1,1)

class CrissCrossAttention(nn.Module):
    def __init__(self, in_channels):
        super(CrissCrossAttention, self).__init__()
        self.in_channels = in_channels
        self.channels = in_channels // 8
        self.ConvQuery = nn.Conv2d(self.in_channels, self.channels, kernel_size=1)
        self.ConvKey = nn.Conv2d(self.in_channels, self.channels, kernel_size=1)
        self.ConvValue = nn.Conv2d(self.in_channels, self.in_channels, kernel_size=1)
 
        self.SoftMax = nn.Softmax(dim=3)
        self.INF = INF
        self.gamma = nn.Parameter(torch.zeros(1))
 
    def forward(self, x):
        b, _, h, w = x.size()
 
        # [b, c', h, w]
        query = self.ConvQuery(x)
        # [b, w, c', h] -> [b*w, c', h] -> [b*w, h, c']
        query_H = query.permute(0, 3, 1, 2).contiguous().view(b*w, -1, h).permute(0, 2, 1)
        # [b, h, c', w] -> [b*h, c', w] -> [b*h, w, c']
        query_W = query.permute(0, 2, 1, 3).contiguous().view(b*h, -1, w).permute(0, 2, 1)
        
        # [b, c', h, w]
        key = self.ConvKey(x)
        # [b, w, c', h] -> [b*w, c', h]
        key_H = key.permute(0, 3, 1, 2).contiguous().view(b*w, -1, h)
        # [b, h, c', w] -> [b*h, c', w]
        key_W = key.permute(0, 2, 1, 3).contiguous().view(b*h, -1, w)
        
        # [b, c, h, w]
        value = self.ConvValue(x)
        # [b, w, c, h] -> [b*w, c, h]
        value_H = value.permute(0, 3, 1, 2).contiguous().view(b*w, -1, h)
        # [b, h, c, w] -> [b*h, c, w]
        value_W = value.permute(0, 2, 1, 3).contiguous().view(b*h, -1, w)
        
        # [b*w, h, c']* [b*w, c', h] -> [b*w, h, h] -> [b, h, w, h]
        energy_H = (torch.bmm(query_H, key_H) + self.INF(b, h, w)).view(b, w, h, h).permute(0, 2, 1, 3)
        # [b*h, w, c']*[b*h, c', w] -> [b*h, w, w] -> [b, h, w, w]
        energy_W = torch.bmm(query_W, key_W).view(b, h, w, w)
        # [b, h, w, h+w]  concate channels in axis=3 
        concate = self.SoftMax(torch.cat([energy_H, energy_W], 3))
        
        # [b, h, w, h] -> [b, w, h, h] -> [b*w, h, h]
        attention_H = concate[:,:,:, 0:h].permute(0, 2, 1, 3).contiguous().view(b*w, h, h)
        # [b, h, w, w] -> [b*h, w, w]
        attention_W = concate[:,:,:, h:h+w].contiguous().view(b*h, w, w)
 
        # [b*w, c, h]*[b*w, h, h] -> [b, w, c, h] -> [b, c, h, w]
        out_H = torch.bmm(value_H, attention_H.permute(0, 2, 1)).view(b, w, -1, h).permute(0, 2, 3, 1)
        # [b*h, c, w]*[b*h, w, w] -> [b*h, c, w] -> [b, c, h, w]
        out_W = torch.bmm(value_W, attention_W.permute(0, 2, 1)).view(b, h, -1, w).permute(0, 2, 1, 3)
 
        return self.gamma*(out_H + out_W) + x
 
if __name__ == "__main__":
    model = CrissCrossAttention(512)
    x = torch.randn(2, 512, 28, 28)
    out = model(x)
    print(out.shape)

2. Recurrent Criss-Cross Attention module

不同于Non-Local一次性计算全图的attentionCriss-cross attention机制则是计算一个点的横纵位置的attention信息。 但是如果只计算一次横纵位置的attention,则其他位置并没有被关联到,也就是这次计算的attention是局限在横纵轴位置上的,其中包括的语义信息并不丰富。

为了解决这个问题,作者串行了两个Criss-cross attention模块,这样,对于一个点的位置,首先计算了他的横纵轴的attention,然后将这个信息输出后,再经过一个Criss-cross attention计算,这个点就间接的与全图位置内的任意点进行了计算。如下图Loop1中浅绿色方块包含了蓝色方块的内容,Loop2中的深绿色与浅绿色方块进行计算,其中包含了浅绿色+蓝色方块内容,也就是深绿色方块同时关联了浅绿色方块和蓝色方块。

RCCA Module通过循环叠加了几个Criss-Cross attention module,还集成了上采样和输出模块。

class RCCAModule(nn.Module):
    def __init__(self, recurrence = 2, in_channels = 2048, num_classes=33):
        super(RCCAModule, self).__init__()
        self.recurrence = recurrence
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.inter_channels = in_channels // 4
        self.conv_in = nn.Sequential(
            nn.Conv2d(self.in_channels, self.inter_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(self.inter_channels)
        )
        self.CCA = CrissCrossAttention(self.inter_channels)
        self.conv_out = nn.Sequential(
            nn.Conv2d(self.inter_channels, self.inter_channels, 3, padding=1, bias=False)
        )
        self.cls_seg = nn.Sequential(
            nn.Conv2d(self.in_channels+self.inter_channels, self.inter_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(self.inter_channels),
            nn.Upsample(scale_factor=8, mode="bilinear", align_corners=True),
            nn.Conv2d(self.inter_channels, self.num_classes, 1)
        )
 
    def forward(self, x):
        # reduce channels from C to C'   2048->512
        output = self.conv_in(x)
 
        for i in range(self.recurrence):
            output = self.CCA(output)
 
        output = self.conv_out(output)
        output = self.cls_seg(torch.cat([x, output], 1))
        return output
 
if __name__ == "__main__":
    model = RCCAModule(in_channels=2048)
    x = torch.randn(2, 2048, 28, 28)
    out = model(x)
    print(out.shape)

3. Criss-cross attention network

因为Criss-Cross attention module比较灵活,可以加在任意位置,所以这里作者简单的加在CNN的输出后面,用于处理feature maps,通过简单的上采样来完成分割任务。

class CCNet(nn.Module):
    def __init__(self, num_classes):
        super(CCNet, self).__init__()
        self.backbone = ResNet.resnet50(replace_stride_with_dilation=[1,2,4])
        self.decode_head = RCCAModule(recurrence=2, in_channels=2048, num_classes=num_classes)
    
    def forward(self, x):
        x = self.backbone(x)
        x = self.decode_head(x)
        return x
 
if __name__ == "__main__":
    model = CCNet(num_classes=2)
    x = torch.randn(2, 3, 224, 224)
    out = model(x)
    print(out.shape)