EMANet: 语义分割的期望最大化注意力网络.
近年来,自注意力机制在自然语言处理领域取得卓越成果。Nonlocal被提出后,在计算机视觉领域也受到了广泛的关注,并被一系列文章证明了在语义分割中的有效性。它使得每个像素可以充分捕获全局信息。然而,自注意力机制需要生成一个巨大的注意力图,其空间复杂度和时间复杂度巨大。其瓶颈在于,每一个像素的注意力图都需要对全图计算。
本文所提出的期望最大化注意力机制(Expectation-Maximization Attention, EMA),摒弃了在全图上计算注意力图的流程,转而通过期望最大化(EM)算法迭代出一组紧凑的基,在这组基上运行注意力机制,从而大大降低了复杂度。其中,E步更新注意力图,M步更新这组基。E、M交替执行,收敛之后用来重建特征图。本文把这一机制嵌入网络中,构造出轻量且易实现的EMA Unit。其作为语义分割头,在多个数据集上取得了较高的精度。
期望最大算法 (Expectation Maximization, EM)旨在为隐变量模型寻找最大似然解。对于观测数据$X$,每一个数据点$x_i$都对应隐变量$z_i$。完整数据\(\{X,Z\}\)的似然函数为$\ln p(X,Z|\theta)$, $\theta$是模型的参数。
E步根据当前参数$\theta^{odd}$计算隐变量$Z$的后验分布,并以之寻找完整数据的似然\(\mathcal{Q}\left(\boldsymbol{\theta}, \boldsymbol{\theta}^{o l d}\right)\):
\[\mathcal{Q}\left(\boldsymbol{\theta}, \boldsymbol{\theta}^{\text {old }}\right)=\sum_{\mathbf{Z}} p\left(\mathbf{Z} \mid \mathbf{X}, \boldsymbol{\theta}^{\text {old }}\right) \ln p(\mathbf{X}, \mathbf{Z} \mid \boldsymbol{\theta})\]M步通过最大化似然函数来更新参数$\theta$:
\[\boldsymbol{\theta}^{\text {new }}=\underset{\boldsymbol{\theta}}{\arg \max } \mathcal{Q}\left(\boldsymbol{\theta}, \boldsymbol{\theta}^{\text {old }}\right)\]自注意力机制的核心算子是:
\[\mathbf{y}_i=\frac{1}{C(\mathbf{x})} \sum_{\forall j} f\left(\mathbf{x}_i, \mathbf{x}_j\right) g\left(\mathbf{x}_j\right)\]它将第$i$个像素的特征$x_i$更新为其他所有像素特征经过$g(\cdot)$变换之后的加权平均$y_i$,权重通过归一化后的核函数$f(\cdot,\cdot)$计算,表征两个像素之间的相关度。自注意力机制可视为像素特征被一组过完备的基进行重构,这组基数目巨大,且存在大量信息冗余。
期望最大化注意力机制由$A_E,A_M,A_R$三部分组成,前两者分别对应EM算法的E步和M步。
假定输入的特征图为\(\mathbf{X} \in R^{N \times C}\),基初始值为\(\mathbf{\mu} \in R^{K \times C}\), $A_E$步估计隐变量\(\mathbf{Z} \in R^{N \times K}\),即每个基对像素的权责。具体地,第$k$个基对第$n$个像素的权责可以计算为:
\[z_{n k}=\frac{\mathcal{K}\left(\mathbf{x}_n, \boldsymbol{\mu}_k\right)}{\sum_{j=1}^K \mathcal{K}\left(\mathbf{x}_n, \boldsymbol{\mu}_j\right)}\]其中内核\(\mathcal{K}(a,b)=\exp(a^Tb)\)。在实现中,可以用如下的方式实现:
\[\mathbf{Z}=\operatorname{softmax}\left(\lambda \mathbf{X}\left(\boldsymbol{\mu}^{\top}\right)\right)\]$A_M$步更新基\(\mathbf{\mu}\)。为了保证\(\mathbf{\mu}\)和\(\mathbf{X}\)处在同一表征空间内,此处\(\mathbf{\mu}\)被计算作\(\mathbf{X}\)的加权平均。具体地,第$k$个基被更新为:
\[\boldsymbol{\mu}_k=\frac{\sum_{n=1}^N z_{n k} \mathbf{x}_n}{\sum_{n=1}^N z_{n k}}\]值得注意的是,如果$\lambda \to \infty$,则$Z$会变成一组one-hot编码。在这种情形下,每个像素仅由一个基负责,而基被更新为其所负责的像素的均值,这便是标准的K-means算法。
$A_E$和$A_M$交替执行$T$步。此后,$A_R$步使用近似收敛的$Z$和$\mu$对$X$进行重估计:
\[\tilde{\mathbf{X}}=\mathbf{Z} \boldsymbol{\mu}\]\(\tilde{\mathbf{X}}\)具有低秩的特性。其在保持类间差异的同时,类别内部差异得到缩小。从图像角度来看,起到了类似保边滤波的效果。
综上,EMA在获得低秩重构特性的同时,将复杂度从Nonlocal的$O(N^2)$降低至$O(NKT)$。实验中,EMA仅需3步就可达到近似收敛,因此$T$作为一个小常数,可以被省去。考虑到$K < < N$,其复杂度得到显著的降低。
对于EM算法而言,参数的初始化会影响到最终收敛时的效果。对于深度网络训练过程中的大量图片,在逐个批次训练的同时,EM参数的迭代初值$\mu^{(0)}$理应得到不断优化。本文中,迭代初值$\mu^{(0)}$的维护采用滑动平均更新方式,即:
\[\boldsymbol{\mu}^{(0)} \leftarrow \alpha \boldsymbol{\mu}^{(0)}+(1-\alpha) \overline{\boldsymbol{\mu}}^{(T)}\]EMA的迭代过程可以展开为一个RNN,其反向传播也会面临梯度爆炸或消失等问题。RNN中采取LayerNorm(LN)来进行归一化是一个合理的选择。但在EMA中,LN会改变基的方向,进而影响其语义。因此本文选择L2Norm来对基进行归一化。这样$\mu^{(0)}$的更新轨迹便处在一个高维球面上。
EMA模块的实现如下:
class EMAModule(nn.Module):
"""Expectation Maximization Attention Module used in EMANet.
Args:
channels (int): Channels of the whole module.
num_bases (int): Number of bases.
num_stages (int): Number of the EM iterations.
"""
def __init__(self, channels, num_bases, num_stages, momentum):
super(EMAModule, self).__init__()
assert num_stages >= 1, 'num_stages must be at least 1!'
self.num_bases = num_bases
self.num_stages = num_stages
self.momentum = momentum
bases = torch.zeros(1, channels, self.num_bases)
bases.normal_(0, math.sqrt(2. / self.num_bases))
# [1, channels, num_bases]
bases = F.normalize(bases, dim=1, p=2)
self.register_buffer('bases', bases)
def forward(self, feats):
"""Forward function."""
batch_size, channels, height, width = feats.size()
# [batch_size, channels, height*width]
feats = feats.view(batch_size, channels, height * width)
# [batch_size, channels, num_bases]
bases = self.bases.repeat(batch_size, 1, 1)
with torch.no_grad():
for i in range(self.num_stages):
# [batch_size, height*width, num_bases]
attention = torch.einsum('bcn,bck->bnk', feats, bases)
attention = F.softmax(attention, dim=2)
# l1 norm
attention_normed = F.normalize(attention, dim=1, p=1)
# [batch_size, channels, num_bases]
bases = torch.einsum('bcn,bnk->bck', feats, attention_normed)
# l2 norm
bases = F.normalize(bases, dim=1, p=2)
feats_recon = torch.einsum('bck,bnk->bcn', bases, attention)
feats_recon = feats_recon.view(batch_size, channels, height, width)
if self.training:
bases = bases.mean(dim=0, keepdim=True)
# l2 norm
bases = F.normalize(bases, dim=1, p=2)
self.bases = (1 - self.momentum) * self.bases + self.momentum * bases
return feats_recon
在EMA Unit中,除了核心的EMA之外,两个$1 \times 1$卷积分别放置于EMA前后。前者将输入的值域从$R^+$映射到$R$;后者将\(\tilde{\mathbf{X}}\)映射到\(\mathbf{X}\)的残差空间。囊括进两个卷积的额外负荷,EMAU的FLOPs仅相当于同样输入输出大小时$3 \times 3$卷积的$1/3$,参数量仅为$2C^2+CK$。
class EMAHead(nn.Module):
def __init__(self,
ema_channels,
num_bases,
num_stages,
concat_input=True,
in_channels=2048,
channels=256,
momentum=0.1,
**kwargs):
super(EMAHead, self).__init__(**kwargs)
self.ema_channels = ema_channels
self.num_bases = num_bases
self.num_stages = num_stages
self.concat_input = concat_input
self.momentum = momentum
self.ema_module = EMAModule(self.ema_channels, self.num_bases,
self.num_stages, self.momentum)
self.channels = channels
self.in_channels = in_channels
self.ema_in_conv = nn.Conv2d(
self.in_channels,
self.ema_channels,
3,
padding=1)
self.ema_mid_conv = nn.Conv2d(
self.ema_channels,
self.ema_channels,
1)
for param in self.ema_mid_conv.parameters():
param.requires_grad = False
self.ema_out_conv = nn.Conv2d(
self.ema_channels,
self.ema_channels,
1)
self.bottleneck = nn.Conv2d(
self.ema_channels,
self.channels,
3,
padding=1)
if self.concat_input:
self.conv_cat = nn.Conv2d(
self.in_channels + self.channels,
self.channels,
kernel_size=3,
padding=1)
self.attention = None
def forward(self, inputs):
"""Forward function."""
x = inputs
feats = self.ema_in_conv(x)
identity = feats
feats = self.ema_mid_conv(feats)
recon = self.ema_module(feats)
self.attention = recon
recon = F.relu(recon, inplace=True)
recon = self.ema_out_conv(recon)
output = F.relu(identity + recon, inplace=True)
output = self.bottleneck(output)
if self.concat_input:
output = self.conv_cat(torch.cat([x, output], dim=1))
return output
EMANet的完整结构如下:
class EMANet(nn.Module):
def __init__(self, num_classes):
super(EMANet, self).__init__()
self.num_classes = num_classes
self.backbone = ResNet.resnet50(replace_stride_with_dilation=[1,2,4])
self.Head = EMAHead(in_channels=2048, channels=256, ema_channels=512, num_bases=64, num_stages=3, momentum=0.1)
self.cls_seg = nn.Sequential(
nn.Upsample(scale_factor=8, mode="bilinear", align_corners=True),
nn.Conv2d(256, 256, 3, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(),
nn.Conv2d(256, self.num_classes, 3, padding=1),
)
def forward(self, x):
x = self.backbone(x)
x = self.Head(x)
x = self.cls_seg(x)
return x