[论文解读] Accuracy on the Line: On the Strong Correlation Between Out-of-Distribution and In-Distribution Generalization
本论文实证显示在大量模型、数据集和分布转变中,训练分布内准确率与训练外分布准确率之间存在强烈、很大程度上线性的相关性,并提供一个简单的基于高斯的理论来解释这一现象。
For machine learning systems to be reliable, we must understand their performance in unseen, out-of-distribution environments. In this paper, we empirically show that out-of-distribution performance is strongly correlated with in-distribution performance for a wide range of models and distribution shifts. Specifically, we demonstrate strong correlations between in-distribution and out-of-distribution performance on variants of CIFAR-10 & ImageNet, a synthetic pose estimation task derived from YCB objects, satellite imagery classification in FMoW-WILDS, and wildlife classification in iWildCam-WILDS. The strong correlations hold across model architectures, hyperparameters, training set size, and training duration, and are more precise than what is expected from existing domain adaptation theory. To complete the picture, we also investigate cases where the correlation is weaker, for instance some synthetic distribution shifts from CIFAR-10-C and the tissue classification dataset Camelyon17-WILDS. Finally, we provide a candidate theory based on a Gaussian data model that shows how changes in the data covariance arising from distribution shift can affect the observed correlations.
研究动机与目标
- 调查在多样数据集和模型中,是否可以从训练分布内的表现来预测分布外的泛化能力。
- 描述在不同分布转变下,何时出现精确的线性 ID-OOD 趋势,何时失败。
- 评估 ID-OOD 相关性对超参数、训练时长、数据规模以及预训练的鲁棒性。
- 提出一个简单的理论高斯模型来解释观测到的相关性并指导未来理论。
提出的方法
- 在分布内数据 D 上训练大量模型(经典与神经)并改变架构、超参数、随机种子与训练设置。
- 在每个模型上对ID(D)和分布外(D′)测试集进行评估,并使用 probit 变换的散点图可视化以揭示线性趋势。
- 在 probit 域中通过 R^2 量化多对 ID/OOD 的线性度(例如 CIFAR-10 等)。
- 检查预训练和零样本推理对 ID-OOD 关系的影响。
- 制定一个简单的高斯数据模型,以导出变换后准确率之间的近线性关系,并识别影响偏差的因素。
实验结果
研究问题
- RQ1训练分布内的准确率是否线性预测分布外的准确率,跨越多样数据集和分布转变?
- RQ2ID-OOD 线性关系对模型架构、超参数、训练时长和训练数据规模有多鲁棒?
- RQ3预训练在 ID-OOD 关系中扮演何种角色,零样本推理如何影响?
- RQ4在何种条件下线性趋势会失败或减弱,原因是什么?
主要发现
| ID 数据集 | OOD 数据集 | 线性拟合在 probit 域的 R^2 | 评估的模型数量 |
|---|---|---|---|
| CIFAR-10 | CIFAR-10.1 | 0.995 | 1,060 |
| CIFAR-10 | CIFAR-10.2 | 0.997 | 1,060 |
| CINIC-10 | ? | 0.991 | 949 |
| STL-10 | ? | 0.995 | 456 |
| CIFAR-10-C Fog | ? | 0.990 | 790 |
| CIFAR-10-C Brightness | ? | 0.940 | 519 |
| ImageNet | ImageNet-V2 | 0.996 | 219 |
| YCB-Objects | YCB-Objects OOD | 0.975 | 39 |
| iWildCam-WILDS ID | iWildCam-WILDS OOD | 0.881 (0.536) | 66 (63) |
| FMoW-WILDS ID | FMoW-WILDS OOD | 0.984 | 162 |
- 对于许多数据集/模型对,ID 与 OOD 准确率之间存在精确的线性趋势,在多个转变下 probit 域的 R^2 值约为 0.98–0.997。
- 该线性关系在多种模型家族(经典与神经)、架构、超参数、训练时长和训练集规模上成立。
- 预训练可在任务和设置条件下保持或改变 ID-OOD 趋势(例如 CIFAR-10.2 和 FMoW-WILDS 与趋势一致;iWildCam-WILDS 在预训练模型下显示偏离)。
- 使用预训练模型的零-shot 预测往往偏离基本线性趋势,趋向接近 x = y 的线,这表明是预训练而非仅仅 ID 训练引入的偏差。
- 某些转变显示较弱或无精确线性趋势(如 Camelyon17-WILDS 组织分类和一些 CIFAR-10-C 毁坏如高斯噪声),凸显该现象的局限性。
- 一个简单的高斯数据模型解释了近线性的 probit 关系,斜率为 α/γ,维度增加时偏差收缩,支持基于协方差的 ID-OOD 相关性直觉。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。