使用变分自编码器预测实例和包标签的非独立同分布多实例学习.

0. TL; DR

这篇论文提出了一种名为多实例变分自编码器(Multi-Instance Variational Auto-Encoder, MIVAE)的全新生成式多实例学习(MIL)框架,明确地将实例的生成过程解构为两个部分:

  1. 一个共享的、包级别的潜在因子 ($z_B$),用于捕捉所有实例共有的上下文依赖和结构信息。
  2. 一系列独立的、实例级别的潜在因子 ($z_I$),用于捕捉每个实例自身的独特变异。

通过设计了一套巧妙的编码器-解码器架构,能够同时推断包级别和实例级别的潜在因子,并端到端地预测包标签和实例标签。实验证明,通过显式地建模实例间的依赖关系,MIVAE在多个经典MIL基准和真实的医疗影像数据集上超越了包括Attention-MIL在内的多种SOTA算法。

1. 背景介绍

多实例学习(MIL)的初衷是为了处理那些无法用单一特征向量简单描述的复杂对象。例如,一个分子(包)由多种可能的三维构象(实例)组成;一张图片(包)由多个区域或图块(实例)构成。这种“包-实例”的表示法,天然地保留了对象内部的丰富结构和上下文信息。同一个包内的实例,理应不是相互独立的。

绝大多数现有的MIL算法假设包内所有实例是独立同分布的(i.i.d.)。这种“i.i.d.假设”不仅与现实世界相悖,更忽视了MIL表示法所提供的结构信息。它导致模型可能无法充分理解实例间的微妙关系,从而影响了包级别和实例级别的预测精度。

本文的出发点是能否设计一个模型,既能捕捉每个实例的独特性,又能显式地建模所有实例共享的上下文依赖,从而实现更精准的联合预测。

2. MIVAE 模型

MIVAE构建了一个生成模型,假设观测到的实例是由两种不同层级的潜在因子共同生成的。

作者假设一旦以共享的包级别因子$z_B$为条件,那么包内的所有实例就变得相互独立了。为了学习这个生成模型,MIVAE采用了变分自编码器(VAE)框架。这需要设计一个生成网络(解码器)和一个推断网络(编码器)

  1. 生成网络$p_θ(x_j | z_B, z_j^I)$从潜在因子重构实例。输入是共享的$z_B$和独立的$z_j^I$,输出是重构的实例$x_j$。这个解码器对于包内所有实例是共享的。
  2. 推断网络需要解决两个挑战:如何从一堆实例中推断出共享的$z_B$,以及如何推断出每个实例独立的$z_j^I$。
    • 推断独立的$z_j^I$ (Instance-level Encoder): 对于每个实例$x_j$,都有一个独立的编码器$q_{\phi_I}(z_j^I | x_j)$来推断其专属的$z_j^I$。
\[q_{\phi_I}(z_j^I | x_j) \sim \mathcal{N}(\mu=f_{\psi_{\phi_I}}(x_j), \sigma^2=f_{\pi_{\phi_I}}(x_j))\] \[q_{\phi_B}(z_B | B) \sim \mathcal{N}\left(\frac{1}{n_i}\sum_j f_{\psi_{\phi_B}}(x_j), \frac{1}{n_i}\sum_j f_{\pi_{\phi_B}}(x_j)\right)\]

MIVAE的最终目标函数由两部分构成:

\[\mathcal{L}_{\text{MIVAE}} = \mathcal{L}_{\text{ELBO}}(B, y) + \alpha \cdot \mathbb{E}[\log q_\omega(y | z_B, z_I)]\]

证据下界 (ELBO) $\mathcal{L}_{\text{ELBO}}(B, y)$ 是VAE的标准损失,包含两部分:

为了实现标签预测,MIVAE额外引入了一个辅助分类器 $q_ω$。分类器中的一部分$f_{ω_I}$接收实例级别因子$z_j^I$作为输入,预测实例标签。整个分类器$q_ω$结合了$f_{ω_I}$的输出(通过max-pooling聚合)和包级别因子$z_B$的直接预测$f_{ω_B}(z_B)$,来得到最终的包标签。

通过这种设计,MIVAE实现了在一个统一的框架内端到端地学习非i.i.d.的实例表示,并同时进行包和实例的标签预测。

3. 实验分析

作者在多个数据集上对MIVAE进行了全面的评估。

3.1 包标签预测:经典MIL基准

评估数据集包括 Musk1, Musk2, Fox, Tiger, ElephantMIVAE的性能具有很强的竞争力,在Musk1Tiger上取得了最高的平均准确率。尽管这些数据集规模较小,特征也是预先计算好的,可能无法完全发挥深度生成模型的优势,但MIVAE依然表现出色。这证明了其模型设计的有效性和普适性。

3.2 实例标签预测:20 NewsGroups

数据集使用20 Newsgroups,一个文本分类任务,标签不平衡性极高(正实例比例仅约3%)。MIVAE在20个数据子集中,有15个取得了最佳性能(AUC-PR)。

这个结果表明通过将共享的上下文信息(如文章主题、作者风格)剥离到$z_B$中,模型可以更专注于利用$z_I$来识别那些真正决定标签的关键实例(如包含特定关键词的句子),从而在高度不平衡和模糊的监督信号下,实现了更精准的实例定位。这直接印证了显式建模非i.i.d.关系的巨大好处。

3.3 端到端应用:结肠癌病理图像分析

对于结肠癌H&E染色全切片图像数据集,MIVAE在包标签预测(准确率)和实例标签预测(AUC-PR)两项任务上,均全面超越了包括mi-Net、MI-Net、AttentionMIL在内的所有SOTA深度学习方法。

从预测热力图可以直观地看到,相比于mi-Net(召回率低,只找到很少的正实例)和AttentionMIL(预测置信度不高,区域模糊),MIVAE生成的热力图与病理学家标注的真实癌变区域高度吻合。这表明MIVAE不仅能判断整张切片是否癌变,还能精准地“圈出”绝大多数癌细胞所在的位置。这背后的原因正是模型学会了分离共享的病理背景($z_B$)和细胞自身的癌变特征($z_I$)。