LSGAN:使用均方误差构造目标函数.
GAN的目标函数为如下min-max形式:
\[\begin{aligned} \mathop{ \min}_{G} \mathop{ \max}_{D} \Bbb{E}_{x \text{~} P_{data}(x)}[\log D(x)] + \Bbb{E}_{x \text{~} P_{G}(x)}[\log(1-D(x))] \end{aligned}\]其中判别器相当于二分类器,输出用softmax或tanh激活。这两类函数有比较明显的梯度消失现象:
Least Squares GAN (LSGAN)对判别器不再使用二分类器,而是使用线性回归代替:
⚪ 目标函数
LSGAN的目标函数如下:
\[\begin{aligned} &\mathop{ \min}_{D} \frac{1}{2} \Bbb{E}_{x \text{~} P_{data}(x)}[(D(x)-b)^2] + \frac{1}{2} \Bbb{E}_{x \text{~} P_{G}(x)}[(D(x)-a)^2] \\ & \mathop{ \min}_{G} \frac{1}{2} \Bbb{E}_{x \text{~} P_{G}(x)}[(D(x)-c)^2] \end{aligned}\]其中$a$是生成样本的标签,$b$是真实样本的标签,$c$是将生成样本看作真实样本的标签。
⚪ 等价性:$\chi^2$散度
判别器的目标函数写作积分形式:
\[\frac{1}{2} \int_{x} P_{data}(x)(D(x)-b)^2 dx + \frac{1}{2} \int_{x} P_{G}(x)(D(x)-a)^2 dx\]对上式取极值可得最优判别器$D^{*}$:
\[D^* = \frac{bP_{data}(x) + aP_{G}(x)}{P_{data}(x) + P_{G}(x)}\]上式代入生成器的目标函数(额外增加了真实数据项):
\[\begin{aligned} & \frac{1}{2} \Bbb{E}_{x \text{~} P_{G}(x)}[(D^*(x)-c)^2] +\frac{1}{2} \Bbb{E}_{x \text{~} P_{data}(x)}[(D^*(x)-c)^2] \\ = & \frac{1}{2} \Bbb{E}_{x \text{~} P_{G}(x)}[(\frac{bP_{data}(x) + aP_{G}(x)}{P_{data}(x) + P_{G}(x)}-c)^2] +\frac{1}{2} \Bbb{E}_{x \text{~} P_{data}(x)}[(\frac{bP_{data}(x) + aP_{G}(x)}{P_{data}(x) + P_{G}(x)}-c)^2] \\ = & \frac{1}{2} \Bbb{E}_{x \text{~} P_{G}(x)}[(\frac{(b-c)P_{data}(x) + (a-c)P_{G}(x)}{P_{data}(x) + P_{G}(x)})^2] +\frac{1}{2} \Bbb{E}_{x \text{~} P_{data}(x)}[(\frac{(b-c)P_{data}(x) + (a-c)P_{G}(x)}{P_{data}(x) + P_{G}(x)})^2] \\ = & \frac{1}{2} \int_{x} P_{G}(x)(\frac{(b-c)P_{data}(x) + (a-c)P_{G}(x)}{P_{data}(x) + P_{G}(x)})^2 dx +\frac{1}{2} \int_{x} P_{data}(x)(\frac{(b-c)P_{data}(x) + (a-c)P_{G}(x)}{P_{data}(x) + P_{G}(x)})^2 dx \\ = & \frac{1}{2} \int_{x} (P_{G}(x)+P_{data}(x))(\frac{(b-c)P_{data}(x) + (a-c)P_{G}(x)}{P_{data}(x) + P_{G}(x)})^2 dx \\ = & \frac{1}{2} \int_{x} \frac{((b-c)P_{data}(x) + (a-c)P_{G}(x))^2}{P_{data}(x) + P_{G}(x)} dx \\ = & \frac{1}{2} \int_{x} \frac{((b-c)(P_{data}(x)+P_{G}(x)) - (b-a)P_{G}(x))^2}{P_{data}(x) + P_{G}(x)} dx \end{aligned}\]令$b-c = 1, b-a =2$,则有:
\[\int_{x} \frac{(P_{data}(x)+P_{G}(x) -2P_{G}(x))^2}{P_{data}(x) + P_{G}(x)} dx = \chi^2_{Pearson}[P_{data}+P_{G}||2P_{G}]\]因此目标函数等价于最小化上述两种分布的Pearson $\chi^2$散度。
⚪ 参数选择
作者提供了两种$a,b,c$的选择方式。
- 若希望目标比较接近Pearson $\chi^2$散度,则选择$a=-1,b=1,c=0$:
- 若希望生成样本接近真实样本,则选择$a=-1,b=c=1$:
⚪ 网络设计
作者在VGGNet的基础上设计网络:
⚪ pytorch实现
LSGAN的完整pytorch实现可参考PyTorch-GAN。
import torch
# 定义生成器和判别器
generator = Generator()
discriminator = Discriminator() # 判别器的输出应为线性激活标量
# 定义损失函数
adversarial_loss = torch.nn.MSELoss() # 均方误差损失代替二元交叉熵
# 定义优化器
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
for epoch in range(n_epochs):
for i, real_imgs in enumerate(dataloader):
# 构造对抗标签
valid = torch.ones(real_imgs.shape[0], 1)
fake = torch.zeros(real_imgs.shape[0], 1)
# 采样并生成样本
z = torch.randn(real_imgs.shape[0], latent_dim)
gen_imgs = generator(z)
# 训练判别器
optimizer_D.zero_grad()
# 计算判别器的损失
real_loss = adversarial_loss(discriminator(real_imgs), valid)
fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake) # 此处不计算生成器的梯度
d_loss = (real_loss + fake_loss) / 2
# 更新判别器参数
d_loss.backward()
optimizer_D.step()
# 训练生成器
optimizer_G.zero_grad()
g_loss = adversarial_loss(discriminator(gen_imgs), valid)
g_loss.backward()
optimizer_G.step()