[论文解读] Efficient and Modular Implicit Differentiation
本论文推出自动隐式微分,一个 Python/JAX 框架,通过指定最优性条件 F 来对优化问题解进行微分,在现有求解器之上实现模块化、与求解器无关的微分,且提供雅可比误差保证及多种应用。
Automatic differentiation (autodiff) has revolutionized machine learning. It allows to express complex computations by composing elementary ones in creative ways and removes the burden of computing their derivatives by hand. More recently, differentiation of optimization problem solutions has attracted widespread attention with applications such as optimization layers, and in bi-level problems such as hyper-parameter optimization and meta-learning. However, so far, implicit differentiation remained difficult to use for practitioners, as it often required case-by-case tedious mathematical derivations and implementations. In this paper, we propose automatic implicit differentiation, an efficient and modular approach for implicit differentiation of optimization problems. In our approach, the user defines directly in Python a function $F$ capturing the optimality conditions of the problem to be differentiated. Once this is done, we leverage autodiff of $F$ and the implicit function theorem to automatically differentiate the optimization problem. Our approach thus combines the benefits of implicit differentiation and autodiff. It is efficient as it can be added on top of any state-of-the-art solver and modular as the optimality condition specification is decoupled from the implicit differentiation mechanism. We show that seemingly simple principles allow to recover many existing implicit differentiation methods and create new ones easily. We demonstrate the ease of formulating and solving bi-level optimization problems using our framework. We also showcase an application to the sensitivity analysis of molecular dynamics.
研究动机与目标
- 降低使用隐式微分的门槛,让用户直接在 Python 中指定最优性条件。
- 将隐式微分与自动微分结合,在不重新实现求解器的情况下对优化解进行微分。
- 提供一个与最先进求解器兼容、适用于各种最优性条件的模块化框架。
- 为近似解提供理论上的雅可比精度保证。
- 通过多样化的应用展示分层优化与灵敏度分析的实际可行性。
提出的方法
- 用户定义映射 F,捕捉算法所求解问题的最优性条件。
- 使用 Python 装饰器 (@custom_root) 在 F 的基础上把隐式微分附加到求解器之上。
- 应用隐函数定理通过线性系统 -∂1F(x*(theta),theta) ∂x*(theta) = ∂2F(x*(theta),theta) 将 x*(theta) 与 theta 关联起来。
- 利用矩阵无关线性求解器(CG、GMRES、BiCGSTAB)高效计算雅可比-向量乘积和向量-雅可比乘积(JVP/VJP)。
- 支持前处理/后处理映射,以便将微分与其他神经网络或损失运算组合。
- 提供各种最优性条件映射的实际实现(驻点、KKT、近端梯度不动点、投影梯度不动点等)。
实验结果
研究问题
- RQ1自动隐式微分是否能够通过用户定义的 F 覆盖广泛的最优性条件目录?
- RQ2当内部优化近似求解时,雅可比精度保障是什么?
- RQ3在不同求解器和固定点表示下,该框架在效率和灵活性上与展开展开(unrolling)方法相比如何?
- RQ4该方法是否能够轻松处理双层问题,如超参数优化、数据集蒸馏和字典学习?
- RQ5在本框架中实现和对常用优化方案(近端、投影、镜像下降)进行微分的实际指南是什么?
主要发现
- 该框架通过对残差映射 F 使用自动微分并应用隐函数定理来对优化问题解进行微分。
- 来自近似解的雅可比误差被界定并随内部解的残差而缩放(定理1)。
- 该方法能够在一个单一、模块化的系统中恢复现有的隐式微分方法并实现新的方法。
- 实验表明在多类别 SVM 的超参数优化、数据集蒸馏、任务驱动的字典学习和分子动力学灵敏度分析等方面实现了高效微分。
- 使用固定点形式(镜像下降、近端梯度、投影梯度)结合隐式微分可实现实际、可扩展的双层优化工作流。
- 与展开法相比,该方法在若干双层任务中表现出更有利的运行时间,展示了速度与简化性方面的实际收益。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。