[論文レビュー] Learning to Branch for Multi-Task Learning
本論文は LearnToBranch を紹介します。エンドツーエンドで訓練可能な手法で、マルチタスク学習においてネットワーク内でどこを共有すべきか、どこで分岐すべきかを自動的に学習します。微分可能な木構造トポロジーを、gumbel-softmax サンプリングに導かれて構築します。合成データ、CelebA、Taskonomy でのタスクグルーピングと性能の改善を示します。
Training multiple tasks jointly in one deep network yields reduced latency during inference and better performance over the single-task counterpart by sharing certain layers of a network. However, over-sharing a network could erroneously enforce over-generalization, causing negative knowledge transfer across tasks. Prior works rely on human intuition or pre-computed task relatedness scores for ad hoc branching structures. They provide sub-optimal end results and often require huge efforts for the trial-and-error process. In this work, we present an automated multi-task learning algorithm that learns where to share or branch within a network, designing an effective network topology that is directly optimized for multiple objectives across tasks. Specifically, we propose a novel tree-structured design space that casts a tree branching operation as a gumbel-softmax sampling procedure. This enables differentiable network splitting that is end-to-end trainable. We validate the proposed method on controlled synthetic data, CelebA, and Taskonomy.
研究の動機と目的
- 手作りのタスク関連性の仮定を用いず、複数タスクの最適なネットワーク共有と分岐構造を自動的に探索する。
- 微分可能な分岐を介してマルチタスク損失を最小化する木構造のトポロジを組み合わせて構築する。
- アーキテクチャと重みを同時に最適化するエンドツーエンドの訓練フレームワークを提供する。
- 合成データ、CelebA、Taskonomy データセットでの有効性を示す。
提案手法
- ネットワークを DAG として、各子ノードが学習可能なカテゴリ分布 p_theta によって親接続をサンプリングする分岐ブロックを持つ表現。
- 学習中に離散的な分岐決定を微分可能にするために gumbel-softmax を用い、ハードな木へ収束するよう温度を徐々にアニーリングする。
- 分岐操作 x_j^{l+1} = E_{d_j ~ p_theta_j}[d_j · Y^l] を定義し、トポロジーと重みのエンドツーエンドの最適化を実現する。
- 設計空間からのネットワーク構成のサンプリングと、バックプロパゲーションによるアーキテクチャ確率とネットワーク重みの更新を交互に行うことで訓練する。
- 訓練後、ノイズなしで theta の argmax によって最終アーキテクチャを選択し、最終性能のためにスクラッチから再訓練する。
- 分岐ブロックを積み重ねて、葉ノードとタスク数を一致させつつ、より深い木構造のマルチタスクネットワークを構築する。
実験結果
リサーチクエスチョン
- RQ1微分可能で木構造の分岐機構は、複数タスクのためにどの層を共有または分岐させるべきかを自動的に決定できるか?
- RQ2アーキテクチャと重みのエンドツーエンド最適化は、手作りまたは静的なトポロジよりも多タスク性能を向上させるか?
- RQ3事前のタスク関連性情報がなくても、タスクグルーピングはバックプロパゲーション信号から自然に現れるか?
- RQ4学習したトポロジーは、合成データ、CelebA、Taskonomy データセットでどれほど効果的か?
主な発見
| Method | Acc (%) | Params (M) |
|---|---|---|
| Moon | 90.94 | 119.73 |
| Indep Group | 91.06 | - |
| MCNN-AUX | 91.29 | - |
| VGG-16 Baseline | 91.44 | 134.41 |
| Branch-VGG | 90.79 | 2.09 |
| LearnToBranch-VGG | 91.55 | 1.94 |
| GNAS-Deep-Wide | 91.36 | 6.41 |
| LearnToBranch-Deep-Wide | 91.62 | 6.33 |
| LNet+ANet | 87 | - |
| Walk and Learn | 88 | - |
| Moon | 90.94 | 119.73 |
| Indep Group | 91.06 | - |
| MCNN-AUX | 91.29 | - |
| VGG-16 Baseline | 91.44 | 134.41 |
| Branch-VGG | 90.79 | 2.09 |
| LearnToBranch-VGG | 91.55 | 1.94 |
| GNAS-Deep-Wide | 91.36 | 6.41 |
| LearnToBranch-Deep-Wide | 91.62 | 6.33 |
- 本手法は、人間の事前知識なしに、関連するタスクをクラスタリングし、タスクが分岐する場合に分岐するタスクグルーピング構造を学習します。
- LearnToBranch は CelebA で、いくつかのベースラインと比較して、同等または優れた精度を、より少ないパラメータで達成します。
- Taskonomy では、LearnToBranch が AdaShare および他のベースラインを、5つのタスク(セグメンテーション、法線、深度、キーポイント、エッジ)でより少ないパラメータ数で上回ります。
- 学習されたアーキテクチャは、実行を通じて一貫した共有パターンを示し、安定した自動タスクグルーピングを示唆します。
- 訓練にはトポロジー探索フェーズ(時間: hours)と、その後の最終アーキテクチャをスクラッチから再訓練する段階が含まれ、エンドツーエンド最適化で高い性能を達成します。
より良い研究を、今すぐ始めましょう
論文設計から論文執筆まで、研究時間を劇的に削減しましょう。
クレジットカード登録不要
このレビューはAIが作成し、人間の編集者が確認しました。