EMANet: 语义分割的期望最大化注意力网络.

近年来,自注意力机制在自然语言处理领域取得卓越成果。Nonlocal被提出后,在计算机视觉领域也受到了广泛的关注,并被一系列文章证明了在语义分割中的有效性。它使得每个像素可以充分捕获全局信息。然而,自注意力机制需要生成一个巨大的注意力图,其空间复杂度和时间复杂度巨大。其瓶颈在于,每一个像素的注意力图都需要对全图计算。

本文所提出的期望最大化注意力机制(Expectation-Maximization Attention, EMA),摒弃了在全图上计算注意力图的流程,转而通过期望最大化(EM)算法迭代出一组紧凑的基,在这组基上运行注意力机制,从而大大降低了复杂度。其中,E步更新注意力图,M步更新这组基。EM交替执行,收敛之后用来重建特征图。本文把这一机制嵌入网络中,构造出轻量且易实现的EMA Unit。其作为语义分割头,在多个数据集上取得了较高的精度。

期望最大算法 (Expectation Maximization, EM)旨在为隐变量模型寻找最大似然解。对于观测数据$X$,每一个数据点$x_i$都对应隐变量$z_i$。完整数据\(\{X,Z\}\)的似然函数为$\ln p(X,Z|\theta)$, $\theta$是模型的参数。

E步根据当前参数$\theta^{odd}$计算隐变量$Z$的后验分布,并以之寻找完整数据的似然\(\mathcal{Q}\left(\boldsymbol{\theta}, \boldsymbol{\theta}^{o l d}\right)\):

\[\mathcal{Q}\left(\boldsymbol{\theta}, \boldsymbol{\theta}^{\text {old }}\right)=\sum_{\mathbf{Z}} p\left(\mathbf{Z} \mid \mathbf{X}, \boldsymbol{\theta}^{\text {old }}\right) \ln p(\mathbf{X}, \mathbf{Z} \mid \boldsymbol{\theta})\]

M步通过最大化似然函数来更新参数$\theta$:

\[\boldsymbol{\theta}^{\text {new }}=\underset{\boldsymbol{\theta}}{\arg \max } \mathcal{Q}\left(\boldsymbol{\theta}, \boldsymbol{\theta}^{\text {old }}\right)\]

自注意力机制的核心算子是:

\[\mathbf{y}_i=\frac{1}{C(\mathbf{x})} \sum_{\forall j} f\left(\mathbf{x}_i, \mathbf{x}_j\right) g\left(\mathbf{x}_j\right)\]

它将第$i$个像素的特征$x_i$更新为其他所有像素特征经过$g(\cdot)$变换之后的加权平均$y_i$,权重通过归一化后的核函数$f(\cdot,\cdot)$计算,表征两个像素之间的相关度。自注意力机制可视为像素特征被一组过完备的基进行重构,这组基数目巨大,且存在大量信息冗余。

期望最大化注意力机制由$A_E,A_M,A_R$三部分组成,前两者分别对应EM算法的E步和M步。

假定输入的特征图为\(\mathbf{X} \in R^{N \times C}\),基初始值为\(\mathbf{\mu} \in R^{K \times C}\), $A_E$步估计隐变量\(\mathbf{Z} \in R^{N \times K}\),即每个基对像素的权责。具体地,第$k$个基对第$n$个像素的权责可以计算为:

\[z_{n k}=\frac{\mathcal{K}\left(\mathbf{x}_n, \boldsymbol{\mu}_k\right)}{\sum_{j=1}^K \mathcal{K}\left(\mathbf{x}_n, \boldsymbol{\mu}_j\right)}\]

其中内核\(\mathcal{K}(a,b)=\exp(a^Tb)\)。在实现中,可以用如下的方式实现:

\[\mathbf{Z}=\operatorname{softmax}\left(\lambda \mathbf{X}\left(\boldsymbol{\mu}^{\top}\right)\right)\]

$A_M$步更新基\(\mathbf{\mu}\)。为了保证\(\mathbf{\mu}\)和\(\mathbf{X}\)处在同一表征空间内,此处\(\mathbf{\mu}\)被计算作\(\mathbf{X}\)的加权平均。具体地,第$k$个基被更新为:

\[\boldsymbol{\mu}_k=\frac{\sum_{n=1}^N z_{n k} \mathbf{x}_n}{\sum_{n=1}^N z_{n k}}\]

值得注意的是,如果$\lambda \to \infty$,则$Z$会变成一组one-hot编码。在这种情形下,每个像素仅由一个基负责,而基被更新为其所负责的像素的均值,这便是标准的K-means算法。

$A_E$和$A_M$交替执行$T$步。此后,$A_R$步使用近似收敛的$Z$和$\mu$对$X$进行重估计:

\[\tilde{\mathbf{X}}=\mathbf{Z} \boldsymbol{\mu}\]

\(\tilde{\mathbf{X}}\)具有低秩的特性。其在保持类间差异的同时,类别内部差异得到缩小。从图像角度来看,起到了类似保边滤波的效果。

综上,EMA在获得低秩重构特性的同时,将复杂度从Nonlocal的$O(N^2)$降低至$O(NKT)$。实验中,EMA仅需3步就可达到近似收敛,因此$T$作为一个小常数,可以被省去。考虑到$K < < N$,其复杂度得到显著的降低。

对于EM算法而言,参数的初始化会影响到最终收敛时的效果。对于深度网络训练过程中的大量图片,在逐个批次训练的同时,EM参数的迭代初值$\mu^{(0)}$理应得到不断优化。本文中,迭代初值$\mu^{(0)}$的维护采用滑动平均更新方式,即:

\[\boldsymbol{\mu}^{(0)} \leftarrow \alpha \boldsymbol{\mu}^{(0)}+(1-\alpha) \overline{\boldsymbol{\mu}}^{(T)}\]

EMA的迭代过程可以展开为一个RNN,其反向传播也会面临梯度爆炸或消失等问题。RNN中采取LayerNorm(LN)来进行归一化是一个合理的选择。但在EMA中,LN会改变基的方向,进而影响其语义。因此本文选择L2Norm来对基进行归一化。这样$\mu^{(0)}$的更新轨迹便处在一个高维球面上。

EMA模块的实现如下:

class EMAModule(nn.Module):
    """Expectation Maximization Attention Module used in EMANet.
    Args:
        channels (int): Channels of the whole module.
        num_bases (int): Number of bases.
        num_stages (int): Number of the EM iterations.
    """
 
    def __init__(self, channels, num_bases, num_stages, momentum):
        super(EMAModule, self).__init__()
        assert num_stages >= 1, 'num_stages must be at least 1!'
        self.num_bases = num_bases
        self.num_stages = num_stages
        self.momentum = momentum
 
        bases = torch.zeros(1, channels, self.num_bases)
        bases.normal_(0, math.sqrt(2. / self.num_bases))
        # [1, channels, num_bases]
        bases = F.normalize(bases, dim=1, p=2)
        self.register_buffer('bases', bases)
 
    def forward(self, feats):
        """Forward function."""
        batch_size, channels, height, width = feats.size()
        # [batch_size, channels, height*width]
        feats = feats.view(batch_size, channels, height * width)
        # [batch_size, channels, num_bases]
        bases = self.bases.repeat(batch_size, 1, 1)
 
        with torch.no_grad():
            for i in range(self.num_stages):
                # [batch_size, height*width, num_bases]
                attention = torch.einsum('bcn,bck->bnk', feats, bases)
                attention = F.softmax(attention, dim=2)
                # l1 norm
                attention_normed = F.normalize(attention, dim=1, p=1)
                # [batch_size, channels, num_bases]
                bases = torch.einsum('bcn,bnk->bck', feats, attention_normed)
                # l2 norm
                bases = F.normalize(bases, dim=1, p=2)
 
        feats_recon = torch.einsum('bck,bnk->bcn', bases, attention)
        feats_recon = feats_recon.view(batch_size, channels, height, width)
 
        if self.training:
            bases = bases.mean(dim=0, keepdim=True)
            # l2 norm
            bases = F.normalize(bases, dim=1, p=2)
            self.bases = (1 - self.momentum) * self.bases + self.momentum * bases
 
        return feats_recon

EMA Unit中,除了核心的EMA之外,两个$1 \times 1$卷积分别放置于EMA前后。前者将输入的值域从$R^+$映射到$R$;后者将\(\tilde{\mathbf{X}}\)映射到\(\mathbf{X}\)的残差空间。囊括进两个卷积的额外负荷,EMAUFLOPs仅相当于同样输入输出大小时$3 \times 3$卷积的$1/3$,参数量仅为$2C^2+CK$。

class EMAHead(nn.Module):
    def __init__(self,
                 ema_channels,
                 num_bases,
                 num_stages,
                 concat_input=True,
                 in_channels=2048,
                 channels=256,
                 momentum=0.1,
                 **kwargs):
        super(EMAHead, self).__init__(**kwargs)
        self.ema_channels = ema_channels
        self.num_bases = num_bases
        self.num_stages = num_stages
        self.concat_input = concat_input
        self.momentum = momentum
        self.ema_module = EMAModule(self.ema_channels, self.num_bases,
                                    self.num_stages, self.momentum)
 
        self.channels = channels
        self.in_channels = in_channels
        self.ema_in_conv = nn.Conv2d(
            self.in_channels,
            self.ema_channels,
            3,
            padding=1)
        self.ema_mid_conv = nn.Conv2d(
            self.ema_channels,
            self.ema_channels,
            1)
        
        for param in self.ema_mid_conv.parameters():
            param.requires_grad = False
 
        self.ema_out_conv = nn.Conv2d(
            self.ema_channels,
            self.ema_channels,
            1)
        self.bottleneck = nn.Conv2d(
            self.ema_channels,
            self.channels,
            3,
            padding=1)
        if self.concat_input:
            self.conv_cat = nn.Conv2d(
                self.in_channels + self.channels,
                self.channels,
                kernel_size=3,
                padding=1)
        self.attention = None

    def forward(self, inputs):
        """Forward function."""
        x = inputs
        feats = self.ema_in_conv(x)
        identity = feats
        feats = self.ema_mid_conv(feats)
        recon = self.ema_module(feats)
        self.attention = recon
        recon = F.relu(recon, inplace=True)
        recon = self.ema_out_conv(recon)
        output = F.relu(identity + recon, inplace=True)
        output = self.bottleneck(output)
        if self.concat_input:
            output = self.conv_cat(torch.cat([x, output], dim=1))
        return output

EMANet的完整结构如下:

class EMANet(nn.Module):
    def __init__(self, num_classes):
        super(EMANet, self).__init__()
        self.num_classes = num_classes
        self.backbone = ResNet.resnet50(replace_stride_with_dilation=[1,2,4])
        self.Head = EMAHead(in_channels=2048, channels=256, ema_channels=512, num_bases=64, num_stages=3, momentum=0.1)
        self.cls_seg = nn.Sequential(
            nn.Upsample(scale_factor=8, mode="bilinear", align_corners=True),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, self.num_classes, 3, padding=1),
        )
        
    def forward(self, x):
        x = self.backbone(x)
        x = self.Head(x)
        x = self.cls_seg(x)
        return x