[论文解读] Multi-Task Learning as a Bargaining Game
本论文将多任务学习中的梯度聚合建模为协商博弈,推出 Nash-MTL 一种基于纳什协商解的梯度更新规则,给出收敛性保证,并在多个基准数据集上实现最先进的结果。
In Multi-task learning (MTL), a joint model is trained to simultaneously make predictions for several tasks. Joint training reduces computation costs and improves data efficiency; however, since the gradients of these different tasks may conflict, training a joint model for MTL often yields lower performance than its corresponding single-task counterparts. A common method for alleviating this issue is to combine per-task gradients into a joint update direction using a particular heuristic. In this paper, we propose viewing the gradients combination step as a bargaining game, where tasks negotiate to reach an agreement on a joint direction of parameter update. Under certain assumptions, the bargaining problem has a unique solution, known as the Nash Bargaining Solution, which we propose to use as a principled approach to multi-task learning. We describe a new MTL optimization procedure, Nash-MTL, and derive theoretical guarantees for its convergence. Empirically, we show that Nash-MTL achieves state-of-the-art results on multiple MTL benchmarks in various domains.
研究动机与目标
- 促使在任务梯度冲突或尺度不一致时改进多任务学习中的优化。
- 引入基于纳什协商解的原理性、公理性梯度聚合方法。
- 给出所提 Nash-MTL 算法在凸与非凸设定下的收敛性保证。
- 展示在视觉、化学和强化学习等多领域的多任务学习基准上达到最先进的性能。
提出的方法
- 将梯度聚合建模为一个包含 K 个任务梯度 g_i 的协商问题及其一致性集合 B_ε。
- 证明纳什协商解给出一个更新方向 Delta_theta,落在任务梯度的张成空间内:Delta_theta = sum_i alpha_i g_i。
- 推导关键方程 G^T G alpha = 1/alpha,其中 G 是任务梯度矩阵,alpha > 0 为协商权重。
- 提出一种高效近似求解 G^T G alpha = 1/alpha 的方法,使用一系列凸代理和类似 CCP 的迭代方案。
- 在温和假设下证明收敛到帕雷托驻点,并在任务为凸时给出基于凸性的增强。
- 通过建议较少频繁更新 alpha 以降低计算成本同时保持性能,来解决实际加速问题。
实验结果
研究问题
- RQ1博弈理论框架是否能为多任务学习提供一个原理性、尺度不变的梯度组合?
- RQ2Nash-MTL 在凸与非凸设定下是否收敛到帕雷托驻点?
- RQ3在多样化的多任务学习基准上,与现有梯度聚合方法相比,Nash-MTL 的表现如何?
- RQ4在处理大量任务时,应用 Nash-MTL 会产生哪些实际的计算成本权衡?
主要发现
| 方法 | MR ↓ | Delta_m % ↓ |
|---|---|---|
| LS | 6.8 | 177.6±3.4 |
| SI | 4.0 | 77.8±9.2 |
| RLW | 8.2 | 203.8±3.4 |
| DWA | 6.4 | 175.3±6.3 |
| MGDA | 5.9 | 120.5±2.0 |
| PCGrad | 5.0 | 125.7±10.3 |
| CAGrad | 5.7 | 112.8±4.0 |
| IMTL-G | 4.7 | 77.2±9.3 |
| Nash-MTL | 2.5 | 62.0±1.4 |
- 在 QM9 上,Nash-MTL 以 MR = 2.5 和 Delta_m = 62.0±1.4 实现最佳性能,超越所有基线。
- Nash-MTL 在 NYUv2 和 Cityscapes 基准上获得最佳的平均秩 (MR),表明跨任务平衡性强。
- Nash-MTL 在 NYUv2 获得最佳 Delta_m, Cityscapes 位居第二,显示在各任务中的鲁棒尺度不变改进。
- 在多任务强化学习(MT10)中,Nash-MTL 在所评估方法中获得最高的平均成功率(跨种子)。
- 所提出的方法对每任务损失尺度具有尺度不变性,减轻单一大梯度的支配效应。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。