多任务注意力网络与动态权重平均.

1. 多任务注意力网络

作者提出了多任务注意力网络(Multi-Task Attention Network,MTAN),用于进行多任务学习。MTAN的设计思路如下,首先使用一个共享参数的网络从输入数据中学习所有任务共享的特征表示;然后针对每一个任务使用一个注意力模块(相当于特征选择器),通过对共享网络的特征施加一个注意力mask来提取适用于该任务的特征表示。

MTAN的网络结构如下图所示。本文选用分割网络SegNet作为共享网络,由于SegNet采用对称的编码器-解码器结构设计,因此图中以编码器为例。灰色网络结构表示共享网络的编码器,采用类似VGG16的卷积网络;针对两个不同的任务分别使用绿色和蓝色两个注意力模块作为特征提取器,注意到一个注意力模块对应卷积网络的一组网络层。

注意力模块的结构如上图所示。该模块的操作如下:

在训练时同时训练主网络和每个子任务的注意力模块。对于第$i$个任务,其在第$j$个网络层中学习到的特征表示为:

\[\hat{a}_i^{(j)} = a_i^{(j)} \odot p^{(j)}\]

注意力mask $a_i^{(j)}$ 可以进一步表示为:

\[a_i^{(j)} = h_i^{(j)}( g_i^{(j)} ([u^{(j)};f_i^{(j)}(\hat{a}_i^{(j)})]) )\]

2. 动态权重平均

为了使多个任务在训练中得到平衡,需要为不同任务设置合适的损失权重。作者提出了动态权重平均(Dynamic Weight Average,DWA),动态地调整每轮训练中每个任务的损失权重,使得不同任务的重要程度相当。

定义第$k$个任务在第$t$轮训练中的相对下降率(relative descending rate)为前两轮训练时对应的损失函数数值之比:

\[w_k(t-1)=\frac{\mathcal{L}_k(t-1)}{\mathcal{L}_k(t-2)}\]

若$w_k(t-1)$较小,表明第$t-1$轮训练使得损失下降,该任务得到较好的学习,因此可以适当减小对该任务的关注。第$k$个任务的权重计算为:

\[\lambda_k(t) = \frac{K \exp(w_k(t-1)/T)}{\sum_{i}^{}\exp(w_i(t-1)/T)}\]

其中温度$T$控制权重分布的平坦程度。较大的$T$使得所有任务的权重都接近$1$。训练前两轮权重均设置为$1$。

3. 实验分析

作者在语义分割、深度估计和表面法线预测三个任务上进行实验,实验结果表明所提多任务注意力网络能够取得最好的结果,且所用动态权重平均的方法能够改进多任务学习的性能。

作者对主网络的特征以及两个任务的提取特征进行可视化,通过注意力机制使得不同任务能够提取到任务相关的特征: