APCNet: Adaptive Pyramid Context Network for Semantic Segmentation
Current image segmentation models suffer from the following problems:
- Multi-scale objects: in semantic segmentation, objects vary widely in size and position. Models that do not aggregate contextual information have difficulty detecting objects whose scales differ this much, and they also lose fine detail.
- Adaptive regions: not every region of an image is relevant to the object being segmented; some pixels strongly affect whether the object is segmented correctly, while others have little influence. Moreover, these informative pixels or regions are not necessarily located around the object; they may lie far from it. The model therefore needs the ability to adaptively select regions, identifying the important ones that help segment the object correctly.
- Weights for fusing global and local information (Global-guided Local Affinity, GLA): once the context vectors have been constructed, how should they be weighted against the original feature map, and how should those weights be chosen and computed?
The paper proposes APCNet, an end-to-end model that fuses global information and improves multi-scale detection. APCNet's pyramid is built from several ACM (Adaptive Context Module) blocks, each of which receives a scale parameter $s$ that determines its region size.
An ACM computes a context vector for each local position and weights it onto the feature map, thereby aggregating contextual information. The module consists of two branches: a GLA branch and an aggregation branch.
In the GLA branch, the feature map output by the backbone is denoted $X$. It first passes through a 1x1 Conv to obtain a reduced feature map $x$, and a spatial global pooling maps $x$ to a global information vector $g(X)$. The sum $x + g(X)$ then goes through another 1x1 Conv and a Sigmoid activation to produce the GLA weights $\alpha^s$, which are reshaped into an $hw \times s^2$ affinity matrix.
In the aggregation branch, the feature map $X$ passes through AdaptivePooling(size=s), a Conv (kernel size 1x1), and a reshape to yield $y^s$ of shape $s^2 \times 512$. Multiplying the affinity matrix $\alpha^s$ from the GLA branch with $y^s$ produces an $hw \times 512$ result, completing the initial feature fusion. Finally, this result is reshaped back to the original spatial size and added to $x$ from the GLA branch as a residual, giving the module's final fused output.
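Putting the two branches together, with $\sigma$ denoting the Sigmoid, the module computes (this is only a restatement of the shapes described above):

$$\alpha^s = \sigma\!\left(\mathrm{Conv}_{1\times 1}\left(x + g(X)\right)\right) \in \mathbb{R}^{hw \times s^2}, \qquad z^s = \alpha^s\, y^s \in \mathbb{R}^{hw \times 512},$$

after which $z^s$ is reshaped to $h \times w \times 512$ for the residual addition with $x$.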
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision


class ACMModule(nn.Module):
    def __init__(self, in_channels=2048, channels=512, pool_scale=1, fusion=True):
        super(ACMModule, self).__init__()
        self.pool_scale = pool_scale
        self.in_channels = in_channels
        self.channels = channels
        self.fusion = fusion
        # 1x1 convs that reduce the input features and the pooled features to `channels`
        self.reduce_Conv = nn.Conv2d(self.in_channels, self.channels, 1)
        self.reduce_Pool_Conv = nn.Conv2d(self.in_channels, self.channels, 1)
        self.residual_conv = nn.Conv2d(self.channels, self.channels, 1)
        # Global Information vector g(X)
        self.global_info = nn.Conv2d(self.channels, self.channels, 1)
        # GLA head: predicts s^2 affinity weights per spatial position
        self.gla = nn.Conv2d(self.channels, self.pool_scale**2, 1, 1, 0)
        if self.fusion:
            self.fusion_conv = nn.Conv2d(self.channels, self.channels, 1)

    def forward(self, X):
        batch_size, c, h, w = X.shape
        x = self.reduce_Conv(X)
        # g(X) is 1x1 spatially, so plain broadcasting in `x + GI` replaces the
        # explicit resize of the original code without changing the result.
        GI = self.global_info(F.adaptive_avg_pool2d(x, 1))
        # GLA affinity matrix: (B, s^2, h, w) -> (B, hw, s^2)
        Affinity_matrix = self.gla(x + GI).permute(0, 2, 3, 1).reshape(batch_size, -1, self.pool_scale**2)
        Affinity_matrix = torch.sigmoid(Affinity_matrix)
        # Aggregation branch: pool X to s x s, then reduce to `channels` -> y^s
        # (the original code skipped reduce_Pool_Conv, leaving 2048 channels that
        # did not match residual_conv's expected 512)
        pooled_x = self.reduce_Pool_Conv(F.adaptive_avg_pool2d(X, self.pool_scale))
        # (B, 512, s, s) -> (B, s^2, 512)
        pooled_x = pooled_x.view(batch_size, self.channels, -1).permute(0, 2, 1).contiguous()
        # (B, hw, s^2) @ (B, s^2, 512) -> (B, hw, 512)
        MatrixProduct = torch.matmul(Affinity_matrix, pooled_x)
        MatrixProduct = MatrixProduct.permute(0, 2, 1).contiguous()
        MatrixProduct = MatrixProduct.view(batch_size, self.channels, h, w)
        MatrixProduct = self.residual_conv(MatrixProduct)
        # Residual connection with the reduced input x
        Z_out = F.relu(MatrixProduct + x)
        if self.fusion:
            Z_out = self.fusion_conv(Z_out)
        return Z_out
```
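A quick shape check of a single module (an illustrative sketch; the batch size, the 64x64 feature grid, and the scale $s=2$ are arbitrary values chosen for the example):

```python
acm = ACMModule(in_channels=2048, channels=512, pool_scale=2)
feats = torch.randn(2, 2048, 64, 64)  # stands in for the backbone output X
out = acm(feats)
print(out.shape)  # torch.Size([2, 512, 64, 64]): channels reduced, spatial size kept
```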
```python
class ACMModuleList(nn.ModuleList):
    def __init__(self, pool_scales=[1, 2, 3, 6], in_channels=2048, channels=512):
        super(ACMModuleList, self).__init__()
        self.pool_scales = pool_scales
        self.in_channels = in_channels
        self.channels = channels
        # One ACM per pyramid scale
        for pool_scale in pool_scales:
            self.append(ACMModule(in_channels, channels, pool_scale))

    def forward(self, x):
        # Run every scale on the same feature map and collect the outputs
        out = []
        for ACM in self:
            out.append(ACM(x))
        return out
```
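Every scale sees the same feature map and returns the same output shape, which is what later allows the pyramid outputs to be concatenated (again an illustrative sketch with arbitrary sizes):

```python
pyramid = ACMModuleList(pool_scales=[1, 2, 3, 6], in_channels=2048, channels=512)
feats = torch.randn(1, 2048, 64, 64)
outs = pyramid(feats)
print([tuple(o.shape) for o in outs])  # [(1, 512, 64, 64)] * 4
```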
```python
class APCNet(nn.Module):
    def __init__(self, num_classes):
        super(APCNet, self).__init__()
        self.num_classes = num_classes
        # The original referenced a local `ResNet` module; here torchvision's
        # resnet50 is used instead, with the strides of the last two stages
        # replaced by dilation (output stride 8, matching the 8x upsampling
        # below) and the avgpool/fc classifier head stripped off.
        resnet = torchvision.models.resnet50(replace_stride_with_dilation=[False, True, True])
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])
        self.in_channels = 2048
        self.channels = 512
        self.ACM_pyramid = ACMModuleList(pool_scales=[1, 2, 3, 6], in_channels=self.in_channels, channels=self.channels)
        self.conv1 = nn.Sequential(
            nn.Conv2d(4 * self.channels + self.in_channels, self.channels, 3, padding=1),
            nn.BatchNorm2d(self.channels),
            nn.ReLU()
        )
        self.cls_conv = nn.Conv2d(self.channels, self.num_classes, 3, padding=1)

    def forward(self, x):
        x = self.backbone(x)
        ACM_out = self.ACM_pyramid(x)
        ACM_out.append(x)              # keep the raw backbone features as well
        x = torch.cat(ACM_out, dim=1)  # 4 * 512 + 2048 channels
        x = self.conv1(x)
        # Undo the output stride of 8 to recover the input resolution
        x = F.interpolate(x, scale_factor=8, mode='bilinear', align_corners=False)
        x = self.cls_conv(x)
        return x
```
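An end-to-end smoke test (illustrative; 21 classes and a 512x512 input are assumptions made for the example, not requirements of the model):

```python
model = APCNet(num_classes=21).eval()
img = torch.randn(1, 3, 512, 512)
with torch.no_grad():
    logits = model(img)
print(logits.shape)  # torch.Size([1, 21, 512, 512]): stride-8 features upsampled 8x
```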