[论文解读] Inductive Representation Learning on Large Graphs
GraphSAGE 提出一个归纳框架,通过学习邻域聚合函数来生成未见节点的节点嵌入,在三个归纳节点分类任务上优于基线。
Low-dimensional embeddings of nodes in large graphs have proved extremely useful in a variety of prediction tasks, from content recommendation to identifying protein functions. However, most existing approaches require that all nodes in the graph are present during training of the embeddings; these previous approaches are inherently transductive and do not naturally generalize to unseen nodes. Here we present GraphSAGE, a general, inductive framework that leverages node feature information (e.g., text attributes) to efficiently generate node embeddings for previously unseen data. Instead of training individual embeddings for each node, we learn a function that generates embeddings by sampling and aggregating features from a node's local neighborhood. Our algorithm outperforms strong baselines on three inductive node-classification benchmarks: we classify the category of unseen nodes in evolving information graphs based on citation and Reddit post data, and we show that our algorithm generalizes to completely unseen graphs using a multi-graph dataset of protein-protein interactions.
研究动机与目标
- 使人们意识到在不断演化的图中需要能泛化到未见节点的归纳式节点嵌入。
- 提出一个通用的 GraphSAGE 框架,学习聚合邻域特征以生成节点嵌入。
- 评估多种聚合器架构,并在多样数据集上展示比基线更好的预测性能。
- 展示该方法能够跨图泛化,并对学习局部图结构提供理论洞见。
提出的方法
- 提出 GraphSAGE,它通过学习可训练的聚合函数来组合节点局部邻域的特征,并通过多跳(K)生成嵌入。
- 使用前向传播过程(算法1),其中每一层聚合邻居表示,与节点自身表示拼接,并应用带有学习权重 W^k 的非线性变换。
- 采用无监督损失(方程1),鼓励相邻节点具有相似表示,距离较远的节点具有不同表示,任务特定目标可选用监督变体。
- 探索不同的聚合器架构(均值、LSTM、池化)以捕捉邻域信息,同时确保对邻居顺序的对称性。
实验结果
研究问题
- RQ1GraphSAGE 是否能够在训练时未见的节点上生成有意义的嵌入(归纳设定)?
- RQ2不同邻域聚合器如何影响归纳嵌入质量和可扩展性?
- RQ3学习到的聚合器在多大程度上能够捕捉局部图结构并实现跨图泛化?
- RQ4在真实世界的归纳任务中,GraphSAGE 与转导基线及其他嵌入方法相比如何?
主要发现
| 名称 | 引用 Unsup F1 | 引用 Sup F1 | Reddit Unsup F1 | Reddit Sup F1 | PPI Unsup F1 | PPI Sup F1 |
|---|---|---|---|---|---|---|
| Random | 0.206 | 0.206 | 0.043 | 0.042 | 0.396 | 0.396 |
| Raw features | 0.575 | 0.575 | 0.585 | 0.585 | 0.422 | 0.422 |
| DeepWalk | 0.565 | 0.565 | 0.324 | 0.324 | — | — |
| DeepWalk + features | 0.701 | 0.701 | 0.691 | 0.691 | — | — |
| GraphSAGE-GCN | 0.742 | 0.772 | 0.908 | 0.930 | 0.465 | 0.500 |
| GraphSAGE-mean | 0.778 | 0.820 | 0.897 | 0.950 | 0.486 | 0.598 |
| GraphSAGE-LSTM | 0.788 | 0.832 | 0.907 | 0.954 | 0.482 | 0.612 |
| GraphSAGE-pool | 0.798 | 0.839 | 0.892 | 0.948 | 0.502 | 0.600 |
- GraphSAGE 在引用、Reddit 和 PPI 数据集上均优于基线(随机、原始特征、DeepWalk 以及 DeepWalk+特征)。
- 在聚合器中,LSTM 与池化变体整体表现最佳,均值在结果上具有竞争力;基于 GCN 的聚合在某些任务中的表现较弱。
- 无监督的 GraphSAGE 也能达到接近完全监督变体的强性能,表明在没有任务特定标签时也具有良好实用性。
- K=2 的 GraphSAGE 变体与中等规模的邻域采样相比,带来显著的准确性提升(平均约比 K=1 高出约10-15%),并且具有较优的运行时间。
- 该方法在 PPI 场景中展示了跨图泛化能力,对多张图进行训练可以改善在未见图上的性能。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。