Learning Image Generation with the Multi-Scale Structural Similarity Metric (MS-SSIM)
1. The Structural Similarity Metric (SSIM)
The Structural Similarity Metric (SSIM) matches the luminance ($I$), contrast ($C$), and structure ($S$) information of two images $x$ and $y$:
\[I(x,y) = \frac{2\mu_x\mu_y+C_1}{\mu_x^2+\mu_y^2+C_1} \\ C(x,y) = \frac{2\sigma_x\sigma_y+C_2}{\sigma_x^2+\sigma_y^2+C_2} \\ S(x,y) = \frac{\sigma_{xy}+C_3}{\sigma_x\sigma_y+C_3}\]
where $\mu_x$, $\mu_y$, $\sigma_x$, and $\sigma_y$ denote the mean and standard deviation of the pixel intensities in a local image window centered at $x$ or $y$. The paper uses a square neighborhood extending $5$ pixels to either side of $x$ or $y$, giving an $11\times 11$ window. $\sigma_{xy}$ denotes the correlation coefficient between corresponding pixels in the windows centered at $x$ and $y$. The constants $C_1$, $C_2$, and $C_3$ are small values added for numerical stability. SSIM combines the three components:
\[\text{SSIM}(x,y) = I(x,y)^{\alpha}\, C(x,y)^{\beta}\, S(x,y)^{\gamma}\]
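To make the definitions concrete, here is a minimal sketch (not from the paper) that evaluates the three components on a pair of flattened $11\times 11$ windows, assuming pixel values in $[0,1]$, $C_1=(0.01)^2$, $C_2=(0.03)^2$, and the common convention $C_3=C_2/2$:
import torch

def ssim_components(win_x: torch.Tensor, win_y: torch.Tensor,
                    C1: float = 0.01 ** 2, C2: float = 0.03 ** 2):
    # win_x, win_y: flattened local windows, e.g. shape (11 * 11,)
    mu_x, mu_y = win_x.mean(), win_y.mean()
    sd_x, sd_y = win_x.std(unbiased=False), win_y.std(unbiased=False)
    cov_xy = ((win_x - mu_x) * (win_y - mu_y)).mean()
    C3 = C2 / 2  # common convention; illustrative here
    I = (2 * mu_x * mu_y + C1) / (mu_x ** 2 + mu_y ** 2 + C1)
    C = (2 * sd_x * sd_y + C2) / (sd_x ** 2 + sd_y ** 2 + C2)
    S = (cov_xy + C3) / (sd_x * sd_y + C3)
    return I, C, S

x = torch.rand(11 * 11)
I, C, S = ssim_components(x, x)
print((I * C * S).item())  # identical windows give SSIM ≈ 1 (with alpha = beta = gamma = 1)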
2. The Multi-Scale Structural Similarity Metric (MS-SSIM)
SSIM assumes a fixed image sampling density and viewing distance, so it is only appropriate for a particular range of image scales. The Multi-Scale Structural Similarity Metric (MS-SSIM) is a variant of SSIM that operates at several scales simultaneously. The input images $x$ and $y$ are iteratively downsampled by a factor of 2 with a low-pass filter, with scale $j$ denoting the images downsampled by a factor of $2^{j-1}$. The contrast $C(x,y)$ and structure $S(x,y)$ components are applied at every scale, while the luminance component $I(x,y)$ is applied only at the coarsest scale $M$. MS-SSIM is defined as:
\[\text{MS-SSIM}(x,y) = I_M(x,y)^{\alpha_M} \prod_{j=1}^{M} C_j(x,y)^{\beta_j} S_j(x,y)^{\gamma_j}\]
In the experiments the number of scales is set to $M=5$ and all exponents to $\alpha_M=\beta_j=\gamma_j=1$. (Note that the implementation below instead uses the standard per-scale MS-SSIM weights $[0.0448, 0.2856, 0.3001, 0.2363, 0.1333]$ as exponents.)
from math import exp

import torch
import torch.nn.functional as F
from torch import nn, Tensor


class MSSIM(nn.Module):
    def __init__(self,
                 in_channels: int = 3,
                 window_size: int = 11,
                 size_average: bool = True) -> None:
        """
        Computes the differentiable MS-SSIM loss
        :param in_channels: (Int) number of image channels
        :param window_size: (Int) side length of the Gaussian window
        :param size_average: (Bool) average the SSIM map over all dimensions
        """
        super(MSSIM, self).__init__()
        self.in_channels = in_channels
        self.window_size = window_size
        self.size_average = size_average

    def gaussian_window(self, window_size: int, sigma: float) -> Tensor:
        # 1D Gaussian kernel, normalized to sum to 1
        kernel = torch.tensor([exp(-(x - window_size // 2) ** 2 / (2 * sigma ** 2))
                               for x in range(window_size)])
        return kernel / kernel.sum()

    def create_window(self, window_size: int, in_channels: int) -> Tensor:
        # Outer product of two 1D Gaussians gives the 2D window,
        # replicated once per channel for a grouped convolution
        _1D_window = self.gaussian_window(window_size, 1.5).unsqueeze(1)
        _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
        window = _2D_window.expand(in_channels, 1, window_size, window_size).contiguous()
        return window

    def ssim(self,
             img1: Tensor,
             img2: Tensor,
             window_size: int,
             in_channel: int,
             size_average: bool) -> Tensor:
        device = img1.device
        window = self.create_window(window_size, in_channel).to(device)

        # Local means via Gaussian filtering
        mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=in_channel)
        mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=in_channel)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        # Local variances and covariance
        sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=in_channel) - mu1_sq
        sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=in_channel) - mu2_sq
        sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=in_channel) - mu1_mu2

        img_range = 1.0  # dynamic range of the pixel values (img1.max() - img1.min())
        C1 = (0.01 * img_range) ** 2
        C2 = (0.03 * img_range) ** 2

        v1 = 2.0 * sigma12 + C2
        v2 = sigma1_sq + sigma2_sq + C2
        cs = torch.mean(v1 / v2)  # contrast sensitivity (the C * S term)

        ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)

        if size_average:
            ret = ssim_map.mean()
        else:
            ret = ssim_map.mean(1).mean(1).mean(1)
        return ret, cs

    def forward(self, img1: Tensor, img2: Tensor) -> Tensor:
        device = img1.device
        weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
        levels = weights.size()[0]
        mssim = []
        mcs = []

        for _ in range(levels):
            sim, cs = self.ssim(img1, img2,
                                self.window_size,
                                self.in_channels,
                                self.size_average)
            mssim.append(sim)
            mcs.append(cs)

            # Downsample by a factor of 2 for the next scale
            img1 = F.avg_pool2d(img1, (2, 2))
            img2 = F.avg_pool2d(img2, (2, 2))

        mssim = torch.stack(mssim)
        mcs = torch.stack(mcs)

        # # Normalize (to avoid NaNs when training unstable models; not part of the original definition)
        # if normalize:
        #     mssim = (mssim + 1) / 2
        #     mcs = (mcs + 1) / 2

        pow1 = mcs ** weights
        pow2 = mssim ** weights

        # Contrast/structure terms at the finer scales, full SSIM only at the coarsest scale
        output = torch.prod(pow1[:-1]) * pow2[-1]
        return 1 - output
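A quick smoke test of the loss on random tensors (the shapes and values below are illustrative, not from the original code):
if __name__ == "__main__":
    loss_fn = MSSIM(in_channels=3, window_size=11, size_average=True)
    recons = torch.rand(4, 3, 64, 64)  # e.g. decoder output, values in [0, 1]
    target = torch.rand(4, 3, 64, 64)
    print(loss_fn(recons, target).item())  # 1 - MS-SSIM, larger for dissimilar images
    print(loss_fn(recons, recons).item())  # ~0 for identical images
Inputs are assumed to lie in $[0,1]$, matching img_range = 1.0 in ssim(), and must be large enough to survive $M-1=4$ rounds of 2x downsampling.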
3. EL-VAE
The authors introduce the MS-SSIM perceptual loss into the VAE to construct the Expected-Loss VAE (EL-VAE). A complete PyTorch implementation of EL-VAE can be found in PyTorch-VAE; the main difference from a standard VAE is that the reconstruction loss is built with MS-SSIM instead of the mean squared error:
self.mssim_loss = MSSIM(self.in_channels,
                        window_size,
                        size_average)

# Reconstruction term: 1 - MS-SSIM between the reconstruction and the input
recons_loss = self.mssim_loss(recons, input)
# KL divergence between the approximate posterior N(mu, sigma^2) and the standard normal prior
kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)
loss = recons_loss + kld_weight * kld_loss
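Putting the pieces together, a single training step might look like the sketch below. The model.encode / reparameterize / decode interface and the kld_weight value are assumptions for illustration (in PyTorch-VAE the KL weight is typically set to batch_size / dataset_size), not the repository's exact code:
mssim_loss = MSSIM(in_channels=3, window_size=11, size_average=True)

def training_step(model, batch, optimizer, kld_weight=0.005):
    # Hypothetical VAE interface: encode -> reparameterize -> decode
    mu, log_var = model.encode(batch)
    z = model.reparameterize(mu, log_var)
    recons = model.decode(z)

    recons_loss = mssim_loss(recons, batch)  # 1 - MS-SSIM in place of MSE
    kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)
    loss = recons_loss + kld_weight * kld_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()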