EncNet: Context Encoding for Semantic Segmentation

EncNet improves the model's understanding of contextual semantics through a Context Encoding Module and a Semantic Encoding Loss (SE-loss).

The Context Encoding Module is introduced to capture global contextual information, in particular the category information associated with the scene. It computes a scaling factor for each channel to highlight class-related feature maps: a set of scaling factors over the feature maps is predicted and applied as channel weights to emphasize the classes that should be stressed.

The Context Encoding Module consists of an Encoding Layer followed by a channel-attention step. On top of the pretrained network, the Encoding Layer captures the statistics of the feature maps as global contextual semantics; its output is referred to as the encoded semantics. To exploit this context, a set of scaling factors is predicted to highlight the class-related feature maps. The Encoding Layer learns an inherent dictionary carrying the contextual semantics and outputs residual encodings rich in contextual information.

The Encoding Layer learns an inherent dictionary of $K$ codewords \(D=\{d_1,...,d_K\}\) together with a set of smoothing factors for the visual centers \(S=\{s_1,...,s_K\}\). The layer outputs residual encodings: the residual between the feature at the $i$-th spatial position and the $k$-th codeword is $r_{ik}=x_i-d_k$, and the residual encodings are constructed by soft-assignment weighting along the codeword dimension:

\[e_{i k}=\frac{\exp \left(-s_k\left\|r_{i k}\right\|^2\right)}{\sum_{j=1}^K \exp \left(-s_j\left\|r_{i j}\right\|^2\right)} r_{i k}\]

The residual encoding for the $k$-th codeword is $e_k=\sum_{i=1}^N e_{ik}$; the full encoded representation is obtained by aggregating (averaging) the residual encodings of all codewords, followed by BN + ReLU.

import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoding(nn.Module):
    def __init__(self, channels, num_codes):
        super(Encoding, self).__init__()
        # init codewords and smoothing factor
        self.channels, self.num_codes = channels, num_codes
        std = 1. / ((num_codes * channels)**0.5)
        # [num_codes, channels]
        self.codewords = nn.Parameter(
            torch.empty(num_codes, channels,
                        dtype=torch.float).uniform_(-std, std),
            requires_grad=True)
        # [num_codes]
        self.scale = nn.Parameter(
            torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0),
            requires_grad=True)
        
    def scaled_l2(self, x, codewords, scale):
        num_codes, channels = codewords.size()
        batch_size = x.size(0)
        reshaped_scale = scale.view((1, 1, num_codes))
        expanded_x = x.unsqueeze(2).expand(
            (batch_size, x.size(1), num_codes, channels))
        reshaped_codewords = codewords.view((1, 1, num_codes, channels))
 
        scaled_l2_norm = reshaped_scale * (
            expanded_x - reshaped_codewords).pow(2).sum(dim=3)
        return scaled_l2_norm
 
    def aggregate(self, assignment_weights, x, codewords):
        num_codes, channels = codewords.size()
        reshaped_codewords = codewords.view((1, 1, num_codes, channels))
        batch_size = x.size(0)
 
        expanded_x = x.unsqueeze(2).expand(
            (batch_size, x.size(1), num_codes, channels))
        encoded_feat = (assignment_weights.unsqueeze(3) *
                        (expanded_x - reshaped_codewords)).sum(dim=1)
        return encoded_feat
 
    def forward(self, x):
        assert x.dim() == 4 and x.size(1) == self.channels
        # [batch_size, channels, height, width]
        batch_size = x.size(0)
        # [batch_size, height x width, channels]
        x = x.view(batch_size, self.channels, -1).transpose(1, 2).contiguous()
 
        # assignment_weights: [batch_size, height x width, num_codes]
        assignment_weights = F.softmax(self.scaled_l2(x, self.codewords, self.scale), dim=2)
        # aggregate: [batch_size, num_codes, channels]
        encoded_feat = self.aggregate(assignment_weights, x, self.codewords)
        return encoded_feat
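
For reference, a minimal shape check of the Encoding layer (a sketch; the channel and codeword counts below are illustrative, not taken from the text):

# input:  [batch, channels, height, width] feature map
# output: [batch, num_codes, channels], one residual encoding e_k per codeword
enc = Encoding(channels=512, num_codes=32)
feat = torch.randn(2, 512, 16, 16)
print(enc(feat).shape)  # torch.Size([2, 32, 512])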

The full encoded representation then acts on the features extracted by the pretrained backbone in a channel-attention fashion, so as to enhance them:

class EncModule(nn.Module):
    def __init__(self, in_channels, num_codes):
        super(EncModule, self).__init__()
        self.encoding_project = nn.Conv2d(in_channels, in_channels, 1)
        self.encoding = nn.Sequential(
            Encoding(channels=in_channels, num_codes=num_codes),
            nn.BatchNorm1d(num_codes),
            nn.ReLU(inplace=True))
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels), nn.Sigmoid())
 
    def forward(self, x):
        """Forward function."""
        encoding_projection = self.encoding_project(x)
        # encoding_feat: [batch_size, channels], the aggregated encoded semantics
        encoding_feat = self.encoding(encoding_projection).mean(dim=1)

        batch_size, channels, _, _ = x.size()
        # gamma: per-channel scaling factors in (0, 1)
        gamma = self.fc(encoding_feat)
        y = gamma.view(batch_size, channels, 1, 1)
        # channel-wise reweighting of the input feature map
        output = F.relu_(x + x * y)
        return encoding_feat, output
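
A minimal usage sketch of EncModule (illustrative sizes), showing that it returns both the encoded semantics fed to the SE branch and the channel-reweighted feature map:

enc_module = EncModule(in_channels=512, num_codes=32)
x = torch.randn(2, 512, 64, 64)
encoding_feat, output = enc_module(x)
print(encoding_feat.shape)  # torch.Size([2, 512])         -> encoded semantics for the SE branch
print(output.shape)         # torch.Size([2, 512, 64, 64]) -> channel-reweighted features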

The SE loss enforces attention to the categories present in a scene: it forces the model to learn which classes may appear in each scene, providing a prior. Unlike a pixel-wise loss, the SE loss treats objects of different sizes equally, because it is computed per class rather than per pixel; large and small objects therefore contribute equally to the loss, which benefits the segmentation of small objects.

The paper also modifies the backbone slightly: the dilation rates of the last two stages are set to 2 and 4. An SE loss can be attached to both stage 3 and stage 4.
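
As an illustration, with a torchvision ResNet-50 (an assumption; the code below uses a custom ResNet module instead) the same modification can be expressed via replace_stride_with_dilation, which keeps the output stride at 8 and dilates the last two stages:

import torchvision

# [False, True, True]: keep stage 2 strided, dilate stages 3 and 4 (dilation 2 and 4)
dilated_resnet = torchvision.models.resnet50(
    replace_stride_with_dilation=[False, True, True])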

class EncHead(nn.Module):
    def __init__(self,
                 num_classes=33,
                 num_codes=32,
                 use_se_loss=True,
                 **kwargs):
        super(EncHead, self).__init__()
        self.use_se_loss = use_se_loss
        self.num_codes = num_codes
        self.in_channels = [256, 512, 1024, 2048]
        self.channels = 512
        self.num_classes = num_classes
        self.bottleneck = nn.Conv2d(
            self.in_channels[-1], self.channels, 3, padding=1)
            
        self.enc_module = EncModule(
            self.channels,
            num_codes=num_codes,
        )
        self.cls_seg = nn.Sequential(
            nn.Conv2d(self.channels, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, self.num_classes, 3, padding=1)
        )
        
        if self.use_se_loss:
            self.se_layer = nn.Linear(self.channels, self.num_classes)
 
    def forward(self, inputs):
        """Forward function."""
        feat = self.bottleneck(inputs[-1])
        encode_feat, output = self.enc_module(feat)
        # upsample from 1/8 of the input resolution back to the original size
        output = F.interpolate(output, scale_factor=8, mode="bilinear", align_corners=False)
        output = self.cls_seg(output)
        if self.use_se_loss:
            se_output = self.se_layer(encode_feat)
            return output, se_output
        else:
            return output
 
class ENCNet(nn.Module):
    def __init__(self, num_classes):
        super(ENCNet, self).__init__()
        self.num_classes = num_classes
        # the backbone is expected to return the feature maps of the four stages as a list
        self.backbone = ResNet.resnet50()
        self.decoder = EncHead(num_classes=num_classes)
        
    def forward(self, x):
        x = self.backbone(x)
        x = self.decoder(x)
        return x
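
A quick forward-pass sketch, assuming the custom ResNet.resnet50() backbone returns the four stage feature maps as a list, with the last one at 1/8 of the input resolution and 2048 channels:

net = ENCNet(num_classes=33)
img = torch.randn(2, 3, 512, 512)
seg_output, se_output = net(img)  # EncHead has use_se_loss=True by default
print(seg_output.shape)  # expected: torch.Size([2, 33, 512, 512])
print(se_output.shape)   # torch.Size([2, 33]) -> per-class presence predictions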

In the implementation, the SE branch predicts all the classes that may be present in the current image:

# 33 is the number of classes; pred.shape[0] is the batch size
exist_class = torch.FloatTensor([[1 if c in y[i_batch] else 0 for c in range(33)]
                                 for i_batch in range(pred.shape[0])])

exist_class = exist_class.cuda()
se_output = net(X)[1]

se_loss = nn.functional.mse_loss(se_output, exist_class)
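
Finally, a minimal sketch of combining the two losses in a training step. The pixel-wise segmentation loss is standard cross-entropy; the 0.2 weight on the SE loss is an assumed hyper-parameter, not a value given above:

seg_output, se_output = net(X)
seg_loss = nn.functional.cross_entropy(seg_output, y)     # y: [batch, H, W] label map
se_loss = nn.functional.mse_loss(se_output, exist_class)  # SE loss, as above
loss = seg_loss + 0.2 * se_loss                            # 0.2: assumed SE-loss weight
loss.backward()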