虚拟对抗训练:一种用于监督学习和半监督学习的正则化方法.

1. 对抗训练

对抗训练(Adversarial Training)是指通过构造对抗样本,对模型进行对抗攻击和防御来增强模型的稳健性。对抗样本通常是指具有小扰动的样本,对于人类来说“看起来”几乎一样,但对于模型来说预测结果却完全不一样:

对抗训练的一般形式:

\[\mathcal{\min}_{\theta} \mathbb{E}_{(x,y)\sim \mathcal{D}} \left[ \mathcal{\max}_{\Delta x \in \Omega} \mathcal{L}(x+\Delta x,y;\theta) \right]\]

其中\(\mathcal{D}\)表示训练集,$x$表示输入样本,$y$表示样本标签,$\theta$表示模型参数,\(\mathcal{L}\)是损失函数,$\Delta x$是样本的对抗扰动,$\Omega$是扰动空间。

完整的对抗训练步骤如下:

  1. 为输入样本$x$添加对抗扰动$\Delta x$,$\Delta x$的目标是使得损失\(\mathcal{L}(x+\Delta x,y;\theta)\)尽可能增大,即尽可能让现有模型的预测出错;
  2. 对抗扰动$\Delta x$要满足一定的约束,比如模长不超过一个常数$\epsilon$:\(\|\Delta x\| \leq \epsilon\);
  3. 对每个样本$x$都构造一个对抗样本$x+\Delta x$,用样本对$(x+\Delta x,y)$最小化损失函数训练模型参数$\theta$。

2. 虚拟对抗训练

虚拟对抗训练(Virtual Adversarial Training, VAT)是一种正则化方法。通过寻找使得损失$l(f(x+\epsilon),f(x))$尽可能大的扰动噪声$\epsilon$,并最小化该损失,从而增强网络对于扰动噪声的鲁棒性。

对损失$l(f(x+\epsilon),f(x))$在$\epsilon$处进行泰勒展开:

\[l(f(x+\epsilon),f(x)) \approx l(f(x),f(x)) + \epsilon^T\nabla_xl(f(x),f_{sg}(x)) + \frac{1}{2}\epsilon^T\nabla_x^2l(f(x),f_{sg}(x))\epsilon\]

对于一般地损失函数$l(,)$,具有性质$l(x,x)=0$;因此上式简化为:

\[l(f(x+\epsilon),f(x)) \approx \frac{1}{2}\epsilon^T\nabla_x^2l(f(x),f_{sg}(x))\epsilon\]

寻找使得损失$l(f(x+\epsilon),f(x))$尽可能大的扰动噪声$\epsilon$,等价于首先计算Hessian矩阵\(\mathcal{H}=\nabla_x^2l(f(x),f_{sg}(x))\),然后求解最大化\(\epsilon^T \mathcal{H} \epsilon\)的$\epsilon$。

根据瑞利商的定义,最大化\(\epsilon^T \mathcal{H} \epsilon\)的$\epsilon$为\(\mathcal{H}\)的最大特征值对应的特征向量(主特征向量)。\(\mathcal{H}\)的主特征向量可以通过幂迭代(power iteration)方法求解。迭代格式:

\[u \leftarrow \frac{\mathcal{H}u}{||\mathcal{H}u||}\]

其中$u$随机初始化。下面简单证明迭代过程收敛,初始化$u^{(0)}$,若\(\mathcal{H}\)可对角化,则\(\mathcal{H}\)的特征向量\(\{v_1 v_2 \cdots v_n\}\)构成一组完备的基,$u^{(0)}$可由这组基表示:

\[u^{(0)} = c_1v_1+c_2v_2+\cdots c_nv_n\]

先不考虑迭代中分母的归一化,则迭代过程\(u \leftarrow \mathcal{H}u\)经过$t$次后为:

\[\mathcal{H}^tu^{(0)} = c_1\mathcal{H}^tv_1+c_2\mathcal{H}^tv_2+\cdots c_n\mathcal{H}^tv_n\]

注意到\(\mathcal{H}v=\lambda v\),则有:

\[\mathcal{H}^tu^{(0)} = c_1\lambda_1^tv_1+c_2\lambda_2^tv_2+\cdots c_n\lambda_n^tv_n\]

不失一般性地假设$\lambda_1$为最大特征值,则有:

\[\frac{\mathcal{H}^tu^{(0)}}{\lambda_1^t} = c_1v_1+c_2(\frac{\lambda_2}{\lambda_1})^tv_2+\cdots c_n(\frac{\lambda_n}{\lambda_1})^tv_n\]

注意到当$t \to \infty$时,$(\frac{\lambda_2}{\lambda_1})^t,\cdots (\frac{\lambda_n}{\lambda_1})^t \to 0$。则有:

\[\frac{\mathcal{H}^tu^{(0)}}{\lambda_1^t} ≈ c_1v_1\]

上述结果表明当迭代次数$t$足够大时,\(\mathcal{H}^tu^{(0)}\)提供了最大特征根对应的特征向量的近似方向,对其归一化后相当于单位特征向量:

\[\begin{aligned} u &= \frac{\mathcal{H}^tu^{(0)}}{||\mathcal{H}^tu^{(0)}||} \\ \mathcal{H} u &≈ \lambda_1 u \end{aligned}\]

在幂迭代中,并不需要知道\(\mathcal{H}\)的具体值,而只需要计算\(\mathcal{H}u\),这可以通过差分来近似计算:

\[\begin{aligned} \mathcal{H}u &= \nabla_x^2l(f(x),f_{sg}(x))u \\ &= \nabla_x\left(u \cdot \nabla_xl(f(x),f_{sg}(x))\right) \\ &\approx \nabla_x\left( \frac{l(f(x+\xi u),f_{sg}(x))-l(f(x),f_{sg}(x))}{\xi} \right) \\ &= \frac{1}{\xi} \nabla_x l(f(x+\xi u),f_{sg}(x)) \\ \end{aligned}\]

VAT的完整流程如下:

  1. 初始化向量\(u\sim \mathcal{N}(0,1)\)、标量$\epsilon, \xi$;
  2. 迭代$r$次:\(\begin{aligned} u &\leftarrow \frac{u}{\mid\mid u \mid\mid} \\ u &\leftarrow \nabla_x l(f(x+\xi u),f_{sg}(x)) \end{aligned}\)
  3. $u \leftarrow \frac{u}{\mid\mid u \mid\mid}$
  4. 用$l(f(x+\epsilon u),f_{sg}(x))$作为损失函数执行梯度下降。

注意到当$r=0$时相当于向输入增加高斯噪声,VAT通过迭代$r \geq 1$次来增加噪声的特异性。

以分类任务为例,完整的VAT过程实现如下:

import contextlib
import torch
import torch.nn as nn
import torch.nn.functional as F

@contextlib.contextmanager
def _disable_tracking_bn_stats(model):

    def switch_attr(m):
        if hasattr(m, 'track_running_stats'):
            m.track_running_stats ^= True
            
    model.apply(switch_attr)
    yield
    model.apply(switch_attr)

def _l2_normalize(d):
    d_reshaped = d.view(d.shape[0], -1, *(1 for _ in range(d.dim() - 2)))
    d /= torch.norm(d_reshaped, dim=1, keepdim=True) + 1e-8
    return d


class VATLoss(nn.Module):

    def __init__(self, xi=10.0, eps=1.0, ip=1):
        """VAT loss
        :param xi: hyperparameter of VAT (default: 10.0)
        :param eps: hyperparameter of VAT (default: 1.0)
        :param ip: iteration times of computing adv noise (default: 1)
        """
        super(VATLoss, self).__init__()
        self.xi = xi
        self.eps = eps
        self.ip = ip

    def forward(self, model, x):
        with torch.no_grad():
            pred = F.softmax(model(x), dim=1)

        # prepare random unit tensor
        d = torch.rand(x.shape).sub(0.5).to(x.device)
        d = _l2_normalize(d)

        with _disable_tracking_bn_stats(model):
            # calc adversarial direction
            for _ in range(self.ip):
                d.requires_grad_()
                pred_hat = model(x + self.xi * d)
                logp_hat = F.log_softmax(pred_hat, dim=1)
                adv_distance = F.kl_div(logp_hat, pred, reduction='batchmean')
                adv_distance.backward()
                d = _l2_normalize(d.grad)
                model.zero_grad()
    
            # calc LDS
            r_adv = d * self.eps
            pred_hat = model(x + r_adv)
            logp_hat = F.log_softmax(pred_hat, dim=1)
            lds = F.kl_div(logp_hat, pred, reduction='batchmean')

        return lds