虚拟对抗训练:一种用于监督学习和半监督学习的正则化方法.
- paper:Virtual Adversarial Training: A Regularization Method for Supervised and Semi-Supervised Learning
1. 对抗训练
对抗训练(Adversarial Training)是指通过构造对抗样本,对模型进行对抗攻击和防御来增强模型的稳健性。对抗样本通常是指具有小扰动的样本,对于人类来说“看起来”几乎一样,但对于模型来说预测结果却完全不一样:
对抗训练的一般形式:
\[\mathcal{\min}_{\theta} \mathbb{E}_{(x,y)\sim \mathcal{D}} \left[ \mathcal{\max}_{\Delta x \in \Omega} \mathcal{L}(x+\Delta x,y;\theta) \right]\]其中\(\mathcal{D}\)表示训练集,$x$表示输入样本,$y$表示样本标签,$\theta$表示模型参数,\(\mathcal{L}\)是损失函数,$\Delta x$是样本的对抗扰动,$\Omega$是扰动空间。
完整的对抗训练步骤如下:
- 为输入样本$x$添加对抗扰动$\Delta x$,$\Delta x$的目标是使得损失\(\mathcal{L}(x+\Delta x,y;\theta)\)尽可能增大,即尽可能让现有模型的预测出错;
- 对抗扰动$\Delta x$要满足一定的约束,比如模长不超过一个常数$\epsilon$:\(\|\Delta x\| \leq \epsilon\);
- 对每个样本$x$都构造一个对抗样本$x+\Delta x$,用样本对$(x+\Delta x,y)$最小化损失函数训练模型参数$\theta$。
2. 虚拟对抗训练
虚拟对抗训练(Virtual Adversarial Training, VAT)是一种正则化方法。通过寻找使得损失$l(f(x+\epsilon),f(x))$尽可能大的扰动噪声$\epsilon$,并最小化该损失,从而增强网络对于扰动噪声的鲁棒性。
对损失$l(f(x+\epsilon),f(x))$在$\epsilon$处进行泰勒展开:
\[l(f(x+\epsilon),f(x)) \approx l(f(x),f(x)) + \epsilon^T\nabla_xl(f(x),f_{sg}(x)) + \frac{1}{2}\epsilon^T\nabla_x^2l(f(x),f_{sg}(x))\epsilon\]对于一般地损失函数$l(,)$,具有性质$l(x,x)=0$;因此上式简化为:
\[l(f(x+\epsilon),f(x)) \approx \frac{1}{2}\epsilon^T\nabla_x^2l(f(x),f_{sg}(x))\epsilon\]寻找使得损失$l(f(x+\epsilon),f(x))$尽可能大的扰动噪声$\epsilon$,等价于首先计算Hessian矩阵\(\mathcal{H}=\nabla_x^2l(f(x),f_{sg}(x))\),然后求解最大化\(\epsilon^T \mathcal{H} \epsilon\)的$\epsilon$。
根据瑞利商的定义,最大化\(\epsilon^T \mathcal{H} \epsilon\)的$\epsilon$为\(\mathcal{H}\)的最大特征值对应的特征向量(主特征向量)。\(\mathcal{H}\)的主特征向量可以通过幂迭代(power iteration)方法求解。迭代格式:
\[u \leftarrow \frac{\mathcal{H}u}{||\mathcal{H}u||}\]其中$u$随机初始化。下面简单证明迭代过程收敛,初始化$u^{(0)}$,若\(\mathcal{H}\)可对角化,则\(\mathcal{H}\)的特征向量\(\{v_1 v_2 \cdots v_n\}\)构成一组完备的基,$u^{(0)}$可由这组基表示:
\[u^{(0)} = c_1v_1+c_2v_2+\cdots c_nv_n\]先不考虑迭代中分母的归一化,则迭代过程\(u \leftarrow \mathcal{H}u\)经过$t$次后为:
\[\mathcal{H}^tu^{(0)} = c_1\mathcal{H}^tv_1+c_2\mathcal{H}^tv_2+\cdots c_n\mathcal{H}^tv_n\]注意到\(\mathcal{H}v=\lambda v\),则有:
\[\mathcal{H}^tu^{(0)} = c_1\lambda_1^tv_1+c_2\lambda_2^tv_2+\cdots c_n\lambda_n^tv_n\]不失一般性地假设$\lambda_1$为最大特征值,则有:
\[\frac{\mathcal{H}^tu^{(0)}}{\lambda_1^t} = c_1v_1+c_2(\frac{\lambda_2}{\lambda_1})^tv_2+\cdots c_n(\frac{\lambda_n}{\lambda_1})^tv_n\]注意到当$t \to \infty$时,$(\frac{\lambda_2}{\lambda_1})^t,\cdots (\frac{\lambda_n}{\lambda_1})^t \to 0$。则有:
\[\frac{\mathcal{H}^tu^{(0)}}{\lambda_1^t} ≈ c_1v_1\]上述结果表明当迭代次数$t$足够大时,\(\mathcal{H}^tu^{(0)}\)提供了最大特征根对应的特征向量的近似方向,对其归一化后相当于单位特征向量:
\[\begin{aligned} u &= \frac{\mathcal{H}^tu^{(0)}}{||\mathcal{H}^tu^{(0)}||} \\ \mathcal{H} u &≈ \lambda_1 u \end{aligned}\]在幂迭代中,并不需要知道\(\mathcal{H}\)的具体值,而只需要计算\(\mathcal{H}u\),这可以通过差分来近似计算:
\[\begin{aligned} \mathcal{H}u &= \nabla_x^2l(f(x),f_{sg}(x))u \\ &= \nabla_x\left(u \cdot \nabla_xl(f(x),f_{sg}(x))\right) \\ &\approx \nabla_x\left( \frac{l(f(x+\xi u),f_{sg}(x))-l(f(x),f_{sg}(x))}{\xi} \right) \\ &= \frac{1}{\xi} \nabla_x l(f(x+\xi u),f_{sg}(x)) \\ \end{aligned}\]VAT的完整流程如下:
- 初始化向量\(u\sim \mathcal{N}(0,1)\)、标量$\epsilon, \xi$;
- 迭代$r$次:\(\begin{aligned} u &\leftarrow \frac{u}{\mid\mid u \mid\mid} \\ u &\leftarrow \nabla_x l(f(x+\xi u),f_{sg}(x)) \end{aligned}\)
- $u \leftarrow \frac{u}{\mid\mid u \mid\mid}$
- 用$l(f(x+\epsilon u),f_{sg}(x))$作为损失函数执行梯度下降。
注意到当$r=0$时相当于向输入增加高斯噪声,VAT通过迭代$r \geq 1$次来增加噪声的特异性。
以分类任务为例,完整的VAT过程实现如下:
import contextlib
import torch
import torch.nn as nn
import torch.nn.functional as F
@contextlib.contextmanager
def _disable_tracking_bn_stats(model):
def switch_attr(m):
if hasattr(m, 'track_running_stats'):
m.track_running_stats ^= True
model.apply(switch_attr)
yield
model.apply(switch_attr)
def _l2_normalize(d):
d_reshaped = d.view(d.shape[0], -1, *(1 for _ in range(d.dim() - 2)))
d /= torch.norm(d_reshaped, dim=1, keepdim=True) + 1e-8
return d
class VATLoss(nn.Module):
def __init__(self, xi=10.0, eps=1.0, ip=1):
"""VAT loss
:param xi: hyperparameter of VAT (default: 10.0)
:param eps: hyperparameter of VAT (default: 1.0)
:param ip: iteration times of computing adv noise (default: 1)
"""
super(VATLoss, self).__init__()
self.xi = xi
self.eps = eps
self.ip = ip
def forward(self, model, x):
with torch.no_grad():
pred = F.softmax(model(x), dim=1)
# prepare random unit tensor
d = torch.rand(x.shape).sub(0.5).to(x.device)
d = _l2_normalize(d)
with _disable_tracking_bn_stats(model):
# calc adversarial direction
for _ in range(self.ip):
d.requires_grad_()
pred_hat = model(x + self.xi * d)
logp_hat = F.log_softmax(pred_hat, dim=1)
adv_distance = F.kl_div(logp_hat, pred, reduction='batchmean')
adv_distance.backward()
d = _l2_normalize(d.grad)
model.zero_grad()
# calc LDS
r_adv = d * self.eps
pred_hat = model(x + r_adv)
logp_hat = F.log_softmax(pred_hat, dim=1)
lds = F.kl_div(logp_hat, pred, reduction='batchmean')
return lds