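[Paper Summary] Counterfactual Generative Networks
This paper proposes Counterfactual Generative Networks (CGNs), which decompose image generation into independent mechanisms for shape, texture, and background. The decomposition enables counterfactual image generation and invariant classifiers, improving robustness to out-of-distribution data, as demonstrated on MNIST variants and ImageNet.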
Neural networks are prone to learning shortcuts -- they often model simple correlations, ignoring more complex ones that potentially generalize better. Prior works on image classification show that instead of learning a connection to object shape, deep classifiers tend to exploit spurious correlations with low-level texture or the background for solving the classification task. In this work, we take a step towards more robust and interpretable classifiers that explicitly expose the task's causal structure. Building on current advances in deep generative modeling, we propose to decompose the image generation process into independent causal mechanisms that we train without direct supervision. By exploiting appropriate inductive biases, these mechanisms disentangle object shape, object texture, and background; hence, they allow for generating counterfactual images. We demonstrate the ability of our model to generate such images on MNIST and ImageNet. Further, we show that the counterfactual images can improve out-of-distribution robustness with a marginal drop in performance on the original classification task, despite being synthetic. Lastly, our generative model can be trained efficiently on a single GPU, exploiting common pre-trained models as inductive biases.
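Research Motivation and Goals
- Encourage robust, causally informed image classification that avoids reliance on spurious correlations.
- Decompose image generation into independent mechanisms that control shape, texture, and background.
- Generate counterfactual images with unseen combinations of factors to train invariant classifiers.
- Demonstrate the approach on MNIST variants and ImageNet, showing improved out-of-distribution robustness with only a small loss in task performance.
- Show emergent properties of the generative model, such as object masks and unsupervised inpainting.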
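Proposed Method
- Model image synthesis as a structural causal model with independent mechanisms (IMs) for shape, texture, and background.
- Use a fixed composition step (alpha blending) to form the image from the mask, texture, and background.
- Train the IMs with several losses, including L_shape (mask fidelity), L_text (texture), L_bg (background inpainting via saliency), and L_rec (reconstruction of pseudo-ground-truth produced by a conditional GAN).
- Initialize the image generators from a pre-trained backbone (e.g., BigGAN) to scale to large datasets such as ImageNet, then fine-tune with mechanism-specific inductive biases.
- Generate counterfactuals X_CF by randomizing the labels fed to the mechanisms while keeping the noise fixed, and use them to train invariant classifiers.
- Train an invariant classifier r on the counterfactual data to predict the label associated with one specific mechanism, so that it stays invariant to the other mechanisms.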
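The composition and counterfactual sampling steps above can be sketched in a few lines. This is a minimal illustration, not the authors' implementation: the three `toy_*` functions are hypothetical stand-ins for the trained mechanisms, the "image" is a flattened list of four grayscale pixels, and the blend is written as x = m * f + (1 - m) * b, assuming m is the mask, f the foreground texture, and b the background.

```python
import random

NUM_PIXELS = 4  # toy "image": four grayscale pixels, flattened


def compose(mask, foreground, background):
    """Alpha-blend per pixel: x = m * f + (1 - m) * b."""
    return [m * f + (1.0 - m) * b
            for m, f, b in zip(mask, foreground, background)]


# Hypothetical stand-ins for the trained mechanisms. Each one takes its
# own class label plus a noise value shared across all mechanisms.
def toy_shape(label, noise):
    # Binary mask: the label decides how many pixels belong to the object.
    return [1.0 if i <= label else 0.0 for i in range(NUM_PIXELS)]


def toy_texture(label, noise):
    # Object intensity is set by the label and jittered by the noise.
    return [0.2 * (label + 1) + 0.05 * noise] * NUM_PIXELS


def toy_background(label, noise):
    # Background intensity is set by the label and jittered by the noise.
    return [0.1 * label + 0.05 * noise] * NUM_PIXELS


def generate_counterfactual(rng, num_classes=3):
    """Sample one noise value, then randomize each mechanism's label independently."""
    noise = rng.random()  # shared by all three mechanisms
    labels = {name: rng.randrange(num_classes)
              for name in ("shape", "texture", "background")}
    image = compose(
        toy_shape(labels["shape"], noise),
        toy_texture(labels["texture"], noise),
        toy_background(labels["background"], noise),
    )
    return image, labels


if __name__ == "__main__":
    image, labels = generate_counterfactual(random.Random(0))
    print(labels, image)
```

Under these assumptions, a pair such as `(image, labels["shape"])` would serve as one training example for a classifier that should predict shape while ignoring texture and background.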
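Experimental Results
Research Questions
- RQ1: Can independent mechanisms that control shape, texture, and background be learned so as to generate high-quality counterfactual images?
- RQ2: Do counterfactual images improve training, making classifiers invariant to spurious correlations and therefore more robust to out-of-distribution data?
- RQ3: Which inductive biases (pre-training, mechanism-specific losses) are needed to disentangle these factors and avoid collapse?
- RQ4: How well do CGNs scale from MNIST variants to ImageNet in producing meaningful counterfactuals and robust classifiers?
Key Findings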
- CGNs can generate high-quality counterfactual images with controllable shape, texture, and background across MNIST variants and ImageNet.
- Disentangling object shape, texture, and background enables training invariant classifiers that are more robust to spurious correlations.
- Counterfactual training improves out-of-distribution robustness with only a marginal drop in performance on the original classification task, both on MNIST variants and at scale on ImageNet.
- Inductive biases, such as initializing from large pre-trained generators and using mechanism-specific losses, are essential to prevent collapse and achieve disentanglement.
- The model produces useful emergent properties, including high-quality object masks and unsupervised inpainting, which arise from the imposed inductive biases and architecture.
This summary was generated by AI and reviewed by human editors.