[论文解读] Routing Networks: Adaptive Selection of Non-linear Functions for Multi-Task Learning
本文提出路由网络(routing networks),一种新颖的多任务学习架构,通过路由器动态组合功能模块(如神经网络层)以实现按输入自适应、任务特定的计算。该模型通过协作式多智能体强化学习进行训练,在准确率和收敛速度方面显著优于交叉连接(cross-stitch)和共享层基线模型,在CIFAR-100(20项任务)上训练速度最快可提升85%。
Multi-task learning (MTL) with neural networks leverages commonalities in tasks to improve performance, but often suffers from task interference which reduces the benefits of transfer. To address this issue we introduce the routing network paradigm, a novel neural network and training algorithm. A routing network is a kind of self-organizing neural network consisting of two components: a router and a set of one or more function blocks. A function block may be any neural network - for example a fully-connected or a convolutional layer. Given an input the router makes a routing decision, choosing a function block to apply and passing the output back to the router recursively, terminating when a fixed recursion depth is reached. In this way the routing network dynamically composes different function blocks for each input. We employ a collaborative multi-agent reinforcement learning (MARL) approach to jointly train the router and function blocks. We evaluate our model against cross-stitch networks and shared-layer baselines on multi-task settings of the MNIST, mini-imagenet, and CIFAR-100 datasets. Our experiments demonstrate a significant improvement in accuracy, with sharper convergence. In addition, routing networks have nearly constant per-task training cost while cross-stitch networks scale linearly with the number of tasks. On CIFAR-100 (20 tasks) we obtain cross-stitch performance levels with an 85% reduction in training time.
研究动机与目标
- 通过实现神经网络组件的动态、输入相关组合,缓解多任务学习(MTL)中的任务干扰问题。
- 克服固定架构设计在多任务学习中的局限性,如共享层或交叉连接网络,这些设计可能导致负迁移。
- 构建一个可泛化的框架,使功能模块(如全连接层或卷积层)能够根据输入和任务自适应选择。
- 实现恒定的每任务训练成本,避免交叉连接网络随任务数量线性增长的缺点。
- 探索使用多智能体强化学习训练路由器,使其学习任务特定的路由策略,同时在有益时共享功能模块。
提出的方法
- 设计一种路由网络,包含两个组件:路由器和一组功能模块,其中路由器递归地选择并应用功能模块,最多达到固定深度。
- 使用强化学习实现硬性路由决策,每个任务拥有独立的强化学习智能体,以学习路由策略。
- 采用协作式多智能体强化学习框架(加权策略学习器,Weighted Policy Learner)联合训练路由器和功能模块,实现任务特定的适应。
- 允许路由器基于输入特征、任务身份、递归深度和历史路由选择进行决策,以支持动态且分层的组合。
- 支持异构功能模块(如不同类型的层),只要其输入/输出维度兼容即可。
- 使用策略梯度方法端到端训练系统,尽管路由器的决策不可微,但可通过策略学习进行优化。
实验结果
研究问题
- RQ1动态自适应的路由机制是否能减少多任务学习中的任务干扰,同时提升泛化能力?
- RQ2在准确率和训练效率方面,路由网络相较于强基线模型(如交叉连接网络和共享层模型)表现如何?
- RQ3路由器在多大程度上能学习到反映底层数据结构和任务相似性的任务特定路由策略?
- RQ4与单智能体或可微路由方法相比,多智能体强化学习是否能实现更好的探索和更快的收敛?
- RQ5路由网络是否能在扩展至大量任务时保持恒定的每任务训练成本,而不会像某些架构那样线性增长?
主要发现
- 在MNIST、mini-ImageNet和CIFAR-100上,路由网络的准确率显著高于交叉连接网络和共享层基线模型。
- 在包含20项任务的CIFAR-100上,路由网络实现了与交叉连接网络相当的性能,但训练时间减少了85%。
- 模型表现出更锐利的收敛性和更快的训练动态,策略熵随时间下降,表明已收敛至稳定、任务特定的路由策略。
- 路由器学习到多样的路由模式:在MNIST-MTL中,网络先使用7个功能模块,减少至4个,再扩展至5个,表明学习到了非平凡的、结构化的模块组合。
- 分析显示,路由策略收敛至纯策略(对某一模块的概率为100%),早期到达的智能体(如图11中的粉红和绿色)主导特定模块。
- 路由图显示模型能有效复用功能模块于不同任务之间,表明实现了有效的正向迁移,且无负向干扰。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。