[论文解读] Adaptive Risk Minimization: Learning to Adapt to Domain Shift
本文提出了自适应风险最小化(ARM),一种利用未标注测试数据在测试时适应领域分布变化的模型训练框架。通过在训练领域上元学习适应策略,ARM 在存在分布偏移的图像分类基准上相较之前的方法提升了 1–4% 的测试准确率。
A fundamental assumption of most machine learning algorithms is that the training and test data are drawn from the same underlying distribution. However, this assumption is violated in almost all practical applications: machine learning systems are regularly tested under distribution shift, due to changing temporal correlations, atypical end users, or other factors. In this work, we consider the problem setting of domain generalization, where the training data are structured into domains and there may be multiple test time shifts, corresponding to new domains or domain distributions. Most prior methods aim to learn a single robust model or invariant feature space that performs well on all domains. In contrast, we aim to learn models that adapt at test time to domain shift using unlabeled test points. Our primary contribution is to introduce the framework of adaptive risk minimization (ARM), in which models are directly optimized for effective adaptation to shift by learning to adapt on the training domains. Compared to prior methods for robustness, invariance, and adaptation, ARM methods provide performance gains of 1-4% test accuracy on a number of image classification problems exhibiting domain shift.
研究动机与目标
- 解决标准机器学习模型在训练与测试数据分布不匹配时失效的局限性。
- 克服不变表示学习的不足,后者假设跨领域输入输出关系保持一致,但在关系发生变化时会失效。
- 在手写识别或医学影像等应用中,仅使用新领域中的未标注样本实现有效的测试时适应。
- 开发一个统一框架,不仅优化模型在训练领域的性能,还优化其在推理时有效适应的能力。
- 证明元学习的适应策略在多种分布偏移基准上优于标准 ERM 和基于不变性的方法。
提出的方法
- 提出自适应风险最小化(ARM),一种训练目标,通过使用未标注测试数据优化模型对未见领域的有效适应能力。
- 利用元学习在一组训练领域上训练模型,使其能够通过在未标注测试批次上微调快速适应新领域。
- 通过上下文元学习实例化 ARM,其中适应由测试输入批次的统计量(如批归一化统计)引导。
- 通过双层优化实现 ARM 目标:内层循环在测试批次上微调模型,外层循环更新模型参数以最小化适应后的期望风险。
- 将框架扩展至流式设置,使模型能够对持续到达的未标注数据逐步适应,实际中表现出快速收敛。
- 将 ARM 与归一化层(如 BatchNorm)集成,以实现高效适应,形成实验中使用的 ARM-BN 变体。
实验结果
研究问题
- RQ1能否仅使用未标注数据,在测试时训练模型以有效适应领域分布变化,而无需测试标签?
- RQ2与基于不变性的表示学习和标准 ERM 相比,ARM 的元学习适应在多种分布偏移场景下的鲁棒性如何?
- RQ3利用未标注测试数据进行适应是否能在具有真实世界分布偏移的多个基准上带来一致的性能提升?
- RQ4ARM 是否能泛化到不同类型领域偏移,包括图像噪声、跨用户手写风格差异以及医学图像分布偏移?
- RQ5在流式设置中,ARM 的性能如何随适应步数或输入数据点数量而变化?
主要发现
- 在存在领域偏移的图像分类基准上,ARM 方法相较之前最先进方法实现了 1–4% 的绝对准确率提升。
- 在 Wilds 基准上,ARM-BN 显著提升了 RxRx1(准确率 87.2%)和 Camelyon17 的性能,优于 ERM 和基于不变性的方法。
- 在流式设置中,ARM 模型在 Tiny ImageNet-C 上使用少于 50 个未标注测试样本即达到优异性能,证明了其快速且有效的适应能力。
- ARM-BN 在 FMoW 上表现不佳,表明适应策略可能需要针对特定数据分布进行定制,凸显了多样化适应工具的必要性。
- 该框架在多个数据集的平均性能和最差情况性能上均表现出一致提升,表明其鲁棒性增强。
- 实证结果证实,即使输入输出关系在不同领域间发生变化,元学习的适应策略依然有效,而基于不变性的方法在此情形下会失效。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。