OCRNet:语义分割中的目标上下文表示.

在图像分割模型中,为了提取多尺度特征,会在骨干网络后引入一些上下文模块,比如空洞空间金字塔卷积层ASPP。一般性的ASPP方法如图(a),其中红点是关注的点,蓝点和黄点是采样出来的周围点,若将其作为红点的上下文,背景和物体没有区分开来,这样的上下文信息对红点像素分类帮助有限。为改善此情况,本文提出OCRNet方法如图(b),其可让上下文信息关注在物体上,从而为红点提供更有用的信息。

OCRNet 方法总体思路:首先用一般的语义分割模型得到一个粗略的分割结果(soft object regions),同时从backbone还可获得每个像素的特征(pixel representation),根据每个像素的语义信息和特征,可以得到每个类别的特征(object region representation);随后可计算像素特征与各个类别特征的相似度(pixel-region relation),根据该相似度可得到每个像素点属于各类别的可能性,进一步把每个区域的表征进行加权,会得到当前像素增强的特征表示(object-contextual representation),整体流程如下:

Step1:提取类别区域特征

根据像素语义信息和像素特征得到每个类别区域特征。其中像素语义信息是常规的语义分割结果,像素特征就是backbone提取得到的特征图。假设共有$20$个类别,图像尺寸为$100 \times 100$,则这一步的步骤如下:

  1. 像素语义($20×100×100$)展开成二维($20×10000$),其每一行表示每个像素点($10000$个像素点)属于某类物体(总共$20$个类)的概率。
  2. 像素特征($512×100×100$)展开成二维($512×10000$),其每一列表示每个像素点($10000$个像素点)在某一维特征($512$维)。
  3. 像素语义的每行乘以像素特征的每列再相加,得到类别区域特征,其每一行表示某个类($20$类)的$512$维特征。

def get_proxy(feats,probs):
    batch_size, c, h, w = probs.size(0), probs.size(1), probs.size(2), probs.size(3)
    # 1, 20, 100, 100
    probs = probs.view(batch_size, c, -1) 
    # (1, 20, 10000)
    feats = feats.view(batch_size, feats.size(1), -1)
    # (1, 512, 10000)
    feats = feats.permute(0, 2, 1) # batch x hw x c 
    # (1, 10000, 512)
    probs = F.softmax(self.scale * probs, dim=2)# batch x k x hw
    # (1, 20, 10000)
    proxy = torch.matmul(probs, feats).permute(0, 2, 1).unsqueeze(3)# batch x k x c
    # (1, 512, 20, 1)
    return proxy
    
if __name__ == "__main__": 
    feats = torch.randn((1, 512, 100, 100))
    probs = torch.randn((1, 20, 100, 100))
    proxy=get_proxy(feats,probs)

Step2:像素区域相似度

对像素特征 featsstep1 得到类别区域特征 proxy ,使用 self-attention 得到像素与区域的相似度,即依赖关系。

def get_sim_map(feats, proxy):
    x=feats
    batch_size, h, w = x.size(0), x.size(2), x.size(3)
    # 1, 100, 100
    
    ## qk
    query = f_pixel(x).view(batch_size, self.key_channels, -1)
    # (1, 256, 10000)
    query = query.permute(0, 2, 1)
    # (1, 256, 10000)
    key = f_object(proxy).view(batch_size, self.key_channels, -1)
    # (1, 256, 20)
    value = self.f_down(proxy).view(batch_size, self.key_channels, -1)
    # (1, 256, 20)
    value = value.permute(0, 2, 1)
    # (1, 20, 256)
   
    ## sim
    sim_map = torch.matmul(query, key)
    # (1, 10000, 20)
    sim_map = (self.key_channels**-.5) * sim_map
    # (1, 10000, 20)
    sim_map = F.softmax(sim_map, dim=-1)  
    # (1, 10000, 20)
    return sim_map
           
if __name__ == "__main__": 
    feats = torch.randn((1, 512, 100, 100))
    proxy = get_proxy(feats, probs) 
    sim_map = get_sim_map(feats, proxy)

Step3:获得上下文表示

step2计算得到像素与区域的相似度 simmap,其乘以每个类别区域特征proxy则可上下文特征,将其和像素特征进行拼接,再做通道调整得到最终的上下文表示:

def get_context(feats, proxy, sim_map):
    context = torch.matmul(sim_map, value) # hw x k x k x c
    # (1, 10000, 256)
    context = context.permute(0, 2, 1).contiguous()
    # (1, 10000, 256)
    context = context.view(batch_size, self.key_channels, *x.size()[2:])
    # (1, 256, 100, 100)
    context = f_up(context)
    # (1, 512, 100, 100)
    
    output = self.conv_bn_dropout(torch.cat([context, feats], 1))
    # (1, 512, 100, 100)
    return output
           
if __name__ == "__main__": 
    feats = torch.randn((1, 512, 100, 100))
    proxy = get_proxy(feats, probs) 
    sim_map = get_sim_map(feats, proxy)  
    output = get_context(proxy, sim_map)