Virtual Adversarial Training: A Regularization Method for Semi-Supervised Learning

Virtual Adversarial Training (VAT) brings the idea of adversarial training into semi-supervised learning. Adversarial training adds adversarial noise to the inputs and trains the model to be robust against such attacks. In an adversarial attack, the noise $r$ is constructed as:

\[r = \mathop{\arg \max}_{||r|| \leq \epsilon} \mathcal{L}(y,f_{\theta}(x+r))\]

where $y$ is the ground-truth label of the sample, $f_{\theta}(x)$ is the model's prediction, and $\mathcal{L}$ is a function measuring the distance between two outputs. Once the attack noise $r$ has been constructed, adversarial training proceeds as supervised learning, minimizing the loss:

\[\mathcal{L}_s^{adv} = \sum_{(x,y) \in \mathcal{D}} \mathcal{L}(y,f_{\theta}(x+r))\]
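In supervised adversarial training, the inner maximization over $r$ is commonly approximated by a single normalized gradient step. Below is a minimal PyTorch sketch under that assumption; `adversarial_loss` is an illustrative helper, not part of VAT itself:

import torch
import torch.nn.functional as F

def adversarial_loss(model, x, y, eps=0.1):
    # One-step linear approximation of r = arg max_{||r|| <= eps} L(y, f(x + r)):
    # take the gradient of the loss w.r.t. x and rescale it to norm eps.
    x = x.detach().requires_grad_()
    loss = F.cross_entropy(model(x), y)
    (g,) = torch.autograd.grad(loss, x)
    g_norm = g.flatten(1).norm(dim=1).view(-1, *([1] * (g.dim() - 1)))
    r = eps * g / (g_norm + 1e-8)
    # supervised adversarial loss L(y, f(x + r))
    return F.cross_entropy(model(x.detach() + r), y)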

In semi-supervised learning the label $y$ is usually unknown, so the model's current prediction $f_{\theta}(x)$ is used in its place, with gradients blocked through it (implemented by a stop_gradient operator, written $\text{sg}(\cdot)$ below). The corresponding VAT procedure is:

\[\begin{aligned} r &= \mathop{\arg \max}_{||r|| \leq \epsilon} \mathcal{L}(\text{sg}(f_{\theta}(x)),f_{\theta}(x+r)) \\ \mathcal{L}_u^{VAT} &= \sum_{x \in \mathcal{D}} \mathcal{L}(\text{sg}(f_{\theta}(x)),f_{\theta}(x+r)) \end{aligned}\]

The VAT loss provides a negative measure of the smoothness of the current model's prediction manifold at each sample point; minimizing this unsupervised loss therefore makes the manifold smoother.

To construct the attack noise $r$, take a second-order Taylor expansion of the loss function \(\mathcal{L}\):

\[\mathcal{L}(\text{sg}(f_{\theta}(x)),f_{\theta}(x+r)) ≈ \mathcal{L}(\text{sg}(f_{\theta}(x)),f_{\theta}(x)) + r^T\nabla_x \mathcal{L}(\text{sg}(f_{\theta}(x)),f_{\theta}(x)) + \frac{1}{2} r^T\nabla_x^2 \mathcal{L}(\text{sg}(f_{\theta}(x)),f_{\theta}(x)) r\]

Note that the loss function \(\mathcal{L}\) typically satisfies:

\[\mathcal{L}(x,x) = 0 , \quad \nabla_x \mathcal{L}(x,x) = 0\]
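For example, taking $\mathcal{L}$ to be the squared Euclidean distance:

\[\mathcal{L}(a,b) = \|a-b\|^2, \quad \mathcal{L}(a,a) = 0, \quad \nabla_b \mathcal{L}(a,b)\big|_{b=a} = 2(b-a)\big|_{b=a} = 0\]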

With these properties, the zeroth- and first-order terms vanish, and the loss is approximated by:

\[\mathcal{L}(\text{sg}(f_{\theta}(x)),f_{\theta}(x+r)) ≈ \frac{1}{2} r^T\nabla_x^2 \mathcal{L}(\text{sg}(f_{\theta}(x)),f_{\theta}(x)) r\]

If $r$ is constrained to be a unit vector, $r^Tr=1$, the right-hand side above is, up to the factor $\frac{1}{2}$, the Rayleigh quotient of the Hessian matrix \(\mathcal{H}=\nabla_x^2 \mathcal{L}(\text{sg}(f_{\theta}(x)),f_{\theta}(x))\), whose range is:

\[\lambda_{min}≤r^T \mathcal{H} r≤\lambda_{max}\]
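This follows by expanding $r$ in an orthonormal eigenbasis $\{v_i\}$ of \(\mathcal{H}\):

\[r = \sum_i c_i v_i, \quad \sum_i c_i^2 = 1 \quad \Rightarrow \quad r^T \mathcal{H} r = \sum_i \lambda_i c_i^2 \in [\lambda_{min}, \lambda_{max}]\]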

The attack noise $r$ should maximize the loss \(\mathcal{L}\), so $r$ is the eigenvector corresponding to the largest eigenvalue of \(\mathcal{H}\), which can be computed by the power iteration method. The update rule is:

\[r \leftarrow \frac{\mathcal{H}r}{||\mathcal{H}r||}\]

where $r$ can be initialized randomly. We now prove that power iteration converges.

Initialize $r^{(0)}$. If \(\mathcal{H}\) is diagonalizable, its eigenvectors \(\{v_1, v_2, \cdots, v_n\}\) form a complete basis, in which $r^{(0)}$ can be expanded:

\[r^{(0)} = c_1v_1+c_2v_2+\cdots + c_nv_n\]

Ignoring the normalizing denominator for the moment, $t$ steps of the iteration \(r \leftarrow \mathcal{H}r\) give:

\[\mathcal{H}^tr^{(0)} = c_1\mathcal{H}^tv_1+c_2\mathcal{H}^tv_2+\cdots + c_n\mathcal{H}^tv_n\]

Noting that \(\mathcal{H}v_i=\lambda_i v_i\), this becomes:

\[\mathcal{H}^tr^{(0)} = c_1\lambda_1^tv_1+c_2\lambda_2^tv_2+\cdots + c_n\lambda_n^tv_n\]

Without loss of generality, assume $\lambda_1$ is the eigenvalue of largest magnitude. Then:

\[\frac{\mathcal{H}^tr^{(0)}}{\lambda_1^t} = c_1v_1+c_2(\frac{\lambda_2}{\lambda_1})^tv_2+\cdots + c_n(\frac{\lambda_n}{\lambda_1})^tv_n\]

As $t \to \infty$, the ratios $(\frac{\lambda_2}{\lambda_1})^t,\cdots,(\frac{\lambda_n}{\lambda_1})^t \to 0$, so:

\[\frac{\mathcal{H}^tr^{(0)}}{\lambda_1^t} ≈ c_1v_1\]

This shows that for sufficiently large $t$, \(\mathcal{H}^tr^{(0)}\) approximates the direction of the eigenvector associated with the largest eigenvalue; normalizing it yields (approximately) the unit eigenvector:

\[\begin{aligned} r &= \frac{\mathcal{H}^tr^{(0)}}{||\mathcal{H}^tr^{(0)}||} \\ \mathcal{H} r &≈ \lambda_1 r \end{aligned}\]

which is exactly the power-iteration update. This completes the proof.
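The convergence can also be checked numerically. Here is a small sketch on a randomly generated symmetric (hence diagonalizable) matrix; the matrix is purely illustrative:

import torch

torch.manual_seed(0)
A = torch.randn(8, 8)
H = (A + A.T) / 2                    # symmetric, hence diagonalizable

r = torch.randn(8)
r = r / r.norm()
for _ in range(200):                 # r <- Hr / ||Hr||
    r = H @ r
    r = r / r.norm()

# compare against the dominant eigenvector from a direct decomposition
evals, evecs = torch.linalg.eigh(H)
v1 = evecs[:, evals.abs().argmax()]
print(torch.allclose(r.abs(), v1.abs(), atol=1e-4))  # expected: True (up to sign)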

In VAT's power iteration, \(\mathcal{H}\) never has to be formed explicitly; only the product \(\mathcal{H}r\) is needed, which can be approximated by a finite difference:

\[\begin{aligned} \mathcal{H}r &= \nabla_x^2 \mathcal{L}(\text{sg}(f_{\theta}(x)),f_{\theta}(x))\, r \\ &= \nabla_x \left(r\cdot \nabla_x \mathcal{L}(\text{sg}(f_{\theta}(x)),f_{\theta}(x))\right) \\ & ≈ \nabla_x \left(\frac{ \mathcal{L}(\text{sg}(f_{\theta}(x)),f_{\theta}(x+\xi r))-\mathcal{L}(\text{sg}(f_{\theta}(x)),f_{\theta}(x))}{\xi}\right) \\ & = \frac{1}{\xi} \nabla_x \mathcal{L}(\text{sg}(f_{\theta}(x)),f_{\theta}(x+\xi r)) \end{aligned}\]

where $\xi$ is a small scalar constant; the final equality uses the property noted above that \(\nabla_x \mathcal{L}(\text{sg}(f_{\theta}(x)),f_{\theta}(x)) = 0\).
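The finite-difference approximation can be sanity-checked against autograd on a toy scalar function; `f` below is purely illustrative, not the VAT loss:

import torch

def f(x):
    # hypothetical smooth test function
    return (x ** 4).sum() + (x ** 2).sum()

def grad_f(z):
    z = z.detach().requires_grad_()
    (g,) = torch.autograd.grad(f(z), z)
    return g

x = torch.randn(5, dtype=torch.float64)
v = torch.randn(5, dtype=torch.float64)
v = v / v.norm()
xi = 1e-5

# Hv ~ (grad_f(x + xi*v) - grad_f(x)) / xi; in VAT the grad_f(x) term
# drops out because the loss gradient vanishes at r = 0.
hv_fd = (grad_f(x + xi * v) - grad_f(x)) / xi
hv_exact = torch.autograd.functional.hvp(f, x, v)[1]
print(torch.allclose(hv_fd, hv_exact, atol=1e-3))  # expected: True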

In practice a single power iteration ($n=1$) already works well; with $n=0$ (no iteration at all), the method degenerates to simply adding random noise to the inputs. The complete VAT loss can then be implemented in PyTorch as follows:

import contextlib
import torch
import torch.nn as nn
import torch.nn.functional as F


@contextlib.contextmanager
def _disable_tracking_bn_stats(model):
    # Temporarily stop BatchNorm layers from updating their running
    # statistics while the adversarial perturbation is computed.

    def switch_attr(m):
        if hasattr(m, 'track_running_stats'):
            m.track_running_stats ^= True
            
    model.apply(switch_attr)
    yield
    model.apply(switch_attr)


def _l2_normalize(d):
    # Normalize each sample's perturbation to unit L2 norm.
    d_reshaped = d.view(d.shape[0], -1, *(1 for _ in range(d.dim() - 2)))
    d /= torch.norm(d_reshaped, dim=1, keepdim=True) + 1e-8
    return d


class VATLoss(nn.Module):

    def __init__(self, xi=10.0, eps=1.0, ip=1):
        """VAT loss
        :param xi: finite-difference scale of the perturbation used in
            power iteration (default: 10.0)
        :param eps: L2 norm of the final adversarial perturbation (default: 1.0)
        :param ip: number of power iterations for the adversarial direction
            (default: 1)
        """
        super(VATLoss, self).__init__()
        self.xi = xi
        self.eps = eps
        self.ip = ip

    def forward(self, model, x):
        # current prediction, used as a fixed target (stop-gradient)
        with torch.no_grad():
            pred = F.softmax(model(x), dim=1)

        # prepare a random unit tensor r^(0)
        r = torch.rand(x.shape).sub(0.5).to(x.device)
        r = _l2_normalize(r)

        with _disable_tracking_bn_stats(model):
            # power iteration: estimate the adversarial direction
            for _ in range(self.ip):
                r.requires_grad_()
                pred_hat = model(x + self.xi * r)
                logp_hat = F.log_softmax(pred_hat, dim=1)
                adv_distance = F.kl_div(logp_hat, pred, reduction='batchmean')
                adv_distance.backward()
                # r <- Hr / ||Hr||, with Hr approximated by the gradient
                # of the loss at x + xi * r (the finite difference above)
                r = _l2_normalize(r.grad)
                model.zero_grad()

            # calc LDS (local distributional smoothness) at x + eps * r
            r_adv = r * self.eps
            pred_hat = model(x + r_adv)
            logp_hat = F.log_softmax(pred_hat, dim=1)
            lds = F.kl_div(logp_hat, pred, reduction='batchmean')

        return lds
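A quick smoke test on a hypothetical toy classifier (the MLP below is only for illustration):

model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3))
x = torch.randn(4, 10)
lds = VATLoss(xi=10.0, eps=1.0, ip=1)(model, x)
print(lds.item())  # a small non-negative scalar (KL divergence)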

During training, VAT is used as follows:

vat_loss = VATLoss(xi=10.0, eps=1.0, ip=1)
cross_entropy = nn.CrossEntropyLoss()

for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()

    # LDS should be calculated before the forward pass for cross entropy
    lds = vat_loss(model, data)
    output = model(data)
    loss = cross_entropy(output, target) + args.alpha * lds
    loss.backward()
    optimizer.step()
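The snippet above computes the VAT term on the labeled batch for brevity. In a genuinely semi-supervised setting, the VAT loss is typically evaluated on unlabeled data as well; a sketch assuming a hypothetical `unlabeled_loader` that yields (input, dummy-label) pairs:

for (data, target), (u_data, _) in zip(train_loader, unlabeled_loader):
    data, target = data.to(device), target.to(device)
    u_data = u_data.to(device)  # unlabeled inputs
    optimizer.zero_grad()

    lds = vat_loss(model, u_data)  # smoothness regularizer on unlabeled data
    output = model(data)
    loss = cross_entropy(output, target) + args.alpha * lds
    loss.backward()
    optimizer.step()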