学习人体姿态估计中的特征金字塔.

姿态估计是一个具有挑战性的计算机视觉任务,主要难点在于摄像机等原因造成的人体尺度的变化。当前金字塔模型在处理尺度变化方面具有良好的性能。本文设计了一个特征金字塔模块(Pyramid Residual Module,PRMs)来处理尺度变化。在给定输入特征的情况下,PRMs学习不同尺度的卷积滤波器,这些卷积滤波器在多分支网络中以不同的下采样率获得。

网络的整个框架采用堆叠沙漏网络(stacked hourglass networks),是一种高度模块化的网络。沙漏网络目的在于捕获各种尺度的信息,它首先通过自底向上的方法对特征进行下采样,然后通过对特征图进行上采样来执行自上而下的处理,同时合并来自底层的高分辨率特征。 这种自下而上、自上而下的处理过程要重复几次,以建立一个“堆叠沙漏”网络,在每个堆叠的末端进行中间监督。

然而,沙漏网络的残差单元只能在一个尺度上捕获视觉模式或语义。本文作者所提出的金字塔残差模块能够捕捉多尺度视觉模式或语义。PRM的结构示意图如下,虚线表示恒等映射。

传统的池化层被广泛使用,但是池化层会使得分辨率下降过快,池化过程过于粗糙。本文使用了fractional max-pooling方法,把输入区域随机划分为与输出尺寸相同的不均匀的子区域,并对每个子区域执行最大池化操作。第$c$层金字塔特征的下采样比例为$s_c=2^{-M\frac{c}{C}}$,其中$c=0,…,C,M\geq 1$。$s_c$的取值范围为$[2^{-M},1]$。当$c=0$时,下采样比例为$1$,和原图一样大。在实验中作者设置$M=1, C=4$(也表示金字塔有五层),最小的下采样有原始输入分辨率的一半。

class BnReluConv(nn.Module):
	def __init__(self, inChannels, outChannels, kernelSize = 1, stride = 1, padding = 0):
		super(BnReluConv, self).__init__()
		self.bn = nn.BatchNorm2d(inChannels)
		self.relu = nn.ReLU()
		self.conv = nn.Conv2d(inChannels, outChannels, kernelSize, stride, padding)

	def forward(self, x):
		x = self.bn(x)
		x = self.relu(x)
		x = self.conv(x)
		return x

class Pyramid(nn.Module):
	def __init__(self, D, cardinality, inputRes):
		super(Pyramid, self).__init__()
		self.cardinality = cardinality
		scale = 2**(-1/self.cardinality)
		_scales = []
		for card in range(self.cardinality):
			temp = nn.Sequential(
					nn.FractionalMaxPool2d(2, output_ratio = scale**(card + 1)),
					nn.Conv2d(D, D, 3, 1, 1),
					nn.Upsample(size = inputRes)#, mode='bilinear')
				)
			_scales.append(temp)
		self.scales = nn.ModuleList(_scales)

	def forward(self, x):
		out = torch.zeros_like(x)
		for card in range(self.cardinality):
			out += self.scales[card](x)
		return out

class BnReluPyra(nn.Module):
	def __init__(self, D, cardinality, inputRes):
		super(BnReluPyra, self).__init__()
		self.bn = nn.BatchNorm2d(D)
		self.relu = nn.ReLU()
		self.pyra = Pyramid(D, cardinality, inputRes)

	def forward(self, x):
		x = self.bn(x)
		x = self.relu(x)
		x = self.pyra(x)
		return x

class PyraConvBlock(nn.Module):
	def __init__(self, inChannels, outChannels, inputRes, baseWidth, cardinality, type = 1):
		super(PyraConvBlock, self).__init__()
		self.branch1 = nn.Sequential(
				BnReluConv(inChannels, outChannels//2, 1, 1, 0),
				BnReluConv(outChannels//2, outChannels//2, 3, 1, 1)
			)
		self.branch2 = nn.Sequential(
				BnReluConv(inChannels, outChannels // baseWidth, 1, 1, 0),
				BnReluPyra(outChannels // baseWidth, cardinality, inputRes),
				BnReluConv(outChannels // baseWidth, outChannels//2, 1, 1, 0)
			)
		self.afteradd = BnReluConv(outChannels//2, outChannels, 1, 1, 0)

	def forward(self, x):
		x = self.branch2(x) + self.branch1(x)
		x = self.afteradd(x)
		return x

class SkipLayer(nn.Module):
	def __init__(self, inChannels, outChannels):
		super(SkipLayer, self).__init__()
		if (inChannels == outChannels):
			self.conv = None
		else:
			self.conv = nn.Conv2d(inChannels, outChannels, 1)

	def forward(self, x):
		if self.conv is not None:
			x = self.conv(x)
		return x

class ResidualPyramid(nn.Module):
	def __init__(self, inChannels, outChannels, inputRes, baseWidth, cardinality, type = 1):
		super(ResidualPyramid, self).__init__()
		self.cb = PyraConvBlock(inChannels, outChannels, inputRes, baseWidth, cardinality, type)
		self.skip = SkipLayer(inChannels, outChannels)

	def forward(self, x):
		out = self.cb(x)
		out = out + self.skip(x)
		return out