[论文解读] Boosting Deep Learning Risk Prediction with Generative Adversarial Networks for Electronic Health Records
本文提出 ehrGAN,一种专为电子健康记录(EHRs)设计的生成对抗网络,用于生成真实且带标签的患者数据,以实现半监督风险预测。通过将有限的真实EHR数据与合成样本相结合,该方法显著提升了深度学习模型在心力衰竭和糖尿病预测中的性能,相较于最先进基线模型,在HF50上AUROC提升最高达0.0291,在Dia50上提升最高达0.0201。
The rapid growth of Electronic Health Records (EHRs), as well as the accompanied opportunities in Data-Driven Healthcare (DDH), has been attracting widespread interests and attentions. Recent progress in the design and applications of deep learning methods has shown promising results and is forcing massive changes in healthcare academia and industry, but most of these methods rely on massive labeled data. In this work, we propose a general deep learning framework which is able to boost risk prediction performance with limited EHR data. Our model takes a modified generative adversarial network namely ehrGAN, which can provide plausible labeled EHR data by mimicking real patient records, to augment the training dataset in a semi-supervised learning manner. We use this generative model together with a convolutional neural network (CNN) based prediction model to improve the onset prediction performance. Experiments on two real healthcare datasets demonstrate that our proposed framework produces realistic data samples and achieves significant improvements on classification tasks with the generated data over several stat-of-the-art baselines.
研究动机与目标
- 解决深度学习在医疗领域中因标注EHR数据有限而影响模型性能的挑战。
- 开发一种生成模型,以生成真实且临床合理的EHR样本,用于扩充训练数据。
- 利用合成数据进行半监督学习,提升心力衰竭和糖尿病等疾病的风险预测性能。
- 证明在低数据环境下,结合基于GAN的数据生成与深度神经网络在临床预测中的有效性。
提出的方法
- 提出 ehrGAN,一种经过修改的GAN架构,通过对抗训练生成具有正确标签的真实EHR序列。
- 采用条件生成器,从基于患者标签的条件学习潜在空间采样,以确保生成数据中的标签一致性。
- 将生成器整合进半监督学习框架中,利用真实数据和生成数据联合训练基于CNN的风险预测器。
- 通过超参数ρ优化训练目标,以控制重建损失与对抗损失之间的权衡,确保生成样本的多样性与合理性。
- 通过超参数μ控制数据使用比例,平衡训练过程中真实标注数据与生成数据的比例。
- 将该框架应用于两个真实世界EHR数据集(HF50和Dia50)进行风险预测任务,使用AUROC和准确率评估性能。
实验结果
研究问题
- RQ1基于GAN的模型能否生成真实、带标签的EHR序列,使其与真实患者记录难以区分?
- RQ2与监督基线相比,使用合成EHR数据进行半监督学习是否能提升风险预测性能?
- RQ3为最大化预测性能,真实数据与生成数据之间的最优比例(由μ控制)是什么?
- RQ4控制生成器损失的超参数ρ如何影响生成样本的质量与实用性,以及下游预测的准确性?
- RQ5所提出的框架能否在标注数据有限的不同临床预测任务中实现泛化?
主要发现
- ehrGAN模型成功生成了在临床合理性与时间模式方面与真实患者记录难以区分的真实EHR样本。
- 所提出的SSL-GAN框架在HF50数据集上相比最佳基线(CNN-BASIC)实现了0.0291的AUROC提升,达到0.9075 AUROC(ρ=0.1,μ=0.6)。
- 在Dia50数据集上,该方法相比基线实现了0.0201的AUROC提升,最优设置下达到0.9354 AUROC。
- ρ的最优设置为0.1,因为ρ=0或ρ=1等值会导致样本质量差或标签不一致,从而降低性能。
- 当ρ=0.1时,μ的最优值为0.6,表明过度使用生成数据会损害性能,提示需谨慎控制真实与生成数据的比例。
- 该框架在两个数据集上均持续优于标准CNN以及两种强基线半监督学习方法(SSL-SMIR和SSL-LGC),展现出稳健的泛化能力。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。