[Paper Review] Conflict-Averse Gradient Descent for Multi-task Learning
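CAGrad minimizes the average loss while constraining each update to stay close to the average gradient and maximizing the worst local improvement across tasks, which yields a convergence guarantee to a minimum of the average loss. It generalizes GD and MGDA and improves multi-task learning performance across supervised, semi-supervised, and reinforcement learning tasks.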
The goal of multi-task learning is to enable more efficient learning than single task learning by sharing model structures for a diverse set of tasks. A standard multi-task learning objective is to minimize the average loss across all tasks. While straightforward, using this objective often results in much worse final performance for each task than learning them independently. A major challenge in optimizing a multi-task model is the conflicting gradients, where gradients of different task objectives are not well aligned so that following the average gradient direction can be detrimental to specific tasks' performance. Previous work has proposed several heuristics to manipulate the task gradients for mitigating this problem. But most of them lack convergence guarantee and/or could converge to any Pareto-stationary point. In this paper, we introduce Conflict-Averse Gradient descent (CAGrad) which minimizes the average loss function, while leveraging the worst local improvement of individual tasks to regularize the algorithm trajectory. CAGrad balances the objectives automatically and still provably converges to a minimum over the average loss. It includes the regular gradient descent (GD) and the multiple gradient descent algorithm (MGDA) in the multi-objective optimization (MOO) literature as special cases. On a series of challenging multi-task supervised learning and reinforcement learning tasks, CAGrad achieves improved performance over prior state-of-the-art multi-objective gradient manipulation methods.
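Motivation and Objectives
- Motivate and address conflicting gradients, an optimization challenge in multi-task learning (MTL).
- Propose a gradient-based update rule that optimizes the average loss while maximizing the worst local improvement across all tasks.
- Provide a theoretical convergence guarantee to a minimum of the average loss under mild conditions.
- Demonstrate empirical improvements over prior gradient-manipulation methods across supervised, semi-supervised, and reinforcement learning settings.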
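Proposed Method
- Define an update direction d that maximizes the worst local improvement across tasks while staying inside a ball around the average gradient g0.
- Formulate the primal problem: max_d min_i <gi, d> subject to ||d - g0|| <= c ||g0||.
- Derive the dual problem, which reduces to solving for task-gradient weights w on the probability simplex: min_w <gw, g0> + sqrt(phi) ||gw||, where gw = sum_i w_i gi and phi = c^2 ||g0||^2.
- Compute the update as theta_t = theta_{t-1} - alpha d, with d = g0 + (sqrt(phi) / ||gw||) gw and phi = c^2 ||g0||^2.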
- Show that c=0 recovers standard gradient descent and c→∞ recovers MGDA; prove convergence for 0 <= c < 1.
- Provide a practical speedup by sampling a subset S of tasks for gradient computations and adjusting the dual objective accordingly.
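The steps above can be sketched for the simplest case of two tasks. This is a minimal illustration, not the authors' implementation: it assumes the dual objective has the form F(w) = <gw, g0> + sqrt(phi) ||gw||, restricts the simplex to a single scalar weight w on the first task, and solves the one-dimensional convex problem by ternary search. The function name `cagrad_direction_two_tasks`, the helpers `dot` and `norm`, and the fallback to g0 when gw vanishes are all assumptions made for this sketch.

```python
import math


def dot(u, v):
    """Inner product of two equal-length vectors given as lists."""
    return sum(a * b for a, b in zip(u, v))


def norm(u):
    """Euclidean norm of a vector given as a list."""
    return math.sqrt(dot(u, u))


def cagrad_direction_two_tasks(g1, g2, c=0.5, iters=100):
    """Sketch of the CAGrad direction for two task gradients.

    Returns (d, w), where d is the update direction and w is the weight
    placed on g1 (the weight on g2 is 1 - w).
    """
    g0 = [(a + b) / 2.0 for a, b in zip(g1, g2)]  # average gradient
    sqrt_phi = c * norm(g0)                        # sqrt(phi), phi = c^2 ||g0||^2

    def g_w(w):
        return [w * a + (1.0 - w) * b for a, b in zip(g1, g2)]

    def dual(w):
        # Assumed dual objective: <gw, g0> + sqrt(phi) * ||gw||
        gw = g_w(w)
        return dot(gw, g0) + sqrt_phi * norm(gw)

    # The dual is convex in w (linear term plus norm of an affine map),
    # so ternary search on [0, 1] finds a minimizer.
    lo, hi = 0.0, 1.0
    for _ in range(iters):
        m1 = lo + (hi - lo) / 3.0
        m2 = hi - (hi - lo) / 3.0
        if dual(m1) <= dual(m2):
            hi = m2
        else:
            lo = m1
    w = 0.5 * (lo + hi)

    gw = g_w(w)
    gw_norm = norm(gw)
    if sqrt_phi == 0.0 or gw_norm < 1e-12:
        # c = 0 gives plain gradient descent; falling back to g0 when gw
        # vanishes is an assumption of this sketch.
        return list(g0), w

    scale = sqrt_phi / gw_norm
    d = [a + scale * b for a, b in zip(g0, gw)]
    return d, w
```

A parameter step would then be theta = theta - alpha * d for a learning rate alpha. For more than two tasks the same dual would need a solver over the full simplex (for example projected gradient descent or a general-purpose constrained optimizer), which this sketch does not cover.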
Experimental Results
Research Questions
- RQ1: Can CAGrad converge to a (local) minimum of the average loss L0 while mitigating gradient conflicts across tasks?
- RQ2: How does CAGrad relate to and generalize existing methods like GD and MGDA, and under what parameter regimes do these special cases arise?
- RQ3: Does CAGrad improve empirical performance over state-of-the-art gradient-manipulation methods across supervised, semi-supervised, and reinforcement learning multi-task problems?
Key Findings
- CAGrad provably converges to a minimum point of the average loss L0 when 0 <= c < 1.
- With c approaching 0, CAGrad reduces to standard gradient descent; with large c, it approximates MGDA.
- Empirically, CAGrad improves over prior gradient-manipulation methods on multi-task supervised learning, semi-supervised learning, and reinforcement learning benchmarks.
- In supervised MTL experiments, CAGrad often achieves the best average task performance on NYU-v2 and CityScapes when using the MTAN backbone.
- In reinforcement learning, CAGrad outperforms several baselines on MetaWorld MT10/MT50 with notable gains in success rates.
- A fast variant (CAGrad-Fast) maintains competitive performance while offering 2x–5x speedups over PCGrad in MT10/MT50 settings.
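The role of c in the first two findings can be checked with a short calculation. The following is a sketch rather than the paper's proof: it only shows that the update direction is a descent direction for the average loss when 0 <= c < 1, and it assumes the dual objective has the form <g_w, g_0> + sqrt(phi) ||g_w|| with phi = c^2 ||g_0||^2.

```latex
\begin{aligned}
&\text{Constraint: } d = g_0 + u, \quad \|u\| \le c\,\|g_0\| \\
&\langle g_0, d \rangle = \|g_0\|^2 + \langle g_0, u \rangle
  \;\ge\; (1 - c)\,\|g_0\|^2 \;>\; 0
  \quad (0 \le c < 1,\ g_0 \ne 0) \\
&c = 0:\quad \phi = 0 \;\Rightarrow\; d = g_0
  \quad \text{(gradient descent on } L_0\text{)} \\
&c \to \infty:\quad
  \frac{1}{c\,\|g_0\|}\Big( \langle g_w, g_0 \rangle + c\,\|g_0\|\,\|g_w\| \Big)
  \;\longrightarrow\; \|g_w\|
  \quad \text{(the min-norm problem solved by MGDA)}
\end{aligned}
```

The second line uses the Cauchy-Schwarz inequality: the inner product <g_0, u> is at least -||g_0|| ||u||, which is at least -c ||g_0||^2. For c >= 1 this bound no longer guarantees progress on L_0, which is consistent with the convergence claim being restricted to 0 <= c < 1.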
This review was generated by AI and checked by a human editor.