Learning image generation with the multi-scale structural similarity metric (MS-SSIM).

1. Structural Similarity Metric (SSIM)

The Structural Similarity Metric (SSIM) compares two images ($x$ and $y$) in terms of their luminance ($I$), contrast ($C$), and structure ($S$):

\[\begin{aligned} I(x,y) &= \frac{2\mu_x\mu_y+C_1}{\mu_x^2+\mu_y^2+C_1} \\ C(x,y) &= \frac{2\sigma_x\sigma_y+C_2}{\sigma_x^2+\sigma_y^2+C_2} \\ S(x,y) &= \frac{\sigma_{xy}+C_3}{\sigma_x\sigma_y+C_3} \end{aligned}\]

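where $\mu_x$, $\mu_y$, $\sigma_x$, and $\sigma_y$ denote the mean and standard deviation of pixel intensity in a local image window centered at $x$ or $y$. The paper uses a square neighborhood extending $5$ pixels on either side of $x$ or $y$, which gives an $11 \times 11$ window. $\sigma_{xy}$ denotes the covariance between corresponding pixels in the windows centered at $x$ and $y$, so that $S$ is, up to the stabilizing constant, their correlation coefficient. The constants $C_1$, $C_2$, and $C_3$ are small values added for numerical stability. SSIM combines the three functions: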

\[\text{SSIM}(x,y) = I(x,y)^{\alpha} C(x,y)^{\beta} S(x,y)^{\gamma}\]
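To make the three components concrete, here is a minimal sketch in plain Python that evaluates $I$, $C$, $S$ and their product for a single pair of windows. It rests on several assumptions that are not stated in the text above: uniform weights inside the window, pixel values in $[0, 1]$, $\alpha=\beta=\gamma=1$, $C_1 = 0.01^2$, $C_2 = 0.03^2$, and the common convention $C_3 = C_2/2$. The function name `ssim_window` is hypothetical.

```python
from math import sqrt


def ssim_window(x, y, c1=0.01 ** 2, c2=0.03 ** 2):
    """SSIM of two equally sized windows given as flat lists of floats.

    Assumptions (illustrative, not from the source text): uniform weights
    over the window, alpha = beta = gamma = 1, and C3 = C2 / 2.
    """
    n = len(x)
    c3 = c2 / 2
    mu_x = sum(x) / n
    mu_y = sum(y) / n
    # Population variances and covariance over the window
    var_x = sum((a - mu_x) ** 2 for a in x) / n
    var_y = sum((b - mu_y) ** 2 for b in y) / n
    cov_xy = sum((a - mu_x) * (b - mu_y) for a, b in zip(x, y)) / n
    sd_x, sd_y = sqrt(var_x), sqrt(var_y)

    luminance = (2 * mu_x * mu_y + c1) / (mu_x ** 2 + mu_y ** 2 + c1)
    contrast = (2 * sd_x * sd_y + c2) / (var_x + var_y + c2)
    structure = (cov_xy + c3) / (sd_x * sd_y + c3)
    return luminance * contrast * structure


patch = [0.1, 0.4, 0.7, 0.9]
same = ssim_window(patch, patch)                       # identical windows
inverted = ssim_window(patch, [1 - p for p in patch])  # contrast-inverted copy
```

Identical windows score 1 (up to floating-point error), while the contrast-inverted copy gets a negative score because the covariance changes sign.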

2. Multi-Scale Structural Similarity Metric (MS-SSIM)

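SSIM assumes a fixed image sampling density and viewing distance, so it is only appropriate for a particular range of image scales. The multi-scale structural similarity metric (MS-SSIM) is a variant of SSIM that operates at several scales simultaneously. The input images $x$ and $y$ are iteratively downsampled by a factor of 2 with a low-pass filter, and scale $j$ denotes downsampling by a factor of $2^{j-1}$. The contrast $C(x,y)$ and structure $S(x,y)$ components are applied at every scale, whereas the luminance component $I(x,y)$ is applied only at the coarsest scale $M$. MS-SSIM is defined as: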

\[\text{MS-SSIM}(x,y) = I_M(x,y)^{\alpha_M} \prod_{j=1}^{M} C_j(x,y)^{\beta_j} S_j(x,y)^{\gamma_j}\]

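The experiments use $M=5$ scales, with all weights set to $\alpha_M=\beta_j=\gamma_j=1$. Note that the implementation below instead uses the per-scale exponents $(0.0448, 0.2856, 0.3001, 0.2363, 0.1333)$.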
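The scale pyramid and the final product can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions: each downsampling step halves the side length with floor division (as 2×2 average pooling without padding would), and the per-scale values $C_j$, $S_j$ and $I_M$ are taken as already computed scalars. The names `scale_sizes` and `ms_ssim` are hypothetical.

```python
from math import prod


def scale_sizes(size, num_scales=5):
    """Side length at each scale j = 1..M, i.e. size / 2^(j-1), floored."""
    return [size // 2 ** (j - 1) for j in range(1, num_scales + 1)]


def ms_ssim(lum_coarsest, contrasts, structures, alpha=1.0, betas=None, gammas=None):
    """Combine per-scale terms: I_M^alpha * prod_j C_j^beta_j * S_j^gamma_j."""
    m = len(contrasts)
    betas = betas or [1.0] * m    # unit weights by default, as in the text
    gammas = gammas or [1.0] * m
    cs = prod(c ** b * s ** g
              for c, s, b, g in zip(contrasts, structures, betas, gammas))
    return lum_coarsest ** alpha * cs


sizes = scale_sizes(64)  # [64, 32, 16, 8, 4]
# A contrast mismatch at the third scale only, everything else perfect
score = ms_ssim(0.9, [1.0, 1.0, 0.5, 1.0, 1.0], [1.0] * 5)
```

With $M=5$ the coarsest scale of a 64-pixel image is only 4 pixels wide, which is smaller than an $11 \times 11$ window, so small inputs rely heavily on padding at the coarse scales.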

from math import exp

import torch
import torch.nn.functional as F
from torch import nn, Tensor


class MSSIM(nn.Module):
    def __init__(self,
                 in_channels: int = 3,
                 window_size: int = 11,
                 size_average:bool = True) -> None:
        """
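        Computes the differentiable MS-SSIM loss
        :param in_channels: (Int) number of image channels
        :param window_size: (Int) side length of the Gaussian window
        :param size_average: (Bool) if True, average SSIM over the whole batch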
        """
        super(MSSIM, self).__init__()
        self.in_channels = in_channels
        self.window_size = window_size
        self.size_average = size_average

    def gaussian_window(self, window_size:int, sigma: float) -> Tensor:
        # Unnormalized 1D Gaussian centered on the middle tap (note the negative exponent)
        kernel = torch.tensor([exp(-(x - window_size // 2) ** 2 / (2 * sigma ** 2))
                               for x in range(window_size)])
        return kernel/kernel.sum()

    def create_window(self, window_size, in_channels):
        # Outer product of the 1D kernel gives the 2D window, replicated once per channel
        _1D_window = self.gaussian_window(window_size, 1.5).unsqueeze(1)
        _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
        window = _2D_window.expand(in_channels, 1, window_size, window_size).contiguous()
        return window

    def ssim(self,
             img1: Tensor,
             img2: Tensor,
             window_size: int,
             in_channel: int,
             size_average: bool) -> tuple[Tensor, Tensor]:

        device = img1.device
        window = self.create_window(window_size, in_channel).to(device)
        mu1 = F.conv2d(img1, window, padding= window_size//2, groups=in_channel)
        mu2 = F.conv2d(img2, window, padding= window_size//2, groups=in_channel)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        sigma1_sq = F.conv2d(img1 * img1, window, padding = window_size//2, groups=in_channel) - mu1_sq
        sigma2_sq = F.conv2d(img2 * img2, window, padding = window_size//2, groups=in_channel) - mu2_sq
        sigma12   = F.conv2d(img1 * img2, window, padding = window_size//2, groups=in_channel) - mu1_mu2

        img_range = 1.0  # Dynamic range, fixed to 1 (alternative: img1.max() - img1.min())
        C1 = (0.01 * img_range) ** 2
        C2 = (0.03 * img_range) ** 2

        # Assuming the common choice C3 = C2 / 2, the product C * S simplifies to v1 / v2
        v1 = 2.0 * sigma12 + C2
        v2 = sigma1_sq + sigma2_sq + C2
        cs = torch.mean(v1 / v2)  # mean contrast-structure term

        ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)

        if size_average:
            ret = ssim_map.mean()
        else:
            ret = ssim_map.mean(1).mean(1).mean(1)
        return ret, cs

    def forward(self, img1: Tensor, img2: Tensor) -> Tensor:
        device = img1.device
        # Per-scale exponents, one per level (finest to coarsest)
        weights = torch.tensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333], device=device)
        levels = weights.size()[0]
        mssim = []
        mcs = []

        for _ in range(levels):
            sim, cs = self.ssim(img1, img2,
                                self.window_size,
                                self.in_channels,
                                self.size_average)
            mssim.append(sim)
            mcs.append(cs)

            img1 = F.avg_pool2d(img1, (2, 2))
            img2 = F.avg_pool2d(img2, (2, 2))

        mssim = torch.stack(mssim)
        mcs = torch.stack(mcs)

        # # Normalize (to avoid NaNs during training unstable models, not compliant with original definition)
        # if normalize:
        #     mssim = (mssim + 1) / 2
        #     mcs = (mcs + 1) / 2

        pow1 = mcs ** weights
        pow2 = mssim ** weights

        # Contrast-structure terms from the finer scales, full SSIM (with luminance) at the coarsest
        output = torch.prod(pow1[:-1]) * pow2[-1]
        return 1 - output
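The window weights are meant to form a normalized Gaussian that peaks at the center tap, which requires a negative exponent. As a quick sanity check that does not need PyTorch, the sketch below rebuilds the same 1D weights (window size 11, $\sigma = 1.5$, as in `create_window`) and their 2D outer product in plain Python. The function name `gaussian_window_1d` is hypothetical.

```python
from math import exp


def gaussian_window_1d(window_size=11, sigma=1.5):
    """Normalized 1D Gaussian weights centered on the middle tap."""
    center = window_size // 2
    weights = [exp(-((i - center) ** 2) / (2 * sigma ** 2))
               for i in range(window_size)]
    total = sum(weights)
    return [w / total for w in weights]


w = gaussian_window_1d()
# 2D window as the outer product of the 1D weights
w2d = [[a * b for b in w] for a in w]
```

The weights sum to 1 in both the 1D and 2D case, so convolving with the window yields a weighted local mean rather than a rescaled one.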

3. EL-VAE

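The authors introduce the perceptual loss MS-SSIM into the VAE to construct the Expected-Loss VAE (EL-VAE). A complete PyTorch implementation of EL-VAE is available in PyTorch-VAE. The main difference from a standard VAE is that the reconstruction loss uses MS-SSIM in place of the mean squared error: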

self.mssim_loss = MSSIM(self.in_channels,
                        window_size,
                        size_average)
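# Reconstruction term: 1 - MS-SSIM between the reconstruction and the input
recons_loss = self.mssim_loss(recons, input)
# KL divergence between N(mu, exp(log_var)) and the standard normal prior, averaged over the batch
kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)
loss = recons_loss + kld_weight * kld_loss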
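The KL term above is the closed-form divergence between a diagonal Gaussian posterior and a standard normal prior. A minimal plain-Python sketch of the same arithmetic, for a batch given as lists of scalars, might look as follows. The helper names `kl_standard_normal` and `el_vae_loss` are hypothetical, and the reconstruction loss is taken as an already computed number.

```python
from math import exp


def kl_standard_normal(mu, log_var):
    """KL( N(mu, diag(exp(log_var))) || N(0, I) ) for one latent vector."""
    return -0.5 * sum(1 + lv - m ** 2 - exp(lv) for m, lv in zip(mu, log_var))


def el_vae_loss(recons_loss, kld_per_sample, kld_weight):
    """Reconstruction loss plus the weighted batch-mean KL term."""
    kld = sum(kld_per_sample) / len(kld_per_sample)
    return recons_loss + kld_weight * kld


kl_zero = kl_standard_normal([0.0, 0.0], [0.0, 0.0])  # posterior equals the prior
kl_shift = kl_standard_normal([1.0], [0.0])           # mean shifted by one unit
total = el_vae_loss(0.2, [kl_zero, kl_shift], kld_weight=0.5)
```

The KL term vanishes when the posterior matches the prior and grows as the mean moves away from zero, so `kld_weight` controls how strongly the latent code is pulled toward the prior relative to the MS-SSIM reconstruction term.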