DIANet: Dense-and-Implicit Attention Network

Dense-and-Implicit Attention (DIA) shares a single attention module across different layers of a network to encourage the integration of hierarchical information. The shared module is an LSTM, which ties the parameters of all layers together and captures long-distance dependencies across them. Experiments show that DIA-LSTM emphasizes the correlation between layer-wise features, significantly improves image classification accuracy, and acts as a strong regularizer that stabilizes the training of deep networks.

The authors modify the standard LSTM cell in two ways: a bottleneck layer replaces the dense gate projections, and the output activation is changed from tanh to sigmoid so that the emitted attention weights lie in (0, 1).

import torch
import torch.nn as nn


class small_cell(nn.Module):
    """Bottleneck replacement for the dense gate projection: input_size -> input_size // 4 -> 4 * hidden_size."""
    def __init__(self, input_size, hidden_size):
        super(small_cell, self).__init__()
        self.seq = nn.Sequential(nn.Linear(input_size, input_size // 4),
                                 nn.ReLU(inplace=True),
                                 nn.Linear(input_size // 4, 4 * hidden_size))

    def forward(self, x):
        return self.seq(x)
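
To see what the bottleneck buys, here is a quick parameter-count comparison against the plain dense projection it replaces (nn.Linear(input_size, 4 * hidden_size)); the channel count 256 is only an illustrative choice, not a value from the paper.

channels = 256  # illustrative channel count, not taken from the paper
dense = nn.Linear(channels, 4 * channels)       # standard LSTM gate projection
bottleneck = small_cell(channels, channels)     # 256 -> 64 -> 1024
count = lambda m: sum(p.numel() for p in m.parameters())
print(count(dense), count(bottleneck))          # 263168 vs. 83008: the bottleneck is much smaller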

class LSTMCell(nn.Module):
    """Multi-layer LSTM cell with the DIA modifications (bottleneck gates, sigmoid output activation)."""
    def __init__(self, input_size, hidden_size, nlayers=1, dropout=0.1):
        super(LSTMCell, self).__init__()

        self.nlayers = nlayers
        self.dropout = nn.Dropout(p=dropout)

        ih, hh = [], []
        for i in range(nlayers):
            if i == 0:
                # First layer: the standard dense gate projections (kept in comments)
                # are replaced by the bottleneck small_cell.
                # ih.append(nn.Linear(input_size, 4 * hidden_size))
                ih.append(small_cell(input_size, hidden_size))
                # hh.append(nn.Linear(hidden_size, 4 * hidden_size))
                hh.append(small_cell(hidden_size, hidden_size))
            else:
                ih.append(nn.Linear(hidden_size, 4 * hidden_size))
                hh.append(nn.Linear(hidden_size, 4 * hidden_size))
        self.w_ih = nn.ModuleList(ih)
        self.w_hh = nn.ModuleList(hh)

    def forward(self, input, hidden):
        """Forward pass; hidden is an (h, c) pair, each of shape nlayers * batch * hidden_size."""
        hy, cy = [], []
        for i in range(self.nlayers):
            hx, cx = hidden[0][i], hidden[1][i]
            gates = self.w_ih[i](input) + self.w_hh[i](hx)
            i_gate, f_gate, c_gate, o_gate = gates.chunk(4, 1)
            i_gate = torch.sigmoid(i_gate)
            f_gate = torch.sigmoid(f_gate)
            c_gate = torch.tanh(c_gate)
            o_gate = torch.sigmoid(o_gate)
            ncx = (f_gate * cx) + (i_gate * c_gate)
            # nhx = o_gate * torch.tanh(ncx)
            # Output activation changed from tanh to sigmoid so nhx can serve directly as attention weights in (0, 1).
            nhx = o_gate * torch.sigmoid(ncx)
            cy.append(ncx)
            hy.append(nhx)
            input = self.dropout(nhx)

        hy, cy = torch.stack(hy, 0), torch.stack(cy, 0)  # nlayers * batch * hidden_size
        return hy, cy
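
A minimal shape check for the cell above (the batch size of 8 and the 64 channels are arbitrary choices): the hidden and cell states are stacked as nlayers * batch * hidden_size, which is the layout the Attention module below relies on.

cell = LSTMCell(input_size=64, hidden_size=64, nlayers=1)
pooled = torch.randn(8, 64)                # pooled feature descriptor, [batch, channels]
h0 = torch.zeros(1, 8, 64)                 # nlayers * batch * hidden_size
c0 = torch.zeros(1, 8, 64)
ht, ct = cell(pooled, (h0, c0))
print(ht.shape, ct.shape)                  # torch.Size([1, 8, 64]) for both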

The code for DIA-LSTM can be found in dia_resnet; the Attention module below shows how the shared LSTM-based attention is applied after each block.

class Attention(nn.Module):
    """Wraps the blocks of one stage (ModuleList) and applies a single shared DIA-LSTM after every block."""
    def __init__(self, ModuleList, channels):
        super(Attention, self).__init__()
        self.ModuleList = ModuleList

        # Alternative: a standard LSTM cell followed by an explicit sigmoid gate.
        # self.lstm = nn.LSTMCell(channels, channels)
        # self.sigmoid = nn.Sequential(nn.Linear(channels, channels), nn.Sigmoid())
        self.lstm = LSTMCell(channels, channels, 1)

        self.GlobalAvg = nn.AdaptiveAvgPool2d((1, 1))
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        for idx, layer in enumerate(self.ModuleList):
            x, org = layer(x)  # each block returns (features, identity)
            # Global average pooling gives the per-channel descriptor fed to the shared LSTM.
            seq = self.GlobalAvg(x)
            seq = seq.view(seq.size(0), seq.size(1))  # [batch, channels]
            if idx == 0:
                # Initialise the recurrent state on the same device/dtype as the features.
                ht = x.new_zeros(1, seq.size(0), seq.size(1))  # 1 = number of LSTM layers
                ct = x.new_zeros(1, seq.size(0), seq.size(1))
            ht, ct = self.lstm(seq, (ht, ct))  # nlayers * batch * channels
            # ht = self.sigmoid(ht)  # only needed with the standard nn.LSTMCell alternative above
            # Channel-wise recalibration with the attention weights from the last LSTM layer.
            x = x * ht[-1].view(ht.size(1), ht.size(2), 1, 1)
            x += org
            x = self.relu(x)
        return x
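
The Attention wrapper assumes that every block in ModuleList returns a pair (features, identity) so that the recalibrated features can be added back onto the residual branch. The toy block below is only a placeholder to exercise the shapes, not the residual block used in dia_resnet:

class ToyBlock(nn.Module):
    """Placeholder block: returns (transformed features, identity), mimicking a residual block."""
    def __init__(self, channels):
        super(ToyBlock, self).__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(channels)

    def forward(self, x):
        return self.bn(self.conv(x)), x

blocks = nn.ModuleList([ToyBlock(64) for _ in range(3)])
stage = Attention(blocks, channels=64)
out = stage(torch.randn(2, 64, 32, 32))
print(out.shape)  # torch.Size([2, 64, 32, 32]); one shared LSTM serves all three blocks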