Image Super-Resolution with Non-Local Sparse Attention
Non-local attention can capture global feature information, but its computational cost is quadratic in the number of pixels. This paper proposes Non-Local Sparse Attention (NLSA), which partitions the pixels of the deep feature map so that pixels in the same attention bucket have similar content. This brings the computation down to near-linear complexity and also makes non-local attention more effective, since each query only attends to relevant features.
When computing the response $y_i$ at output position $i$, the non-local operation takes a weighted sum of the transformed features $h(x_j)$ over all input positions:
\[y_i = \sum_{j} \alpha_{j} h(x_j) = \sum_{j} \frac{e^{f(x_i)^T g(x_j)}}{\sum_k e^{f(x_i)^T g(x_k)}} h(x_j)\]
NLSA rewrites this in a sparse form:
\[y_i = \sum_{j \in \delta_i} \frac{e^{f(x_i)^T g(x_j)}}{\sum_{k \in \delta_i} e^{f(x_i)^T g(x_k)}} h(x_j)\]
where $\delta_i$ is the attention bucket of position $i$: the set of positions the query pixel searches over, which restricts the range of the non-local attention.
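To make the sparse form concrete, here is a minimal PyTorch sketch for a single query, assuming $f$, $g$ and $h$ are identity maps and that bucket_idx is a hypothetical, precomputed list of the members of $\delta_i$:

import torch

x = torch.randn(16, 8)                 # 16 "pixels" with 8 channels each
bucket_idx = torch.tensor([2, 5, 9])   # hypothetical delta_i: positions sharing query 2's hash code
q = x[2]                               # query feature f(x_2)
keys = x[bucket_idx]                   # g(x_j) for j in delta_i
weights = torch.softmax(keys @ q, dim=0)        # the softmax runs over the bucket only
y = (weights[:, None] * x[bucket_idx]).sum(0)   # y_2 aggregates h(x_j) within the bucket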
Attention buckets are constructed with locality-sensitive hashing (LSH). The LSH function projects a feature vector onto the unit hypersphere and takes the nearest vertex of an inscribed polytope as its hash code. For a vector $x$, the vector is first normalized to unit length and then multiplied by a random rotation matrix $A \in \mathbb{R}^{c \times m}$:
\[\hat{x} = A^\top \frac{x}{\|x\|_2}\]
The hash code is the index $i$ at which the projection $\hat{x}$ attains its maximum:
\[h(x) = \arg\max_i \hat{x}_i\]
(In the implementation below, $\hat{x}$ is concatenated with $-\hat{x}$ before the argmax, so an $m$-dimensional projection yields $2m$ buckets.)
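A toy run of this hashing step (a standalone sketch, not the paper's code; the sizes are arbitrary):

import torch
import torch.nn.functional as F

c, m = 8, 4                          # channels, half the number of buckets
x = torch.randn(10, c)               # 10 pixel feature vectors
A = torch.randn(c, m)                # random projection matrix
x_hat = F.normalize(x, dim=-1) @ A   # project the unit vectors: [10, m]
x_hat = torch.cat([x_hat, -x_hat], dim=-1)   # 2*m candidate bucket vertices
codes = x_hat.argmax(dim=-1)         # hash code per pixel, in [0, 2*m)
print(codes)                         # pixels with equal codes fall in the same bucket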
Each pixel's bucket is then the set of pixels that share its hash code:
\[\delta_i = \{ j \mid h(x_j) = h(x_i) \}\]
The authors' PyTorch implementation is annotated below; default_conv, BasicBlock and batched_index_select are helpers from the EDSR-style common module of their codebase (the import path is assumed; adjust it to match your setup).

import torch
import torch.nn as nn
import torch.nn.functional as F
from model import common  # EDSR-style helpers: default_conv, BasicBlock, batched_index_select

class NonLocalSparseAttention(nn.Module):
    def __init__(self, n_hashes=4, channels=64, k_size=3, reduction=4,
                 chunk_size=144, conv=common.default_conv, res_scale=1):
        super(NonLocalSparseAttention, self).__init__()
        self.chunk_size = chunk_size  # number of elements per attention chunk (144 here)
        self.n_hashes = n_hashes      # number of independent hash rounds (4 here)
        self.reduction = reduction
        self.res_scale = res_scale
        self.conv_match = common.BasicBlock(conv, channels, channels//reduction, k_size, bn=False, act=None)
        self.conv_assembly = common.BasicBlock(conv, channels, channels, 1, bn=False, act=None)
        # For comparison, the standard non-local module uses:
        # self.conv_match1 = common.BasicBlock(conv, channel, channel//reduction, 1, bn=False, act=nn.PReLU())
        # self.conv_match2 = common.BasicBlock(conv, channel, channel//reduction, 1, bn=False, act=nn.PReLU())
        # self.conv_assembly = common.BasicBlock(conv, channel, channel, 1, bn=False, act=nn.PReLU())
    def LSH(self, hash_buckets, x):
        # x: [N, H*W, C]
        N = x.shape[0]  # batch size
        device = x.device

        # generate a random rotation matrix, one slice per hash round
        rotations_shape = (1, x.shape[-1], self.n_hashes, hash_buckets//2)  # [1, C, n_hashes, hash_buckets//2]
        random_rotations = torch.randn(rotations_shape, dtype=x.dtype, device=device).expand(N, -1, -1, -1)  # [N, C, n_hashes, hash_buckets//2]

        # locality-sensitive hashing: rotate each pixel vector, contracting away
        # the channel dimension; this is the projection step in the paper's pipeline
        rotated_vecs = torch.einsum('btf,bfhi->bhti', x, random_rotations)  # [N, n_hashes, H*W, hash_buckets//2]
        rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1)  # [N, n_hashes, H*W, hash_buckets]

        # get hash codes: the position of the largest rotated component becomes the
        # hash value of that feature-map pixel
        hash_codes = torch.argmax(rotated_vecs, dim=-1)  # [N, n_hashes, H*W, hash_buckets] -> [N, n_hashes, H*W]

        # add offsets so that hash codes from different rounds never overlap
        offsets = torch.arange(self.n_hashes, device=device)  # [0, 1, 2, 3]
        offsets = torch.reshape(offsets * hash_buckets, (1, -1, 1))  # [0, hb, 2*hb, 3*hb], shape (1, n_hashes, 1)
        hash_codes = torch.reshape(hash_codes + offsets, (N, -1,))  # [N, n_hashes, H*W] -> [N, n_hashes*H*W]

        return hash_codes
    def add_adjacent_buckets(self, x):
        # connect each chunk to its neighbouring chunks
        x_extra_back = torch.cat([x[:, :, -1:, ...], x[:, :, :-1, ...]], dim=2)  # roll chunks down by one, so chunk i-1 aligns with chunk i
        x_extra_forward = torch.cat([x[:, :, 1:, ...], x[:, :, :1, ...]], dim=2)  # roll chunks up by one, so chunk i+1 aligns with chunk i
        # this neatly places chunks i, i-1 and i+1 side by side in a single row
        return torch.cat([x, x_extra_back, x_extra_forward], dim=3)  # concatenate along the chunk_size dimension
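    # A worked illustration of add_adjacent_buckets: with chunks [c0, c1, c2]
    # along dim=2, the two rolls give
    #   x_extra_back    = [c2, c0, c1]
    #   x_extra_forward = [c1, c2, c0]
    # so concatenating along dim=3 yields row i = [c_i, c_{i-1}, c_{i+1}],
    # wrapping around at both ends.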
    def forward(self, input):
        N, _, H, W = input.shape
        x_embed = self.conv_match(input).view(N, -1, H*W).contiguous().permute(0, 2, 1)  # channels reduced by `reduction`: [N, H*W, C/reduction]
        y_embed = self.conv_assembly(input).view(N, -1, H*W).contiguous().permute(0, 2, 1)  # channels unchanged: [N, H*W, C]
        L, C = x_embed.shape[-2:]  # L = H*W, C = channels/reduction

        # number of hash buckets: roughly one per chunk, forced to be even, capped at 128
        hash_buckets = min(L//self.chunk_size + (L//self.chunk_size) % 2, 128)

        # get the assigned hash codes / bucket numbers
        hash_codes = self.LSH(hash_buckets, x_embed)  # [N, n_hashes*H*W]
        hash_codes = hash_codes.detach()  # no gradients flow through the hashing

        # group elements with the same hash code by sorting
        _, indices = hash_codes.sort(dim=-1)  # [N, n_hashes*H*W], ascending; sort returns (values, indices)
        _, undo_sort = indices.sort(dim=-1)  # undo_sort recovers the original order:
        # undo_sort[i] is the rank of hash_codes[i] within the sorted sequence, so
        # gathering the sorted tensors with undo_sort restores the original layout
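        # worked example of the sort/unsort trick: if hash_codes = [3, 1, 2],
        # then indices = [1, 2, 0] (positions in ascending order) and
        # undo_sort = [2, 0, 1]; gathering the sorted values [1, 2, 3] with
        # undo_sort gives back [3, 1, 2], the original sequence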
        mod_indices = (indices % L)  # map back to pixel indices in [0, H*W)
        x_embed_sorted = common.batched_index_select(x_embed, mod_indices)  # [N, n_hashes*H*W, C]
        y_embed_sorted = common.batched_index_select(y_embed, mod_indices)  # [N, n_hashes*H*W, C*reduction]
        # for reference, batched_index_select is defined as:
        # def batched_index_select(values, indices):
        #     last_dim = values.shape[-1]
        #     return values.gather(1, indices[:, :, None].expand(-1, -1, last_dim))
        # where indices[:, :, None] appends a trailing dimension, like np.newaxis
        # pad the embedding if its length is not divisible by chunk_size
        padding = self.chunk_size - L % self.chunk_size if L % self.chunk_size != 0 else 0
        x_att_buckets = torch.reshape(x_embed_sorted, (N, self.n_hashes, -1, C))  # [N, n_hashes, H*W, C]
        y_att_buckets = torch.reshape(y_embed_sorted, (N, self.n_hashes, -1, C*self.reduction))
        if padding:
            pad_x = x_att_buckets[:, :, -padding:, :].clone()
            pad_y = y_att_buckets[:, :, -padding:, :].clone()
            x_att_buckets = torch.cat([x_att_buckets, pad_x], dim=2)
            y_att_buckets = torch.cat([y_att_buckets, pad_y], dim=2)  # duplicate the last few elements as padding

        x_att_buckets = torch.reshape(x_att_buckets, (N, self.n_hashes, -1, self.chunk_size, C))  # [N, n_hashes, num_chunks, chunk_size, C]
        y_att_buckets = torch.reshape(y_att_buckets, (N, self.n_hashes, -1, self.chunk_size, C*self.reduction))
        x_match = F.normalize(x_att_buckets, p=2, dim=-1, eps=5e-5)  # L2-normalize along the channel dimension
        # [N, n_hashes, num_chunks, chunk_size, C]

        # allow attending to adjacent buckets; from the paper: "We then apply the
        # Non-Local (NL) operation within the bucket that the query pixel belongs
        # to, or across adjacent buckets after sorting."
        x_match = self.add_adjacent_buckets(x_match)
        y_att_buckets = self.add_adjacent_buckets(y_att_buckets)

        # unnormalized attention scores
        raw_score = torch.einsum('bhkie,bhkje->bhkij', x_att_buckets, x_match)  # [N, n_hashes, num_chunks, chunk_size, chunk_size*3]
        # softmax: logsumexp gives the log of the softmax denominator (a smooth
        # upper bound on max), so exponentiating the difference yields the softmax
        bucket_score = torch.logsumexp(raw_score, dim=-1, keepdim=True)
        score = torch.exp(raw_score - bucket_score)  # softmax of raw_score
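        # sanity check of the identity above: for any row r,
        # torch.exp(r - torch.logsumexp(r, -1, keepdim=True)) equals
        # torch.softmax(r, -1), because logsumexp(r) = log(sum(exp(r)))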
        bucket_score = torch.reshape(bucket_score, [N, self.n_hashes, -1])

        # attention
        ret = torch.einsum('bukij,bukje->bukie', score, y_att_buckets)  # [N, n_hashes, num_chunks, chunk_size, C*reduction]
        ret = torch.reshape(ret, (N, self.n_hashes, -1, C*self.reduction))
        # if padded, remove the extra elements
        if padding:
            ret = ret[:, :, :-padding, :].clone()
            bucket_score = bucket_score[:, :, :-padding].clone()

        # recover the original order
        ret = torch.reshape(ret, (N, -1, C*self.reduction))  # [N, n_hashes*H*W, C*reduction]
        bucket_score = torch.reshape(bucket_score, (N, -1,))  # [N, n_hashes*H*W]
        ret = common.batched_index_select(ret, undo_sort)  # [N, n_hashes*H*W, C*reduction]
        bucket_score = bucket_score.gather(1, undo_sort)  # [N, n_hashes*H*W]

        # weighted sum over the multi-round attention results: rounds whose buckets
        # produce larger softmax normalizers receive larger fusion weights
        ret = torch.reshape(ret, (N, self.n_hashes, L, C*self.reduction))  # [N, n_hashes, H*W, C*reduction]
        bucket_score = torch.reshape(bucket_score, (N, self.n_hashes, L, 1))  # [N, n_hashes, H*W, 1]
        probs = nn.functional.softmax(bucket_score, dim=1)
        ret = torch.sum(ret * probs, dim=1)

        # reshape back to [N, C, H, W] and add the scaled residual connection
        ret = ret.permute(0, 2, 1).view(N, -1, H, W).contiguous() * self.res_scale + input
        return ret
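A quick smoke test of the module (a sketch; it assumes the common helpers above are importable, and the 48×48 input size is arbitrary):

import torch

attn = NonLocalSparseAttention(n_hashes=4, channels=64, chunk_size=144)
feat = torch.randn(2, 64, 48, 48)   # [N, C, H, W] deep features
out = attn(feat)                    # H*W = 2304 = 16 chunks of 144, so no padding is needed
print(out.shape)                    # torch.Size([2, 64, 48, 48]), same shape as the input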