Virtual Adversarial Training: A Regularization Method for Semi-Supervised Learning
- Paper: Virtual Adversarial Training: A Regularization Method for Supervised and Semi-Supervised Learning
Virtual Adversarial Training (VAT) brings the idea of adversarial training into semi-supervised learning. Adversarial training adds an adversarial perturbation to the input and trains the model to be robust against this attack. In an adversarial attack, the noise $r$ is constructed by:
\[r = \mathop{\arg \max}_{||r|| \leq \epsilon} \mathcal{L}(y,f_{\theta}(x+r))\]where $y$ is the true label of the sample, $f_{\theta}(x)$ is the model's prediction, and $\mathcal{L}$ measures the distance between the two. Once the attack noise $r$ has been constructed, adversarial training is carried out as supervised learning, i.e. by minimizing the loss:
\[\mathcal{L}_s^{adv} = \sum_{(x,y) \in \mathcal{D}} \mathcal{L}(y,f_{\theta}(x+r))\]In semi-supervised learning the label $y$ is usually unknown, so it is replaced by the model's current prediction $f_{\theta}(x)$, which is excluded from gradient computation (implemented with a stop_gradient operator). The corresponding VAT objective is:
\[\begin{aligned} r &= \mathop{\arg \max}_{||r|| \leq \epsilon} \mathcal{L}(\text{sg}(f_{\theta}(x)),f_{\theta}(x+r)) \\ \mathcal{L}_u^{VAT} &= \sum_{x \in \mathcal{D}} \mathcal{L}(\text{sg}(f_{\theta}(x)),f_{\theta}(x+r)) \end{aligned}\]The VAT loss is a negative measure of the local smoothness of the current model's prediction at each sample point; minimizing this unsupervised loss makes the prediction manifold smoother.
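As a concrete illustration of the supervised adversarial loss above, the inner maximization can be approximated with a single normalized gradient step on the input (a one-step attack in the spirit of FGSM, but with an L2-normalized step). The sketch below is a minimal PyTorch version under that assumption; `model`, `x`, `y`, and `epsilon` are placeholder names, not part of the original text.

```python
import torch
import torch.nn.functional as F

def supervised_adversarial_loss(model, x, y, epsilon=1.0):
    # One-step approximation of r = argmax_{||r|| <= eps} L(y, f(x + r)):
    # take the gradient of the loss w.r.t. the input and rescale it to norm eps.
    x = x.clone().detach().requires_grad_(True)
    loss = F.cross_entropy(model(x), y)
    grad = torch.autograd.grad(loss, x)[0]
    grad_norm = grad.flatten(1).norm(dim=1).view(-1, *([1] * (x.dim() - 1)))
    r = epsilon * grad / (grad_norm + 1e-8)
    # adversarial training then minimizes the loss on the perturbed input
    return F.cross_entropy(model(x + r.detach()), y)
```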
To construct the attack noise $r$, take a second-order Taylor expansion of the loss function \(\mathcal{L}\) around $x$:
\[\mathcal{L}(\text{sg}(f_{\theta}(x)),f_{\theta}(x+r)) \approx \mathcal{L}(\text{sg}(f_{\theta}(x)),f_{\theta}(x)) + r^T\nabla_x \mathcal{L}(\text{sg}(f_{\theta}(x)),f_{\theta}(x)) + \frac{1}{2} r^T\nabla_x^2 \mathcal{L}(\text{sg}(f_{\theta}(x)),f_{\theta}(x))\, r\]Note that the loss function \(\mathcal{L}\) typically has the properties:
\[\mathcal{L}(x,x) = 0 , \quad \nabla_x \mathcal{L}(x,x) = 0\]so the zeroth- and first-order terms vanish and the loss is approximated by:
\[\mathcal{L}(\text{sg}(f_{\theta}(x)),f_{\theta}(x+r)) \approx \frac{1}{2} r^T\nabla_x^2 \mathcal{L}(\text{sg}(f_{\theta}(x)),f_{\theta}(x))\, r\]If $r$ is constrained to be a unit vector, $r^Tr=1$, the right-hand side is (up to the factor $\frac{1}{2}$) the Rayleigh quotient of the Hessian \(\mathcal{H}=\nabla_x^2 \mathcal{L}(\text{sg}(f_{\theta}(x)),f_{\theta}(x))\), whose value lies in the range:
\[\lambda_{min} \leq r^T \mathcal{H} r \leq \lambda_{max}\]The attack noise $r$ should maximize the loss \(\mathcal{L}\), so $r$ is the eigenvector associated with the largest eigenvalue of \(\mathcal{H}\), which can be computed by power iteration. The update is:
\[r \leftarrow \frac{\mathcal{H}r}{||\mathcal{H}r||}\]where $r$ can be initialized randomly. We now prove that power iteration converges.
Initialize $r^{(0)}$. If \(\mathcal{H}\) is diagonalizable, its eigenvectors \(\{v_1, v_2, \cdots, v_n\}\) form a complete basis, so $r^{(0)}$ can be written in this basis:
\[r^{(0)} = c_1v_1+c_2v_2+\cdots +c_nv_n\]Ignoring the normalizing denominator for the moment, $t$ steps of the iteration \(r \leftarrow \mathcal{H}r\) give:
\[\mathcal{H}^tr^{(0)} = c_1\mathcal{H}^tv_1+c_2\mathcal{H}^tv_2+\cdots +c_n\mathcal{H}^tv_n\]Noting that \(\mathcal{H}v_i=\lambda_i v_i\), this becomes:
\[\mathcal{H}^tr^{(0)} = c_1\lambda_1^tv_1+c_2\lambda_2^tv_2+\cdots +c_n\lambda_n^tv_n\]Without loss of generality, assume $\lambda_1$ is the eigenvalue of largest magnitude; then:
\[\frac{\mathcal{H}^tr^{(0)}}{\lambda_1^t} = c_1v_1+c_2\left(\frac{\lambda_2}{\lambda_1}\right)^tv_2+\cdots +c_n\left(\frac{\lambda_n}{\lambda_1}\right)^tv_n\]As $t \to \infty$, $\left(\frac{\lambda_2}{\lambda_1}\right)^t,\cdots,\left(\frac{\lambda_n}{\lambda_1}\right)^t \to 0$, so:
\[\frac{\mathcal{H}^tr^{(0)}}{\lambda_1^t} \approx c_1v_1\]This shows that when the number of iterations $t$ is large enough, \(\mathcal{H}^tr^{(0)}\) approximates the direction of the eigenvector of the largest eigenvalue; after normalization it is approximately the unit eigenvector:
\[\begin{aligned} r &= \frac{\mathcal{H}^tr^{(0)}}{||\mathcal{H}^tr^{(0)}||} \\ \mathcal{H} r &\approx \lambda_1 r \end{aligned}\]which is exactly the power-iteration update, completing the proof.
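As a quick numeric check (not part of VAT itself), the sketch below runs this power iteration on a small random symmetric matrix standing in for \(\mathcal{H}\), compares the result against torch.linalg.eigh, and verifies the Rayleigh-quotient bound above; the matrix size and iteration count are arbitrary choices.

```python
import torch

torch.manual_seed(0)
n = 8
A = torch.randn(n, n)
H = (A + A.T) / 2                      # random symmetric matrix standing in for the Hessian

# power iteration: r <- H r / ||H r||
r = torch.randn(n)
r = r / r.norm()
for _ in range(500):
    r = H @ r
    r = r / r.norm()

eigvals, eigvecs = torch.linalg.eigh(H)
dominant = eigvecs[:, eigvals.abs().argmax()]  # eigenvector of the largest-magnitude eigenvalue
# same direction up to sign, provided the dominant eigenvalue is well separated
print(torch.allclose(r.abs(), dominant.abs(), atol=1e-3))

# Rayleigh quotient of any unit vector lies between the extreme eigenvalues
v = torch.randn(n)
v = v / v.norm()
rq = v @ H @ v
print((eigvals.min() <= rq).item() and (rq <= eigvals.max()).item())  # True
```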
In VAT's power iteration, \(\mathcal{H}\) never needs to be formed explicitly; only the product \(\mathcal{H}r\) is required, and it can be approximated by a finite difference:
\[\begin{aligned} \mathcal{H}r &= \nabla_x^2 \mathcal{L}(\text{sg}(f_{\theta}(x)),f_{\theta}(x))\, r \\ &= \nabla_x \left(r\cdot \nabla_x \mathcal{L}(\text{sg}(f_{\theta}(x)),f_{\theta}(x))\right) \\ & \approx \nabla_x \left(\frac{ \mathcal{L}(\text{sg}(f_{\theta}(x)),f_{\theta}(x+\xi r))-\mathcal{L}(\text{sg}(f_{\theta}(x)),f_{\theta}(x))}{\xi}\right) \\ & = \frac{1}{\xi} \nabla_x \mathcal{L}(\text{sg}(f_{\theta}(x)),f_{\theta}(x+\xi r)) \end{aligned}\]where $\xi$ is a small scalar constant and the last step again uses \(\nabla_x \mathcal{L}(\text{sg}(f_{\theta}(x)),f_{\theta}(x)) = 0\). The complete VAT procedure is therefore:
- Initialize a vector $r \sim \mathcal{N}(0,1)$ and the scalars $\xi, \epsilon$;
- $r \leftarrow \frac{r}{\|r\|}$;
- Repeat $n$ times:
  - $r \leftarrow \nabla_x \mathcal{L}(\text{sg}(f_{\theta}(x)),f_{\theta}(x+\xi r))$;
  - $r \leftarrow \frac{r}{\|r\|}$;
- Perform gradient descent on the loss \(\mathcal{L}_u^{VAT} = \sum_{x \in \mathcal{D}} \mathcal{L}(\text{sg}(f_{\theta}(x)),f_{\theta}(x+\epsilon r))\).
In practice a single iteration $n=1$ is sufficient; with $n=0$ (no power iteration at all), the method degenerates to simply adding random Gaussian noise to the input.
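The finite-difference approximation of \(\mathcal{H}r\) derived above can also be checked numerically against an exact Hessian-vector product obtained by double backpropagation. The sketch below does this for an arbitrary smooth scalar function standing in for the loss (unlike the VAT divergence, its gradient at $x$ is not zero, so the baseline gradient is subtracted explicitly); the function, dimensions, and $\xi$ are illustrative choices.

```python
import torch

torch.manual_seed(0)

def scalar_loss(x):
    # arbitrary smooth scalar function of x, standing in for the loss
    return (x.sin() * x).sum()

x = torch.randn(10, requires_grad=True)
r = torch.randn(10)
r = r / r.norm()
xi = 1e-3

# exact Hessian-vector product via double backprop: H r = grad_x (r . grad_x L)
g = torch.autograd.grad(scalar_loss(x), x, create_graph=True)[0]
hvp_exact = torch.autograd.grad(g @ r, x)[0]

# finite-difference approximation: H r ~ (grad_x L(x + xi r) - grad_x L(x)) / xi
g0 = torch.autograd.grad(scalar_loss(x), x)[0]
g1 = torch.autograd.grad(scalar_loss(x + xi * r), x)[0]
hvp_approx = (g1 - g0) / xi

print(torch.allclose(hvp_exact, hvp_approx, atol=1e-2))  # True: the two agree closely
```

A PyTorch implementation of the VAT loss follows.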
```python
import contextlib
import torch
import torch.nn as nn
import torch.nn.functional as F


@contextlib.contextmanager
def _disable_tracking_bn_stats(model):
    # Temporarily stop BatchNorm layers from updating their running statistics
    # during the extra forward passes used to estimate the adversarial direction.
    def switch_attr(m):
        if hasattr(m, 'track_running_stats'):
            m.track_running_stats ^= True

    model.apply(switch_attr)
    yield
    model.apply(switch_attr)


def _l2_normalize(d):
    # Normalize each sample in the batch to unit L2 norm.
    d_reshaped = d.view(d.shape[0], -1, *(1 for _ in range(d.dim() - 2)))
    d /= torch.norm(d_reshaped, dim=1, keepdim=True) + 1e-8
    return d


class VATLoss(nn.Module):
    def __init__(self, xi=10.0, eps=1.0, ip=1):
        """VAT loss
        :param xi: hyperparameter of VAT (default: 10.0)
        :param eps: hyperparameter of VAT (default: 1.0)
        :param ip: iteration times of computing adv noise (default: 1)
        """
        super(VATLoss, self).__init__()
        self.xi = xi
        self.eps = eps
        self.ip = ip

    def forward(self, model, x):
        # the current prediction acts as the (stop-gradient) target distribution
        with torch.no_grad():
            pred = F.softmax(model(x), dim=1)

        # prepare a random unit tensor as the initial direction r
        r = torch.rand(x.shape).sub(0.5).to(x.device)
        r = _l2_normalize(r)

        with _disable_tracking_bn_stats(model):
            # power iteration: estimate the adversarial direction
            for _ in range(self.ip):
                r.requires_grad_()
                pred_hat = model(x + self.xi * r)
                logp_hat = F.log_softmax(pred_hat, dim=1)
                adv_distance = F.kl_div(logp_hat, pred, reduction='batchmean')
                adv_distance.backward()
                r = _l2_normalize(r.grad)
                model.zero_grad()

            # LDS: local distributional smoothness at x + eps * r
            r_adv = r * self.eps
            pred_hat = model(x + r_adv)
            logp_hat = F.log_softmax(pred_hat, dim=1)
            lds = F.kl_div(logp_hat, pred, reduction='batchmean')

        return lds
```
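Before plugging the loss into a training loop, a quick smoke test on a toy classifier (the two-layer network and tensor shapes below are made up for illustration) shows that it returns a non-negative scalar:

```python
# smoke test on a hypothetical toy classifier
toy_model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 5))
x = torch.randn(16, 20)
lds = VATLoss(xi=10.0, eps=1.0, ip=1)(toy_model, x)
print(lds)  # a small non-negative scalar tensor
```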
VAT is used in a training loop as follows:
```python
for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    vat_loss = VATLoss(xi=10.0, eps=1.0, ip=1)
    cross_entropy = nn.CrossEntropyLoss()
    # LDS should be calculated before the forward pass for cross entropy
    lds = vat_loss(model, data)
    output = model(data)
    loss = cross_entropy(output, target) + args.alpha * lds
    loss.backward()
    optimizer.step()
```
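The loop above computes the VAT term on the labeled batch itself. In the semi-supervised setting, the same term is typically also computed on unlabeled data, since no labels are needed. A minimal sketch under the assumption that an `unlabeled_loader` exists and yields plain input batches (both the loader and the zipping scheme are hypothetical):

```python
vat_loss = VATLoss(xi=10.0, eps=1.0, ip=1)
cross_entropy = nn.CrossEntropyLoss()

for (data, target), unlabeled in zip(train_loader, unlabeled_loader):
    data, target = data.to(device), target.to(device)
    unlabeled = unlabeled.to(device)
    optimizer.zero_grad()
    # the VAT term only needs the model's own predictions, so unlabeled data suffices
    lds = vat_loss(model, unlabeled)
    output = model(data)
    loss = cross_entropy(output, target) + args.alpha * lds
    loss.backward()
    optimizer.step()
```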