[论文解读] A simple neural network module for relational reasoning
本文提出关系网络(RNs)作为一种可插拔的关系推理模块,使 CNN/LSTM 架构能够在 CLEVR、bAbI 和动态物理系统任务上实现最先进、超越人类的推理。
Relational reasoning is a central component of generally intelligent behavior, but has proven difficult for neural networks to learn. In this paper we describe how to use Relation Networks (RNs) as a simple plug-and-play module to solve problems that fundamentally hinge on relational reasoning. We tested RN-augmented networks on three tasks: visual question answering using a challenging dataset called CLEVR, on which we achieve state-of-the-art, super-human performance; text-based question answering using the bAbI suite of tasks; and complex reasoning about dynamic physical systems. Then, using a curated dataset called Sort-of-CLEVR we show that powerful convolutional networks do not have a general capacity to solve relational questions, but can gain this capacity when augmented with RNs. Our work shows how a deep learning architecture equipped with an RN module can implicitly discover and learn to reason about entities and their relations.
研究动机与目标
- 将关系推理确立为智能行为的核心动机,并指出标准神经网络在此类任务中的局限性。
- 提出一个简单的、可插拔的 RN 模块,用于计算对象对之间的关系。
- 展示 RN 在多种领域的有效性:视觉问答(CLEVR)、文本问答(bAbI)和动态物理系统。
提出的方法
- 将 RN 定义为 RN(O)=f_phi(sum_{i,j} g_theta(o_i, o_j)),其中 o_i 表示对象表征。
- 使用 g_theta 计算对象对之间的关系,使用 f_phi 聚合关系。
- 通过求和实现对输入作为对象集合的顺序不变性的处理。
- 在适用时将 g_theta 条件化为附加输入,如问题嵌入。
- 展示 RN 可以通过从 CNN/LSTM 特征学习上游对象表征,在非结构化输入上工作。
- 端到端训练,使用 Adam 和标准的 CNN/LSTM 组件。
实验结果
研究问题
- RQ1是否存在一个专用的关系模块能够提高神经网络在不同领域推断对象之间关系的能力?
- RQ2当附着于现有架构时,Relation Networks 是否提供数据高效、顺序不变的关系推理?
- RQ3RNs 是否能够解决视觉问答、文本问答和动态物理系统中的关系性问题?
主要发现
| 模型 | 总体 | 计数 | 存在 | 数字比较 | 查询属性 | 比较属性 |
|---|---|---|---|---|---|---|
| Human | 92.6 | 86.7 | 96.6 | 86.5 | 95.0 | 96.0 |
| Q-type baseline | 41.8 | 34.6 | 50.2 | 51.0 | 36.0 | 51.3 |
| LSTM | 46.8 | 41.7 | 61.1 | 69.8 | 36.8 | 51.8 |
| CNN + LSTM | 52.3 | 43.7 | 65.2 | 67.1 | 49.3 | 53.0 |
| CNN+LSTM+SA | 68.5 | 52.2 | 71.1 | 73.5 | 85.3 | 52.3 |
| CNN+LSTM+SA* | 76.6 | 64.4 | 82.7 | 77.4 | 82.6 | 75.4 |
| CNN+LSTM+RN | 95.5 | 90.1 | 97.8 | 93.6 | 97.9 | 97.1 |
- 在 CLEVR 从像素输入达到最先进的、超越人类的性能(总分 95.5%)的增强模型。
- 在 CLEVR 的状态描述上获得 96.4% 的准确率。
- 在 Sort-of-CLEVR 上,CNN+RN 能以>94% 的准确率解决关系性与非关系性问题,而 CNN+MLP 在关系性问题上表现不佳。
- 在 bAbI 上,该模型解决了 18/20 个任务,未发生灾难性失败。
- 在动态物理系统中,RN 以 93% 的准确率推断连接,并以 95% 的准确率计数连接系统,优于 MLP。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。