TFPose: 基于Transformer的2D人体关节点回归.

1. 人体姿态估计

人体姿态估计要求计算机在输入图像中获取人体感兴趣的关键点,在行为理解等许多计算机视觉任务中起着重要作用。现有解决该任务的主流方法大致可以分为基于热图的方法和基于回归的方法。

基于热图的方法通常首先使用卷积神经网络(通常是全卷积网络)预测人体关节点的热力图谱,然后根据热图中的峰值位置定位人体关节。大多数姿态估计方法都是基于热图的,因为热图具有较高的精度。然而,基于热图的方法可能会遇到以下问题。

  1. 需要进行后处理步骤(如取最大值操作)。由于后处理操作可能是不可微的,这使得框架无法端到端的训练。
  2. 卷积神经网络预测的热图分辨率通常低于输入图像的分辨率。降低分辨率会引入量化误差,限制了关键点的定位精度。这种量化误差可以通过根据峰值附近像素的值动态移动输出坐标来解决,但这会使框架更加复杂,并引入更多的超参数。
  3. 用作ground-truth的热图需要人工设计和启发性调整,这可能会导致ground-truth热图中包含许多噪声和模糊。

而基于回归的方法通常引入全连接层直接将输入图像映射为人体关节的坐标,从而避免生成热图这一中间步骤。因此基于回归的方法比基于热图的方法更直接,且可以绕过上述基于热图的方法的缺点,具有更大的应用前景。

然而基于回归方法的性能往往不如基于热图的方法,原因可能以下几点。首先,为了减少全连接层的网络参数,在大多数基于回归的方法中应用一个全局平均池化层来减少全连接层之前的特征分辨率,从而破坏了卷积特征映射的空间结构,显著降低了性能。并且在基于回归的方法中卷积特征和预测是不对齐的,导致关键点定位精度低。此外,基于回归的方法只回归关节的坐标,没有考虑这些关键点之间的结构化依赖。

Transformer除了在自然语言处理方面取得了重大进展,近期也引起了计算机视觉界的广泛关注。Transformer最初是为序列到序列任务而设计的,应用到图像时是通过序列化的image patches实现的。如果将人体姿态估计任务看作为预测长度为KK的序列坐标的问题(KK是一个人的身体关节数),则能够通过Transformer实现人体姿态估计。

2. 使用Transformer进行姿态估计

本文提出了一个基于回归的姿态估计框架TFPose,将姿态估计任务转化为一个序列预测问题,并通过Transformer有效地解决。由于Transformer中的注意机制,所提框架能够自适应地注意到与目标关键点最相关的特征,并内在地利用关键点之间的结构化关系。在MS COCOMPII数据集上的实验表明,所提方法可以显著提高基于回归的姿态估计的表现,并实现与基于热图的姿态估计方法相当的水平。

所提框架如图所示,以卷积神经网络的特征映射为输入,通过Transformer依次预测KK个坐标。TFPose可以绕过基于回归的方法遇到的困难,它不依赖于全局平均池化,能够避免卷积特征与预测之间的特征错位,并且可以自然地捕获关键点之间的结构化依赖关系,从而提高性能。

TFPose通过将卷积神经网络与Transformer结构相结合,直接并行地预测所有关键点坐标序列。Transformer解码器将一定数量的关键点查询向量和编码器输出特征作为输入,并通过一个多层前馈网络预测最终的关键点坐标。

TFPose是第一个基于Transformer的姿态估计框架。该框架适应简单直观的基于回归的方法,具有端到端的可训练性,克服了基于热图的方法的诸多缺点。此外,TFPose可以自然地学会利用关键点之间的结构化依赖关系,而无需启发式设计,这将提高性能和模型的可解释性。TFPose极大地提升了回归方法的表现,可以与目前最先进的基于热图的方法相媲美。

3. TFPose

TFPose首先应用一个基于卷积神经网络的人体检测器获取每个目标的边界框,然后根据检测框从输入图像中裁剪出每个目标。用IRh×h×3IRh×h×3表示裁剪后的图像,其中hhww分别是图像的高度和宽度。对于基于热图的方法,通常是应用卷积神经网络FF去预测关键点热图HRh×h×kHRh×h×k,其中kk为预测关键点的个数。形式上有

H=F(I)H=F(I)

HH的每个像素代表身体关节出现的概率。为了进一步获得关节坐标JR3×kJR3×k,这些方法通常使用取最大值操作来获得具有峰值激活的位置。形式上,设ppHH上的空间位置,可以写成

Jk=argmaxp(Hk(p))Jk=argmaxp(Hk(p))

需要注意的是,在基于热图的方法中,pp的定位精度受限于HH的分辨率,而这往往远远低于输入图像的分辨率,从而导致量化误差。此外argmax操作是不可微操作,使得整个过程无法端到端训练。

TFPose中,将JJ作为长度为kk的序列,直接将输入图像II映射到身体关节的坐标JJ:

J=F(I)J=F(I)

首先使用卷积神经网络提取输入图像的多层次特征。多级特征映射分别用C2、C3、C4和C5表示,其尺寸缩小系数分别为4、8、16和32。分别对这些特征图应用1 × 1卷积,使得它们的输出通道数量相同。这些特征映射被展平并连接在一起,将F0Rn×cF0Rn×c输入到Transformer中的编码器中,其中n是F0F0中的像素数。和原始Transformer相同,F0F0加上位置嵌入编码,用FE0FE0表示带有位置嵌入的F0F0。然后F0F0FE0FE0被送到Transformer中计算注意力特征MRn×cMRn×c。在Transformer解码器中使用查询矩阵QRK×cQRK×c得到K个关节点的坐标JRK×2JRK×2

Transformer输入F0的位置嵌入中EliR1×cEliR1×c表示层次嵌入,编码特征向量所在的层级。EPRn×cEPRn×c表示一个特征向量在特征图上的空间位置的像素嵌入。使用FE0FE0表示带有位置嵌入的F0F0F0F0FE0FE0都是Transformer的输入。

TFPose中,使用NE=6NE=6个编码器层。对于EthEth编码器层,将前一个编码器层的输出作为该层的输入,计算每个编码器层的输出向量之间的像素对像素注意力。应用NENETransformer编码器层后,可得到编码器特征M。

在解码器中,目标是从编码器特征M中解码所需的关键点坐标,使用查询矩阵QRK×cQRK×c实现,K是关节点个数,Q是一个额外的可学习矩阵,它在训练过程中与模型参数联合更新,其中每一行对应一个关键点。

TFPose中有NDNDTransformer解码器层。每一解码器层以编码器特征M和前一解码器层Qd1RK×cQd1RK×c的输出作为输入(第一层以M和矩阵Q作为输入),Qd1Qd1加入位置嵌入, 结果用QEd1QEd1表示。Qd=1Qd=1QEd1QEd1将被发送到查询到查询注意力模块,该模块旨在模拟人体关节之间的依赖关系。该注意模块使用Qd1Qd1QEd1QEd1QED1QED1作为值向量、查询向量和键向量。之后,该注意力模块和编码器特征M被用于计算像素到查询注意力,并应用全连接层预测K个关节点的坐标JRK×2JRK×2

受之前基于回归的姿态识别网络的启发,不是简单地预测最终解码器层中的关键点坐标,而是要求所有解码器层预测关键点坐标。即先让第一个解码器层直接预测目标坐标。然后每一个解码器层对前解码器层进行预测细化ΔˆydRk×2Δ^ydRk×2。通过这种方式,关键点坐标可以逐步细化。形式上,设ˆyd1^yd1为第(d-1)层解码器层预测的关键点坐标,则第d层解码器层的预测为

ˆyd=σ(σ1(ˆyd1+Δˆyd))^yd=σ(σ1(^yd1+Δ^yd))

其中σσσ1σ1分别表示sigmoid函数和inverse sigmoid函数。ˆy0^y0是一个随机初始化的矩阵,在训练过程中与模型参数联合更新。

TFPose的损失函数由两部分组成。第一部分是L1回归损失。设yRK×2yRK×2ground-truth坐标,回归损失公式为:

Lreg=NDd=1yˆydLreg=NDd=1y^yd

其中NDND为解码器的个数,每个解码器层都由目标关键点坐标监督。第二部分是辅助损失LauxLaux,即在训练中使用了辅助热图学习,这样可以获得更好的性能。为了使用热图学习,从编码器特征M中收集特征向量C5,并将其重构为原始的空间形状。结果用MC5R(h/32)×(w/32)×cMC5R(h/32)×(w/32)×c表示。对MC5MC5应用3次反卷积对特征图上采样8倍,生成热图ˆhR(h/4)×(w/4)×k^hR(h/4)×(w/4)×k,然后计算预测热图和ground-truth热图之间的均方误差损失。形式上,辅助损失函数为

Laux=∥HˆH2Laux=H^H2

将两个损失函数相加,得到最终的总损失如下,λ是一个常数,用来平衡这两个损失。

Loverall=Lreg+λLauxLoverall=Lreg+λLaux

4. 实验分析

数据集使用COCOMPII数据集。卷积神经网络默认使用在ImageNet上预训练的ResNet-18,输入图像的大小为256×192或384×288。对于Transformer,采用Deformable DETR中提出的Deformable Attention Module,并使用了相同的超参数。所有模型均采用AdamW优化,β1和β2被设置为0.9和0.999,权重衰减设置为104104。λ默认设置为50,以平衡回归损失和辅助损失。所有的实验都使用初始学习率为4×1034×103的余弦退火学习率。训练时应用了随机旋转、随机缩放、翻转等数据扩增。

为了研究自注意力机制如何定位人体关节点,将该模块在特征图C3的采样位置进行可视化。在自注意力模块中有8个自注意力head,每个head会在每个特征图上采样4个点。所以对于C3特征图有32个采样点。如图所示,采样点(红点)都密集分布在ground truth(黄圈)附近。该可视化说明TFPose在一定程度上解决了特征对齐问题。

为了进一步研究自注意力机制是如何工作的,将注意力权重可视化。如图所示有两个明显的关注模式: 第一个关注模式是对称的关节(如左肩与右肩)更有可能相互影响;第二个注意模式是相邻关节(如眼睛、鼻子、嘴巴)更有可能相互影响。这种注意模式表明TFPose可以利用身体关节之间的上下文和结构关系来定位和分类身体关节的类型。