[논문 리뷰] Conflict-Averse Gradient Descent for Multi-task Learning
CAGrad는 평균 손실을 최소화하는 동시에 업데이트를 정규화하여 작업 간 최악의 지역 개선을 최대화하고, 평균 손실의 수렴 최저값으로 수렴함을 보장합니다. 이는 GD와 MGDA를 일반화하고 감독 학습, 준감독 학습, 강화 학습 작업에서 다중 작업 학습 성능을 개선합니다.
The goal of multi-task learning is to enable more efficient learning than single task learning by sharing model structures for a diverse set of tasks. A standard multi-task learning objective is to minimize the average loss across all tasks. While straightforward, using this objective often results in much worse final performance for each task than learning them independently. A major challenge in optimizing a multi-task model is the conflicting gradients, where gradients of different task objectives are not well aligned so that following the average gradient direction can be detrimental to specific tasks' performance. Previous work has proposed several heuristics to manipulate the task gradients for mitigating this problem. But most of them lack convergence guarantee and/or could converge to any Pareto-stationary point. In this paper, we introduce Conflict-Averse Gradient descent (CAGrad) which minimizes the average loss function, while leveraging the worst local improvement of individual tasks to regularize the algorithm trajectory. CAGrad balances the objectives automatically and still provably converges to a minimum over the average loss. It includes the regular gradient descent (GD) and the multiple gradient descent algorithm (MGDA) in the multi-objective optimization (MOO) literature as special cases. On a series of challenging multi-task supervised learning and reinforcement learning tasks, CAGrad achieves improved performance over prior state-of-the-art multi-objective gradient manipulation methods.
연구 동기 및 목표
- 다중 작업 학습(MTL)에서 상충하는 그래디언트의 최적화 과제를 동기 부여하고 해결합니다.
- 평균 손실을 최적화하면서 작업 간 최악의 지역 개선을 최소화하는 그래디언트 기반 업데이트 규칙을 제안합니다.
- 완만한 조건에서 평균 손실의 최소값에 대한 이론적 수렴 보장을 제공합니다.
- 감독 학습, 준감독 학습 및 강화 학습 설정 전반에서 기존 그래디언트 조작 방법들에 비해 경험적 개선을 보여줍니다.
제안 방법
- 업데이트 방향 d를 평균 그래디언트 g0를 중심으로 한 구 안에서 유지되면서도 작업 전반에 걸친 최악의 지역 개선을 최대화하도록 정의합니다.
- 원문 문제를 형식화합니다: max_d min_i <gi, d> subject to ||d - g0|| <= c ||g0||.
- 듀얼 문제를 도출하여 단순도(simplex)에서 작업 그래디언트에 대한 가중치 w를 구하도록 감소시킵니다, gw = sum_i w_i gi.
- 업데이트를 계산합니다: theta_{t} = theta_{t-1} - alpha (g0 + (sqrt(phi)/||gw||) gw), where phi = c^2 ||g0||^2.
- c=0이 표준 경사 하강법을 복구하고 c→∞가 MGDA를 근사하는 것을 보이고, 0 <= c < 1에서 수렴을 입증합니다.
- 경사 계산의 부분 집합 S를 샘플링하고 듀얼 목적을 적절히 조정하여 실용적인 속도up를 제공합니다.
실험 결과
연구 질문
- RQ1CAGrad가 작업 간 그래디언트 충돌을 완화하면서 평균 손실 L0의 (로컬) 최소값으로 수렴할 수 있습니까?
- RQ2CAGrad가 GD 및 MGDA와 어떤 관계를 갖고 일반화하며, 이러한 특수 사례가 어떤 매개변수 조합에서 나타납니까?
- RQ3CAGrad가 감독 학습, 준감독 학습 및 강화 학습 다중 작업 문제에서 최첨단 그래디언트 조작 방법들보다 실험적으로 성능을 개선합니까?
주요 결과
- CAGrad는 0 <= c < 1일 때 평균 손실 L0의 최소 지점으로 수렴하는 것이 증명됩니다.
- c가 0에 접근하면 CAGrad는 표준 경사 하강법으로 축소되고, 큰 c에서는 MGDA에 근사합니다.
- 실험적으로 CAGrad는 다중 작업 감독 학습, 준감독 학습, 강화 학습 벤치마크에서 기존의 그래디언트 조작 방법들보다 개선을 보입니다.
- 감독형 MTL 실험에서 MTAN 백본을 사용할 때 NYU-v2 및 CityScapes에서 평균 작업 수행도가 종종 최고를 기록합니다.
- 강화 학습에서 CAGrad는 MetaWorld MT10/MT50에서 여러 베이스라인보다 더 높은 성공률의 눈에 띄는 이점을 보입니다.
- 빠른 변형(CAGrad-Fast)은 MT10/MT50 설정에서 PCGrad 대비 2배~5배의 속도 향상을 제공하면서도 경쟁력 있는 성능을 유지합니다.
더 나은 연구,지금 바로 시작하세요
연구 설계부터 논문 작성까지, 연구 시간을 획기적으로 줄여보세요.
카드 등록 없음 · 무료 플랜 제공
이 리뷰는 AI가 만들고, 인간 에디터가 검토했습니다.