BN-VAE: 通过批量归一化缓解KL散度消失问题.

1. KL Vanishing Problem

VAE的目标函数为最小化负ELBO

\[\mathcal{L} = \mathbb{E}_{z \text{~} q(z|x)} [-\log p(x | z)] + KL[q(z|x)||p(z)]\]

损失包括重构损失和变分后验的正则化项。如果正则化项过强,会导致$KL[q(z|x)||p(z)]=0$。此时后验分布$q(z|x)$退化为高斯先验$p(z)$,即数据$x$和隐变量$z$无关。此时VAE会退化为常规的自编码器,编码器输出常数向量,VAE失去了无监督构建编码向量的能力,这个问题称为KL Vanishing问题。KL Vanishing问题多存在于自然语言处理任务中。

2. 将BN引入VAE

KL Vanishing问题是指在训练时KL散度项变成$KL[q(z|x)||p(z)]=0$,若通过调整编码器输出使得KL散度有一个大于零的下界,则能够缓解KL Vanishing问题。

$KL[q(z|x)||p(z)]$衡量后验分布$q(z|x)$和先验分布$p(z)$之间的KL散度。$q(z|x)$优化的目标是趋近标准正态分布,此时$p(z)$指定为标准正态分布$z$~\(\mathcal{N}(0,I)\)。$q(z|x)$通过神经网络进行拟合(即概率编码器),其形式人为指定为多维对角正态分布 \(\mathcal{N}(\mu,\sigma^{2})\)。

由于两个分布都是正态分布,KL散度有闭式解(closed-form solution),计算如下:

\[KL[q(z|x)||p(z)] =\frac{1}{B} \sum_{b=1}^{B} \sum_{d=1}^{D} \frac{1}{2} (-\log \sigma_{b,d}^2 + \mu_{b,d}^2+\sigma_{b,d}^2-1)\]

上式表示采样$B$个样本,且隐变量的编码维度为$D$。由于$e^x≥x+1$,所以$\sigma_{b,d}^2−\log \sigma_{b,d}^2−1≥0$,因此:

\[KL[q(z|x)||p(z)] \geq \frac{1}{B} \sum_{b=1}^{B} \sum_{d=1}^{D} \frac{1}{2} \mu_{b,d}^2 = \frac{1}{2}\sum_{d=1}^{D} (\frac{1}{B} \sum_{b=1}^{B} \mu_{b,d}^2)\]

对上式取期望:

\[\Bbb{E}[KL[q(z|x)||p(z)]] \geq \Bbb{E}[\frac{1}{2}\sum_{d=1}^{D} (\frac{1}{B} \sum_{b=1}^{B} \mu_{b,d}^2)] = \frac{1}{2}\sum_{d=1}^{D} (\frac{1}{B} \sum_{b=1}^{B} \Bbb{E}[\mu_{b,d}^2])\]

根据均值方差公式$Var[\mu]=\Bbb{E}[\mu^2]-(\Bbb{E}[\mu])^2$,因此:

\[\Bbb{E}[KL[q(z|x)||p(z)]] \geq \frac{1}{2}\sum_{d=1}^{D} (\frac{1}{B} \sum_{b=1}^{B} (Var[\mu_{b,d}]+(\Bbb{E}[\mu_{b,d}])^2))\]

上式给出了KL散度项的一个下界,由编码器的输出均值$μ$在一个批量样本内的二阶矩决定。如果在均值$μ$后增加BatchNorm层,则能够将均值$μ$的均值调整为$\beta$,方差调整为$\gamma^2$,此时有:

\[\Bbb{E}[KL[q(z|x)||p(z)]] \geq \frac{D}{2}(\gamma^2+\beta^2)\]

通过引入BatchNorm层,使得KL散度项有个正的下界,从而缓解了KL Vanishing问题。