EBGAN: Energy-Based Generative Adversarial Network.
1. Introducing Energy-Based Models into GANs
An energy-based model fits a batch of real data $x_1,x_2,\cdots,x_n \sim P_{data}(x)$ with an energy distribution of the form:
\[q_{\theta}(x) = \frac{e^{-U_{\theta}(x)}}{Z_{\theta}}, \quad Z_{\theta} = \int e^{-U_{\theta}(x)}dx\]where $U_{\theta}(x)$ is a parameterized energy function and $Z_{\theta}$ is the partition function (normalizing constant). Intuitively, the real data concentrate where the energy function is lowest. We would like adversarial training to make the energy of the generated samples $\hat{x}_1,\hat{x}_2,\cdots,\hat{x}_n$ as low as possible.
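To make the definition concrete, here is a minimal sketch (illustrative only, not part of EBGAN) that evaluates $q_{\theta}(x)$ on a 1D grid for a hypothetical quadratic energy $U_{\theta}(x)=\theta x^2$, estimating the partition function numerically:
import torch

# Hypothetical 1D energy-based density q_theta(x) = exp(-U_theta(x)) / Z_theta
theta = 0.5
xs = torch.linspace(-5.0, 5.0, steps=1001)   # evaluation grid
energy = theta * xs ** 2                      # U_theta(x), assumed quadratic for this sketch
unnormalized = torch.exp(-energy)             # exp(-U_theta(x))
Z = torch.trapezoid(unnormalized, xs)         # numerical estimate of the partition function
q = unnormalized / Z                          # normalized density; highest where the energy is lowest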
We use a discriminator $D(x)$ to model the energy function $U_{\theta}(x)$ and a generator $G(z)$ to induce the generated distribution $P_G(x)$. The discriminator's objective is then to minimize the energy of the real data distribution while maximizing the energy of the generated distribution:
\[D^* \leftarrow \mathop{ \min}_{D} \Bbb{E}_{x \sim P_{data}(x)} [ D(x)]- \Bbb{E}_{x \sim P_G(x)}[D(x) ]\]Meanwhile, the generator's objective is to minimize the energy of the generated distribution:
\[G^* \leftarrow \mathop{ \min}_{G} \Bbb{E}_{x \sim P_G(x)}[D(x) ]\]From the energy-model perspective, the GAN objective can therefore be written as:
\[\begin{aligned} D^* &\leftarrow \mathop{ \min}_{D} \Bbb{E}_{x \sim P_{data}(x)} [ D(x)]- \Bbb{E}_{x \sim P_G(x)}[D(x) ] \\ G^* &\leftarrow \mathop{ \min}_{G} \Bbb{E}_{x \sim P_G(x)}[D(x) ] \end{aligned}\]
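As a sanity check, these two objectives translate almost directly into code. The sketch below uses toy stand-ins for $D$ and $G$ (a scalar-energy discriminator and a noise-to-image generator, both assumptions for illustration), not the actual EBGAN networks:
import torch
import torch.nn as nn

# Toy stand-ins: D outputs one scalar energy per image, G maps 16-dim noise to a 1x28x28 "image"
D = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 1))
G = nn.Sequential(nn.Linear(16, 28 * 28), nn.Tanh(), nn.Unflatten(1, (1, 28, 28)))

real_imgs = torch.randn(8, 1, 28, 28)   # placeholder batch of "real" images
z = torch.randn(8, 16)

# Discriminator: lower the energy of real samples, raise the energy of generated samples
d_loss = D(real_imgs).mean() - D(G(z).detach()).mean()
# Generator: lower the energy of generated samples
g_loss = D(G(z)).mean()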
2. Energy-based GAN (EBGAN)
EBGAN's discriminator takes the form of an autoencoder, and its energy function is the sample's reconstruction loss:
\[U(x) = ||D(x)-x|| = ||Dec(Enc(x))-x||\]Intuitively, if an image is reconstructed well by the autoencoder, the discriminator regards it as real; the reconstruction error is then small and can be interpreted as the image's "energy", whose minimum value is $0$.
EBGAN's discriminator is implemented as follows (note: it also returns the encoder's embedding, which is needed later when constructing the loss):
import torch.nn as nn

# opt is assumed to hold the parsed command-line options (opt.channels, opt.img_size, opt.latent_dim, ...)
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # Downsampling (encoder): halves the spatial resolution
        self.down = nn.Sequential(nn.Conv2d(opt.channels, 64, 3, 2, 1), nn.ReLU())
        # Embedding layer
        self.down_size = opt.img_size // 2
        down_dim = 64 * self.down_size ** 2
        self.embedding = nn.Linear(down_dim, 32)
        # Fully-connected layers: map the embedding back to a feature map
        self.fc = nn.Sequential(
            nn.BatchNorm1d(32, 0.8),
            nn.ReLU(inplace=True),
            nn.Linear(32, down_dim),
            nn.BatchNorm1d(down_dim),
            nn.ReLU(inplace=True),
        )
        # Upsampling (decoder): restores the original resolution
        self.up = nn.Sequential(nn.Upsample(scale_factor=2), nn.Conv2d(64, opt.channels, 3, 1, 1))

    def forward(self, img):
        out = self.down(img)
        embedding = self.embedding(out.view(out.size(0), -1))
        out = self.fc(embedding)
        out = self.up(out.view(out.size(0), 64, self.down_size, self.down_size))
        return out, embedding
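As a quick sanity check on shapes and on the energy interpretation, the discriminator can be probed as follows (a sketch assuming, for illustration only, opt.img_size = 32 and opt.channels = 1):
import argparse
import torch

# Hypothetical options for this sketch; the real values come from the command line
opt = argparse.Namespace(img_size=32, channels=1, latent_dim=100, batch_size=64)

discriminator = Discriminator()
imgs = torch.randn(16, opt.channels, opt.img_size, opt.img_size)
recon, emb = discriminator(imgs)
print(recon.shape, emb.shape)             # [16, 1, 32, 32] and [16, 32]
energy = torch.mean((recon - imgs) ** 2)  # reconstruction error acts as the energy U(x)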
EBGAN's generator is a standard GAN generator, implemented as follows:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.init_size = opt.img_size // 4
        self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))
        self.conv_blocks = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, noise):
        # Project the latent vector to a low-resolution feature map, then upsample to an image
        out = self.l1(noise)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img
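Wiring the two networks together (reusing the hypothetical opt and the discriminator instance from the sketch above), a generated batch can be pushed through the discriminator in the same way:
generator = Generator()
z = torch.randn(16, opt.latent_dim)                      # latent noise
gen_imgs = generator(z)                                  # [16, 1, 32, 32] for img_size = 32
fake_recon, fake_emb = discriminator(gen_imgs)
fake_energy = torch.mean((fake_recon - gen_imgs) ** 2)   # energy assigned to the generated samples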
3. The EBGAN Objective
From the energy-model perspective, the GAN objective is:
\[\begin{aligned} D^* &\leftarrow \mathop{ \min}_{D} \Bbb{E}_{x \sim P_{data}(x)} [ D(x)]- \Bbb{E}_{x \sim P_G(x)}[D(x) ] \\ G^* &\leftarrow \mathop{ \min}_{G} \Bbb{E}_{x \sim P_G(x)}[D(x) ] \end{aligned}\]Note that EBGAN's energy function is built from a mean-squared reconstruction error, so its minimum value is $0$. Directly optimizing the objective above tends to drive the energy of generated samples toward $\infty$, which makes training unstable.
In practice, the energy of generated samples is therefore capped at a margin $m$:
\[\begin{aligned} D^* &\leftarrow \mathop{ \min}_{D} \Bbb{E}_{x \sim P_{data}(x)} [ D(x)]+ \Bbb{E}_{x \sim P_G(x)}[\max(0, m-D(x)) ] \\ G^* &\leftarrow \mathop{ \min}_{G} \Bbb{E}_{x \sim P_G(x)}[D(x) ] \end{aligned}\]In addition, the authors propose a repelling regularizer that encourages diversity among generated samples when training the generator. Concretely, for two generated samples $x_i, x_j$, the feature embeddings $S_i, S_j$ extracted by the discriminator's encoder should be as dissimilar as possible, measured by the squared cosine similarity:
\[f_{PT}(S) = \frac{1}{N(N-1)}\sum_i \sum_{j \neq i} (\frac{S_i^TS_j}{||S_i|| \cdot ||S_j||})^2\]Adding this term to the generator's objective as a pulling-away loss gives:
\[\begin{aligned} D^* &\leftarrow \mathop{ \min}_{D} \Bbb{E}_{x \sim P_{data}(x)} [ D(x)]+ \Bbb{E}_{x \sim P_G(x)}[\max(0, m-D(x)) ] \\ G^* &\leftarrow \mathop{ \min}_{G} \Bbb{E}_{(x_i,x_j) \sim P_G(x)}[(\frac{E(x_i)^TE(x_j)}{||E(x_i)|| \cdot ||E(x_j)||})^2]+ \Bbb{E}_{x \sim P_G(x)}[D(x) ] \end{aligned}\]where $E(\cdot)$ denotes the discriminator's encoder. The pulling-away loss is implemented as follows:
def pullaway_loss(embeddings):
    # embeddings -> [batch, dims]
    # Normalize each embedding to unit length
    norm = torch.sqrt(torch.sum(embeddings ** 2, -1, keepdim=True))
    normalized_emb = embeddings / norm
    # Pairwise cosine similarities, [batch, batch]
    similarity = torch.matmul(normalized_emb, normalized_emb.transpose(1, 0))
    batch_size = embeddings.size(0)
    # Average the squared cosine similarity over off-diagonal pairs (i != j);
    # each diagonal entry is 1, so the diagonal contributes exactly batch_size and is subtracted
    loss_pt = (torch.sum(similarity ** 2) - batch_size) / (batch_size * (batch_size - 1))
    return loss_pt
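A quick behavioural check (illustrative only): identical embeddings should yield a loss of 1, while mutually orthogonal embeddings should yield 0:
same = torch.ones(4, 8)       # four identical embeddings, maximally similar
print(pullaway_loss(same))    # tensor(1.)
ortho = torch.eye(4)          # four mutually orthogonal embeddings
print(pullaway_loss(ortho))   # tensor(0.)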
The EBGAN training loop is given below:
lambda_pt = 0.1                            # weight of the pulling-away term
margin = max(1, opt.batch_size / 64.0)     # margin m for the hinge on fake energies
# Define the reconstruction loss
pixelwise_loss = nn.MSELoss()              # mean squared error
for epoch in range(n_epochs):
    for i, real_imgs in enumerate(dataloader):
        # Sample noise and generate a batch of images
        z = torch.randn(real_imgs.shape[0], latent_dim)
        gen_imgs = generator(z)
        # Train the discriminator
        optimizer_D.zero_grad()
        # Compute the discriminator loss
        real_recon, _ = discriminator(real_imgs)
        fake_recon, _ = discriminator(gen_imgs.detach())
        d_loss_real = pixelwise_loss(real_recon, real_imgs)
        d_loss_fake = pixelwise_loss(fake_recon, gen_imgs.detach())
        # Equivalent to D(x_real) + max(0, m - D(x_fake)) up to the additive constant m
        d_loss = d_loss_real - torch.clamp(d_loss_fake, max=margin)
        # Update the discriminator parameters
        d_loss.backward()
        optimizer_D.step()
        # Train the generator
        optimizer_G.zero_grad()
        fake_recon, img_embeddings = discriminator(gen_imgs)
        g_loss_fake = pixelwise_loss(fake_recon, gen_imgs)
        g_loss_pull = pullaway_loss(img_embeddings)
        g_loss = g_loss_fake + lambda_pt * g_loss_pull
        g_loss.backward()
        optimizer_G.step()
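The loop above assumes that the networks, optimizers, and data loader have already been constructed. A minimal setup sketch is given below; the hyperparameter values (Adam with lr = 0.0002 and betas = (0.5, 0.999), 200 epochs, a 100-dimensional latent space) are assumptions for illustration, not prescribed by the text:
import argparse
import torch
from torch.utils.data import DataLoader

# Hypothetical options; the original script parses these from the command line
opt = argparse.Namespace(img_size=32, channels=1, latent_dim=100, batch_size=64)
n_epochs, latent_dim = 200, opt.latent_dim

generator = Generator()
discriminator = Discriminator()
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Placeholder data: a DataLoader over a plain tensor yields image batches directly.
# In practice this would be a real image dataset scaled to [-1, 1] to match the generator's Tanh output.
images = torch.randn(1024, opt.channels, opt.img_size, opt.img_size)
dataloader = DataLoader(images, batch_size=opt.batch_size, shuffle=True)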
A complete PyTorch implementation of EBGAN is available in the PyTorch-GAN repository.