使用对数双曲余弦损失改进变分自编码器.

在变分自编码器VAE中,解码样本和原输入之间的重构损失函数默认选择L2损失,本文作者建议将其替换为对数双曲余弦(log cosh)损失,实验结果表明其能够显著改善VAE的重构质量。

1. 对数双曲余弦损失 Log Hyperbolic Cosine Loss

对数双曲余弦(log cosh)函数的表达式如下:

\[f(t;a) = \frac{1}{a} \log( \cosh(at)) = \frac{1}{a} \log(\frac{e^{at}+e^{-at}}{2})\]

对于较大的$| t |$,该函数接近L1函数,对于较小的$| t |$,该函数接近L2函数,从而结合了L2函数的平滑特点以及L1函数的鲁棒性和图像清晰度优势。

此外,该函数的导数是简单的tanh函数,容易训练且实现简单:

\[\nabla_t f(t;a)= \nabla_t \frac{1}{a} \log(\frac{e^{at}+e^{-at}}{2}) = \frac{1}{a}\frac{2}{e^{at}+e^{-at}} \nabla_t \frac{e^{at}+e^{-at}}{2} \\= \frac{1}{a}\frac{2}{e^{at}+e^{-at}} \frac{ae^{at}-ae^{-at}}{2} = \frac{e^{at}-e^{-at}}{e^{at}+e^{-at}} = \tanh (at)\]

2. LogCosh VAE的Pytorch

LogCosh VAE的完整pytorch实现可参考PyTorch-VAE,与标准VAE的主要区别在于构造重构损失时使用log cosh替代均方误差:

t = recons - input
recons_loss = self.alpha * t + \
              torch.log(1. + torch.exp(- 2 * self.alpha * t)) - \
              torch.log(torch.tensor(2.0))
recons_loss = (1. / self.alpha) * recons_loss.mean()

kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

loss = recons_loss + self.beta * kld_loss