深度变分信息瓶颈.

变分信息瓶颈 (Variational Information Bottleneck, VIB)的出发点是通过使用尽可能少的信息来完成任务,从而得到更好的泛化能力。

以分类任务为例,存在标注数据$(x_1,y_1),…,(x_N,y_N)$。分类任务可以被进一步拆分成编码+分类两个步骤。第一步是把$x$编码为一个隐变量$z$,第二步是把隐变量$z$识别为类别$y$。

\[x \to z \to y\]

变分信息瓶颈希望能尽可能地减少隐变量$z$包含的信息量,这可以通过互信息$I(x,z)$衡量。互信息$I(x,z)$衡量随机变量$x$由于已知随机变量$z$而降低的不确定性,计算为:

\[I(x,z) = \mathbb{E}_{p(x,z)} \left[ \log \frac{p(x,z)}{p(x)p(z)} \right]\]

最小化隐变量$z$包含的信息量,等价于最小化互信息$I(x,z)$。因此可以构造损失函数:

\[\mathcal{L}_{VIB} = \iint p(x,z)\log \frac{p(x,z)}{p(x)p(z)} dxdz\]

通常隐变量$z$的先验分布是未知的,因此通过引入一个形式已知的分布$q(z)$来估计上述损失函数的一个上界:

\[\begin{aligned} \mathcal{L}_{VIB} &= \iint p(x,z)\log \frac{p(x,z)q(z)}{p(x)p(z)q(z)} dxdz \\ &= \iint p(x,z)\log \frac{p(z|x)q(z)}{p(z)q(z)} dxdz \\ &= \iint p(x,z)\log \frac{p(z|x)}{q(z)} dxdz + \iint p(x,z)\log \frac{q(z)}{p(z)} dxdz \\ &= \iint p(z|x)p(x)\log \frac{p(z|x)}{q(z)} dxdz - \iint p(x,z)\log \frac{p(z)}{q(z)} dxdz \\ &= \int p(x) \left[ \int p(z|x)\log \frac{p(z|x)}{q(z)} dz\right]dx - \int p(z)\log \frac{p(z)}{q(z)} dz \\ &= \int p(x) KL\left[ p(z|x) \mid\mid q(z)\right]dx - KL\left[ p(z) \mid\mid q(z)\right] \\ &\leq \int p(x) KL\left[ p(z|x) \mid\mid q(z)\right]dx \\ &= \mathbb{E}_{p(x)} \left[ KL\left[ p(z|x) \mid\mid q(z)\right] \right] \\ \end{aligned}\]

此外,对于原始任务的损失函数,比如分类任务中的交叉熵损失,将其写成先编码$z$再分类$y$的形式,并且编码器$p(z|x)$编码特征分布的均值和方差,并引入重参数化操作:

\[\begin{aligned} \mathcal{L}_{CE} &= \mathbb{E}_{p(x)} \left[ \mathbb{E}_{p(z|x)} \left[ -\log p(y|z) \right] \right] \\ \end{aligned}\]

对于分类任务,引入变分信息瓶颈后的总损失函数表示为:

\[\begin{aligned} \mathcal{L} &= \mathbb{E}_{p(x)} \left[ \mathbb{E}_{p(z|x)} \left[ -\log p(y|z) \right] + \lambda KL\left[ p(z|x) \mid\mid q(z)\right] \right] \\ \end{aligned}\]

相比原始的监督学习任务,变分信息瓶颈的改动是:

  1. 使用编码器$p(z|x)$编码特征的均值和方差,加入了重参数化操作;
  2. 加入了后验分布$p(z|x)$与给定的先验分布$q(z)$之间的KL散度为额外的损失函数。

变分信息瓶颈的表现形式与变分自编码器非常类似。先验分布$q(z)$指定为标准正态分布$N(0,1)$,后验分布$p(z|x)$建模为$N(\mu, \sigma^2)$;由于两个分布都是正态分布,KL散度有闭式解(closed-form solution),计算如下:

\[\begin{aligned} KL[q(z|x)||q(z)] &= KL[\mathcal{N}(\mu,\sigma^{2})||\mathcal{N}(0,1)] \\ &= \int_{}^{} \frac{1}{\sqrt{2\pi\sigma^2}}e^{-\frac{(x-\mu)^2}{2\sigma^2}} \log \frac{\frac{1}{\sqrt{2\pi\sigma^2}}e^{-\frac{(x-\mu)^2}{2\sigma^2}}}{\frac{1}{\sqrt{2\pi}}e^{-\frac{x^2}{2}}} dx \\&= \int_{}^{} \frac{1}{\sqrt{2\pi\sigma^2}}e^{-\frac{(x-\mu)^2}{2\sigma^2}} [-\frac{1}{2}\log \sigma^2 + \frac{x^2}{2}-\frac{(x-\mu)^2}{2\sigma^2}] dx \\ &= \frac{1}{2} (-\log \sigma^2 + \mu^2+\sigma^2-1) \end{aligned}\]

变分信息瓶颈的损失函数计算为:

(mu, std), logit = self.model(x)

class_loss = F.cross_entropy(logit,y)
info_loss = -0.5*(1+2*std.log()-mu.pow(2)-std.pow(2)).sum(1).mean()
total_loss = class_loss + self.lambd*info_loss

应用变分信息瓶颈时,需要把网络建模为编码器+解码器形式,并引入重参数化:

class ToyNet(nn.Module):
    def __init__(self, K=256):
        super(ToyNet, self).__init__()
        self.K = K

        self.encode = nn.Sequential(
            nn.Linear(784, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 2*self.K))

        self.decode = nn.Sequential(
                nn.Linear(self.K, 10))

    def forward(self, x):
        if x.dim() > 2 : x = x.view(x.size(0),-1)

        statistics = self.encode(x)
        mu = statistics[:,:self.K]
        std = F.softplus(statistics[:,self.K:]-5,beta=1)

        encoding = self.reparametrize_n(mu,std,num_sample)
        logit = self.decode(encoding)
        return (mu, std), logit

    def reparametrize_n(self, mu, std):
        eps = torch.randn_like(std)
        return mu + eps * std