DFCVAE: Constraining Deep Feature Consistency with a Feature Perceptual Loss.

The authors replace the pixel-wise reconstruction loss in the VAE with a deep feature consistency loss, so that the reconstructed images have a more natural visual appearance and higher perceptual quality. The feature perceptual loss is constructed from the hidden-layer features of a pretrained deep convolutional neural network.

1. Feature Perceptual Loss

A pixel-wise reconstruction loss (such as the L2 loss) measures the difference between the reconstructed image and the original image. However, images generated this way tend to be very blurry, because a pixel-wise loss cannot capture the perceptual difference and spatial correlation between two images. For example, two images that differ by only a few pixels are almost indistinguishable to human perception, yet they can have a very high pixel-wise loss.
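
To make the point concrete, here is a toy check (not from the paper or from the PyTorch-VAE code analyzed below): an image and a copy of it shifted by two pixels are perceptually almost identical, yet their pixel-wise MSE is far from zero.

import torch
import torch.nn.functional as F

x = torch.rand(1, 3, 64, 64)                  # stand-in for an input image
x_shifted = torch.roll(x, shifts=2, dims=-1)  # shift by 2 pixels along the width

# Perceptually the two images are nearly the same, but the pixel-wise loss is large
print(F.mse_loss(x_shifted, x))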

The feature perceptual loss is defined as the difference between the hidden-layer feature representations of two images extracted from a pretrained deep convolutional neural network. By enforcing consistency between the hidden-layer features of the input and output images, it improves the quality of the images generated by the VAE.

Instead of comparing the input image and the generated image directly in pixel space, the feature perceptual loss feeds both of them into a pretrained deep convolutional network and measures the difference between their hidden-layer feature representations.

The experiments use VGGNet as the pretrained model. The feature perceptual loss at layer $l$ is defined as the mean squared error between the hidden-layer features of the input image and the reconstructed image at that layer:

\[\mathcal{L}_{rec}^{l} = \frac{1}{2C^lH^lW^l}\sum_{c=1}^{C^l}\sum_{h=1}^{H^l}\sum_{w=1}^{W^l}(\Phi(x)_{c,h,w}^l-\Phi(\hat{x})_{c,h,w}^l)^2\]
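
As a minimal sketch, this per-layer term can be written directly with `F.mse_loss`; the helper name and the `phi_x` / `phi_x_hat` arguments are illustrative and stand for the layer-$l$ feature maps $\Phi(x)^l$ and $\Phi(\hat{x})^l$:

import torch.nn.functional as F

def layer_feature_loss(phi_x, phi_x_hat):
    # 1/(2*C*H*W) * sum of squared differences == 0.5 * mean squared error
    return 0.5 * F.mse_loss(phi_x_hat, phi_x)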

The total loss of the network is then the sum of the feature perceptual loss (summed over the selected layers) and the KL divergence:

\[\mathcal{L}_{total} = \alpha \mathcal{L}_{KL} + \beta \sum_{l=1}^{L} \mathcal{L}_{rec}^{l}\]
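
Here $\mathcal{L}_{KL}$ is the standard closed-form KL divergence between the diagonal Gaussian posterior $q(z|x)=\mathcal{N}(\mu,\mathrm{diag}(\sigma^2))$ and the standard normal prior; the formula is not spelled out in the original text, but it is what the implementation below computes:

\[\mathcal{L}_{KL} = -\frac{1}{2}\sum_{j}\left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)\]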

2. PyTorch Implementation of DFCVAE

The network architecture of the proposed DFCVAE is shown in the figure below.

A complete PyTorch implementation of DFCVAE is available in PyTorch-VAE and is analyzed below.

The pretrained network is set to VGGNet:

from torchvision.models import vgg19_bn
self.feature_network = vgg19_bn(pretrained=True)
# Freeze the pretrained feature network
for param in self.feature_network.parameters():
    param.requires_grad = False
self.feature_network.eval()  # keep BatchNorm layers in inference mode

Extract the hidden-layer features of both the input image and the reconstructed image from the pretrained network:

def extract_features(self,
                     input: Tensor,
                     feature_layers: List = None) -> List[Tensor]:
    """
    Extracts the features from the pretrained model
    at the layers indicated by feature_layers.
    :param input: (Tensor) [B x C x H x W]
    :param feature_layers: List of string of IDs
    :return: List of the extracted features
    """
    if feature_layers is None:
        feature_layers = ['14', '24', '34', '43']
    features = []
    result = input
    for (key, module) in self.feature_network.features._modules.items():
        result = module(result)
        if key in feature_layers:
            features.append(result)
    return features

recons_features = self.extract_features(recons)
input_features = self.extract_features(input)

Construct the loss function:

recons_loss = F.mse_loss(recons, input)              # pixel-wise reconstruction loss
feature_loss = 0.0
for (r, i) in zip(recons_features, input_features):  # sum the feature loss over the selected layers
    feature_loss += F.mse_loss(r, i)
kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

loss = self.beta * (recons_loss + feature_loss) + self.alpha * kld_loss
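
For context, here is a minimal sketch of one training step using this loss. The names `model` and `loader` and the exact `forward` / `loss_function` signatures are assumptions for illustration; the actual PyTorch-VAE interfaces differ slightly:

import torch

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for x, _ in loader:
    # forward pass returns the reconstruction together with the posterior parameters
    recons, input, mu, log_var = model(x)
    # assembles beta * (recons_loss + feature_loss) + alpha * kld_loss as above
    loss = model.loss_function(recons, input, mu, log_var)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()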