[论文解读] Learning Sparse Nonparametric DAGs
该论文提出了一种通用的、可微分的优化框架,通过使用偏导数将代数无环性约束扩展至非参数结构方程模型(SEMs),从而实现从数据中学习稀疏的非参数有向无环图(DAGs)。该方法支持使用标准优化求解器进行端到端训练,在非线性和非参数模型上实现了最先进性能,且无需专用算法或针对特定模型的实现。
We develop a framework for learning sparse nonparametric directed acyclic graphs (DAGs) from data. Our approach is based on a recent algebraic characterization of DAGs that led to a fully continuous program for score-based learning of DAG models parametrized by a linear structural equation model (SEM). We extend this algebraic characterization to nonparametric SEM by leveraging nonparametric sparsity based on partial derivatives, resulting in a continuous optimization problem that can be applied to a variety of nonparametric and semiparametric models including GLMs, additive noise models, and index models as special cases. Unlike existing approaches that require specific modeling choices, loss functions, or algorithms, we present a completely general framework that can be applied to general nonlinear models (e.g. without additive noise), general differentiable loss functions, and generic black-box optimization routines. The code is available at https://github.com/xunzheng/notears.
研究动机与目标
- 开发一种通用的、与模型无关的基于评分的DAG学习框架,避免为每种模型类型设计专用算法。
- 将此前仅限于线性SEM的无环性连续优化公式扩展至一般非参数和半参数模型。
- 通过将DAG学习表述为光滑、可微分的规划问题,使标准优化例程(如L-BFGS-B)能够应用于非参数DAG学习。
- 在多种模型上(包括加法模型、指数模型、神经网络和正交基展开)展示该框架的有效性。
- 证明现成的求解器可在无需模型特定或算法特定调优的情况下实现具有竞争力的性能。
提出的方法
- 通过有向无环图的结构函数的雅可比矩阵的矩阵指数的迹,将无环性约束从线性SEM推广至非参数SEM。
- 利用结构函数的偏导数定义一个连续、可微的惩罚项,以在非参数模型中强制实现无环性。
- 使用灵活的函数族(如多层感知机(MLPs)和Sobolev型正交基展开)对结构函数进行参数化。
- 将DAG学习问题重新表述为带有可微无环性惩罚的约束优化问题,可通过标准非线性求解器求解。
- 集成邻域选择和边剪枝作为预处理/后处理步骤,以提升稀疏性和性能。
- 在PyTorch中实现该框架,支持端到端反向传播,并与深度学习工具包兼容。
实验结果
研究问题
- RQ1DAG学习中的无环性约束能否超越线性模型,推广到任意非参数结构方程模型?
- RQ2能否使用单一、统一的优化框架在参数模型、半参数模型和非参数模型中学习DAG,而无需针对每类模型设计专用算法?
- RQ3无环性的可微分、连续公式是否能使现成求解器在非线性和非参数数据上实现具有竞争力的性能?
- RQ4与现有最先进方法相比,该框架在真实世界生物数据上的表现如何?
- RQ5模型容量(如隐藏单元数量)对非参数DAG学习中的性能和泛化能力有何影响?
主要发现
- 在包含13条边的真实Sachs数据集上,该框架的SHD为16,优于NOTEARS(SHD 22)和GNN(SHD 19)。
- 在d=20、n=200的加法GP和GP设置下,结合边剪枝和邻域选择(NOTEARS-MLP++)的方法达到SHD 1.2,与CAM相当或更优。
- 在d=20、n=1000的加法GP模型中,隐藏单元数从0增至20可改善SHD,但进一步增至100时性能下降,表明在样本有限时存在过拟合。
- 该方法成功恢复了Sachs数据集中11条共识边中的7条,包括其他方法未能发现的3条:mek→erk、PIP3→PLCg和PKC→mek。
- 该框架支持使用基于梯度的优化实现所有边的高效全局同步更新,而不同于一次仅更新一条边的局部搜索方法。
- 该方法对模型选择具有鲁棒性:适用于MLPs、Sobolev基展开以及多种损失函数,且可与预处理(PNS)和后处理(边剪枝)结合使用。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。