[论文解读] Adaptive Risk Minimization: A Meta-Learning Approach for Tackling Group Distribution Shift
本文提出自适应风险最小化(ARM),一种元学习框架,可使模型在测试时适应不同数据组之间的分布偏移。通过在训练期间针对多样化、模拟的组分布偏移优化性能,ARM 在具有组分布偏移的图像分类基准上,相比以往的鲁棒性、不变性及自适应方法,测试准确率提高了1–4%。
A fundamental assumption of most machine learning algorithms is that the training and test data are drawn from the same underlying distribution. However, this assumption is violated in almost all practical applications: machine learning systems are regularly tested under distribution shift, due to changing temporal correlations, atypical end users, or other factors. In this work, we consider the setting where the training data are structured into groups and there may be multiple test time shifts, corresponding to new groups or group distributions. Most prior methods aim to learn a single robust model or invariant feature space to tackle this group shift. In contrast, we aim to learn models that adapt at test time to shift using unlabeled test points. Our primary contribution is to introduce the framework of adaptive risk minimization (ARM), in which models are optimized for post adaptation performance on training batches sampled from different groups, which simulate group shifts that may occur at test time. We use meta-learning to solve the ARM problem, and compared to prior methods for robustness, invariance, and adaptation, ARM methods provide consistent gains of 1-4% test accuracy on image classification problems exhibiting group shift.
研究动机与目标
- 为解决现实世界机器学习应用中测试数据与训练数据在组特定分布上存在差异的组分布偏移挑战。
- 开发一种方法,使模型能够利用未标记的测试样本在测试时动态适应,而非依赖固定的鲁棒或不变表示。
- 通过元学习在训练期间模拟此类偏移,以提升在多种未见组分布偏移下的泛化能力。
- 通过显式优化在多样化组分布上的适应后性能,超越现有方法在鲁棒性、不变性和自适应方面的表现。
提出的方法
- 提出自适应风险最小化(ARM),一种训练目标,通过在不同组中采样训练批次以模拟潜在的测试时分布偏移,优化模型在适应后的性能。
- 利用元学习训练模型,使其在推理时仅使用未标记的测试数据即可快速适应新的组分布。
- 在多个支持集上进行训练,以代表不同的组偏移,每个支持集由特定组分布的少量标记样本组成。
- 在测试时,使用来自新组的少量未标记测试样本对模型进行微调,通过小样本适应步骤最小化适应后的风险。
- 将元目标表述为在训练期间对所有模拟组偏移的适应后期望风险进行最小化。
- 采用两阶段优化过程:内层循环在支持集上适应模型,外层循环更新模型参数以最小化所有组上的适应后风险。
实验结果
研究问题
- RQ1通过在测试时使用未标记数据进行适应,元学习模型是否能在组分布偏移下实现更好的泛化?
- RQ2与学习不变表示或单一鲁棒模型的方法相比,所提出的ARM框架在分布偏移下的测试准确率表现如何?
- RQ3在训练期间优化适应后性能是否能在多样化组偏移场景下带来一致的性能提升?
- RQ4模型在测试时的适应能力在多大程度上减少了在未见组分布下的性能下降?
主要发现
- ARM在具有组分布偏移的图像分类任务中,相比专注于鲁棒性、不变性或自适应的先前方法,测试准确率持续提高1–4%。
- 该方法在展示组偏移的多个基准数据集上均表现出色,证明其对多样化且未见组分布的鲁棒性。
- 元学习使模型仅使用来自新组的少量未标记样本即可在测试时实现有效适应,且推理阶段无需标签数据。
- ARM带来的性能增益在显著的组分布偏移下最为明显,表明其在真实世界部署场景中的有效性。
- 该框架通过显式优化适应性能,优于不变风险最小化和标准域自适应方法。
- 消融研究证实,ARM成功的关键在于训练期间模拟组偏移并优化适应后性能。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。