Skip to main content
QUICK REVIEW

[论文解读] Meta-Learning for Stochastic Gradient MCMC

Wenbo Gong, Yingzhen Li|arXiv (Cornell University)|Jun 12, 2018
Domain Adaptation and Few-Shot Learning参考文献 38被引用 18
一句话总结

该论文提出了首个用于自动设计随机梯度 MCMC(SG-MCMC)采样器的元学习框架,通过神经网络扩展哈密顿动力学,以参数化状态相关的漂移和扩散矩阵。所学习的采样器在不同数据集和网络架构间具有良好的泛化能力,在贝叶斯神经网络推理中表现出更快的收敛速度和更高的采样效率,优于标准的 SG-MCMC 方法。

ABSTRACT

Stochastic gradient Markov chain Monte Carlo (SG-MCMC) has become increasingly popular for simulating posterior samples in large-scale Bayesian modeling. However, existing SG-MCMC schemes are not tailored to any specific probabilistic model, even a simple modification of the underlying dynamical system requires significant physical intuition. This paper presents the first meta-learning algorithm that allows automated design for the underlying continuous dynamics of an SG-MCMC sampler. The learned sampler generalizes Hamiltonian dynamics with state-dependent drift and diffusion, enabling fast traversal and efficient exploration of neural network energy landscapes. Experiments validate the proposed approach on both Bayesian fully connected neural network and Bayesian recurrent neural network tasks, showing that the learned sampler out-performs generic, hand-designed SG-MCMC algorithms, and generalizes to different datasets and larger architectures.

研究动机与目标

  • 自动化设计针对特定概率模型(尤其是贝叶斯神经网络)的 SG-MCMC 采样器。
  • 克服在设计保持正确平稳分布的 SG-MCMC 动力学时对人工物理直觉的依赖。
  • 实现所学采样器在不同数据集、网络架构和任务复杂度间的泛化能力。
  • 开发一种元学习框架,通过在简单任务上训练采样器,并将其迁移至复杂、高维的后验分布。
  • 通过学习最优动力学以穿越复杂能量景观,提升贝叶斯深度学习中的采样效率与收敛速度。

提出的方法

  • 提出一种通过神经网络参数化漂移矩阵和扩散矩阵来扩展哈密顿动力学的元学习 SG-MCMC 采样器。
  • 采用连续 SDE 公式,使用可学习函数表示漂移(旋度矩阵)和扩散(扩散矩阵),确保目标后验分布保持为平稳分布。
  • 采用元学习目标,使采样器在任务分布(如不同数据集或架构)上进行训练,以学习可泛化的动力学。
  • 采用展开轨迹的双层优化设置进行训练:内层循环模拟 SG-MCMC 动力学,外层循环通过验证性能的梯度下降更新元参数。
  • 通过两个神经网络的乘积参数化扩散矩阵:一个用于自适应摩擦,一个用于动量控制,实现基于能量和梯度方向的动态调整。
  • 将所学采样器应用于全连接和循环神经网络,在贝叶斯推理任务上通过测试负对数似然(NLL)和收敛速度评估性能。

实验结果

研究问题

  • RQ1元学习能否用于自动设计适用于贝叶斯神经网络推理的 SG-MCMC 采样器,使其具备最优动力学?
  • RQ2元学习采样器是否能在不重新训练的情况下泛化至不同数据集和网络架构?
  • RQ3所学习的动力学采样器是否在采样效率和精度上优于人工设计的 SG-MCMC 算法(如 SGHMC 和 SGLD)?
  • RQ4所学习的漂移和扩散矩阵如何适应能量景观的不同区域(如高能区与低能区)?
  • RQ5当应用于具有不同数据分布或序列结构(如 RNN 中)的任务时,元学习采样器是否能保持性能?

主要发现

  • 在 Piano-midi 数据集上,元学习采样器的收敛速度优于 SGHMC,且在早期训练阶段与 Santa 和 Adam 具有相近的速度。
  • 在 MuseData 数据集上,元学习采样器的最终测试 NLL 优于 SGHMC,表明其在长时程采样中具有更优性能。
  • 元学习采样器成功泛化至 Nottingham 和 JSB chorales 数据集,尽管性能略逊于其他方法,可能由于与训练数据的分布差异所致。
  • 在摩擦网络中移除偏置项 β(NNSGHMC-s)后,过拟合现象减少,JSB 数据集上的测试 NLL 降低至 8.40,表明鲁棒性得到提升。
  • 所学习的扩散矩阵动态调整摩擦:在低能区域摩擦更高以防止发散,在高能区域摩擦更低以维持动量,其调整基于梯度与动量符号的一致性。
  • 在贝叶斯全连接和循环神经网络中,元学习采样器在采样效率和泛化能力上均优于通用的 SG-MCMC 方法(如 SGLD 和 SGHMC),验证了其有效性。

更好的研究,从现在开始

从论文设计到论文写作,大幅缩短您的研究时间。

无需绑定信用卡

本解读由 AI 生成,并经人工编辑审核。