KAN:柯尔莫哥洛夫-阿诺德网络.
1. Kolmogorov–Arnold Networks (KAN)
(1)柯尔莫哥洛夫-阿诺德表示定理 (Kolmogorov-Arnold Representation Theorem)
柯尔莫哥洛夫-阿诺德表示定理指出,对于任意有界域上的多变量连续函数$f$,可以表示为有限数量的单变量连续函数的两层嵌套加法运算。对于光滑函数$f:[0,1]^n\rightarrow \mathbb{R}$,有:
\[f(\mathbf{x}) = f(x_1,\cdots,x_n) = \sum_{q=1}^{2n+1} \Phi_q \left( \sum_{p=1}^n\phi_{q,p}(x_p) \right)\]其中$\phi_{q,p}:[0,1]\rightarrow \mathbb{R},\Phi_q:\mathbb{R}\rightarrow \mathbb{R}$。该定义指出学习一个高维函数等价于学习多项式数量的一维函数。然而由于这些需要学习的一维函数可能是非光滑的,因此在实践中通常认为该定理不适用于深度学习。注意到原始的柯尔莫哥洛夫-阿诺德表示定理只有两层非线性层和$2n+1$项隐藏层单元,若将该网络推广到任意宽度和深度时,可以缓解一维函数的学习困难。
(2)KAN的网络结构
一个具有$n_{in}$维输入、$n_{out}$维输出的KAN层定义为:
\[\mathbf{x}_{l+1} = \underbrace{\begin{pmatrix} \phi_{l,1,1}(\cdot) & \phi_{l,1,2}(\cdot) & \cdots & \phi_{l,1,n_{in}}(\cdot) \\ \phi_{l,2,1}(\cdot) & \phi_{l,2,2}(\cdot) & \cdots & \phi_{l,2,n_{in}}(\cdot) \\ \vdots & \vdots & & \vdots \\ \phi_{l,n_{out},1}(\cdot) & \phi_{l,n_{out},2}(\cdot) & \cdots & \phi_{l,n_{out},n_{in}}(\cdot) \end{pmatrix}}_{\Phi_l} \mathbf{x}_{l}\]KAN网络是$L$个KAN层的组合:
\[KAN(\mathbf{x}) = \left( \Phi_{L-1} \circ \Phi_{L-2} \circ \cdots \circ \Phi_{1} \circ \Phi_{0}\right)\mathbf{x}\]原始的柯尔莫哥洛夫-阿诺德表示定理可以视为$L=2$时的KAN网络,其神经元的数量为$[n,2n+1,1]$。
由于所需要学习的函数是单变量函数,因此将一维函数$\phi(\cdot)$参数化为残差激活函数,由一个基函数$b(x)$和一个B样条函数$\text{spline}(x)$求和:
\[\begin{aligned} & \phi(x) = w\left( b(x)+\text{spline}(x) \right) \\ & b(x) = \text{silu}(x) = \frac{x}{1+e^{-x}} \\ & \text{spline}(x) = \sum_{i=1} c_i B_i(x) \end{aligned}\]其中B样条函数初始化为$\text{spline}(x)\approx 0$,$w$采用Xavier初始化。
⚪ KAN的近似理论(Approximation Theory)
假设函数$f(\mathbf{x})$可以表示为:
\[f(\mathbf{x}) = \left( \Phi_{L-1} \circ \Phi_{L-2} \circ \cdots \circ \Phi_{1} \circ \Phi_{0}\right)\mathbf{x}\]则存在$k$阶网格尺寸为$G$的B样条函数使得对于任意$0≤m≤k$,存在一个常数$C$满足:
\[\left\| f-\left( \Phi_{L-1} \circ \Phi_{L-2} \circ \cdots \circ \Phi_{1} \circ \Phi_{0}\right)\mathbf{x}\right\|_{C^m} \leq CG^{-k-1+m}\]该理论表明KAN网络近似函数$f$的精度与输入维度无关,因此不会受到维度诅咒(curse of dimensionality)的影响。
⚪ KAN的尺度定律(Scaling Law)
KAN网络的测试误差$l$随着模型参数$N$的增加而减小:
\[l \propto N^{-(k+1)}\]其中$k$是B样条的分段多项式阶数。
(3)KAN的网格扩展
由于B样条函数可以通过设置网格细粒度来提高目标函数的精确程度,因此对于KAN网络,可以先用更少的参数训练,然后通过简单地精细化其样条网格,将其扩展到具有更多参数的KAN网络,而不需要从头开始重新训练更大的模型。
假设用$k$阶的B样条函数在有界区域$[a, b]$中近似一维函数$f$,具有$G_1$区间的粗粒度网格,其网格点为${t_0 = a, t_1, t_2,\cdots,t_{G_1} = b}$,将其增广为${t_{-k},\cdots,t_{-1},t_0 , t_1, \cdots,t_{G_1},t_{G_1+1},\cdots,,t_{G_1+k}}$。设置$G_1+k$个B样条基函数,其中第$i$个基函数$B_i(x)$只在 $t_{-k+i},t_{i+1}$ 区间内非零,则函数$f$表示为这些B样条基函数的线性组合:
\[f(x) = \sum_{i=0}^{G_1+k-1} c_i B_i(x)\]给定具有$G_2>G_1$区间的细粒度网格,函数$f$表示为:
\[f(x) = \sum_{j=0}^{G_2+k-1} c_j^\prime B_j^\prime(x)\]其中参数$c_j^\prime$可以通过最小化上述两种函数表示之间的分布距离初始化:
\[\{c_j^\prime\} = {\arg\min}_{\{c_j^\prime\}} \mathbb{E}_{x\sim p(x)}\left(\sum_{j=0}^{G_2+k-1} c_j^\prime B_j^\prime(x)-\sum_{i=0}^{G_1+k-1} c_i B_i(x)\right)^2\](4)KAN的可解释性
在实践中从一个足够大的KAN网络开始,用稀疏正则化训练,然后进行剪枝;剪枝后的KAN网络具有较好的可解释性。
⚪ 稀疏正则化
激活函数$\phi(x)$的L1范数定义为其$N_p$个输入上的平均幅度:
\[\left| \phi \right|_1 = \frac{1}{N_p} \sum_{s=1}^{N_p} \left| \phi(x^{(s)}) \right|\]一个具有$n_{in}$维输入、$n_{out}$维输出的KAN层$\Phi$的L1范数定义为其所有激活函数$\phi(x)$的L1范数之和:
\[\left| \Phi \right|_1 = \sum_{i=1}^{n_{in}} \sum_{j=1}^{n_{out}} \left| \phi_{i,j} \right|_1\]$\Phi$的熵定义为:
\[S(\Phi) = -\sum_{i=1}^{n_{in}} \sum_{j=1}^{n_{out}} \frac{\left| \phi_{i,j} \right|_1}{\left| \Phi \right|_1} \log \left(\frac{\left| \phi_{i,j} \right|_1}{\left| \Phi \right|_1}\right)\]则稀疏正则化训练是指在训练损失中引入L1损失和熵正则化:
\[\mathcal{L}_{total} = \mathcal{L}_{pred} + \lambda\left( \mu_1\sum_{l=0}^{L-1} \left| \Phi_l \right|_1 + \mu_2\sum_{l=0}^{L-1} S(\Phi_l) \right)\]⚪ 剪枝
使用稀疏正则化训练后,在节点级别对KAN网络进行剪枝。对于每个节点(假设是第$l$层的第$i$个神经元),将其传入和传出的分数定义为:
\[\begin{aligned} & I(l,i) = \max_{k}(\left| \phi_{l-1,k,i} \right|_1) \\ & O(l,i) = \max_{j}(\left| \phi_{l+1,j,i} \right|_1) \end{aligned}\]如果传入和传出的分数都大于阈值$θ = 10^{−2}$,则认为节点是重要的,否则对节点进行剪枝。
⚪ 可解释性
在完成稀疏正则化和剪枝后,KAN网络具有一定程度上的符号可解释性。将网络中所有的激活函数符号化,并设置合适的仿射参数,则KAN网络可以执行符号回归:
2. 比较KAN和MLP
KAN网络和多层感知机(Multi-Layer Perceptrons, MLP)的主要区别在于:
- KAN网络的激活函数是可学习的(B样条),MLP的激活函数是固定的(如ReLU)。
- KAN网络的激活函数作用于神经元之间的边上,MLP的激活函数作用于神经元节点上。
在相同的参数量下,KAN网络的计算时间成本通常是MLP的10倍左右。然而KAN网络通常具有更小的计算图,在拟合一些简单的函数时(符号回归),更小深度的KAN网络性能超过MLP。
KAN网络相比于MLP最大的优势是可解释性;在实践中可以通过权衡训练时间和需求选择不同的模型。