Skip to main content
QUICK REVIEW

[论文解读] Layer-Dependent Importance Sampling for Training Deep and Large Graph Convolutional Networks

Difan Zou, Ziniu Hu|arXiv (Cornell University)|Nov 17, 2019
Advanced Graph Neural Networks被引用 87
一句话总结

LADIES 引入层依赖的重要性采样,以在较低的内存/时间成本下训练深度与大规模 GCN,并在比以往采样方法更好的泛化性能。

ABSTRACT

Graph convolutional networks (GCNs) have recently received wide attentions, due to their successful applications in different graph tasks and different domains. Training GCNs for a large graph, however, is still a challenge. Original full-batch GCN training requires calculating the representation of all the nodes in the graph per GCN layer, which brings in high computation and memory costs. To alleviate this issue, several sampling-based methods have been proposed to train GCNs on a subset of nodes. Among them, the node-wise neighbor-sampling method recursively samples a fixed number of neighbor nodes, and thus its computation cost suffers from exponential growing neighbor size; while the layer-wise importance-sampling method discards the neighbor-dependent constraints, and thus the nodes sampled across layer suffer from sparse connection problem. To deal with the above two problems, we propose a new effective sampling algorithm called LAyer-Dependent ImportancE Sampling (LADIES). Based on the sampled nodes in the upper layer, LADIES selects their neighborhood nodes, constructs a bipartite subgraph and computes the importance probability accordingly. Then, it samples a fixed number of nodes by the calculated probability, and recursively conducts such procedure per layer to construct the whole computation graph. We prove theoretically and experimentally, that our proposed sampling algorithm outperforms the previous sampling methods in terms of both time and memory costs. Furthermore, LADIES is shown to have better generalization accuracy than original full-batch GCN, due to its stochastic nature.

研究动机与目标

  • 在大图上训练深 GCN,克服全批次成本与节点级采样的冗余问题。
  • 开发一个层依赖的采样方案,以维持连接性并降低方差。
  • 证明相对于现有方法在理论上的效率与方差改进。
  • 在基准数据集上展示在运行时间、内存和准确度方面的经验提升。

提出的方法

  • 提出 LADIES,对于每一层,基于上层采样节点及其邻居构建一个二部子图。
  • 使用 p_i^{(l-1)} = ||Q^{(l)} P_{*,i}||_2^2 / ||Q^{(l)} P||_F^2 计算逐层重要性概率以指导采样。
  • 基于计算得到的概率在每一层采样固定数量的节点,并构建一个密集的、归一化的采样邻接矩阵 tilde{P}^{(l-1)} 以传播嵌入。
  • 使用自上而下的层依赖采样以确保连通性并避免感受野指数级增长。
  • 通过逐行归一化 tilde{P}^{(l)} 以稳定训练。
  • 给出内存/时间复杂度与方差的理论分析,并在多个数据集上进行经验验证。

实验结果

研究问题

  • RQ1层依赖采样如何改善深度 GCN 的计算图连通性与效率?
  • RQ2与节点级和层级先前方法相比,LADIES 是否提供更低的内存/时间复杂度和更小的方差?
  • RQ3LADIES 是否在标准图基准数据集上改善或保持预测准确性与泛化能力?
  • RQ4在非常大规模图上,哪些采样规模足以实现强性能?

主要发现

DatasetSample MethodF1-Score(%)Total Time(s)Mem(MB)Batch Time(ms)Batch Num
Cora (2708)Full-Batch76.5±1.41.19±0.8230.7215.75±0.5280.8±51.7
Cora (2708)GraphSage (5)75.2±1.56.77±4.94471.3978.42±0.8765.2±52.1
Cora (2708)FastGCN (64)25.1±8.40.55±0.653.139.22±0.2063.2±71.2
Cora (2708)FastGCN (512)78.0±2.14.70±1.357.3310.08±0.29487±147
Cora (2708)LADIES (64)77.6±1.44.19±1.163.139.68±0.48436±118.4
Cora (2708)LADIES (512)78.3±1.60.72±0.397.359.77±0.2875.6±37.0
Citeseer (3327)Full-Batch62.3±3.10.61±0.7068.1315.77±0.5840.6±22.8
Citeseer (3327)GraphSage (5)59.4±0.94.51±3.68595.7153.14±1.9057.2±42.1
Citeseer (3327)FastGCN (64)19.2±2.70.53±0.485.898.88±0.4064.0±57.0
Citeseer (3327)FastGCN (512)44.6±10.84.34±1.7313.9710.41±0.51386±167
Citeseer (3327)FastGCN (1024)63.5±1.82.24±1.0123.2410.54±0.27223±98.6
Citeseer (3327)LADIES (64)65.0±1.42.17±0.655.899.60±0.39232±66.8
Citeseer (3327)LADIES (512)64.3±2.40.41±0.2213.9210.32±0.2337.6±11.9
Pubmed (19717)Full-Batch71.9±1.94.80±1.53137.9344.69±0.57102±33.4
Pubmed (19717)GraphSage (5)70.1±1.45.53±2.57453.5844.73±0.3074.8±31.7
Pubmed (19717)FastGCN (64)38.5±6.90.40±0.691.927.42±0.1658.8±94.8
Pubmed (19717)FastGCN (512)39.3±9.20.44±0.614.5310.06±0.4144.8±55.0
Pubmed (19717)FastGCN (8192)74.4±0.83.47±1.1649.4117.84±0.33195±56.9
Pubmed (19717)LADIES (64)76.8±0.82.57±0.721.929.43±0.47277±82.2
Pubmed (19717)LADIES (512)75.9±1.12.27±1.174.3910.43±0.36245±84.5
Reddit (232965)Full-Batch91.6±1.6474.3±84.42370.481564±3.41179±75.5
Reddit (232965)GraphSage (5)92.1±1.113.12±2.841234.63121.47±0.7281.5±42.3
Reddit (232965)FastGCN (64)27.8±12.62.06±1.293.757.85±0.7257.4±43.7
Reddit (232965)FastGCN (512)17.5±16.70.31±0.416.9110.01±0.3132.1±72.3
Reddit (232965)FastGCN (8192)89.5±1.25.63±2.1274.2816.57±0.58278±51.2
Reddit (232965)LADIES (64)83.5±0.95.62±1.583.759.42±0.48453±88.2
Reddit (232965)LADIES (512)92.8±1.66.87±1.177.2610.87±0.63393±74.4
  • LADIES 在内存和时间成本方面低于节点级采样方法,并保持同等或更好的准确度。
  • 与 FastGCN 相比,LADIES 由于更小的有效连接节点集而获得更小的方差,并且在大图上受益于较小的采样规模。
  • 在基准数据集(Cora、Citeseer、Pubmed、Reddit)上,LADIES 以更小的样本量(如 64)与更深的结构获得最佳测试准确度。
  • LADIES 展现出较强的泛化能力,即使使用随机采样,也常常优于全批次 GCN 的验证/测试性能。
  • LADIES 能在不出现指数级计算增长的情况下,扩展到极大规模图与深层 GCN。

更好的研究,从现在开始

从论文设计到论文写作,大幅缩短您的研究时间。

无需绑定信用卡

本解读由 AI 生成,并经人工编辑审核。