PSPNet: 金字塔场景解析网络.
PSPNet模型提出是为了解决场景解析问题,以提升场景分析中对于相似颜色、形状的物体的检测精度。在ADE20K数据集上进行实验时,主要发现有如下问题:
- 错误匹配:图像分割模型缺乏收集上下文能力,导致了分类错误;
- 相似的标签会导致一些奇怪的错误,比如earth和field,模型会出现混淆。
- 小目标的丢失问题,小物体很难被网络所发现;而一些特别大的物体,在感受野不够大的情况下,往往会丢失一部分信息,导致预测不连续。
为解决以上问题,作者提出了金字塔池化模块(Pyramid Pooling Module, PPM)。在深层网络中,感受野的大小大致上体现了模型能获得的上下文新消息。尽管在理论上深层卷积网络的感受野已经大于图像尺寸,但是实际上会小得多。这就导致了很多网络不能充分的将上下文信息结合起来,于是作者提出了一种全局的先验方法:引入平均池化。
PPM模块并联了四个不同大小的平均池化层,将原始的特征图池化生成不同级别的特征图,经过卷积和上采样恢复到原始大小。这种操作聚合了多尺度的图像特征,融合了不同尺度和不同子区域之间的信息。最后,这个先验信息再和原始特征图进行相加,输入到最后的卷积模块完成预测。
PPM模块的实现如下:
class PPM(nn.ModuleList):
def __init__(self, pool_sizes, in_channels, out_channels):
super(PPM, self).__init__()
self.pool_sizes = pool_sizes
self.in_channels = in_channels
self.out_channels = out_channels
for pool_size in pool_sizes:
self.append(
nn.Sequential(
nn.AdaptiveMaxPool2d(pool_size),
nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1),
)
)
def forward(self, x):
out_puts = []
for ppm in self:
ppm_out = nn.functional.interpolate(ppm(x), size=x.size()[-2:], mode='bilinear', align_corners=True)
out_puts.append(ppm_out)
PSPNet的网络架构十分简单,backbone为resnet网络,将原始图像下采样$8$倍成特征图,特征图输入到PPM模块,并与其输出相加,最后经过卷积和$8$倍双线性差值上采样得到结果。
此外作者还在网络中引入了辅助损失(auxiliary loss),在resnet101的res4b22层引出一条FCN分支,用于计算辅助损失。论文里设置了赋值损失loss2的权重为$0.4$。
用于计算辅助损失的辅助头定义如下:
# 构建一个FCN分割头,用于计算辅助损失
class Aux_Head(nn.Module):
def __init__(self, in_channels=1024, num_classes=3):
super(Aux_Head, self).__init__()
self.num_classes = num_classes
self.in_channels = in_channels
self.decode_head = nn.Sequential(
nn.Conv2d(self.in_channels, self.in_channels//2, kernel_size=3, padding=1),
nn.BatchNorm2d(self.in_channels//2),
nn.ReLU(),
nn.Conv2d(self.in_channels//2, self.in_channels//4, kernel_size=3, padding=1),
nn.BatchNorm2d(self.in_channels//4),
nn.ReLU(),
nn.Conv2d(self.in_channels//4, self.num_classes, kernel_size=3, padding=1),
)
def forward(self, x):
return self.decode_head(x)
PSPNet构建如下:
from torchvision.models import resnet50, resnet101
from torchvision.models._utils import IntermediateLayerGetter
class PSPHEAD(nn.Module):
def __init__(self, in_channels, out_channels,pool_sizes = [1, 2, 3, 6],num_classes=3):
super(PSPHEAD, self).__init__()
self.pool_sizes = pool_sizes
self.num_classes = num_classes
self.in_channels = in_channels
self.out_channels = out_channels
self.psp_modules = PPM(self.pool_sizes, self.in_channels, self.out_channels)
self.final = nn.Sequential(
nn.Conv2d(self.in_channels + len(self.pool_sizes)*self.out_channels, self.out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(self.out_channels),
nn.ReLU(),
)
def forward(self, x):
out = self.psp_modules(x)
out.append(x)
out = torch.cat(out, 1)
out = self.final(out)
class Pspnet(nn.Module):
def __init__(self, num_classes, aux_loss = True):
super(Pspnet, self).__init__()
self.num_classes = num_classes
# backbone调用resnet50
# 替换最后两个layer为dialation模式
# 引出layer3的计算结果用于计算辅助损失。
self.backbone = IntermediateLayerGetter(
resnet50(pretrained=False, replace_stride_with_dilation=[False, True, True]),
return_layers={'layer3':"aux" ,'layer4': 'stage4'}
)
self.aux_loss = aux_loss
self.decoder = PSPHEAD(in_channels=2048, out_channels=512, pool_sizes = [1, 2, 3, 6], num_classes=self.num_classes)
self.cls_seg = nn.Sequential(
nn.Conv2d(512, self.num_classes, kernel_size=3, padding=1),
)
if self.aux_loss:
self.aux_head = Aux_Head(in_channels=1024, num_classes=self.num_classes)
def forward(self, x):
_, _, h, w = x.size()
feats = self.backbone(x)
x = self.decoder(feats["stage4"])
x = self.cls_seg(x)
x = nn.functional.interpolate(x, size=(h, w),mode='bilinear', align_corners=True)
# 如果需要添加辅助损失
if self.aux_loss:
aux_output = self.aux_head(feats['aux'])
aux_output = nn.functional.interpolate(aux_output, size=(h, w),mode='bilinear', align_corners=True)
return {"output":x, "aux_output":aux_output}
return {"output":x}