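[Paper Review] Debiasing Graph Neural Networks via Learning Disentangled Causal Substructure
Proposes DisC, a disentangled GNN framework. It uses a learnable edge mask to split each graph into a causal subgraph and a bias subgraph, trains a separate GNN on each subgraph with a causal loss and a bias-aware loss respectively, and synthesizes counterfactual unbiased samples to improve generalization under severe bias.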
Most Graph Neural Networks (GNNs) predict the labels of unseen graphs by learning the correlation between the input graphs and labels. However, by presenting a graph classification investigation on the training graphs with severe bias, surprisingly, we discover that GNNs always tend to explore the spurious correlations to make decisions, even if the causal correlation always exists. This implies that existing GNNs trained on such biased datasets will suffer from poor generalization capability. By analyzing this problem in a causal view, we find that disentangling and decorrelating the causal and bias latent variables from the biased graphs are both crucial for debiasing. Inspired by this, we propose a general disentangled GNN framework to learn the causal substructure and bias substructure, respectively. Particularly, we design a parameterized edge mask generator to explicitly split the input graph into causal and bias subgraphs. Then two GNN modules supervised by causal/bias-aware loss functions respectively are trained to encode causal and bias subgraphs into their corresponding representations. With the disentangled representations, we synthesize the counterfactual unbiased training samples to further decorrelate causal and bias variables. Moreover, to better benchmark the severe bias problem, we construct three new graph datasets, which have controllable bias degrees and are easier to visualize and explain. Experimental results well demonstrate that our approach achieves superior generalization performance over existing baselines. Furthermore, owing to the learned edge mask, the proposed model has appealing interpretability and transferability. Code and data are available at: https://github.com/googlebaba/DisC.
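Research Motivation and Objectives
- Motivate and analyze, from a causal perspective, how severe bias affects GNNs in graph classification.
- Develop a framework that disentangles causal substructures from bias substructures in graphs.
- Learn a global edge mask that splits each graph into a causal subgraph and a bias subgraph.
- Train a dedicated GNN module on each subgraph, supervised by a causal loss and a bias-aware loss, respectively.
- Generate counterfactual unbiased representations to decorrelate causal and bias signals.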
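A minimal sketch of the edge-mask split, assuming (for illustration only) that each edge is scored by a single linear layer over the concatenated embeddings of its two endpoints followed by a sigmoid, and that the bias subgraph receives the complementary weight 1 - m. The names `edge_scores`, `split_graph`, `w`, and `b` are hypothetical, and DisC's actual generator architecture may differ. In this sketch, sharing one `(w, b)` across all graphs is what makes the mask "global".

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def edge_scores(node_emb, edges, w, b):
    """Return a causal probability m_ij in (0, 1) for every edge (i, j).

    Hypothetical scorer: one linear layer over [h_i || h_j] plus a sigmoid.
    The same (w, b) are shared by all graphs, so the mask is "global".
    """
    scores = []
    for i, j in edges:
        feat = node_emb[i] + node_emb[j]  # list concatenation = [h_i || h_j]
        logit = sum(wk * fk for wk, fk in zip(w, feat)) + b
        scores.append(sigmoid(logit))
    return scores


def split_graph(edges, scores):
    """Soft split: the causal subgraph weights edge e by m_e, the bias subgraph by 1 - m_e."""
    causal = [(e, m) for e, m in zip(edges, scores)]
    bias = [(e, 1.0 - m) for e, m in zip(edges, scores)]
    return causal, bias


# Toy usage: 3 nodes with 2-d embeddings and 2 edges.
emb = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.5]]
edges = [(0, 1), (1, 2)]
m = edge_scores(emb, edges, w=[0.5, -0.5, 1.0, 0.2], b=0.1)
causal_sub, bias_sub = split_graph(edges, m)
print(causal_sub)
print(bias_sub)
```

Because the causal and bias weights of each edge sum to 1, every edge is softly assigned between the two subgraphs rather than duplicated or dropped. The two weighted edge lists would then be passed to the causal GNN and the bias GNN, respectively.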
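Proposed Method
- Introduce a parameterized edge mask generator that assigns each edge a probability; these probabilities are used to form the causal subgraph and the bias subgraph.
- Train two GNNs on the masked subgraphs, a causal GNN and a bias GNN, each supervised by its own loss function (a causal loss and a generalized cross-entropy (GCE) bias loss, respectively).
- In the bias branch, the bias-aware GCE loss amplifies learning of the bias substructure; in the causal branch, a weighted cross-entropy emphasizes the causal substructure.
- Compute an unbiasedness score for each graph to reweight it, and train the causal branch with this per-graph weighted loss.
- Generate counterfactual unbiased samples by permuting bias representations across graphs and swapping the corresponding labels, then train with a joint loss that includes these samples.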
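The two branch losses and the per-graph reweighting can be sketched in plain Python. This assumes the standard GCE form (1 - p_y^q) / q and a Learning-from-Failure-style weight W = CE_bias / (CE_causal + CE_bias), which is close to 1 when the bias branch fails on a graph (a bias-conflicting sample) and close to 0 when it fits the graph. The default q = 0.7 and all function names are illustrative assumptions, not values or identifiers taken from the paper.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, y):
    return -math.log(softmax(logits)[y])


def gce_loss(logits, y, q=0.7):
    """Generalized cross-entropy (1 - p_y^q) / q; q -> 0 recovers CE, q = 1 gives 1 - p_y."""
    p_y = softmax(logits)[y]
    return (1.0 - p_y ** q) / q


def unbiased_weight(causal_logits, bias_logits, y, eps=1e-8):
    """LfF-style score: near 1 when the bias branch fails, near 0 when it fits the graph."""
    ce_c = cross_entropy(causal_logits, y)
    ce_b = cross_entropy(bias_logits, y)
    return ce_b / (ce_c + ce_b + eps)


def branch_losses(causal_logits, bias_logits, y, q=0.7):
    """Per-graph losses for the two branches (the weight is treated as a constant)."""
    w = unbiased_weight(causal_logits, bias_logits, y)
    causal_loss = w * cross_entropy(causal_logits, y)  # weighted CE for the causal GNN
    bias_loss = gce_loss(bias_logits, y, q)  # GCE for the bias GNN
    return causal_loss, bias_loss, w


# Bias-aligned graph: the bias branch already predicts the label confidently.
print(branch_losses([0.0, 0.0], [5.0, 0.0], y=0))
# Bias-conflicting graph: the bias branch is confidently wrong.
print(branch_losses([0.0, 0.0], [0.0, 5.0], y=0))
```

In a real training loop the weight would be detached so no gradient flows through it, and each scalar loss would be backpropagated through its own GNN branch. The effect is that bias-conflicting graphs dominate the causal branch's objective while the GCE loss pushes the bias branch toward the easy, bias-aligned patterns.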
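The counterfactual step can be sketched as a random permutation of bias embeddings within a batch: graph i keeps its causal embedding and label y_i, but is paired with graph perm[i]'s bias embedding, whose label becomes the bias target. Embeddings are plain lists here and `counterfactual_samples` is a hypothetical helper; a real implementation would operate on tensors and feed the swapped vectors to both classifiers as part of the joint loss.

```python
import random


def counterfactual_samples(z_causal, z_bias, labels, rng):
    """Pair each graph's causal embedding with a randomly chosen graph's bias embedding.

    Returns (z_swap, causal_target, bias_target) triples,
    where z_swap = [z_c(i) || z_b(perm[i])].
    """
    n = len(labels)
    perm = list(range(n))
    rng.shuffle(perm)  # random permutation of bias embeddings within the batch
    samples = []
    for i in range(n):
        j = perm[i]
        z_swap = z_causal[i] + z_bias[j]  # list concatenation
        # The causal classifier should still predict y_i; the bias classifier predicts y_j.
        samples.append((z_swap, labels[i], labels[j]))
    return samples


rng = random.Random(0)
zc = [[1.0], [2.0], [3.0]]
zb = [[10.0], [20.0], [30.0]]
ys = [0, 1, 2]
for z, y_causal, y_bias in counterfactual_samples(zc, zb, ys, rng):
    print(z, y_causal, y_bias)
```

In a severely biased training set, labels are strongly correlated with bias patterns; shuffling the bias half breaks that correlation inside the augmented samples, which is the decorrelation the joint loss relies on.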
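Experimental Results
Research Questions
- RQ1: How does severe bias affect the generalization of GNNs in graph classification?
- RQ2: Can causal and bias substructures be disentangled from biased graphs at a global level?
- RQ3: Can counterfactual unbiased samples be generated to decorrelate causal and bias factors and improve generalization?
- RQ4: How effectively does the global edge mask identify causal and bias subgraphs across populations of graphs?
- RQ5: Does the DisC framework improve generalization across multiple backbone GNN architectures?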
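Key Findings
- On three biased graph datasets at varying bias degrees, DisC substantially improves generalization over the baseline GNNs.
- DisC consistently outperforms existing debiasing methods such as DIR and StableGNN, especially under heavier bias.
- The edge mask yields interpretable subgraphs: causal subgraphs align with the ground-truth discriminative structures, and bias subgraphs align with the spurious patterns.
- Counterfactual embedding generation decorrelates causal and bias factors, so that the causal signal drives predictions.
- The disentangled representations cluster by causal and bias factors in the latent space, supporting interpretability and transferability.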
This review was generated by AI and checked by human editors.