[论文解读] DAG-GNN: DAG Structure Learning with Graph Neural Networks
DAG-GNN 使用图神经网络和变分自编码器的深度生成模型来学习有向无环图结构,能够处理连续/离散和向量值变量,使用实际可行的连通性约束和增强拉格朗日训练。
Learning a faithful directed acyclic graph (DAG) from samples of a joint distribution is a challenging combinatorial problem, owing to the intractable search space superexponential in the number of graph nodes. A recent breakthrough formulates the problem as a continuous optimization with a structural constraint that ensures acyclicity (Zheng et al., 2018). The authors apply the approach to the linear structural equation model (SEM) and the least-squares loss function that are statistically well justified but nevertheless limited. Motivated by the widespread success of deep learning that is capable of capturing complex nonlinear mappings, in this work we propose a deep generative model and apply a variant of the structural constraint to learn the DAG. At the heart of the generative model is a variational autoencoder parameterized by a novel graph neural network architecture, which we coin DAG-GNN. In addition to the richer capacity, an advantage of the proposed model is that it naturally handles discrete variables as well as vector-valued ones. We demonstrate that on synthetic data sets, the proposed method learns more accurate graphs for nonlinearly generated samples; and on benchmark data sets with discrete variables, the learned graphs are reasonably close to the global optima. The code is available at \url{https://github.com/fishmoon1234/DAG-GNN}.
研究动机与目标
- 从样本中学习忠实的 DAG 结构的动机,超越线性 SEM 假设。
- 开发一个能够捕捉非线性和多样数据类型(连续、离散、向量值)的深度生成框架。
- 在变分自编码器中利用基于图神经网络的编码器/解码器来建模数据分布,条件为加权邻接矩阵 A。
- 通过实际的连续约束确保无环性,并通过增强拉格朗日优化进行训练。
- 在合成、基准和应用数据集上显示相比现有线性 SEM 基方法的改进结构恢复。
提出的方法
- 对 DAG 进行加权邻接矩阵 A 的参数化,并将 X 表示为 X = f2((I−AT)−1 f1(Z)),其中 Z 为潜在输入。
- 使用基于图神经网络的 VAE,其中编码器计算 q(Z|X),解码器在给定 A 的条件下定义 p(X|Z)。
- 采用变分目标 ELBO,具有闭式 KL(q(Z|X)‖p(Z)) 和蒙特卡罗重构项。
- 通过使用沿行 softmax 的分类解码器输出来处理离散变量;相应地调整似然。
- 引入一个实用的无环性约束 h(A) = tr[(I + α A∘A)^m] − m = 0,并通过增强拉格朗日方法优化。
- 训练在最小化增强拉格朗日项和更新拉格朗日乘子及惩罚 C 以满足无环性约束之间交替进行。
实验结果
研究问题
- RQ1一个基于图神经网络的变分自编码器是否能够在非线性且混合类型数据(连续和离散)中准确恢复 DAG 结构?
- RQ2提出的 DAG-GNN 框架是否在合成非线性数据和离散变量基准上优于线性 SEM 基方法(例如 DAG-NOTEARS)?
- RQ3在深度生成 DAG 模型中实施实际的无环性约束如何影响学习到的图的质量?
- RQ4向量值变量是否可以在这个 DAG-GNN 框架中有效建模,相较于标量变量方法是否提供更好的结构恢复?
主要发现
| 数据集 | m | Groundtruth | GOPNILP | DAG-GNN |
|---|---|---|---|---|
| Child | 20 | -1.27e+4 | -1.27e+4 | -1.38e+4 |
| Alarm | 37 | -1.27e+4 | -1.12e+4 | -1.28e+4 |
| Pigs | 441 | -3.48e+5 | -3.50e+5 | -3.69e+5 |
- 在合成非线性数据上,DAG-GNN 提升了 SHD,并在非线性增加时显著降低 FDR,相较于 DAG-NOTEARS。
- 对于向量值数据,DAG-GNN 的表现优于 DAG-NOTEARS,能够以可比甚至更好的准确度恢复所有真实边。
- 在离散基准数据集(Child、Alarm、Pigs)上,DAG-GNN 的 BIC 分数接近真实值并与 GOPNILP 具有竞争力,尽管自编码器更简单。
- 在蛋白质信号网络(Sachs 等)中,DAG-GNN 在比较方法中实现了最低的 SHD(18-22 范围),并学习出一个可行的无环图,具有多条真实和有前瞻性的间接/反向边。
- 该框架扩展到知识库关系图,从 KB 方案数据中提取直观的因果边。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。