高度不平衡分割任务中的边界损失.

在医学图像分割任务中通常存在严重的类别不平衡问题,目标前景区域的大小常常比背景区域小几个数量级,比如下图中前景区域比背景区域小500倍以上。

分割通常采用的交叉熵损失函数是一种基于分布的损失函数,在处理类别高度不平衡的问题上存在着众所周知的缺点。它假设所有样本类别的重要性相同,这通常会导致训练的不稳定,并导致决策边界偏向于数量多的类别。

分割中另一种常见的损失函数dice loss是基于区域的损失函数,在不平衡的医学图像分割问题中通常效果好。但遇到非常小的区域时可能会遇到困难,错误分类的像素可能会导致loss的剧烈降低,从而导致优化的不稳定。此外,dice loss对应查准率和召回率的调和平均,当true positive不变时,该损失对待false postivefalse negative的重要性相同,因此dice loss主要适用于这两种类型的误差数量差不多的情况。

本文提出了一种基于边界的损失函数Boundary loss,它在轮廓空间而不是区域空间上采用距离度量的形式。边界损失计算的不是区域上积分,而是区域之间边界上积分,因此可以缓解高度不平衡分割问题中区域损失的相关问题。

如何根据CNNregional softmax输出来表示对应的boundary points是个很大的挑战,本文受到用离散基于图的优化方法来计算曲线演化梯度流的启发,采用积分方法来计算边界的变化,避免了轮廓点上的局部微分计算,最终的boundary loss是网络输出区域softmax概率的线性函数和,因此可以和现有的区域损失结合使用。

1. Boundary loss的形式

$I:Ω⊂R^{2,3}→R$ 表示空间域$Ω$中的一张图片,\(g:Ω→\{0,1\}\)是该图片的ground truth分割二值图,如果像素$p$属于目标区域 $G⊂Ω$ (前景区域),$g(p)=1$;否则为$0$,即$p∈Ω∖G$(背景区域)。$s_θ:Ω→[0,1]$表示分割网络的softmax概率输出,$S_θ⊂Ω$表示模型输出的对应前景区域,即\(S_θ=\{p∈Ω\) $|$ \(s_θ(p)⩾δ\}\),其中$δ$是提前设定的阈值。

我们的目的是构建一个边界损失函数$Dist(∂G,∂S_θ)$,它采用$Ω$中区域边界空间中距离度量的形式,其中$∂G$是ground truth区域$G$的边界的一种表示(比如边界上所有点的集和),$∂S_θ$是网络输出定义的分割区域的边界。考虑下面的形状空间上非对称L2 distance的表示,它评估的是两个临近边界$∂S$和$∂G$之间的距离变化:

\[\operatorname{Dist}(\partial G, \partial S)=\int_{\partial G}\left\|y_{\partial S}(p)-p\right\|^2 d p\]

其中$p∈Ω$是边界$∂G$上的一点,$y_{∂S}(p)$是边界$∂S$上对应的点,即$y_{∂S}(p)$是边界$∂G$上$p$处的法线与$∂S$的交点,如图(a)所示。上式中的微分边界变化可以用积分方法来近似,这就避免了涉及轮廓上点的微分计算,并用区域积分来表示边界变化,如下:

\[\operatorname{Dist}(\partial G, \partial S) \approx 2 \int_{\Delta S} D_G(q) d q\]

其中$△S$表示两个轮廓之间的区域,$D_G:Ω→R^+$是一个相对于边界$∂G$的距离变换图,即$D_G(q)$表示任意点$q∈Ω$与轮廓$∂G$上最近点$z_{∂G}(q)$之间的距离:$D_G(q)=||q−z_{∂G}(q)||$,如图(b)所示。

为了证明这种近似,沿连接$∂G$上的一点$p$与$y_{∂S}(p)$之间的法线对距离图$2D_G(q)$进行积分可得:

\[\int_p^{y_{\partial S}(p)} 2 D_G(q) d q=\int_0^{\left\|y_{\partial S}(p)-p\right\|} 2 D_G d D_G=\left\|y_{\partial S}(p)-p\right\|^2\]

边界损失函数$Dist(∂G,∂S)$可以进一步表示为:

\[\begin{aligned} \frac{1}{2} \operatorname{Dist}(\partial G, \partial S)&=\int_S \phi_G(q) d q-\int_G \phi_G(q) d q \\ & =\int_{\Omega} \phi_G(q) s(q) d q-\int_{\Omega} \phi_G(q) g(q) d q \end{aligned}\]

其中\(s:Ω→\{0,1\}\)是区域$S$的二元指示函数:如果$q∈S$表示属于目标$s(q)=1$,否则为$0$。$ϕ_G:Ω→R$是边界$∂G$的水平集表示:如果$q∈G$则$ϕ_G(q)=−D_G(q)$否则$ϕ_G(q)=D_G(q)$。

注意到上式第二项不包含模型参数,因此可以丢弃。对于$S=S_θ$,即用网络的softmax输出$s_θ(q)$替换式中的$s(q)$,可以得到如下所示的边界损失,其中水平集函数$ϕ_G$是直接根据区域$G$提前计算得到的。

\[\mathcal{L}_B(\theta) = \int_{\Omega} \phi_G(q) s_{\theta}(q) d q\]

在边界损失中,每个点$q$的softmax输出通过距离函数进行加权。而在基于区域的损失函数中,这种到边界距离的信息被忽略了,区域内每个点不管到边界距离大小都都按同样的权重进行处理。

对于边界损失,当距离函数中所有的负值都保留(模型对gt区域中所有像素的softmax预测都为1)而所有的正值都舍去(即模型对背景的softmax预测都为0)时,边界损失到达全局最小,即模型的softmax预测正好输出ground truth时边界损失最小,这也验证了边界损失的有效性。

在后续的实验中可以看到,通常要把边界损失和区域损失结合起来使用才能取得好的效果。

\[\mathcal{L}_R(\theta) + \alpha \mathcal{L}_B(\theta)\]

2. Boundary loss的实现

import torch
import numpy as np
from einops import rearrange, einsum
from scipy.ndimage import distance_transform_edt


class BDLoss(nn.Module):
    def __init__(self):
        """
        compute boudary loss
        only compute the loss of foreground
        """
        super(BDLoss, self).__init__()

    @torch.no_grad()
    def one_hot2dist(self, seg):
        res = np.zeros_like(seg)
        for c in range(seg.shape[1]):
            posmask = seg[:,c,...]
            if posmask.any():
                negmask = 1.-posmask
                neg_map = distance_transform_edt(negmask)
                pos_map = distance_transform_edt(posmask)
                res[:,c,...] = neg_map * negmask - (pos_map - 1) * posmask
        return res

    def forward(self, result, gt):
        """
        result: (batch_size, class, h, w)
        gt: (batch_size, h, w)
        """
        result = torch.softmax(result, dim=1)
        gt = rearrange(gt, 'b h w -> b 1 h w')

        y_onehot = torch.zeros_like(result)
        y_onehot = y_onehot.scatter_(1, gt.data, 1)

        bound = torch.from_numpy(self.one_hot2dist(y_onehot.cpu().numpy())).float()
        pc = result[:, 1:, ...]
        dc = bound[:, 1:, ...]
        multipled = pc * dc
        return multipled.mean()
    
seg_loss = BDLoss()
result = torch.randn((16, 8, 64, 64))
gt = torch.randint(0, 8, (16, 64, 64))
loss = seg_loss(result, gt)
print(loss)