ULSAM:超轻量级子空间注意力机制。
本文作者提出了一种简单而有效的“超轻量级子空间注意力机制”(Ultra-Lightweight Subspace Attention Mechanism,ULSAM),它为每个特征子空间分别学习不同的注意力特征图,可以实现多尺度和多频率的特征表示,有利于细粒度图像分类。
ULSAM对输入特征进行分组,对每组子特征(对应一个特征子空间)通过深度可分离卷积构造空间注意力分布,进行空间上的重新校准;最后把所有特征连接作为输出特征。
class SubSpace(nn.Module):
    """Spatial attention for one feature subspace (one group of channels).

    The input goes through a 1x1 depthwise convolution, batch norm and ReLU,
    a 3x3 stride-1 max pool, and a 1x1 pointwise convolution that collapses
    all channels into a single map (again followed by batch norm and ReLU).
    That map is softmax-normalised over all spatial positions and used to
    re-weight the input with a residual connection::

        out = attention * x + x

    Args:
        nin: Number of channels of the tensors passed to :meth:`forward`.
    """

    def __init__(self, nin: int) -> None:
        super().__init__()

        # Depthwise stage: groups == channels, so every channel gets its own
        # 1x1 filter (a learned per-channel scale and bias).
        self.conv_dws = nn.Conv2d(
            nin, nin, kernel_size=1, stride=1, padding=0, groups=nin
        )
        # NOTE(review): in PyTorch, `momentum` is the weight given to the
        # *new* batch statistic, so 0.9 makes the running estimates follow
        # the latest batch very closely. If this value was carried over from
        # a framework where momentum weights the *old* estimate, 0.1 would be
        # the equivalent setting -- confirm before changing.
        self.bn_dws = nn.BatchNorm2d(nin, momentum=0.9)
        self.relu_dws = nn.ReLU(inplace=False)

        # kernel 3 / stride 1 / padding 1 keeps the spatial size unchanged.
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)

        # Pointwise stage: mixes all channels down to one attention map.
        self.conv_point = nn.Conv2d(
            nin, 1, kernel_size=1, stride=1, padding=0, groups=1
        )
        self.bn_point = nn.BatchNorm2d(1, momentum=0.9)
        self.relu_point = nn.ReLU(inplace=False)

        # Applied to the map flattened to (batch, 1, height * width), so
        # dim=2 normalises over every spatial position.
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Re-weight ``x`` with its spatial attention map.

        Args:
            x: 4-D input whose second dimension has ``nin`` channels.

        Returns:
            A tensor with the same shape as ``x``.
        """
        out = self.relu_dws(self.bn_dws(self.conv_dws(x)))
        out = self.maxpool(out)
        out = self.relu_point(self.bn_point(self.conv_point(out)))

        # Flatten the spatial dimensions so a single softmax spans the whole
        # map, then restore the 4-D layout.
        batch, channels, height, width = out.shape
        out = self.softmax(out.view(batch, channels, -1))
        out = out.view(batch, channels, height, width)

        # The one-channel map broadcasts across every channel of x; the
        # residual term keeps the original features in the output.
        return out * x + x
class ULSAM(nn.Module):
    """Ultra-Lightweight Subspace Attention Module.

    The input is split along the channel dimension into ``num_splits``
    equally sized groups. Each group is processed by its own ``SubSpace``
    module, and the results are concatenated back along the channel
    dimension.

    Args:
        nin: Total number of input channels.
        num_splits: Number of channel groups; must be positive and divide
            ``nin`` evenly.

    Raises:
        ValueError: If ``num_splits`` is not positive or does not divide
            ``nin``.
    """

    def __init__(self, nin: int, num_splits: int) -> None:
        super().__init__()
        # Explicit checks instead of `assert`, which is removed under
        # `python -O` and would let a bad configuration through.
        if num_splits <= 0:
            raise ValueError(
                f"num_splits must be a positive integer, got {num_splits}"
            )
        if nin % num_splits != 0:
            raise ValueError(
                f"nin ({nin}) must be divisible by num_splits ({num_splits})"
            )

        self.nin = nin
        self.num_splits = num_splits

        # One independent attention module per channel group.
        channels_per_split = nin // num_splits
        self.subspaces = nn.ModuleList(
            [SubSpace(channels_per_split) for _ in range(num_splits)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply per-subspace attention to ``x``.

        Args:
            x: Input tensor; dimension 1 is split into ``num_splits`` chunks.

        Returns:
            The per-chunk outputs concatenated along dimension 1.
        """
        # Split along the channel dimension (dim=1), not the batch dimension.
        sub_feat = torch.chunk(x, self.num_splits, dim=1)
        out = []
        for idx, subspace in enumerate(self.subspaces):
            out.append(subspace(sub_feat[idx]))
        return torch.cat(out, dim=1)