使用变分自编码器预测实例和包标签的非独立同分布多实例学习.
0. TL; DR
这篇论文提出了一种名为多实例变分自编码器(Multi-Instance Variational Auto-Encoder, MIVAE)的全新生成式多实例学习(MIL)框架,明确地将实例的生成过程解构为两个部分:
- 一个共享的、包级别的潜在因子 ($z_B$),用于捕捉所有实例共有的上下文依赖和结构信息。
- 一系列独立的、实例级别的潜在因子 ($z_I$),用于捕捉每个实例自身的独特变异。
通过设计了一套巧妙的编码器-解码器架构,能够同时推断包级别和实例级别的潜在因子,并端到端地预测包标签和实例标签。实验证明,通过显式地建模实例间的依赖关系,MIVAE在多个经典MIL基准和真实的医疗影像数据集上超越了包括Attention-MIL在内的多种SOTA算法。
1. 背景介绍
多实例学习(MIL)的初衷是为了处理那些无法用单一特征向量简单描述的复杂对象。例如,一个分子(包)由多种可能的三维构象(实例)组成;一张图片(包)由多个区域或图块(实例)构成。这种“包-实例”的表示法,天然地保留了对象内部的丰富结构和上下文信息。同一个包内的实例,理应不是相互独立的。
- 一个分子的所有构象,共享相同的化学键结构。
- 一篇文章的所有段落,共享相同的主题和作者的写作风格。
- 一张病理切片的所有细胞图块,共享来自同一个病人的遗传背景和染色工艺。
绝大多数现有的MIL算法假设包内所有实例是独立同分布的(i.i.d.)。这种“i.i.d.假设”不仅与现实世界相悖,更忽视了MIL表示法所提供的结构信息。它导致模型可能无法充分理解实例间的微妙关系,从而影响了包级别和实例级别的预测精度。
本文的出发点是能否设计一个模型,既能捕捉每个实例的独特性,又能显式地建模所有实例共享的上下文依赖,从而实现更精准的联合预测。
2. MIVAE 模型
MIVAE构建了一个生成模型,假设观测到的实例是由两种不同层级的潜在因子共同生成的。
- $z_B$: 包级别潜在因子 (Bag-level Latent Factor)。这是共享的,包内所有实例都依赖于它。它负责编码这个包共有的上下文、结构或风格信息。
- $z_{ij}^I$: 实例级别潜在因子 (Instance-level Latent Factor)。这是独立的,每个实例$x_{ij}$都有一个自己专属的$z_{ij}^I$。它负责编码这个实例独特的、区别于其他实例的变异信息。
- $x_{ij}$: 观测到的实例。它由共享的$z_B$和独立的$z_{ij}^I$共同解码生成。
- $Y_i$: 观测到的包标签。MIVAE假设包标签主要与共享的包级别因子$z_B$相关。
作者假设一旦以共享的包级别因子$z_B$为条件,那么包内的所有实例就变得相互独立了。为了学习这个生成模型,MIVAE采用了变分自编码器(VAE)框架。这需要设计一个生成网络(解码器)和一个推断网络(编码器)。
- 生成网络$p_θ(x_j | z_B, z_j^I)$从潜在因子重构实例。输入是共享的$z_B$和独立的$z_j^I$,输出是重构的实例$x_j$。这个解码器对于包内所有实例是共享的。
- 推断网络需要解决两个挑战:如何从一堆实例中推断出共享的$z_B$,以及如何推断出每个实例独立的$z_j^I$。
- 推断独立的$z_j^I$ (Instance-level Encoder): 对于每个实例$x_j$,都有一个独立的编码器$q_{\phi_I}(z_j^I | x_j)$来推断其专属的$z_j^I$。
- 推断共享的$z_B$ (Bag-level Encoder): 首先使用编码器$q_{φ_B}(ẑ{B_j} | x_j)$为每个实例$x_j$推断一个“临时的”包级别因子$ẑ{B_j}$。然后将所有这些临时的$ẑ_{B_j}$的分布参数(均值和方差)进行平均,得到最终的、共享的包级别因子$z_B$的分布。
MIVAE的最终目标函数由两部分构成:
\[\mathcal{L}_{\text{MIVAE}} = \mathcal{L}_{\text{ELBO}}(B, y) + \alpha \cdot \mathbb{E}[\log q_\omega(y | z_B, z_I)]\]证据下界 (ELBO) $\mathcal{L}_{\text{ELBO}}(B, y)$ 是VAE的标准损失,包含两部分:
- 重构损失: 鼓励模型能从潜在因子准确重构出原始实例。
- KL散度: 两个KL散度项,分别约束包级别因子$z_B$和实例级别因子$z_j^I$的后验分布去逼近预设的先验分布(如高斯分布),起到了正则化的作用。
为了实现标签预测,MIVAE额外引入了一个辅助分类器 $q_ω$。分类器中的一部分$f_{ω_I}$接收实例级别因子$z_j^I$作为输入,预测实例标签。整个分类器$q_ω$结合了$f_{ω_I}$的输出(通过max-pooling聚合)和包级别因子$z_B$的直接预测$f_{ω_B}(z_B)$,来得到最终的包标签。
通过这种设计,MIVAE实现了在一个统一的框架内端到端地学习非i.i.d.的实例表示,并同时进行包和实例的标签预测。
3. 实验分析
作者在多个数据集上对MIVAE进行了全面的评估。
3.1 包标签预测:经典MIL基准
评估数据集包括 Musk1, Musk2, Fox, Tiger, Elephant。MIVAE的性能具有很强的竞争力,在Musk1和Tiger上取得了最高的平均准确率。尽管这些数据集规模较小,特征也是预先计算好的,可能无法完全发挥深度生成模型的优势,但MIVAE依然表现出色。这证明了其模型设计的有效性和普适性。
3.2 实例标签预测:20 NewsGroups
数据集使用20 Newsgroups,一个文本分类任务,标签不平衡性极高(正实例比例仅约3%)。MIVAE在20个数据子集中,有15个取得了最佳性能(AUC-PR)。
这个结果表明通过将共享的上下文信息(如文章主题、作者风格)剥离到$z_B$中,模型可以更专注于利用$z_I$来识别那些真正决定标签的关键实例(如包含特定关键词的句子),从而在高度不平衡和模糊的监督信号下,实现了更精准的实例定位。这直接印证了显式建模非i.i.d.关系的巨大好处。
3.3 端到端应用:结肠癌病理图像分析
对于结肠癌H&E染色全切片图像数据集,MIVAE在包标签预测(准确率)和实例标签预测(AUC-PR)两项任务上,均全面超越了包括mi-Net、MI-Net、AttentionMIL在内的所有SOTA深度学习方法。
从预测热力图可以直观地看到,相比于mi-Net(召回率低,只找到很少的正实例)和AttentionMIL(预测置信度不高,区域模糊),MIVAE生成的热力图与病理学家标注的真实癌变区域高度吻合。这表明MIVAE不仅能判断整张切片是否癌变,还能精准地“圈出”绝大多数癌细胞所在的位置。这背后的原因正是模型学会了分离共享的病理背景($z_B$)和细胞自身的癌变特征($z_I$)。