EncNet: Context Encoding for Semantic Segmentation
EncNet strengthens a model's grasp of contextual semantics through a Context Encoding Module and a Semantic Encoding Loss (SE-loss).
The Context Encoding Module captures global contextual information, in particular the category information associated with the scene: it predicts a set of per-channel scaling factors that selectively highlight the feature maps of the classes that should be emphasized.
The Context Encoding Module consists of an encoding layer followed by a channel-attention step. Given the feature maps of a pretrained network, the encoding layer captures their statistics as global contextual semantics; its output is referred to as the encoded semantics. To exploit this context, a set of scaling factors is predicted from the encoded semantics to highlight the class-relevant feature maps. Internally, the encoding layer learns an inherent dictionary that carries the contextual semantics and outputs residual encodings rich in contextual information.
The encoding layer learns an inherent dictionary of $K$ codewords \(D=\{d_1,\ldots,d_K\}\) together with a set of smoothing factors for the visual centers \(S=\{s_1,\ldots,s_K\}\). It outputs residual encodings: the residual between the feature at the $i$-th spatial position and the $k$-th codeword is $r_{ik}=x_i-d_k$, and the residual encodings are formed by soft-assignment weighting along the dictionary dimension:
\[e_{i k}=\frac{\exp \left(-s_k\left\|r_{i k}\right\|^2\right)}{\sum_{j=1}^K \exp \left(-s_j\left\|r_{i j}\right\|^2\right)} r_{i k}\]
The encoding for the $k$-th codeword aggregates over all $N$ spatial positions, $e_k=\sum_{i=1}^N e_{ik}$; the full encoding averages the $e_k$ over all codewords and is then passed through BN + ReLU. A PyTorch implementation of this encoding layer:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoding(nn.Module):
    """Encoding layer: learns K codewords with smoothing factors and
    aggregates soft-assigned residuals into encoded statistics."""
def __init__(self, channels, num_codes):
super(Encoding, self).__init__()
# init codewords and smoothing factor
self.channels, self.num_codes = channels, num_codes
std = 1. / ((num_codes * channels)**0.5)
# [num_codes, channels]
self.codewords = nn.Parameter(
torch.empty(num_codes, channels,
dtype=torch.float).uniform_(-std, std),
requires_grad=True)
# [num_codes]
self.scale = nn.Parameter(
torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0),
requires_grad=True)
    def scaled_l2(self, x, codewords, scale):
        # scaled squared distance scale_k * ||x_i - d_k||^2 between every
        # position's feature and every codeword (scale is learned and
        # initialized negative, playing the role of -s_k in the formula)
num_codes, channels = codewords.size()
batch_size = x.size(0)
reshaped_scale = scale.view((1, 1, num_codes))
expanded_x = x.unsqueeze(2).expand(
(batch_size, x.size(1), num_codes, channels))
reshaped_codewords = codewords.view((1, 1, num_codes, channels))
scaled_l2_norm = reshaped_scale * (
expanded_x - reshaped_codewords).pow(2).sum(dim=3)
return scaled_l2_norm
    def aggregate(self, assignment_weights, x, codewords):
        # aggregate the residuals (x_i - d_k), weighted by the soft
        # assignment, over the N = H x W spatial positions
        num_codes, channels = codewords.size()
        reshaped_codewords = codewords.view((1, 1, num_codes, channels))
        batch_size = x.size(0)
        expanded_x = x.unsqueeze(2).expand(
            (batch_size, x.size(1), num_codes, channels))
        encoded_feat = (assignment_weights.unsqueeze(3) *
                        (expanded_x - reshaped_codewords)).sum(dim=1)
return encoded_feat
    def forward(self, x):
        assert x.dim() == 4 and x.size(1) == self.channels
        # x: [batch_size, channels, height, width]
        batch_size = x.size(0)
        # flatten to [batch_size, height x width, channels]
        x = x.view(batch_size, self.channels, -1).transpose(1, 2).contiguous()
        # assignment_weights: [batch_size, height x width, num_codes]
        assignment_weights = F.softmax(
            self.scaled_l2(x, self.codewords, self.scale), dim=2)
        # aggregate: [batch_size, num_codes, channels]
        encoded_feat = self.aggregate(assignment_weights, x, self.codewords)
        return encoded_feat
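A quick shape check of the layer above (a minimal sketch; the sizes are arbitrary):
enc = Encoding(channels=512, num_codes=32)
feat = torch.randn(2, 512, 60, 60)  # [batch_size, channels, H, W]
print(enc(feat).shape)              # torch.Size([2, 32, 512])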
The aggregated residual encoding then acts on the features extracted by the pretrained network in the style of channel attention, enhancing them:
class EncModule(nn.Module):
def __init__(self, in_channels, num_codes):
super(EncModule, self).__init__()
        self.encoding_project = nn.Conv2d(in_channels, in_channels, 1)
self.encoding = nn.Sequential(
Encoding(channels=in_channels, num_codes=num_codes),
nn.BatchNorm1d(num_codes),
nn.ReLU(inplace=True))
self.fc = nn.Sequential(
nn.Linear(in_channels, in_channels), nn.Sigmoid())
    def forward(self, x):
        """Forward function."""
        encoding_projection = self.encoding_project(x)
        # residual encodings averaged over the K codewords: [batch_size, channels]
        encoding_feat = self.encoding(encoding_projection).mean(dim=1)
        batch_size, channels, _, _ = x.size()
        # per-channel scaling factors gamma in (0, 1)
        gamma = self.fc(encoding_feat)
        y = gamma.view(batch_size, channels, 1, 1)
        # channel-wise re-weighting of the input features
        output = F.relu_(x + x * y)
        return encoding_feat, output
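A quick check of the module's two outputs (a minimal sketch with arbitrary sizes): encoding_feat is the pooled encoding later consumed by the SE-loss branch, while output is the channel-reweighted feature map:
enc_mod = EncModule(in_channels=512, num_codes=32)
feat = torch.randn(2, 512, 28, 28)
encoding_feat, output = enc_mod(feat)
print(encoding_feat.shape, output.shape)  # [2, 512] and [2, 512, 28, 28]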
The SE-loss makes the model attend to the categories present in a scene: it forces the network to predict which classes may appear in each image, providing a form of prior knowledge. Unlike a per-pixel loss, the SE-loss treats objects of different sizes equally, since it is computed per category rather than per pixel; large and small objects therefore contribute equally to the loss, which benefits the segmentation of small objects.
The paper also modifies the backbone: the last two stages use dilated convolutions with dilation rates 2 and 4, respectively, and an SE-loss can be attached to both stage 3 and stage 4.
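As a sketch of this backbone modification, assuming a torchvision ResNet-50 (torchvision's replace_stride_with_dilation flag converts the stride-2 stages into dilated ones, giving dilation rates 2 and 4 in the last two stages and an output stride of 8):
from torchvision.models import resnet50

# Replace the stride of the last two stages with dilation: layer3 uses
# dilation 2 and layer4 uses dilation 4, so the output stride becomes 8.
backbone = resnet50(replace_stride_with_dilation=[False, True, True])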
class EncHead(nn.Module):
    def __init__(self, num_classes=33, num_codes=32, use_se_loss=True,
                 **kwargs):
super(EncHead, self).__init__()
self.use_se_loss = use_se_loss
self.num_codes = num_codes
self.in_channels = [256, 512, 1024, 2048]
self.channels = 512
self.num_classes = num_classes
self.bottleneck = nn.Conv2d(
self.in_channels[-1],
self.channels,
3,
padding=1,
)
self.enc_module = EncModule(
self.channels,
num_codes=num_codes,
)
        self.cls_seg = nn.Sequential(
            nn.Conv2d(self.channels, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, self.num_classes, 3, padding=1))
if self.use_se_loss:
self.se_layer = nn.Linear(self.channels, self.num_classes)
def forward(self, inputs):
"""Forward function."""
feat = self.bottleneck(inputs[-1])
encode_feat, output = self.enc_module(feat)
        # restore the input resolution (the dilated backbone has output stride 8)
        output = F.interpolate(output, scale_factor=8, mode='bilinear',
                               align_corners=False)
output = self.cls_seg(output)
if self.use_se_loss:
se_output = self.se_layer(encode_feat)
return output, se_output
else:
return output
class ENCNet(nn.Module):
def __init__(self, num_classes):
super(ENCNet, self).__init__()
self.num_classes = num_classes
        # the backbone is expected to return the list of per-stage feature
        # maps of a dilated ResNet-50 (channels [256, 512, 1024, 2048])
        self.backbone = ResNet.resnet50()
        self.decoder = EncHead(num_classes=num_classes)
def forward(self, x):
x = self.backbone(x)
x = self.decoder(x)
return x
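A minimal smoke test, assuming the backbone returns a list of four stage features; the DummyBackbone below is a stand-in for illustration, not the actual ResNet.resnet50:
class DummyBackbone(nn.Module):
    # Stand-in that mimics a dilated ResNet-50: four stages whose last two
    # keep the 1/8 resolution, with channel widths [256, 512, 1024, 2048].
    def forward(self, x):
        b, _, h, w = x.shape
        return [torch.randn(b, c, h // s, w // s)
                for c, s in zip([256, 512, 1024, 2048], [4, 8, 8, 8])]

net = ENCNet(num_classes=33)
net.backbone = DummyBackbone()
seg, se = net(torch.randn(2, 3, 224, 224))
print(seg.shape, se.shape)  # [2, 33, 224, 224] and [2, 33]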
At training time, the SE-loss target encodes which classes are present in the current image:
# 33 is the number of classes; pred.shape[0] is the batch size.
# exist_class[i][c] is 1 iff class c appears in the ground truth y[i].
exist_class = torch.FloatTensor(
    [[1 if c in y[i_batch] else 0 for c in range(33)]
     for i_batch in range(pred.shape[0])])
exist_class = exist_class.cuda()
se_output = net(X)[1]
se_loss = F.mse_loss(se_output, exist_class)
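A minimal sketch of how this would be combined with the per-pixel loss in a training step; the weight se_weight is an assumption here, treated as a hyperparameter rather than a value fixed by the paper:
se_weight = 0.2  # assumed hyperparameter weighting the SE-loss term
seg_output, se_output = net(X)
seg_loss = F.cross_entropy(seg_output, y)      # per-pixel segmentation loss
se_loss = F.mse_loss(se_output, exist_class)   # scene-level SE-loss
total_loss = seg_loss + se_weight * se_loss
total_loss.backward()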