[Paper Review] Conflict-Averse Gradient Descent for Multi-task Learning
CAGrad minimizes the average loss while regularizing updates to maximize the worst local improvement among tasks, with convergence guarantees to a minimum of the average loss. It generalizes GD and MGDA and improves multi-task learning performance across supervised, semi-supervised, and reinforcement learning tasks.
The goal of multi-task learning is to enable more efficient learning than single-task learning by sharing model structures across a diverse set of tasks. A standard multi-task learning objective is to minimize the average loss across all tasks. While straightforward, this objective often yields much worse final performance on each task than learning the tasks independently. A major challenge in optimizing a multi-task model is conflicting gradients: the gradients of different task objectives are not well aligned, so following the average gradient direction can hurt the performance of specific tasks. Previous work has proposed several heuristics that manipulate the task gradients to mitigate this problem, but most of them lack convergence guarantees and/or can converge to an arbitrary Pareto-stationary point. In this paper, we introduce Conflict-Averse Gradient descent (CAGrad), which minimizes the average loss function while leveraging the worst local improvement of individual tasks to regularize the algorithm trajectory. CAGrad balances the objectives automatically and still provably converges to a minimum of the average loss. It includes regular gradient descent (GD) and the multiple gradient descent algorithm (MGDA) from the multi-objective optimization (MOO) literature as special cases. On a series of challenging multi-task supervised learning and reinforcement learning tasks, CAGrad achieves improved performance over prior state-of-the-art multi-objective gradient manipulation methods.
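To make the conflicting-gradient problem concrete, here is a minimal Python sketch with two hypothetical 2-D task gradients (the numbers are illustrative and do not come from the paper). It uses the first-order approximation L_i(theta - lr * g0) - L_i(theta) approximately equal to -lr * <g_i, g0> to show that a step along the average gradient g0 can lower one task's loss while raising the other's. The helper names are our own.

```python
import math
from typing import List

Vec = List[float]


def dot(a: Vec, b: Vec) -> float:
    return sum(x * y for x, y in zip(a, b))


def cosine(a: Vec, b: Vec) -> float:
    return dot(a, b) / math.sqrt(dot(a, a) * dot(b, b))


def first_order_loss_changes(grads: List[Vec], lr: float) -> List[float]:
    """Predicted change of each task loss after theta <- theta - lr * g0.

    Uses the first-order Taylor approximation
    L_i(theta - lr * g0) - L_i(theta) ~ -lr * <g_i, g0>,
    where g0 is the average of the task gradients.
    """
    k = len(grads)
    g0 = [sum(g[j] for g in grads) / k for j in range(len(grads[0]))]
    return [-lr * dot(g, g0) for g in grads]


g1 = [1.0, 0.0]   # hypothetical gradient of task 1
g2 = [-0.8, 0.1]  # hypothetical gradient of task 2, conflicting with task 1
print(cosine(g1, g2))                           # negative: the gradients conflict
print(first_order_loss_changes([g1, g2], 0.1))  # task 1 loss falls, task 2 loss rises
```

Here the average gradient is dominated by task 1, so the step is detrimental to task 2 even though the average loss decreases.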
Motivation & Objective
- Motivate and address the optimization challenge of conflicting gradients in multi-task learning (MTL).
- Propose a gradient-based update rule that maximizes the worst-case local improvement across tasks while optimizing the average loss.
- Provide theoretical convergence guarantees to a minimum of the average loss under mild conditions.
- Demonstrate empirical improvements over prior gradient manipulation methods across supervised, semi-supervised, and reinforcement learning settings.
Proposed method
- Define update direction d to maximize the worst local improvement across tasks while staying within a ball around the average gradient g0.
- Formulate the primal problem: max_d min_i <gi, d> subject to ||d - g0|| <= c ||g0||.
- Derive the dual problem, which reduces to finding weights w on the probability simplex that minimize gw^T g0 + c ||g0|| ||gw||, where gw = sum_i w_i gi.
- Compute the update with the optimal dual weights w: theta_t = theta_{t-1} - alpha (g0 + (sqrt(phi)/||gw||) gw), where phi = c^2 ||g0||^2.
- Show that c=0 recovers standard gradient descent and c→∞ recovers MGDA; prove convergence for 0 <= c < 1.
- Provide a practical speedup by sampling a subset S of tasks for gradient computations and adjusting the dual objective accordingly.
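The method steps above can be sketched in plain Python as follows. This is a minimal, non-authoritative sketch: it computes the average gradient g0, minimizes the dual objective F(w) = <gw, g0> + sqrt(phi) ||gw|| over the simplex, and forms d = g0 + (sqrt(phi)/||gw||) gw. The dual solver (Frank-Wolfe with a golden-section line search) is our own assumption and not necessarily what the authors use; all function names are hypothetical, task gradients are assumed to be given as equal-length lists of floats, and the fast task-sampling variant is not covered.

```python
import math
from typing import Callable, List

Vec = List[float]


def dot(a: Vec, b: Vec) -> float:
    return sum(x * y for x, y in zip(a, b))


def norm(a: Vec) -> float:
    return math.sqrt(dot(a, a))


def combine(w: Vec, grads: List[Vec]) -> Vec:
    """gw = sum_i w_i * g_i."""
    return [sum(wi * g[j] for wi, g in zip(w, grads)) for j in range(len(grads[0]))]


def golden_min(f: Callable[[float], float], iters: int = 100) -> float:
    """Minimize a unimodal function on [0, 1] by golden-section search."""
    r = (math.sqrt(5.0) - 1.0) / 2.0
    lo, hi = 0.0, 1.0
    for _ in range(iters):
        x1, x2 = hi - r * (hi - lo), lo + r * (hi - lo)
        if f(x1) <= f(x2):
            hi = x2
        else:
            lo = x1
    return (lo + hi) / 2.0


def solve_dual(grads: List[Vec], g0: Vec, sqrt_phi: float, iters: int = 100) -> Vec:
    """Minimize F(w) = <gw, g0> + sqrt(phi) * ||gw|| over the simplex.

    Assumed solver: Frank-Wolfe with an exact line search (F is convex in w).
    """
    k = len(grads)
    w = [1.0 / k] * k  # start from uniform weights
    for _ in range(iters):
        gw = combine(w, grads)
        gw_norm = norm(gw) + 1e-12  # small epsilon guards against gw = 0
        # dF/dw_i = <g_i, g0> + sqrt(phi) * <g_i, gw> / ||gw||
        partial = [dot(g, g0) + sqrt_phi * dot(g, gw) / gw_norm for g in grads]
        i = min(range(k), key=partial.__getitem__)
        vertex = [1.0 if j == i else 0.0 for j in range(k)]

        def point(t: float, w=w, vertex=vertex) -> Vec:
            return [(1.0 - t) * a + t * b for a, b in zip(w, vertex)]

        def objective(t: float) -> float:
            g = combine(point(t), grads)
            return dot(g, g0) + sqrt_phi * norm(g)

        w = point(golden_min(objective))
    return w


def cagrad_direction(grads: List[Vec], c: float) -> Vec:
    """d = g0 + (sqrt(phi) / ||gw||) * gw with sqrt(phi) = c * ||g0||."""
    k = len(grads)
    g0 = [sum(g[j] for g in grads) / k for j in range(len(grads[0]))]
    sqrt_phi = c * norm(g0)
    w = solve_dual(grads, g0, sqrt_phi)
    gw = combine(w, grads)
    scale = sqrt_phi / (norm(gw) + 1e-12)
    return [a + scale * b for a, b in zip(g0, gw)]


def cagrad_step(theta: Vec, grads: List[Vec], c: float, alpha: float) -> Vec:
    """theta_t = theta_{t-1} - alpha * d."""
    d = cagrad_direction(grads, c)
    return [t - alpha * di for t, di in zip(theta, d)]
```

For example, with the hypothetical gradients g1 = [1.0, 0.0] and g2 = [-0.9, 0.3] the average gradient gives task 2 zero local improvement, while `cagrad_direction([g1, g2], 0.5)` returns a direction with a strictly positive inner product with both gradients, at distance c ||g0|| from g0. In a real model, each g_i would be the flattened gradient of task i's loss with respect to the shared parameters.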
Experimental results
Research questions
- RQ1: Can CAGrad converge to a (local) minimum of the average loss L0 while mitigating gradient conflicts across tasks?
- RQ2: How does CAGrad relate to and generalize existing methods like GD and MGDA, and under what parameter regimes do these special cases arise?
- RQ3: Does CAGrad improve empirical performance over state-of-the-art gradient-manipulation methods across supervised, semi-supervised, and reinforcement learning multi-task problems?
Key findings
- CAGrad provably converges to a minimum point of the average loss L0 when 0 <= c < 1.
- With c = 0, CAGrad reduces to standard gradient descent; as c grows large, it approaches MGDA.
- Empirically, CAGrad improves over prior gradient-manipulation methods on multi-task supervised learning, semi-supervised learning, and reinforcement learning benchmarks.
- In supervised MTL experiments, CAGrad often achieves the best average task performance on NYU-v2 and CityScapes when using the MTAN backbone.
- In reinforcement learning, CAGrad outperforms several baselines on MetaWorld MT10/MT50 with notable gains in success rates.
- A fast variant (CAGrad-Fast) maintains competitive performance while offering 2x–5x speedups over PCGrad in MT10/MT50 settings.
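The first two findings can be sanity-checked numerically on a two-task toy problem. One step of reasoning helps: since the CAGrad direction satisfies ||d - g0|| = c ||g0|| (when gw is nonzero), the Cauchy-Schwarz inequality gives <d, g0> >= (1 - c) ||g0||^2, so for c < 1 the update always decreases the average loss to first order. This is a consistency check, not a substitute for the paper's proof. The Python sketch below assumes two tasks, so the simplex is the segment w = (w1, 1 - w1) and the dual can be solved by a 1-D golden-section search (our choice). It compares the large-c direction with the standard closed-form min-norm (MGDA) point for two gradients, and runs CAGrad with c = 0.5 on hypothetical quadratic losses L_i(theta) = 0.5 ||theta - target_i||^2. All names and numbers are illustrative.

```python
import math
from typing import Callable, List

Vec = List[float]


def dot(a: Vec, b: Vec) -> float:
    return sum(x * y for x, y in zip(a, b))


def norm(a: Vec) -> float:
    return math.sqrt(dot(a, a))


def mix(w1: float, g1: Vec, g2: Vec) -> Vec:
    """gw for two tasks, with w = (w1, 1 - w1) on the simplex."""
    return [w1 * x + (1.0 - w1) * y for x, y in zip(g1, g2)]


def golden_min(f: Callable[[float], float], iters: int = 100) -> float:
    """Minimize a unimodal function on [0, 1] by golden-section search."""
    r = (math.sqrt(5.0) - 1.0) / 2.0
    lo, hi = 0.0, 1.0
    for _ in range(iters):
        x1, x2 = hi - r * (hi - lo), lo + r * (hi - lo)
        if f(x1) <= f(x2):
            hi = x2
        else:
            lo = x1
    return (lo + hi) / 2.0


def cagrad_two_tasks(g1: Vec, g2: Vec, c: float) -> Vec:
    g0 = mix(0.5, g1, g2)  # average gradient
    sqrt_phi = c * norm(g0)
    # Dual objective F(w1) = <gw, g0> + sqrt(phi) * ||gw||, convex in w1.
    w1 = golden_min(lambda t: dot(mix(t, g1, g2), g0) + sqrt_phi * norm(mix(t, g1, g2)))
    gw = mix(w1, g1, g2)
    scale = sqrt_phi / (norm(gw) + 1e-12)
    return [x + scale * y for x, y in zip(g0, gw)]


def mgda_two_tasks(g1: Vec, g2: Vec) -> Vec:
    """Closed-form min-norm point on the segment between g1 and g2."""
    diff = [x - y for x, y in zip(g1, g2)]
    denom = dot(diff, diff)
    if denom == 0.0:
        return list(g1)
    w1 = dot([y - x for x, y in zip(g1, g2)], g2) / denom
    return mix(min(1.0, max(0.0, w1)), g1, g2)


def cosine(a: Vec, b: Vec) -> float:
    return dot(a, b) / (norm(a) * norm(b))


def run_quadratic(c: float, steps: int = 500, alpha: float = 0.1) -> Vec:
    """Hypothetical toy problem: L_i(theta) = 0.5 * ||theta - target_i||^2."""
    a, b = [1.0, 0.0], [-1.0, 0.5]  # average loss is minimized at [0.0, 0.25]
    theta = [2.0, 2.0]
    for _ in range(steps):
        g1 = [t - x for t, x in zip(theta, a)]
        g2 = [t - x for t, x in zip(theta, b)]
        d = cagrad_two_tasks(g1, g2, c)
        theta = [t - alpha * di for t, di in zip(theta, d)]
    return theta
```

With g1 = [1.0, 0.0] and g2 = [-0.9, 0.3], `cagrad_two_tasks(g1, g2, 0.0)` returns the average gradient exactly (plain GD), `cagrad_two_tasks(g1, g2, 100.0)` points almost exactly along `mgda_two_tasks(g1, g2)`, and `run_quadratic(0.5)` ends close to [0.0, 0.25], the minimizer of the average loss.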
This review was created by AI and reviewed by human editors.