使用对数双曲余弦损失改进变分自编码器.

在变分自编码器VAE中,解码样本和原输入之间的重构损失函数默认选择L2损失,本文作者建议将其替换为对数双曲余弦(log cosh)损失,实验结果表明其能够显著改善VAE的重构质量。

1. 对数双曲余弦损失 Log Hyperbolic Cosine Loss

对数双曲余弦(log cosh)函数的表达式如下:

f(t;a)=1alog(cosh(at))=1alog(eat+eat2)

对于较大的|t|,该函数接近L1函数,对于较小的|t|,该函数接近L2函数,从而结合了L2函数的平滑特点以及L1函数的鲁棒性和图像清晰度优势。

此外,该函数的导数是简单的tanh函数,容易训练且实现简单:

tf(t;a)=t1alog(eat+eat2)=1a2eat+eatteat+eat2=1a2eat+eataeataeat2=eateateat+eat=tanh(at)

2. LogCosh VAE的Pytorch

LogCosh VAE的完整pytorch实现可参考PyTorch-VAE,与标准VAE的主要区别在于构造重构损失时使用log cosh替代均方误差:

t = recons - input
recons_loss = self.alpha * t + \
              torch.log(1. + torch.exp(- 2 * self.alpha * t)) - \
              torch.log(torch.tensor(2.0))
recons_loss = (1. / self.alpha) * recons_loss.mean()

kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

loss = recons_loss + self.beta * kld_loss