The Geometric Occam's Razor Implicit in Deep Learning

(1) Gradient penalty on the parameters

Gradient descent is a first-order approximate optimization algorithm. Writing $g(W)=\nabla_W L(W)$ for the gradient and $\gamma$ for the learning rate, one can show that a finite-step-size update effectively follows a modified gradient $\tilde{g}(W)$, which amounts to implicitly adding a parameter-gradient penalty to the loss:

\[\begin{aligned} \tilde{g}(W) & \approx g(W) + \frac{1}{4}\gamma \nabla_{W} ||g(W)||^2 \\ & = \nabla_{W} \left( L(W) + \frac{1}{4}\gamma ||\nabla_{W} L(W)||^2 \right) \end{aligned}\]

This gradient penalty term encourages the model to settle in flatter regions of the loss landscape, which tends to improve generalization. The penalty can also be added to the loss explicitly:

\[\mathcal{L}(x,y;W) + \lambda ||\nabla_{W} \mathcal{L}(x,y;W)||^2\]
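As a concrete illustration, the explicit penalty can be implemented with double backpropagation, keeping the first gradient differentiable via `create_graph=True`. Below is a minimal PyTorch sketch; the architecture, random data, and penalty weight `lam` are illustrative assumptions rather than anything specified in the text:

```python
# Minimal sketch: explicit parameter-gradient penalty via double backprop.
# Model, data, and the penalty weight `lam` are illustrative assumptions.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
lam = 1e-3  # the lambda in the formula above

x, y = torch.randn(16, 10), torch.randn(16, 1)

loss = criterion(model(x), y)
# First differentiation with create_graph=True so the gradient norm
# itself remains differentiable (double backprop).
grads = torch.autograd.grad(loss, model.parameters(), create_graph=True)
grad_norm_sq = sum(g.pow(2).sum() for g in grads)

total = loss + lam * grad_norm_sq
optimizer.zero_grad()
total.backward()
optimizer.step()
```

The second differentiation pass makes each step noticeably more expensive, which is part of why the implicit penalty that plain gradient descent provides for free is interesting.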

(2) Gradient penalty on the inputs

In adversarial training, applying the adversarial perturbation $\Delta x = \epsilon \nabla_x \mathcal{L}(x,y;W)$ to the input is, to first order, equivalent to adding an input-gradient penalty to the loss:

\[\begin{aligned} \mathcal{L}(x+\Delta x,y;W) &\approx \mathcal{L}(x,y;W)+\epsilon ||\nabla_x\mathcal{L}(x,y;W)||^2 \end{aligned}\]

Here the gradient penalty (or, equivalently, adversarial training) makes the model robust to small input perturbations. Input-gradient penalties are also used to enforce Lipschitz constraints on a model.
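Both views of the input-side regularizer are short to implement. The sketch below shows the explicit input-gradient penalty and the equivalent perturbation $\Delta x = \epsilon \nabla_x \mathcal{L}$; again the model, data, and $\epsilon$ are assumptions for illustration:

```python
# Minimal sketch: input-gradient penalty and the equivalent
# adversarial perturbation. Model, data, and eps are assumptions.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
criterion = nn.MSELoss()
eps = 0.1  # the epsilon in the formulas above

x = torch.randn(16, 10, requires_grad=True)
y = torch.randn(16, 1)

loss = criterion(model(x), y)
# Gradient w.r.t. the input, kept differentiable for the penalty term.
(grad_x,) = torch.autograd.grad(loss, x, create_graph=True)

# View 1: explicit input-gradient penalty added to the loss.
penalized = loss + eps * grad_x.pow(2).sum()

# View 2: adversarial training with the perturbation eps * grad_x,
# whose first-order expansion matches the penalized loss above.
x_adv = (x + eps * grad_x).detach()
adv_loss = criterion(model(x_adv), y)
```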

(3) The relationship between the two gradient penalties

Consider an $L$-layer MLP whose layers are defined recursively by

\[h^{(l+1)} = g^{(l)}(W^{(l)}h^{(l)}+b^{(l)})\]

where $g^{(l)}$ is the activation function of layer $l$ and $h^{(1)} = x$ is the input. Denote the full parameter set by $\theta = (W^{(1)},b^{(1)},\dots,W^{(L)},b^{(L)})$, and let $f$ be any scalar function of the final output $h^{(L+1)}$ (such as the loss). Then the following inequality holds:

\[||\nabla_x f||^2 \left( \frac{1+||h^{(1)}||^2}{||W^{(1)}||^2||\nabla_x h^{(1)}||^2}+\cdots + \frac{1+||h^{(L)}||^2}{||W^{(L)}||^2||\nabla_x h^{(L)}||^2} \right) \leq ||\nabla_{\theta} f||^2\]

This inequality shows that, to some extent, a gradient penalty on the parameters already contains a gradient penalty on the inputs. To prove it, it suffices to establish the per-layer bound (writing $\theta^{(l)} = (W^{(l)}, b^{(l)})$):

\[\begin{aligned} ||\nabla_x f||^2 \left( \frac{1+||h^{(l)}||^2}{||W^{(l)}||^2||\nabla_x h^{(l)}||^2} \right) \leq ||\nabla_{\theta^{(l)}} f||^2 \end{aligned}\]

Since $||\nabla_{\theta^{(l)}} f||^2 = ||\nabla_{W^{(l)}} f||^2 + ||\nabla_{b^{(l)}} f||^2$, this is equivalent to proving the pair of inequalities:

\[\begin{aligned} ||\nabla_x f||^2 \left( \frac{||h^{(l)}||^2}{||W^{(l)}||^2||\nabla_x h^{(l)}||^2} \right) &\leq ||\nabla_{W^{(l)}} f||^2 \\ ||\nabla_x f||^2 \left( \frac{1}{||W^{(l)}||^2||\nabla_x h^{(l)}||^2} \right) &\leq ||\nabla_{b^{(l)}} f||^2 \end{aligned}\]

Write $z^{(l)}=W^{(l)}h^{(l)}+b^{(l)}$; then by the chain rule:

\[\begin{aligned} \nabla_x f = \frac{\partial f}{\partial x} = \frac{\partial f}{\partial z^{(l)}}\frac{\partial z^{(l)}}{\partial h^{(l)}} \frac{\partial h^{(l)}}{\partial x} = \frac{\partial f}{\partial z^{(l)}} W^{(l)} \frac{\partial h^{(l)}}{\partial x} \end{aligned}\]

Furthermore, from the formal relation $W^{(l)}=(z^{(l)}-b^{(l)})(h^{(l)})^{-1}$, where $(h^{(l)})^{-1} = h^{(l)\top}/||h^{(l)}||^2$ is the pseudo-inverse of the column vector $h^{(l)}$ (so that $||(h^{(l)})^{-1}|| = ||h^{(l)}||^{-1}$), we have:

\[\begin{aligned} \frac{\partial f}{\partial z^{(l)}} = \frac{\partial f}{\partial W^{(l)}} \frac{\partial W^{(l)}}{\partial z^{(l)}} = \frac{\partial f}{\partial W^{(l)}} (h^{(l)})^{-1} \end{aligned}\]

Combining the two expressions gives:

\[\begin{aligned} \nabla_x f = \frac{\partial f}{\partial W^{(l)}} (h^{(l)})^{-1}W^{(l)} \frac{\partial h^{(l)}}{\partial x}= \nabla_{W^{(l)}} f \cdot (h^{(l)})^{-1}W^{(l)} \nabla_x h^{(l)} \end{aligned}\]

Taking norms of both sides and applying submultiplicativity:

\[\begin{aligned} ||\nabla_x f ||&= ||\nabla_{W^{(l)}} f \cdot (h^{(l)})^{-1}W^{(l)} \nabla_x h^{(l)}|| \\ & \leq ||\nabla_{W^{(l)}} f|| \cdot ||h^{(l)}||^{-1}\cdot ||W^{(l)}||\cdot ||\nabla_x h^{(l)}|| \end{aligned}\]

Rearranging yields exactly the first of the two target inequalities:

\[\begin{aligned} ||\nabla_x f||^2 \left( \frac{||h^{(l)}||^2}{||W^{(l)}||^2||\nabla_x h^{(l)}||^2} \right) &\leq ||\nabla_{W^{(l)}} f||^2 \end{aligned}\]

A similar argument, now using $\nabla_{b^{(l)}} f = \partial f/\partial z^{(l)}$ (since $z^{(l)}$ depends on $b^{(l)}$ through an additive term), proves the second:

\[\begin{aligned} ||\nabla_x f||^2 \left( \frac{1}{||W^{(l)}||^2||\nabla_x h^{(l)}||^2} \right) &\leq ||\nabla_{b^{(l)}} f||^2 \end{aligned}\]
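The full inequality is easy to sanity-check numerically. The sketch below instantiates a tiny two-layer MLP ($L=2$, $h^{(1)}=x$, $\tanh$ activation, identity output activation), evaluates both sides using Frobenius norms (which only loosen the bound), and confirms it; the architecture and random inputs are assumptions made for the check:

```python
# Numerical sanity check of the layer-wise inequality on a tiny MLP.
import torch

torch.manual_seed(0)
d, m = 5, 7
W1 = torch.randn(m, d, requires_grad=True)
b1 = torch.randn(m, requires_grad=True)
W2 = torch.randn(1, m, requires_grad=True)
b2 = torch.randn(1, requires_grad=True)
x = torch.randn(d, requires_grad=True)

def h2_fn(inp):
    # h^{(2)} = g^{(1)}(W^{(1)} h^{(1)} + b^{(1)}), with h^{(1)} = x.
    return torch.tanh(W1 @ inp + b1)

def f_fn(inp):
    # Scalar f = h^{(3)} = W^{(2)} h^{(2)} + b^{(2)} (identity g^{(2)}).
    return (W2 @ h2_fn(inp) + b2).squeeze()

f = f_fn(x)
grad_x, gW1, gb1, gW2, gb2 = torch.autograd.grad(f, (x, W1, b1, W2, b2))

h1, h2 = x.detach(), h2_fn(x).detach()
J1 = torch.eye(d)                                           # nabla_x h^{(1)} = I
J2 = torch.autograd.functional.jacobian(h2_fn, x.detach())  # nabla_x h^{(2)}

def term(h, W, J):
    return (1 + h.norm() ** 2) / (W.norm() ** 2 * J.norm() ** 2)

lhs = grad_x.norm() ** 2 * (term(h1, W1, J1) + term(h2, W2, J2))
rhs = sum(g.norm() ** 2 for g in (gW1, gb1, gW2, gb2))
print(f"lhs = {lhs.item():.4f} <= rhs = {rhs.item():.4f}: {bool(lhs <= rhs)}")
```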

(4) Main conclusion of the paper

SGD implicitly adds a gradient penalty on the parameters; the parameter-gradient penalty implicitly contains an input-gradient penalty; the input-gradient penalty is related to the Dirichlet energy; and the Dirichlet energy serves as a measure of model complexity. The conclusion: SGD itself is biased toward selecting models of lower complexity.
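For concreteness: the Dirichlet energy in question is, up to constants, $\mathbb{E}_x\big[||\nabla_x f(x)||^2\big]$, so it can be estimated by averaging squared input-gradient norms over samples. A minimal sketch with an assumed model and random stand-in data:

```python
# Minimal sketch: Monte-Carlo estimate of the Dirichlet energy
# E_x ||nabla_x f(x)||^2. Model and data are stand-in assumptions.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 32), nn.Tanh(), nn.Linear(32, 1))

def dirichlet_energy(model, xs):
    xs = xs.clone().requires_grad_(True)
    # Summing the scalar outputs gives per-sample input gradients,
    # since each output row depends only on its own input row.
    (g,) = torch.autograd.grad(model(xs).sum(), xs)
    return g.pow(2).sum(dim=1).mean()

xs = torch.randn(256, 10)  # stand-in for real data
print(dirichlet_energy(model, xs).item())
```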