ODConv: Omni-Dimensional Dynamic Convolution

Omni-Dimensional Dynamic Convolution (ODConv) uses a multi-dimensional attention mechanism in a parallel manner to learn complementary attentions along the four dimensions of the kernel space, yielding a more expressive construction of dynamic convolution kernels. It can be embedded into existing CNN architectures and improves the performance of large models.

ODConv extends the dynamic behavior of dynamic convolution from a single dimension to all dimensions of the kernel, simultaneously accounting for the dynamics along the spatial, input-channel, and output-channel dimensions. A dynamic convolution kernel $w$ typically has size $c_{out}\times c_{in} \times k \times k$. Earlier methods construct the dynamic kernel by applying channel attention to the input feature $x$ to obtain weights $\alpha_w \in \Bbb{R}^{N}$, and then fusing the $N$ candidate kernels:

\[\alpha_{w1}W_1+\cdots + \alpha_{wN}W_N\]
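As a minimal illustration of this scalar fusion, consider the following PyTorch sketch; the pooling-plus-linear attention branch and all shapes here are illustrative assumptions rather than the exact design of any specific method:

import torch
import torch.nn.functional as F
from torch import nn

N, c_in, c_out, k = 4, 3, 8, 3
W = torch.randn(N, c_out, c_in, k, k)   # the N candidate kernels W_1..W_N
fc = nn.Linear(c_in, N)                 # toy channel-attention branch

x = torch.randn(1, c_in, 32, 32)
pooled = x.mean(dim=(2, 3))                            # global average pooling -> [1, c_in]
alpha_w = torch.softmax(fc(pooled), dim=1).squeeze(0)  # [N], sums to 1

# fuse: alpha_w1 * W_1 + ... + alpha_wN * W_N, then run a single convolution
W_fused = (alpha_w[:, None, None, None, None] * W).sum(dim=0)  # [c_out, c_in, k, k]
out = F.conv2d(x, W_fused, padding=1)                          # [1, c_out, 32, 32]

Note that the kernel is fused per input, so the convolution itself costs the same as a static one.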

ODConv generalizes this scalar attention into a multi-dimensional attention mechanism, learning more flexible attentions over the four dimensions of the kernel space. Besides the scalar attention $\alpha_w \in \Bbb{R}^{N}$, the authors introduce a spatial attention $\alpha_s \in \Bbb{R}^{k \times k}$, an input-channel attention $\alpha_c \in \Bbb{R}^{c_{in}}$, and an output-channel attention $\alpha_f \in \Bbb{R}^{c_{out}}$. The four attentions are fused as described below.

For kernel $W_i$, $\alpha_{si}$ assigns a different attention value to each spatial position of the kernel; $\alpha_{ci}$ assigns a different attention value to each of its input channels; $\alpha_{fi}$ assigns a different attention value to each of its output channels; and $\alpha_{wi}$ assigns one attention value to each of the $N$ kernels as a whole. The kernel fusion in ODConv can then be written as:

\[(\alpha_{w1} \odot \alpha_{f1} \odot \alpha_{c1} \odot \alpha_{s1})W_1+\cdots + (\alpha_{wN} \odot \alpha_{fN} \odot \alpha_{cN} \odot \alpha_{sN})W_N\]
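A sketch of this four-way fusion via broadcasting is given below; the attention values are random placeholders (in the paper they are produced by a shared SE-style branch, with a softmax over $\alpha_w$ and sigmoids for the other three), so only the broadcasting pattern is the point here:

import torch

N, c_out, c_in, k = 4, 8, 3, 3
W = torch.randn(N, c_out, c_in, k, k)

alpha_w = torch.softmax(torch.randn(N), dim=0)   # kernel-wise,         R^N
alpha_f = torch.sigmoid(torch.randn(N, c_out))   # output-channel-wise, R^{c_out} per kernel
alpha_c = torch.sigmoid(torch.randn(N, c_in))    # input-channel-wise,  R^{c_in} per kernel
alpha_s = torch.sigmoid(torch.randn(N, k, k))    # spatial,             R^{k x k} per kernel

# broadcast each attention over the remaining kernel dimensions and fuse
W_fused = (alpha_w[:, None, None, None, None]
           * alpha_f[:, :, None, None, None]
           * alpha_c[:, None, :, None, None]
           * alpha_s[:, None, None, :, :]
           * W).sum(dim=0)                       # [c_out, c_in, k, k]

The reference-style layer below implements a per-sample kernel fusion of this family, folding the kernel-wise and output-channel-wise attentions into a single routing branch: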

import functools
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.utils import _pair
from torch.nn.parameter import Parameter


class _coefficient(nn.Module):
    """Routing branch: maps a pooled feature to per-expert, per-output-channel
    attention weights, normalized over the experts."""
    def __init__(self, in_channels, num_experts, out_channels, dropout_rate):
        super(_coefficient, self).__init__()
        self.num_experts = num_experts
        self.dropout = nn.Dropout(dropout_rate)
        self.fc = nn.Linear(in_channels, num_experts*out_channels)

    def forward(self, x):
        x = torch.flatten(x)              # [1, c, 1, 1] -> [c]; assumes a single sample
        x = self.dropout(x)
        x = self.fc(x)                    # [num_experts * out_channels]
        x = x.view(self.num_experts, -1)  # [num_experts, out_channels]
        return torch.softmax(x, dim=0)    # normalize over the experts
    
    
class DyNet2D(_ConvNd):
    r"""
    in_channels (int): Number of channels in the input image
    out_channels (int): Number of channels produced by the convolution
    kernel_size (int or tuple): Size of the convolving kernel
    stride (int or tuple, optional): Stride of the convolution. Default: 1
    padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
    padding_mode (string, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
    dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
    groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
    bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
    num_experts (int): Number of experts per layer
    dropout_rate (float, optional): Dropout applied inside the routing branch. Default: 0.2
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1,
                 bias=True, padding_mode='zeros', num_experts=3, dropout_rate=0.2):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super(DyNet2D, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), groups, bias, padding_mode)

        # global average pooling
        self._avg_pooling = functools.partial(F.adaptive_avg_pool2d, output_size=(1, 1))
        # routing branch that produces the attention weights
        self._coefficient_fn = _coefficient(in_channels, num_experts, out_channels, dropout_rate)
        # weights of the num_experts candidate convolution kernels
        self.weight = Parameter(torch.empty(
            num_experts, out_channels, in_channels // groups, *kernel_size))
        self.reset_parameters()

    def _conv_forward(self, input, weight):
        if self.padding_mode != 'zeros':
            # _reversed_padding_repeated_twice is provided by _ConvNd (PyTorch >= 1.5)
            return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                            weight, self.bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
    
    def forward(self, inputs): # [b, c, h, w]
        # the routing branch handles one sample at a time, so each sample in
        # the batch gets its own fused kernel
        res = []
        for input in inputs:
            input = input.unsqueeze(0) # [1, c, h, w]
            pooled_inputs = self._avg_pooling(input) # [1, c, 1, 1]
            routing_weights = self._coefficient_fn(pooled_inputs) # [num_experts, out_channels]
            # weighted sum of the expert kernels -> [out_channels, in_channels // groups, k, k]
            kernels = torch.sum(routing_weights[:, :, None, None, None] * self.weight, 0)
            out = self._conv_forward(input, kernels)
            res.append(out)
        return torch.cat(res, dim=0)
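
A quick sanity check of the layer (all sizes here are illustrative):

layer = DyNet2D(in_channels=16, out_channels=32, kernel_size=3,
                padding=1, num_experts=3)
x = torch.randn(4, 16, 64, 64)
y = layer(x)       # each sample is convolved with its own fused kernel
print(y.shape)     # torch.Size([4, 32, 64, 64])

Because the routing branch flattens a single pooled vector, the layer loops over the batch and fuses one kernel per sample; this is simple but slower than batching the per-sample kernels through one grouped convolution.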