DIANet: Dense-and-Implicit Attention Network
Dense-and-Implicit Attention (DIA) shares a single attention module across different layers of a network to encourage the integration of layer-wise information. The shared module is an LSTM, so its parameters are reused at every layer and it can capture long-distance dependencies across layers. Experiments show that DIA-LSTM emphasizes the correlation between layer-wise features, noticeably improves image classification accuracy, and acts as a strong regularizer that stabilizes the training of deep networks.
The authors modify the standard LSTM cell in two ways: a bottleneck layer replaces the full linear gate projections, and the output activation is changed from tanh to sigmoid, so the hidden state can be used directly as channel-wise attention weights in (0, 1).
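For a rough sense of what the bottleneck in the small_cell module below saves, assume input_size == hidden_size == C and ignore biases (illustrative numbers only, not reported in the paper):

C = 64
standard = C * (4 * C)                          # nn.Linear(C, 4*C): 16384 weights
bottleneck = C * (C // 4) + (C // 4) * (4 * C)  # small_cell: 1024 + 4096 = 5120 weights
print(standard, bottleneck)                     # roughly a 3x reduction for the gate projections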
import torch
import torch.nn as nn

# Bottleneck replacement for the gate projections: C -> C//4 -> 4*hidden.
class small_cell(nn.Module):
    def __init__(self, input_size, hidden_size):
        """Constructor of the class"""
        super(small_cell, self).__init__()
        self.seq = nn.Sequential(nn.Linear(input_size, input_size // 4),
                                 nn.ReLU(inplace=True),
                                 nn.Linear(input_size // 4, 4 * hidden_size))

    def forward(self, x):
        return self.seq(x)
class LSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size, nlayers=1, dropout=0.1):
        """Constructor of the class"""
        super(LSTMCell, self).__init__()
        self.nlayers = nlayers
        self.dropout = nn.Dropout(p=dropout)
        ih, hh = [], []
        for i in range(nlayers):
            if i == 0:
                # ih.append(nn.Linear(input_size, 4 * hidden_size))   # standard LSTM projection
                ih.append(small_cell(input_size, hidden_size))         # bottleneck version
                # hh.append(nn.Linear(hidden_size, 4 * hidden_size))  # standard LSTM projection
                hh.append(small_cell(hidden_size, hidden_size))        # bottleneck version
            else:
                ih.append(nn.Linear(hidden_size, 4 * hidden_size))
                hh.append(nn.Linear(hidden_size, 4 * hidden_size))
        self.w_ih = nn.ModuleList(ih)
        self.w_hh = nn.ModuleList(hh)

    def forward(self, input, hidden):
        """Defines the forward computation of the LSTMCell"""
        hy, cy = [], []
        for i in range(self.nlayers):
            hx, cx = hidden[0][i], hidden[1][i]
            gates = self.w_ih[i](input) + self.w_hh[i](hx)
            i_gate, f_gate, c_gate, o_gate = gates.chunk(4, 1)
            i_gate = torch.sigmoid(i_gate)
            f_gate = torch.sigmoid(f_gate)
            c_gate = torch.tanh(c_gate)
            o_gate = torch.sigmoid(o_gate)
            ncx = (f_gate * cx) + (i_gate * c_gate)
            # nhx = o_gate * torch.tanh(ncx)        # standard LSTM output
            # DIA uses sigmoid so the hidden state lies in (0, 1) and can act
            # directly as channel attention weights.
            nhx = o_gate * torch.sigmoid(ncx)
            cy.append(ncx)
            hy.append(nhx)
            input = self.dropout(nhx)
        hy, cy = torch.stack(hy, 0), torch.stack(cy, 0)  # number of layers * batch * hidden
        return hy, cy
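A minimal shape check for the cell above (a sketch; the batch size and channel count are made up for illustration):

cell = LSTMCell(input_size=64, hidden_size=64, nlayers=1)
x = torch.randn(8, 64)        # pooled channel descriptor, [batch, channels]
h0 = torch.zeros(1, 8, 64)    # [nlayers, batch, hidden]
c0 = torch.zeros(1, 8, 64)
h1, c1 = cell(x, (h0, c0))    # both [1, 8, 64]; values of h1 lie in (0, 1)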
The DIA-LSTM code can be found in dia_resnet:
class Attention(nn.Module):
    def __init__(self, ModuleList, channels):
        super(Attention, self).__init__()
        self.ModuleList = ModuleList
        # A plain nn.LSTMCell plus a sigmoid head would be the unmodified alternative:
        # self.lstm = nn.LSTMCell(channels, channels)
        # self.sigmoid = nn.Sequential(nn.Linear(channels, channels), nn.Sigmoid())
        self.lstm = LSTMCell(channels, channels, 1)   # shared DIA-LSTM cell
        self.GlobalAvg = nn.AdaptiveAvgPool2d((1, 1))
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        for idx, layer in enumerate(self.ModuleList):
            # Each block returns its residual-branch output and the shortcut.
            x, org = layer(x)
            if idx == 0:
                seq = self.GlobalAvg(x)
                # list = seq.view(seq.size(0), 1, seq.size(1))
                seq = seq.view(seq.size(0), seq.size(1))  # [b, c]
                # Initialize the LSTM state at the first block; 1 is the number of layers.
                ht = torch.zeros(1, seq.size(0), seq.size(1), device=x.device)
                ct = torch.zeros(1, seq.size(0), seq.size(1), device=x.device)
                ht, ct = self.lstm(seq, (ht, ct))  # 1 * batch size * length
                # ht = self.sigmoid(ht)
                x = x * (ht[-1].view(ht.size(1), ht.size(2), 1, 1))  # channel-wise rescaling
                x += org
                x = self.relu(x)
            else:
                seq = self.GlobalAvg(x)
                # list = torch.cat((list, seq.view(seq.size(0), 1, seq.size(1))), 1)
                seq = seq.view(seq.size(0), seq.size(1))
                # Reuse the same LSTM cell and its carried state at every block.
                ht, ct = self.lstm(seq, (ht, ct))
                # ht = self.sigmoid(ht)
                x = x * (ht[-1].view(ht.size(1), ht.size(2), 1, 1))
                x += org
                x = self.relu(x)
        return x  # , list
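To exercise Attention end to end, the blocks in the ModuleList must return both their output and the shortcut branch, as the modified residual blocks in dia_resnet do. A minimal stand-in (DummyBlock is hypothetical, for illustration only):

class DummyBlock(nn.Module):
    """Stand-in for a residual block that returns (residual branch, shortcut)."""
    def __init__(self, channels):
        super(DummyBlock, self).__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(channels)

    def forward(self, x):
        return self.bn(self.conv(x)), x   # (out before the skip addition, identity)

blocks = nn.ModuleList([DummyBlock(64) for _ in range(3)])
att = Attention(blocks, channels=64)
y = att(torch.randn(2, 64, 32, 32))       # output shape [2, 64, 32, 32]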