Using Fast Fourier Transform (FFT) to Speed Up Convolution Operations.

本文目录:

  1. 卷积运算 Convolution
  2. 傅里叶变换 Fourier Transform
  3. 使用FFT加速深度学习中的卷积运算

1. 卷积运算 Convolution

(1)连续形式的卷积

卷积 (Convolution) 运算是一种通过两个函数$\displaystyle f$和$\displaystyle g$生成第三个函数的数学算子,定义为:

\[(f*g)(t)=\int_{-\infty}^{\infty}f(\tau)g(t-\tau)d\tau\]

卷积表征函数$\displaystyle f$与经过翻转和平移的$\displaystyle g$的乘积函数所围成的曲边梯形的面积。如果将参加卷积的一个函数看作区间的指示函数,卷积可以被看作是“移动平均”的推广。

⚪ 讨论:卷积与互相关

互相关 (Cross-Correlation) 运算也是一种通过两个函数$\displaystyle f$和$\displaystyle g$生成第三个函数的数学算子定义为:

\[(f\circledast g)(t)=\int_{-\infty}^{\infty}\overline{f(\tau-t)}g(\tau)d\tau\]

函数$\displaystyle f(t)$和$\displaystyle g(t)$的互相关等价于函数$\displaystyle f(-t)$的共轭复数和$\displaystyle g(t)$的卷积:

\[(f\circledast g)(t) = \overline{\displaystyle f(-t)} * g(t)\]

在实数域中,卷积和互相关的主要区别是卷积具有对其中一个函数的翻转操作,该操作保证了卷积满足交换律

\[\begin{aligned} (f*g)(t)&=\int_{-\infty}^{\infty}f(\tau)g(t-\tau)d\tau \\ (\text{let} &\quad x=t-\tau)\\ &=-\int_{\infty}^{-\infty}f(t-x)g(x)dx \\ &=\int_{-\infty}^{\infty}g(\tau)f(t-\tau)d\tau =(g*f)(t) \end{aligned}\]

卷积神经网络中,卷积实际指的是互相关操作。一方面,卷积神经网络的操作是使用函数$\displaystyle g(t)$(称为卷积核)对$\displaystyle f(t)$(称为输入特征)进行“卷积”,因此交换律是不必要的;另一方面,卷积神经网络的卷积核参数是可学习的,(如果需要)也可以学习到翻转的参数,没有必要显式地进行翻转。

(2)离散形式的卷积

a. 线性卷积 Linear Convolution

线性卷积参与运算的是两个有限长度的非周期序列。对于长度为$M$的序列$x$和长度为$N$的序列$h$,线性卷积的输出序列$y$的长度为$M+N-1$,定义为:

\[y[n] = \sum_{k=-\infty}^{\infty}x[k]\cdot h[n-k]\]

b. 循环卷积 Circular Convolution

循环卷积也叫圆周卷积,参与运算的序列$h$是周期序列(序列$x$通常是有限的非周期序列,即卷积核)。

计算循环卷积时需要先给定一个周期$T$(通常是序列$h$的周期),然后将两个信号的长度都扩展为$T$,进行卷积运算后取出主值区间(即区间$[0,T]$)的结果;

\[y[n] = \sum_{k=0}^{T-1}x[k]\cdot h_T[n-k] R_T[n]\]

其中$R_T[n]$是周期为$T$的矩形序列。循环卷积的输出长度为周期$T$。

⚪ 讨论:线性卷积与循环卷积的关系

循环卷积可以看作线性卷积以$T$为周期进行周期延拓后取主值区间的结果。证明如下:

\[\begin{aligned} y_{\text{CircularConv}}[n] &= \sum_{k=0}^{T-1}x[k]\cdot h_T[n-k] R_T[n] \\ &= \sum_{k=0}^{T-1}x[k]\cdot \sum_{r=-\infty}^{\infty}h[n+rT-k] R_T[n] \\ &= \sum_{r=-\infty}^{\infty}\left(\sum_{k=0}^{T-1}x[k]\cdot h[n+rT-k]\right) R_T[n] \\ &= \sum_{r=-\infty}^{\infty}y_{\text{LinearConv}}[n+rT] R_T[n] \\ \end{aligned}\]

线性卷积与循环卷积的可视化如下:

注意到在上述例子中,线性卷积与循环卷积的结果不同。这是因为要让两个有限长度的非周期序列的线性卷积与循环卷积计算结果一致,则循环卷积的延拓周期$T$与两个序列的长度$M,N$需满足关系:

\[T \geq M+N-1\]

若不满足上述关系,线性卷积在做周期延拓时会有时域上的混叠。下面举一个简单的例子,假设序列$x$和序列$h$的长度都为$2$,以周期$T=2$进行周期延拓,则线性卷积和循环卷积的结果如下:

结果表明,在做循环卷积的第一个和最后一个有效输出时,分别有上一周期和下一周期的序列参与了运算,导致了时域上的混叠。

若以周期$T=3$进行周期延拓(通过在序列结尾padding适当的$0$实现),则线性卷积和循环卷积的结果如下:

此时,线性卷积和循环卷积的有效结果一致(输出需要丢弃尾部的padding部分)。

2. 傅里叶变换 Fourier Transform

傅里叶变换(Fourier transform, FT)是一种线性变换,通常定义为一种积分变换。其基本思想是一个函数可以用无穷多个周期函数的线性组合来逼近。这些组合系数在保有原函数的几乎全部信息的同时,还直接地反映了该函数的“频域特征”。

(1)傅里叶变换的形式

傅里叶变换实现了时域函数$f(x)$与频域函数$\hat{f}(k)$之间的转换。其形式包括:

函数在时(频)域的离散对应于其像函数在频(时)域的周期性。反之连续则对应于其像函数的非周期性。

⚪ 连续傅里叶变换

连续傅里叶变换把时域的连续、非周期性函数转换成频域的连续、非周期性函数:

\[\mathcal{F}(f)(k) = \hat{f}(k) = \int_{-\infty}^{\infty} f(x) e^{-2\pi i k x} dx \\ \mathcal{F}^{-1}(\hat{f})(x) = f(x) = \int_{-\infty}^{\infty} \hat{f}(k) e^{2\pi i x k} dk\]

⚪ 傅里叶级数(Fourier series)

傅里叶级数把时域的连续、周期性函数转换成频域的离散、非周期性函数:

\[\begin{aligned} f(x) &\sim A_0 + \sum_{n=1}^\infty \left( A_n\cos \left(\frac{2\pi n x}{T}\right) + B_n\sin \left(\frac{2\pi n x}{T}\right)\right) \\ A_0 &= \frac{1}{T} \int_{-T/2}^{T/2} f(x) dx \\ A_n &= \frac{2}{T} \int_{-T/2}^{T/2} f(x) \cos \left(\frac{2\pi n x}{T}\right) dx \\ B_n &= \frac{2}{T} \int_{-T/2}^{T/2} f(x) \sin \left(\frac{2\pi n x}{T}\right) dx \end{aligned}\]

⚪ 离散时间傅里叶变换(DTFT)

离散时间傅里叶变换把时域的离散、非周期性函数转换成频域的连续、周期性函数:

\[\mathcal{F}(f)[k] = \hat{f}[k] = \sum_{x=-\infty}^{\infty} f[x] e^{- i k x}\]

DTFT可以被看作是傅里叶级数的逆变换。

⚪ 离散傅里叶变换(DFT)

离散傅里叶变换把时域的离散、周期性函数转换成频域的离散、周期性函数:

\[\mathcal{F}(f)[k] = \hat{f}[k] = \sum_{x=0}^{T-1} f[x] e^{-\frac{2\pi i k x}{T}}, \quad k=0,1,...,T-1 \\ \mathcal{F}^{-1}(\hat{f})[x] = f[x] =\frac{1}{T} \int_{k=0}^{T-1} \hat{f}[k] e^{\frac{2\pi i x k}{T}} , \quad x=0,1,...,T-1\]

在频域上,DFT的离散谱是对DTFT连续谱的等间隔采样。

(2)卷积定理

卷积定理指出,在适当的条件下,函数卷积(乘积)的傅里叶变换是函数傅里叶变换的乘积(卷积)。即一个域中的卷积对应于另一个域中的乘积:

\[\mathcal{F}(f * g) = \mathcal{F}(f) \cdot \mathcal{F}(g) \\ \mathcal{F}(f \cdot g) = \mathcal{F}(f) *\mathcal{F}(g)\]

以连续傅里叶变换为例,证明如下:

\[\begin{aligned} \mathcal{F}(f * g) &= \int_{-\infty}^{\infty} (f*g)(x) e^{-2\pi i k x} dx \\ &= \int_{-\infty}^{\infty} \int_{-\infty}^{\infty}f(\tau)g(x-\tau)e^{-2\pi i k x} d\tau dx \\ (\text{let} \quad & y = x - \tau) \\ &= \int_{-\infty}^{\infty} \int_{-\infty}^{\infty}f(\tau)g(y)e^{-2\pi i k (y+\tau)} d\tau dy \\ &= \left[\int_{-\infty}^{\infty}f(\tau)e^{-2\pi i k \tau} d\tau\right] \left[ \int_{-\infty}^{\infty}g(y)e^{-2\pi i k y} dy \right] \\ &= \mathcal{F}(f) \cdot \mathcal{F}(g) \end{aligned}\]

(3)快速傅里叶变换 Fast Fourier Transform

快速傅里叶变换 (FFT) 是计算DFT的高效、快速方法,采用这种算法能使计算机计算DFT所需要的乘法次数大大减少,从而把DFT的$\mathcal{O}(n^2)$计算复杂度降低为$\mathcal{O}(n \log n)$。

⚪ Cooley-Tukey算法

假设$T$是$2$的整数幂,把DFT按照时域的奇数点和偶数点进行拆分:

\[\begin{aligned} \hat{f}[k] &= \sum_{x=0}^{T-1} f[x] e^{-\frac{2\pi i k x}{T}} \\ &= \sum_{n=0}^{T/2-1} f[2n] e^{-\frac{2\pi i k (2n)}{T}} + \sum_{x=0}^{T/2-1} f[2n+1] e^{-\frac{2\pi i k (2n+1)}{T}} \\ &= \sum_{n=0}^{T/2-1} f[2n] e^{-\frac{2\pi i k (2n)}{T}} + e^{-\frac{2\pi i k }{T}}\sum_{x=0}^{T/2-1} f[2n+1] e^{-\frac{2\pi i k (2n)}{T}} \\ (\text{let} \quad & W_T = e^{-\frac{2\pi i}{T}}) \\ &= \sum_{n=0}^{T/2-1} f[2n] W_T^{2k} + W_T^k\sum_{x=0}^{T/2-1} f[2n+1] W_T^{2k} \\ (\text{let} \quad & A_1(x)= \sum_{n=0}^{T/2-1} f[2n] x, \quad A_2(x)= \sum_{n=0}^{T/2-1} f[2n+1] x) \\ &= A_1(W_T^{2k}) + W_T^kA_2(W_T^{2k}), \quad k=0,1,...,T-1 \\ \end{aligned}\]

并且注意到:

\[\begin{aligned} &W_T^{2k}= e^{-\frac{2\pi i(2k)}{T}}= e^{-\frac{2\pi ik}{T/2}} = W_{\frac{T}{2}}^k \\ &W_T^{k+\frac{T}{2}}= e^{-\frac{2\pi i(k+\frac{T}{2})}{T}}= e^{-\frac{2\pi ik}{T}}\cdot e^{\pi i} = -W_T^k \\ &W_T^T= e^{-\frac{2\pi iT}{T}}= e^{-2\pi i} = 1 \\ \end{aligned}\]

下面分情况讨论。

\[\begin{aligned} \hat{f}[k]&= A_1(W_T^{2k}) + W_T^kA_2(W_T^{2k}) \\ &= A_1(W_{\frac{T}{2}}^k) + W_T^kA_2(W_{\frac{T}{2}}^k) \\ \end{aligned}\] \[\begin{aligned} \hat{f}[k]&= A_1(W_T^{2k}) + W_T^kA_2(W_T^{2k}) \\ &= A_1(W_T^{2(k+\frac{T}{2})}) + W_T^{k+\frac{T}{2}}A_2(W_T^{2(k+\frac{T}{2})}) \\ &= A_1(W_T^{2k}W_T^T) + W_T^{k+\frac{T}{2}}A_2(W_T^{2k}W_T^T)\\ &= A_1(W_{\frac{T}{2}}^k) -W_T^kA_2(W_{\frac{T}{2}}^k)\\ \end{aligned}\]

因此如果能够求出$A_1(W_{\frac{T}{2}}^k)$和$A_2(W_{\frac{T}{2}}^k)$,则可以$O(1)$的时间复杂度求出$\hat{f}[k]$。而求$A_1(W_{\frac{T}{2}}^k)$和$A_2(W_{\frac{T}{2}}^k)$又可以递归地转换为求$A_{11}(W_{\frac{T}{4}}^k),A_{12}(W_{\frac{T}{4}}^k)$和$A_{21}(W_{\frac{T}{4}}^k),A_{22}(W_{\frac{T}{4}}^k)$。递归的终止条件为:

\[A_{\cdots}(W^k_1) = e^{-2\pi i} = 1\]

⚪ 蝶形运算

下面以一个长度为$N=8$的序列$x[0],x[1],…,x[7]$为例,说明FFT运算的过程。

假设已经求出了$A_1(W_{\frac{T}{2}}^k)$(即使用$x[0],x[2],x[4],x[6]$求出了$G[0],G[1],G[2],G[3]$)和$A_2(W_{\frac{T}{2}}^k)$(即使用$x[1],x[3],x[5],x[7]$求出了$H[0],H[1],H[2],H[3]$)。则最终结果计算为:

\[\begin{aligned} X[k]&= \begin{cases} G[k] + W_N^kH[k], & k=0,1,2,3 \\ G[k] - W_N^kH[k], & k=4,5,6,7 \end{cases} \end{aligned}\]

上图中红框部分(即计算$A_1(W_{\frac{T}{2}}^k)$)又可以递归地计算为(相当于计算$A_1(W_{\frac{T}{4}}^k),A_2(W_{\frac{T}{4}}^k)$):

上图中红框部分(即计算$A_1(W_{\frac{T}{4}}^k)$满足递归的终止条件:

该序列完整的FFT运算过程如下:

由于每一个递归单元中的数据连接关系像一个蝴蝶🦋,因此FFT运算也被称作“蝶形运算”。

一个长度为$N$的FFT需要 $\frac{N}{2} \log_2 N$ 次复数乘法(每个复数乘法对应$4$次实数乘法)和 $N \log_2 N$ 次复数加法(每个复数加法对应$2$次实数加法)。因此:

⚪ iFFT

FFT算法的逆运算即离散傅里叶逆变换(IDFT)。对于DFT

\[\begin{aligned} \hat{f}[k] &= \sum_{x=0}^{T-1} f[x] W_T^{kx} \\ W_T &= e^{-\frac{2\pi i}{T}} \end{aligned}\]

如果用矩阵封装上述过程,则可以表示为:

\[\begin{bmatrix} 1 & 1 & 1 & \cdots & 1 \\ 1 & W_T & W_T^2 & \cdots & W_T^{T-1} \\ 1 & W_T^2 & W_T^4 & \cdots & W_T^{2(T-1)} \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 1 & W_T^{T-1} & W_T^{2(T-1)} & \cdots & W_T^{(T-1)^2} \\ \end{bmatrix} \begin{bmatrix} f[0] \\ f[1] \\ f[2] \\ \vdots \\ f[T-1] \\ \end{bmatrix} = \begin{bmatrix} \hat{f}[0] \\ \hat{f}[1] \\ \hat{f}[2] \\ \vdots \\ \hat{f}[T-1] \\ \end{bmatrix} \\ \downarrow \\ WF=\hat{F}\]

IDFT即寻找一个矩阵$W^{-1}$,从而求解:

\[F=W^{-1}\hat{F}\]

注意到:

\[\begin{aligned} &\begin{bmatrix} 1 & W_T^k & W_T^{2k} & \cdots & W_T^{(T-1)k} \\ \end{bmatrix} \begin{bmatrix} 1 \\ W_T^{(T-1)l} \\ W_T^{(T-2)l} \\ \vdots \\ W_T^l \\ \end{bmatrix} \\ &= \sum_{t=0}^{T-1} W_T^{kt} W_T^{l(T-t)} \\ &= \sum_{t=0}^{T-1} e^{-\frac{2\pi ikt}{T}}e^{-\frac{2\pi il(T-t)}{T}}= \sum_{t=0}^{T-1} e^{-\frac{2\pi i(k-l)t}{T}}\\ &= \begin{cases} T, & k=l \\ 0, & k \neq l \\ \end{cases} \end{aligned}\]

因此,$W^{-1}$可以表示为:

\[\frac{1}{T}\begin{bmatrix} 1 & 1 & 1 & \cdots & 1 \\ 1 & W_T^{T-1} & W_T^{2(T-1)} & \cdots & W_T^{(T-1)(T-1)} \\ 1 & W_T^{T-2} & W_T^{2(T-2)} & \cdots & W_T^{(T-2)(T-1)} \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 1 & W_T^1 & W_T^2 & \cdots & W_T^{T-1} \\ \end{bmatrix}\]

此外注意到:

\[\begin{aligned} \overline{W_T^k} &=\overline{e^{-\frac{2\pi ik}{T}}} = \cos\frac{2k\pi}{T} - i\sin \frac{2k\pi}{T} \\ &= \cos\left(2\pi-\frac{2k\pi}{T}\right) + i\sin\left(2\pi-\frac{2k\pi}{T}\right) \\ &= e^{-i(2\pi-\frac{2k\pi}{T})}= e^{-\frac{2\pi i(T-k)}{T}} \\ &= W_T^{T-k} \end{aligned}\]

因此,$W^{-1}$进一步化简为:

\[\frac{1}{T}\begin{bmatrix} 1 & 1 & 1 & \cdots & 1 \\ 1 & \overline{W_T^1} & \overline{W_T^2} & \cdots & \overline{W_T^{T-1}} \\ 1 & \overline{W_T^2} & \overline{W_T^4} & \cdots & \overline{W_T^{2(T-1)}} \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 1 & \overline{W_T^{T-1}} & \overline{W_T^{2(T-1)}} & \cdots & \overline{W_T^{(T-1)(T-1)}} \\ \end{bmatrix}\]

注意到$W^{-1}$和$W$的区别是将系数$W_T^k$取共轭复数,并除以$T$。因此实现iFFT,只需要在FFT的分治过程中乘上的系数变为其共轭复数,分治完的每一项除以$T$即可。

⚪ Bailey’s FFT算法

Bailey’s FFT也称为4-step FFT,是一种用于计算FFT的高性能算法。该算法将样本视为二维矩阵,对矩阵的列和行执行简短的 FFT 运算,并在运算之间乘以“旋转因子”进行校正。

上述分解表明,一个DFT矩阵可以分解成多个矩阵的乘积。对于一个长度为 $N = N_1 \times N_2$ 的序列,Bailey’s FFTDFT矩阵$F_N$分解为:

\[F_N = P (I_{N_2} \otimes F_{N_1}) D P^{-1} (I_{N_1} \otimes F_{N_2}) P\]

其中$P$ 是置换矩阵;$D$ 是包含 Twiddle 因子的对角矩阵;$F_{N_1}$ 和 $F_{N_2}$ 是较小的 FFT 矩阵。

3. 使用FFT加速深度学习中的卷积运算

(1)卷积在时域和频域计算的复杂度

给定一个长度为$N$的序列数据$x[n]$,使用一个大小为$K$的卷积核$h[k]$处理该数据;下面分析其在时域运算和频域运算的复杂度(主要考虑乘法)。

⚪ 时域卷积的复杂度

在时域中,直接计算卷积的过程如下:

\[y[n] = \sum_{k=0}^{K-1} h[k] \cdot x[n - k]\]

为了计算完整的卷积,需要对输入序列两端各进行长度$K-1$的填充。此时输出序列的长度为$N + K - 1$。计算每个输出点需要进行$K$次乘法和k-1次加法。因此计算复杂度为:

\[O((N + K - 1) \times K)\]

⚪ 频域卷积的复杂度

在频域中,计算卷积的主要步骤如下:

  1. 零填充:为了避免循环卷积的影响,需要将序列和卷积核都补零到至少 $M\geq N + K - 1$ 的长度;
  2. FFT计算:对零填充后的输入序列进行FFT:$O(2M \log_2 M)$;对零填充后的卷积核进行FFT:$O(2M \log_2 M)$;
  3. 频域乘法:在频域中需要进行$M$次复数乘法,对应$4M$次实数乘法;
  4. IFFT计算:对乘积结果进行iFFT:$O(2M \log_2 M)$

因此,频域卷积的总计算复杂度为:

\[\begin{aligned} &O(6M \log_2 M +4M) \\ \approx &O(6(N + K - 1) \log_2 (N + K - 1) +4(N + K - 1)) \\ \end{aligned}\]

⚪ 复杂度比较

当$K \ll N$时,时域卷积复杂度$\approx O(N)$小于频域卷积复杂度$\approx O(N \log_2 N)$,时域卷积更高效。当$K \approx N$时,时域卷积复杂度$\approx O(N^2)$大于时域卷积复杂度$\approx O(N \log_2 N)$,频域卷积更高效。

为了找到时域和频域计算复杂度相当的临界$K$,可以设:

\[(N + K - 1) \times K = 6(N + K - 1) \log_2 (N + K - 1) +4(N + K - 1) \\ \downarrow \\ K = 6 \log_2 (N + K - 1) + 4\]

列出一些$K$和$N$的对应关系:

$N$ 64 128 256 512 1024 2048
$K$ 44 49 54 59 65 70

注意到当卷积核较大时,频域卷积的复杂度优势会越来越明显。

(2)频域卷积的PyTorch实现

import torch

def fftconv(u, k):
    """
    A convolution through the fourier domain (from the Convolution Theorem)
    """
    seqlen = u.shape[-1]
    fft_size = 2 * seqlen

    k_f = torch.fft.rfft(k, n=fft_size) / fft_size
    u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)

    if len(u.shape) > 3: k_f = k_f.unsqueeze(1)
    y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :seqlen]

    return y.to(dtype=u.dtype)

⚪ 讨论:卷积的因果性

卷积操作具有因果性(causal)。因果性是指输出信号只依赖于当前及之前的输入信号,而不依赖于未来的信号。因果性可以避免未来信息泄漏到当前时刻,从而保证信号的稳定性和可靠性。

下面用图例进行说明。假设输入信号$u$和卷积核为$k$具有相同的长度$N$,同时填充至$2N$。在计算卷积时,对于序列的当前输入位置,仅有最多前$N$个输入信号对输出有贡献。

有时需要打破卷积的因果性,比如希望实现类似于自注意力机制的全局上下文提取。一种简单的解决方式是对卷积核进行重复填充,保证在卷积运算时至少有$N$个位置使得输入序列的所有有效元素均参与运算。输出结果取这部分位置即可。

import torch

def fftconv(u, k):
    """
    A convolution through the fourier domain (from the Convolution Theorem)
    """
    seqlen = u.shape[-1]
    kernellen = k.shape[-1]
    fft_size = 2 * seqlen

    k_f = torch.fft.rfft(k, n=fft_size) / fft_size
    u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)

    if len(u.shape) > 3:
        k_f = k_f.unsqueeze(1)
    y = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")

    y = y[..., kernellen // 2 : seqlen + kernellen // 2]
    return y.to(dtype=u.dtype)

(3)深度学习中的例子

Hyena Operator

Hyena Operator是一种用于大规模语言建模的卷积语言模型,通过构造与输入序列等长的卷积核处理并提取特征,实现了与注意力机制相当的性能。

Hyena Operator的核心组件是递归地使用隐式参数化的长卷积和数据控制的门控机制来构建高效的子二次复杂度操作:

\[z^{n+1}_t = x_t^n \cdot (h^n * z_t^n)\]

其中卷积运算$h^n * z_t^n$可以使用FFT的卷积定理来快速计算:首先将输入序列的 FFT 相乘,然后应用逆 FFT 来有效地计算卷积的输出。要将此定理用于上述非循环卷积,需要用零填充等长的输入序列和卷积核,然后丢弃输出序列的填充部分。

\[\text{pad}(h^n * z_t^n) = \text{iFFT}(\text{FFT}(\text{pad}(h^n)) \cdot \text{FFT}(\text{pad}(z_t^n)))\]

Hyena 使用隐式参数化来定义卷积滤波器,避免了显式存储长卷积滤波器带来的参数数量问题。具体来说,卷积滤波器 $h^n$ 可以通过一个前馈神经网络(FFN)来参数化:

\[h_t = \text{Window}(t) \cdot (\text{FFN}\odot \text{PositionalEncoding})(t)\]

其中,$\text{Window}(t)$ 是一个衰减窗口函数,用于调制卷积滤波器;$\text{PositionalEncoding}(t)$ 是位置编码,用于引入序列中的位置信息;$\text{FFN}$ 是一个浅层前馈神经网络,用于生成卷积滤波器的值。

FlashFFTConv

FlashFFTConv 的核心思想是将 FFT 卷积通过Bailey’s FFT进行加速,从而充分利用现代硬件(如 GPUTensor Cores)的计算能力。

Bailey’s FFTFFT 分解为一系列矩阵乘法操作。这种分解不仅提高了计算效率,还可以通过内核融合减少内存 I/O 成本。内核融合允许将多个操作(如矩阵乘法和逐元素乘法)合并到一个内核中,从而减少数据在不同内存层次之间的传输。例如,对于长序列,FlashFFTConv 可以将内层矩阵操作和逐元素乘法融合在一起,仅在最外层矩阵操作时进行 I/O 操作。