用于实例标签预测和分布外泛化的多实例因果表示学习.

0. TL; DR

这篇论文提出了一种名为 CausalMIL 的全新框架。其基本假设是任何一个实例的产生,都源于两类潜在因素:因果因素 (causal factors, $z_c$),即决定其标签的核心特征(如数字”1”的笔画);和非因果/混淆因素 (non-causal factors, $z_e$),即与标签无关的风格、环境等信息(如数字”1”的书写角度、颜色)。

CausalMIL可识别变分自编码器 (identifiable VAE, iVAE) 的思想融入MIL。它利用“包”作为一种独特的条件信息,在理论上保证了能够将因果因素$z_c$和非因果因素$z_e$分离开来(即可识别性)。

通过只利用纯净的因果表征$z_c$进行预测,CausalMIL在多个数据集上的实例标签预测任务中,显著超越了包括Attention-MIL在内的多种SOTA基线。由于模型学会了忽略环境带来的混淆因素$z_e$,它在面对与训练集分布不一致的测试数据时,表现出惊人的鲁棒性。

1. 背景介绍

多实例学习 (Multiple Instance Learning, MIL) 是弱监督学习的一个典型范式:有一堆“包”(Bags),每个包里装着若干“实例”(Instances);只知道包的标签,却不知道里面每个实例的具体标签。

大多数现有的MIL算法,其核心思路都可以归结为“标签消歧”(disambiguation):要么试图从正包中找出那个“最像正”的实例作为代表;要么通过注意力机制,为每个实例分配一个权重,来猜测它对包标签的贡献度。这些方法本质上仍是在处理包级标签“不精确”带来的麻烦,而没有挖掘“包”结构本身可能蕴含的更深层次信息。

本文作者指出,一个包内的所有实例,往往是在相似的“环境”或“上下文”中产生的。“环境”信息虽然可能与实例的核心标签无关,但它作为一种共有的、非因果的混淆因素,系统性地影响了包内所有实例的外观。

如果能将这种包继承的非因果因素 ($z_e$) 与决定实例标签的内在因果因素 ($z_c$) 分离开来,只用纯净的 $z_c$ 去预测实例标签,排除了 $z_e$ 带来的干扰;当模型面对一个来自新“环境”的测试包时,由于它已经学会了忽略环境因素 $z_e$,因此能做出更鲁棒的预测。

2. CausalMIL模型

CausalMIL建模为一个因果图模型。令$B_i$为第$i$个包所代表的辅助/环境信息;$z_{ij}^c$为第$i$个包中第$j$个实例的因果表征。它直接决定了实例标签$y_{ij}$;$z_{ij}^e$为第$i$个包中第$j$个实例的非因果表征。它受到包环境$B_i$的影响;$x_{ij}$为观测到的实例。它由$z_{ij}^c$和$z_{ij}^e$共同生成。

包标签$Y_i$由其内部所有实例标签$y_{ij}$根据标准MIL假设决定。$p(x_{ij}|z_{ij})$的生成机制在不同包之间是不变的。$p(y_{ij}|z_{ij}^c)$的因果机制在不同包之间也是不变的。

CausalMIL通过iVAE从观测数据$x$中把$z_c$和$z_e$分离开来。iVAE证明,如果VAE的先验分布 $p(z)$ 不是无条件的,而是以一个额外的辅助变量 $u$ 为条件的,即 $p(z|u)$,那么在一定条件下,潜在变量$z$是可以被识别(identify)出来的(识别指最多相差一个置换或简单的线性变换)。

作者在理论上证明,只要满足一些条件(如生成函数可逆、存在足够多样化的包等),并且潜在变量的条件先验 $p(z|B, y)$ 属于一个广义的指数族分布,那么潜在变量$z$就是可识别的。

\[p(z|B, y) = \frac{Q(z)}{C(B, y)} \exp[T(z)^T \lambda(B, y)]\]

其中,$Q(z)$是基础测量,$C(B, y)$是归一化常数,$T(z)$是充分统计量,$\lambda(B,y)$是一个由包信息和标签决定的参数函数。

由于一个包是实例的集合,具有置换不变性,作者采用了一个置换不变的深度集合网络 (Deep Sets) 来对包内所有实例信息进行聚合,从而得到一个表征包环境的向量$B_i$。

\[B_i = \text{net}(\{x_{i1}, \dots, x_{in_i}\}) = \rho[\text{pool}(\{\phi(x_{i1}), \dots, \phi(x_{in_i})\})]\]

其中 $\phi$ 和 $\rho$ 是神经网络,pool是求和操作。

至此有了可识别的$z$。但$z$中仍然混合着$z_c$和$z_e$。为了利用弱监督的包标签$Y_i$来引导模型只关注$z_c$,CausalMIL设计了一个ELBO(证据下界)目标函数:

\[\begin{aligned} \mathcal{L}_{\text{CausalMIL}} &= \log p_f(x^*|z^*) + \alpha \log p_\omega(Y|z^*) - \text{KL}[q_\phi(z^*|x^*, B) || p_{T,\lambda}(z^*|B)] \\ & + \log p_v(B|z) - \text{KL}[q_\psi(z|B) || p(B)] \end{aligned}\]

式中的 $z^{*}$ 和 $x^{*}$ 并不是包里的所有实例,而是通过一个max操作选出的那个“最可能为正”的实例。即 $z^{*} = \arg\max_B p_\omega(Y|z)$。这意味着,对于每个包,重构损失 $\log p_f(x^{*}|z^{*})$ 和分类损失 $\log p_\omega(Y|z^{*})$ 都只作用于这个被选出来的“关键实例”上。

由于分类器 $p_\omega(Y|z)$ 是一个简单的线性分类器,并且在每个mini-batch中它需要处理来自不同包(不同环境)的实例,它将被迫只依赖于那些在不同包之间保持不变且与标签相关的因素,也就是$z_c$,而忽略掉随包变化的$z_e$。同时,因为重构损失也只关注这个关键的正实例,编码器 $q_\phi$ 也会被激励去主要编码$z_c$,因为它包含了重构正实例所需的核心内容信息。通过这种设计,CausalMIL在端到端的训练中,自然地引导模型将注意力集中在因果表征$z_c$上。

目标函数中的最后两项构成了一个独立的辅助VAE,学习一个关于“包信息/环境B”的良好概率模型,确保包信息$B_i$是一个高质量、有意义、结构化的信号。

3. 实验分析

作者在合成的MNIST、FashionMNISTKuzushijiMNIST多实例数据集,以及真实的结肠癌病理图像数据集上,评估了CausalMIL的实例标签预测能力。

相比于Attn-MIL等判别式方法,CausalMIL的性能有质的飞跃。这表明简单地对实例加权,远不如从根源上分离因果和非因果因素来得有效。即使与同为生成模型的MIVAE相比,CausalMIL也表现出巨大优势。MIVAE虽然也使用VAE,但缺乏可识别性理论保证,因此学到的表征质量远不如CausalMIL

通过重构图像,可以直观地看到CausalMIL确实学到了因果特征。例如,在识别数字”4”时,模型只重构出了”4”的关键笔画结构,而忽略了其他数字以及”4”本身的书写风格、角度等非因果信息。

为了测试模型在面对与训练集分布不一致的数据时的鲁棒性,使用了经典的ColoredMNIST任务。任务是判断数字是否小于5,但数字的颜色与标签在训练集中存在伪相关(如,小于5的数字大概率是绿色)。而在测试集中,这种相关性被逆转。一个好的模型应该学会忽略颜色,只看数字本身。

CausalMIL在只使用弱监督包标签的情况下,测试集准确率达到了89.2%,远超所有全监督的对比方法。这个结果表明CausalMIL成功地将颜色(包环境带来的混淆因素 $z_e$)与数字的形状(因果因素 $z_c$)分离开来,并只依赖后者进行预测。这也揭示了MIL框架在OOD问题上的一个天然优势:“包”本身就可以被看作是不同的“环境”,为模型学习不变性提供了丰富的数据。