Implement SoftArgmax with Pytorch.

在编程时,有时候需要返回一个张量最大值所在的维度序号(如分类任务中返回概率最大的类别编号、定位任务中返回概率最大的空间坐标编号),此时需要用到argmax操作。

Pytorch中的argmax函数定义为torch.argmax(input, dim=None, keepdim=False),其中的dim参数指定寻找最大值的维度,keepdim参数指定是否保持原张量的维度。

如一个尺寸为$(3,4,5)$的三维张量,若设置dim=1,keepdim=False则输出张量的尺寸是$(3,5)$;若设置dim=1,keepdim=True则输出张量的尺寸是$(3,1,5)$。

由于argmax函数是不可导的,在构建网络时无法反向传播梯度。在实际构建网络时通常使用SoftArgmax函数作为替代。对于张量中的每一个位置$i$,做如下近似:

\[y=argmax(x) ≈ \sum_{i}^{} i \cdot softmax(x)_i\]

如对一个尺寸为$(batchNumber \times Channel \times Height \times Width \times Depth)$的三维空间张量寻找其最大值对应的坐标,返回尺寸为$(batchNumber \times Channel \times 3)$的张量,其Pytorch实现如下:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class SoftArgmax(nn.Module):
    def __init__(self, height, width, depth, channel):
        super(SoftArgmax, self).__init__()
        self.height = height
        self.width = width
        self.depth = depth
        self.channel = channel
        
        pos_x, pos_y, pos_z = np.meshgrid(
                np.linspace(-1., 1., self.height),
                np.linspace(-1., 1., self.width),
                np.linspace(-1., 1., self.depth)
                )
        pos_x = torch.from_numpy(pos_x.reshape(self.height*self.width*self.depth)).float()
        pos_y = torch.from_numpy(pos_y.reshape(self.height*self.width*self.depth)).float()
        pos_z = torch.from_numpy(pos_z.reshape(self.height*self.width*self.depth)).float()
        self.register_buffer('pos_x', pos_x)
        self.register_buffer('pos_y', pos_y)
        self.register_buffer('pos_z', pos_z)
        
        
    def forward(self, input):
        # input:  (N, C, H, W, D)
        # output: (N, C, 3)
        input = input.view(-1, self.height*self.width*self.depth)
        softmax_attention = F.softmax(input, dim=1)
        
        self.pos_x = self.pos_x.to(input.device)
        self.pos_y = self.pos_y.to(input.device)
        self.pos_z = self.pos_z.to(input.device)
        softmax_attention = softmax_attention.to(input.device)
        
        expected_x = torch.sum(self.pos_x*softmax_attention, dim=1, keepdim=True)
        expected_y = torch.sum(self.pos_y*softmax_attention, dim=1, keepdim=True)
        expected_z = torch.sum(self.pos_z*softmax_attention, dim=1, keepdim=True)
        expected_xyz = torch.cat([expected_x, expected_y, expected_z], 1)
        coordinates = expected_xyz.view(-1, self.channel, 3)
        return coordinates