多人姿态估计的级联金字塔网络.
本文提出了一种金字塔型的串接模型,即CPN(cascaded pyramid network),这个模型是top-down的多人检测模型,即先在图片中检测人体,再做单人姿态估计。CPN能够同时兼顾人体关节点的局部信息以及全局信息,优化对于难以检测的点(关键点遮挡,关键点不可见,复杂背景)的预测。
人体检测使用的是FPN进行边界框定位,并应用了ROIAlign。然后使用CPN检测关键点,CPN本体由两部分组成:GlobalNet和RefineNet,流程如下图所示。
GlobalNet对关键点进行粗提取(用ResNet的网络架构回归heatmap),RefineNet对不同层信息进行了融合,可以更好的综合特征定位关键点。首先对于可以看见的关键点直接预测得到,对于不可见的关键点,使用增大感受野来获得关键点位置,对于还未检测出的点,使用上下文context进行预测(即融合多个感受野信息,最后concatenate所有的同一尺寸特征图进一步回归关键点位置)。
class CPN(nn.Module):
def __init__(self, resnet, output_shape, num_class, pretrained=True):
super(CPN, self).__init__()
channel_settings = [2048, 1024, 512, 256]
self.resnet = resnet
self.global_net = globalNet(channel_settings, output_shape, num_class)
self.refine_net = refineNet(channel_settings[-1], output_shape, num_class)
def forward(self, x):
res_out = self.resnet(x)
global_fms, global_outs = self.global_net(res_out)
refine_out = self.refine_net(global_fms)
return global_outs, refine_out
def CPN50(out_size,num_class,pretrained=True):
res50 = resnet50(pretrained=pretrained)
model = CPN(res50, output_shape=out_size,num_class=num_class, pretrained=pretrained)
return model
GlobalNet采用类似于FPN的特征金字塔结构,并在每个elem-sum前添加了1x1卷积,负责所有关键点的检测,重点是对比较容易检测的眼睛、胳膊等部位的关键点。
class globalNet(nn.Module):
def __init__(self, channel_settings, output_shape, num_class):
super(globalNet, self).__init__()
self.channel_settings = channel_settings
laterals, upsamples, predict = [], [], []
for i in range(len(channel_settings)):
laterals.append(self._lateral(channel_settings[i]))
predict.append(self._predict(output_shape, num_class))
if i != len(channel_settings) - 1:
upsamples.append(self._upsample())
self.laterals = nn.ModuleList(laterals)
self.upsamples = nn.ModuleList(upsamples)
self.predict = nn.ModuleList(predict)
def _lateral(self, input_size):
layers = []
layers.append(nn.Conv2d(input_size, 256,
kernel_size=1, stride=1, bias=False))
layers.append(nn.BatchNorm2d(256))
layers.append(nn.ReLU(inplace=True))
return nn.Sequential(*layers)
def _upsample(self):
layers = []
layers.append(torch.nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))
layers.append(torch.nn.Conv2d(256, 256,
kernel_size=1, stride=1, bias=False))
layers.append(nn.BatchNorm2d(256))
return nn.Sequential(*layers)
def _predict(self, output_shape, num_class):
layers = []
layers.append(nn.Conv2d(256, 256,
kernel_size=1, stride=1, bias=False))
layers.append(nn.BatchNorm2d(256))
layers.append(nn.ReLU(inplace=True))
layers.append(nn.Conv2d(256, num_class,
kernel_size=3, stride=1, padding=1, bias=False))
layers.append(nn.Upsample(size=output_shape, mode='bilinear', align_corners=True))
layers.append(nn.BatchNorm2d(num_class))
return nn.Sequential(*layers)
def forward(self, x):
global_fms, global_outs = [], []
for i in range(len(self.channel_settings)):
if i == 0:
feature = self.laterals[i](x[i])
else:
feature = self.laterals[i](x[i]) + up
global_fms.append(feature)
if i != len(self.channel_settings) - 1:
up = self.upsamples[i](feature)
feature = self.predict[i](feature)
global_outs.append(feature)
return global_fms, global_outs
GolbalNet对身体部位遮挡或者有复杂背景的关键点预测误差较大,RefineNet基于GlobalNet生成的特征金字塔,其链接了所有层的金字塔特征用于定位不容易检测的关键点。
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * 2)
self.relu = nn.ReLU(inplace=True)
self.downsample = nn.Sequential(
nn.Conv2d(inplanes, planes * 2,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * 2),
)
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class refineNet(nn.Module):
def __init__(self, lateral_channel, out_shape, num_class):
super(refineNet, self).__init__()
cascade = []
num_cascade = 4
for i in range(num_cascade):
cascade.append(self._make_layer(lateral_channel, num_cascade-i-1, out_shape))
self.cascade = nn.ModuleList(cascade)
self.final_predict = self._predict(4*lateral_channel, num_class)
def _make_layer(self, input_channel, num, output_shape):
layers = []
for i in range(num):
layers.append(Bottleneck(input_channel, 128))
layers.append(nn.Upsample(size=output_shape, mode='bilinear', align_corners=True))
return nn.Sequential(*layers)
def _predict(self, input_channel, num_class):
layers = []
layers.append(Bottleneck(input_channel, 128))
layers.append(nn.Conv2d(256, num_class,
kernel_size=3, stride=1, padding=1, bias=False))
layers.append(nn.BatchNorm2d(num_class))
return nn.Sequential(*layers)
def forward(self, x):
refine_fms = []
for i in range(4):
refine_fms.append(self.cascade[i](x[i]))
out = torch.cat(refine_fms, dim=1)
out = self.final_predict(out)
return out
在该阶段的训练中,还使用了online hard keypoints mining难例挖掘策略。在coco数据集中有17个关键点需要预测,GolbalNet预测所有的17个点,并计算所有17个点的loss;RefineNet也预测所有的17个点,但是只有最难的8个点的loss贡献给总loss。