[논문 리뷰] Learning to Branch for Multi-Task Learning
이 논문은 다중 작업 학습을 위해 네트워크 내에서 공유 또는 분기를 어디에 둘지 자동으로 학습하는 end-to-end 학습 가능한 방법인 LearnToBranch를 소개하며, gumbel-softmax 샘플링으로 안내되는 미분 가능한 트리 구조 토폴로지를 사용합니다. 합성 데이터, CelebA, Taskonomy에서 작업 군집화 및 성능 개선을 보여줍니다.
Training multiple tasks jointly in one deep network yields reduced latency during inference and better performance over the single-task counterpart by sharing certain layers of a network. However, over-sharing a network could erroneously enforce over-generalization, causing negative knowledge transfer across tasks. Prior works rely on human intuition or pre-computed task relatedness scores for ad hoc branching structures. They provide sub-optimal end results and often require huge efforts for the trial-and-error process. In this work, we present an automated multi-task learning algorithm that learns where to share or branch within a network, designing an effective network topology that is directly optimized for multiple objectives across tasks. Specifically, we propose a novel tree-structured design space that casts a tree branching operation as a gumbel-softmax sampling procedure. This enables differentiable network splitting that is end-to-end trainable. We validate the proposed method on controlled synthetic data, CelebA, and Taskonomy.
연구 동기 및 목표
- 여러 작업에 대해 손으로 설계된 작업 관련성 가정 없이 최적의 공유 및 분기 구조를 자동으로 검색한다.
- differentiable 분기를 통해 다중 작업 손실을 최소화하는 트리 구조 토폴로지를 구성한다.
- 아키텍처와 가중치를 함께 최적화하는 end-to-end 학습 프레임워크를 제공한다.
- 합성 데이터, CelebA, 및 Taskonomy 데이터셋에서 효과를 입증한다.
제안 방법
- 네트워크를 DAG로 표현하고 각 자식 노드가 learnable한 범주 분포 p_theta를 통해 부모 연결을 샘플링하는 분기 블록을 포함한다.
- 학습 중에 이산적 분기 결정을 미분 가능하게 만들고 hard 트리로 수렴하기 위해 gumbel-softmax를 사용하고 온도를 점진적으로 완화한다.
- 분기 연산 x_j^{l+1} = E_{d_j ~ p_theta_j}[d_j · Y^l] 를 정의하여 토폴로지와 가중치에 대해 엔드 투 엔드 최적화를 가능하게 한다.
- 디자인 공간에서 네트워크 구성을 번갈아 샘플링하고 아키텍처 확률과 네트워크 가중치를 역전파로 업데이트하여 학습한다.
- 학습 이후 노이즈 없는 theta에 대해 argmax를 사용하여 최종 아키텍처를 선택하고 최종 성능을 위해 처음부터 재학습한다.
- 리프 노드와의 작업 수를 맞추면서 더 깊은 트리 구조의 다중 작업 네트워크를 구축하기 위해 분기 블록을 스택한다.
실험 결과
연구 질문
- RQ1미분 가능하고 트리 구조의 분기 메커니즘이 여러 작업에 대해 어떤 층을 공유하거나 분리할지 자동으로 결정할 수 있는가?
- RQ2아키텍처와 가중치의 엔드-투-엔드 최적화가 수동으로 설계되거나 고정된 토폴로지보다 다중 작업 성능을 더 잘 이끌어내는가?
- RQ3사전 작업 관련 정보 없이도 역전파 신호에서 자연스럽게 작업 군집이 나타날 수 있는가?
- RQ4학습된 토폴로지가 합성 데이터, CelebA, Taskonomy 데이터셋에서 얼마나 효과적인가?
주요 결과
| 방법 | 정확도(%) | 매개변수(M) |
|---|---|---|
| Moon | 90.94 | 119.73 |
| Indep Group | 91.06 | - |
| MCNN-AUX | 91.29 | - |
| VGG-16 Baseline | 91.44 | 134.41 |
| Branch-VGG | 90.79 | 2.09 |
| LearnToBranch-VGG | 91.55 | 1.94 |
| GNAS-Deep-Wide | 91.36 | 6.41 |
| LearnToBranch-Deep-Wide | 91.62 | 6.33 |
| LNet+ANet | 87 | - |
| Walk and Learn | 88 | - |
| Moon | 90.94 | 119.73 |
| Indep Group | 91.06 | - |
| MCNN-AUX | 91.29 | - |
| VGG-16 Baseline | 91.44 | 134.41 |
| Branch-VGG | 90.79 | 2.09 |
| LearnToBranch-VGG | 91.55 | 1.94 |
| GNAS-Deep-Wide | 91.36 | 6.41 |
| LearnToBranch-Deep-Wide | 91.62 | 6.33 |
- 이 방법은 인간의 사전 지식 없이도 관련 작업을 클러스터링하고 작업이 갈라질 때 분기하는 작업 그룹 구조를 학습한다.
- LearnToBranch는 CelebA에서 여러 베이스라인보다 더 적은 매개변수로 경쟁력 있는 정확도 또는 우수한 정확도를 달성한다.
- Taskonomy에서 LearnToBranch는 다섯 가지 작업(분할, 표준, 깊이, 키포인트, 에지)에서 AdaShare 및 다른 베이스라인을 매개변수 수가 더 적은 상태로 앞지른다.
- 학습된 아키텍처는 실행 간 일관된 공유 패턴을 보여 주어 안정적인 자동 작업 군집 구성을 시사한다.
- 토폴로지 검색 단계(시간: 시간 단위) 후 최종 아키텍처를 처음부터 재학습하여 엔드-투-엔드 최적화로 강력한 성능을 달성한다.
더 나은 연구,지금 바로 시작하세요
연구 설계부터 논문 작성까지, 연구 시간을 획기적으로 줄여보세요.
카드 등록 없음 · 무료 플랜 제공
이 리뷰는 AI가 만들고, 인간 에디터가 검토했습니다.