MMD GAN:最大平均差异生成对抗网络.

本文作者通过积分概率度量(integral probability metrics, IPM)构造了真实数据分布$P_{data}(x)$和生成数据分布$P_G(x)$之间的最大平均差异(maximum mean discrepancy, MMD),并进一步设计了最大平均差异生成对抗网络(Maximum Mean Discrepancy GAN, MMD GAN)。

1. 最大平均差异

积分概率度量寻找满足某种限制条件的函数集合\(\mathcal{F}\)中的连续函数$f(\cdot)$,使得该函数能够提供足够多的关于矩的信息;然后寻找一个最优的\(f(x)\in \mathcal{F}\)使得两个概率分布$p(x)$和$q(x)$之间的差异最大,该最大差异即为两个分布之间的距离:

\[d_{\mathcal{F}}(p(x),q(x)) = \mathop{\sup}_{f(x)\in \mathcal{F}} \Bbb{E}_{x \text{~} p(x)}[f(x)]-\Bbb{E}_{x \text{~} q(x)}[f(x)]\]

最大平均差异没有直接定义函数空间\(\mathcal{F}\),而是通过核方法定义了连续函数$f(\cdot)$的内积空间。

为了使函数$f(\cdot)$是可计算的,MMD将其限制在再生希尔伯特空间(reproducing kernel Hilbert space, RKHS)的单位球内:$||f||_{\mathcal{H}} \leq 1$。MMD定义如下:

\[\text{MMD}(f,p(x),q(x)) = \mathop{\sup}_{||f||_{\mathcal{H}} \leq 1} \Bbb{E}_{x \text{~} p(x)}[f(x)]-\Bbb{E}_{x \text{~} q(x)}[f(x)]\]

根据Riesz表示定理,希尔伯特空间中的函数$f(x)$可以表示为内积形式:

\[f(x) = <f,\phi(x)>_{\mathcal{H}}\]

MMD可以写作:

\[\begin{aligned} \text{MMD}(f,p(x),q(x)) &= \mathop{\sup}_{||f||_{\mathcal{H}} \leq 1} <f,\Bbb{E}_{x \text{~} p(x)}[\phi(x)]>_{\mathcal{H}} -<f,\Bbb{E}_{x \text{~} q(x)}[\phi(x)]>_{\mathcal{H}} \\ &= \mathop{\sup}_{||f||_{\mathcal{H}} \leq 1} <f,\Bbb{E}_{x \text{~} p(x)}[\phi(x)]-\Bbb{E}_{x \text{~} q(x)}[\phi(x)]>_{\mathcal{H}} \end{aligned}\]

MMD的平方可以写作:

\[\begin{aligned} \text{MMD}^2(f,p(x),q(x)) &= (\mathop{\sup}_{||f||_{\mathcal{H}} \leq 1} <f,\Bbb{E}_{x \text{~} p(x)}[\phi(x)]-\Bbb{E}_{x \text{~} q(x)}[\phi(x)]>_{\mathcal{H}})^2 \\ &= ||\Bbb{E}_{x \text{~} p(x)}[\phi(x)]-\Bbb{E}_{x \text{~} q(x)}[\phi(x)]||_{\mathcal{H}}^2 \end{aligned}\]

其中上式应用了两个向量内积的最大值必定在两向量同方向上取得,即当\(f=\Bbb{E}_{x \text{~} p(x)}[\phi(x)]-\Bbb{E}_{x \text{~} q(x)}[\phi(x)]\)时取得最大值。

对期望的计算应用蒙特卡洛估计,可得:

\[\begin{aligned} \text{MMD}^2(f,p(x),q(x)) &= ||\Bbb{E}_{x \text{~} p(x)}[\phi(x)]-\Bbb{E}_{x \text{~} q(x)}[\phi(x)]||_{\mathcal{H}}^2 \\ &≈ ||\frac{1}{m}\sum_{i=1}^m\phi(x_i)-\frac{1}{n}\sum_{j=1}^n\phi(x_j)||_{\mathcal{H}}^2 \end{aligned}\]

引入核函数 $\kappa(x,x’)=<\phi(x),\phi(x’)>_{\mathcal{H}}$,则MMD的平方近似计算为:

\[\begin{aligned} \text{MMD}^2(f,p(x),q(x)) = & \Bbb{E}_{x,x' \text{~} p(x)} [\kappa(x,x')] + \Bbb{E}_{x,x' \text{~} q(x)} [\kappa(x,x')] \\ & -2\Bbb{E}_{x \text{~} p(x),x' \text{~} q(x)} [\kappa(x,x')] \\ ≈ &\frac{1}{m^2}\sum_{i=1}^m\sum_{i'=1}^m \kappa(x_i,x_{i'})+\frac{1}{n^2}\sum_{j=1}^n\sum_{j'=1}^n \kappa(x_j,x_{j'}) \\ & - \frac{2}{mn}\sum_{i=1}^m\sum_{j=1}^n \kappa(x_i,x_j) \end{aligned}\]

2. MMD GAN

MMD GAN使用MMD衡量真实数据分布$P_{data}(x)$和生成数据分布$P_G(x)$之间的距离,其中核函数$\kappa(x,x’)$通过判别器进行学习:

\[\kappa(x,x') = \exp(-||D(x)-D(x')||)\]

MMD GAN的目标函数如下:

\[\begin{aligned} \mathop{ \min}_{G} \mathop{ \max}_{D} \Bbb{E}_{x,x' \text{~} P_{data}(x)} [\kappa(x,x')] + \Bbb{E}_{x,x' \text{~} P_G(x)} [\kappa(x,x')] -2\Bbb{E}_{x \text{~} P_{data}(x),x' \text{~} P_G(x)} [\kappa(x,x')] \end{aligned}\]