[论文解读] Learning Surrogate Losses
本文提出了一种新颖的即插即用优化框架,通过神经网络学习平滑、可微分的代理损失函数,以最小化非可微且不可分解的机器学习目标(如AUC、F1、Jaccard指数和MCR)。通过联合训练代理损失网络与预测模型,采用双层优化方法,该方法在九个不同数据集上均优于手工设计的代理损失函数,性能表现更优。
The minimization of loss functions is the heart and soul of Machine Learning. In this paper, we propose an off-the-shelf optimization approach that can minimize virtually any non-differentiable and non-decomposable loss function (e.g. Miss-classification Rate, AUC, F1, Jaccard Index, Mathew Correlation Coefficient, etc.) seamlessly. Our strategy learns smooth relaxation versions of the true losses by approximating them through a surrogate neural network. The proposed loss networks are set-wise models which are invariant to the order of mini-batch instances. Ultimately, the surrogate losses are learned jointly with the prediction model via bilevel optimization. Empirical results on multiple datasets with diverse real-life loss functions compared with state-of-the-art baselines demonstrate the efficiency of learning surrogate losses.
研究动机与目标
- 为解决AUC、F1和Jaccard指数等非可微且不可分解损失函数的优化挑战,这些损失函数无法通过梯度下降直接最小化。
- 通过端到端学习任务特定的代理损失函数,消除对手工设计代理松弛函数的依赖。
- 将代理学习过程形式化为双层优化问题,实现预测模型与代理损失网络的联合训练。
- 证明基于数据集的代理学习可实现优于通用或预训练代理函数的泛化性能。
- 提供一种通用的、即插即用的优化框架,适用于任何非可微损失函数,且无需真实损失函数的梯度信息。
提出的方法
- 该方法将代理损失定义为可学习的神经网络,用于在小批量数据上近似真实非可微损失函数。
- 代理网络具有集合不变性,即对小批量中样本顺序不敏感,从而能够正确处理不可分解损失。
- 代理损失通过双层优化进行训练:外层循环最小化训练集上的真实损失,内层循环优化代理网络以匹配真实损失。
- 采用交替优化算法,联合训练预测模型与代理损失网络,同时通过代理网络反向传播梯度。
- 该方法将真实损失视为黑箱函数,无需显式计算其关于模型参数的梯度。
- 该方法学习每个数据集特定的代理损失,而非依赖通用代理函数,从而提升准确率与适应性。
实验结果
研究问题
- RQ1能否训练一个神经网络,使其学习到平滑、可微分的代理损失函数,以准确逼近非可微的真实损失函数?
- RQ2学习数据集特定的代理损失是否优于通用或手工设计的代理松弛函数?
- RQ3双层优化能否在无需真实损失梯度信息的情况下,实现预测模型与代理损失网络的联合训练?
- RQ4所提出的方法在具有复杂不可分解损失的真实世界数据集上是否具备可扩展性与高效性?
- RQ5在多种损失函数下,代理学习与最先进基线方法相比,最终模型性能如何?
主要发现
- 在九个数据集上,代理学习(SL-R)在所有四种损失函数(MCR、AUC、F1、JAC)上的测试损失均低于所有最先进基线方法。
- 平均而言,SL-R在MCR任务中赢得5.5个数据集,在AUC任务中赢得8.0个,在JAC任务中赢得5.5个,在F1任务中赢得6.0个,表明其性能持续优越。
- 在IJC数据集上,SL-R的AUC达到0.0030,显著优于次佳基线方法(GO)的0.0258。
- 在SUSY数据集上,SL-R将F1损失降低至0.2289,而代价敏感基线方法为0.2420,表现更优。
- SL-R在所有数据集上均取得了AUC和JAC的最先进结果,且始终优于Lovasz Soft-Max与成对排序基线方法。
- 在最大数据集(SUSY)上,训练时间约为1天4小时,运行于单张GPU上,证明了其在实际应用中的可行性,尽管增加了计算复杂度。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。