[论文解读] Test-Time Adaptation via Conjugate Pseudo-labels
这篇论文引入共轭伪标签(Conjugate PL)用于测试时自适应(TTA),从训练损失的凸共轭派生损失。它显示采用交叉熵训练的模型自然偏好软最大熵(TENT 风格)损失,而平方损失模型偏好负平方损失。实验证据表,Conjugate PL 在多数据集和多种损失下改善 TTA,包括 PolyLoss,并且可被解读为一种使用共轭伪标签的自训练方案。
Test-time adaptation (TTA) refers to adapting neural networks to distribution shifts, with access to only the unlabeled test samples from the new domain at test-time. Prior TTA methods optimize over unsupervised objectives such as the entropy of model predictions in TENT [Wang et al., 2021], but it is unclear what exactly makes a good TTA loss. In this paper, we start by presenting a surprising phenomenon: if we attempt to meta-learn the best possible TTA loss over a wide class of functions, then we recover a function that is remarkably similar to (a temperature-scaled version of) the softmax-entropy employed by TENT. This only holds, however, if the classifier we are adapting is trained via cross-entropy; if trained via squared loss, a different best TTA loss emerges. To explain this phenomenon, we analyze TTA through the lens of the training losses's convex conjugate. We show that under natural conditions, this (unsupervised) conjugate function can be viewed as a good local approximation to the original supervised loss and indeed, it recovers the best losses found by meta-learning. This leads to a generic recipe that can be used to find a good TTA loss for any given supervised training loss function of a general class. Empirically, our approach consistently dominates other baselines over a wide range of benchmarks. Our approach is particularly of interest when applied to classifiers trained with novel loss functions, e.g., the recently-proposed PolyLoss, where it differs substantially from (and outperforms) an entropy-based loss. Further, we show that our approach can also be interpreted as a kind of self-training using a very specific soft label, which we refer to as the conjugate pseudolabel. Overall, our method provides a broad framework for better understanding and improving test-time adaptation. Code is available at https://github.com/locuslab/tta_conjugate.
研究动机与目标
- 激发并理解在没有标注测试数据的分布迁移下,如何选择或推导有效的 TTA 损失。
- 将 TTA 损失的设计与监督损失的凸共轭联系起来,以解释在何时最优为熵样损失或其他替代损失。
- 提供一个实用、通用的配方,用于获得适用于各种训练损失的良好 TTA 损失(例如交叉熵、平方损失、PolyLoss)。
- 表明所提出的 Conjugate PL 方法对应于使用共轭伪标签的自训练方案,并在基准测试中取得实证提升。
提出的方法
- 通过监督损失的凸共轭来表述 TTA 损失,显示 L_conj(h(x)) = -f^*(∇f(h(x))).
- 将其具体化到常见损失:当 f(h)=log sum exp(h) 时,交叉熵得到的 L_conj 为 softmax-entropy,与 TENT 相匹配。
- 对于平方损失,若 f(h)=½||h||^2,则 L_conj 变为负的平方范数,解释了另一种元学习所得的损失。
- 将 Conjugate PL 解释为使用伪标签 y_hat^CPL = ∇f(h(x)) 的自训练,即共轭伪标签。
- 通过在标准形式或扩展共轭形式中表达,将其扩展到 PolyLoss,从而为非标准损失启用 CPL。
- 提供一个算法(Conjugate PL),在未标注的测试批次上使用 CPL 进行自训练来更新模型参数,并引入温度缩放作为实际的增强方式。
实验结果
研究问题
- RQ1在测试时分布迁移下,给定一个监督训练损失时,良好 TTA 损失的原理性形式是什么?
- RQ2为什么基于熵的损失在对数交叉熵训练的模型上表现良好,以及何时应偏好替代损失?
- RQ3凸共轭是否可以提供一个普适框架,为多种训练损失(例如 PolyLoss、平方损失)推导 TTA 损失?
- RQ4共轭伪标签与自训练的关系如何,以及对不同损失而言,哪些伪标签是最优的?
主要发现
- 对最佳 TTA 损失进行元学习往往会为使用交叉熵训练的分类器恢复一个带温标的 softmax-熵,而对于使用平方损失训练的分类器则恢复一个负平方损失。
- 凸共轭框架解释了为何对交叉熵会出现 softmax-熵,以及为何其他损失会产生不同的 TTA 损失,提供了新的损失的一般性配方。
- Conjugate PL 在 CIFAR-10/100-C、ImageNet-C 以及域适应任务中,持续优于基线 TTA 损失(例如最小化熵、鲁棒伪标签、MEMO)。
- 将 Conjugate PL 应用于 PolyLoss 和平方损失分类器可带来显著提升,说明当源训练损失与标准交叉熵不同时时方法的好处。
- 该方法可被解读为使用共轭伪标签的自训练方法,提供一个有原理、广泛适用的 TTA 框架。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。