SRM:一种基于风格的卷积神经网络重校准模块.
本文提出了一种基于风格的重校准模块(Style-based Recalibration Module, SRM),该模块通过利用中间特征图的风格来自适应地重新校准特征。SRM首先通过风格池化从特征图的每个通道中提取风格信息,然后通过与通道无关的风格集成来估计每个通道的重校准权重。通过把单个风格的相对重要性纳入特征图,SRM有效地增强了CNN的表示能力。
SRM的总体结构如图所示,由两个主要组件组成:风格池化(Style Pooling)和风格集成(Style Integration)。给定输入特征\(X \in \Bbb{R}^{N\times C \times H \times W}\),风格池化通过汇总跨空间维度的特征来从每个通道提取风格特征\(T \in \Bbb{R}^{N\times C \times d}\);风格集成通过基于通道的操作利用风格特征来生成特定于输入样本的风格权重\(G \in \Bbb{R}^{N\times C}\)。风格权重最终重新校准特征图,以强调或隐藏部分信息。该模块可以无缝集成到现代CNN架构中,并以端到端的方式进行训练。
风格池化将每个特征图的通道级统计信息(均值和标准差)用作风格特征($d = 2$):
风格集成通过通道级的全连接层把风格特征转换为风格权重,并通过BN和Sigmoid相应地强调或压抑与各个通道关联的风格的重要性。
class SRMLayer(nn.Module):
def __init__(self, channel, reduction=None):
# Reduction for compatibility with layer_block interface
super(SRMLayer, self).__init__()
# CFC: channel-wise fully connected layer
self.cfc = nn.Conv1d(channel, channel, kernel_size=2, bias=False,
groups=channel)
self.bn = nn.BatchNorm1d(channel)
def forward(self, x):
b, c, _, _ = x.size()
# Style pooling
# AvgPool(全局平均池化):
mean = x.view(b, c, -1).mean(-1).unsqueeze(-1)
# StdPool(全局标准池化)
std = x.view(b, c, -1).std(-1).unsqueeze(-1)
u = torch.cat((mean, std), -1) # (b, c, 2)
# Style integration
# CFC(全连接层)
z = self.cfc(u) # (b, c, 1)
# BN(归一化)
z = self.bn(z)
# Sigmoid
g = torch.sigmoid(z)
g = g.view(b, c, 1, 1)
return x * g.expand_as(x)