[论文解读] Training (Overparametrized) Neural Networks in Near-Linear Time
本文提出了一种近线性时间算法,用于训练过参数化的 ReLU 神经网络,通过将高斯-牛顿法重新表述为 ℓ2-回归问题,并利用快速-约翰逊-林登施特劳斯(Fast-JL)降维技术预处理格拉姆矩阵,从而实现加速。该方法每轮迭代的时间复杂度为 eO(mnd + n³),总运行时间为 eO((mnd + n³) log(1/ǫ)),可将训练损失降低至 ǫ,相较于以往的二阶方法实现了二次加速,并证明了先进随机化线性代数技术在深度学习优化中的可行性。
The slow convergence rate and pathological curvature issues of first-order gradient methods for training deep neural networks, initiated an ongoing effort for developing faster $\mathit{second}$-$\mathit{order}$ optimization algorithms beyond SGD, without compromising the generalization error. Despite their remarkable convergence rate ($\mathit{independent}$ of the training batch size $n$), second-order algorithms incur a daunting slowdown in the $\mathit{cost}$ $\mathit{per}$ $\mathit{iteration}$ (inverting the Hessian matrix of the loss function), which renders them impractical. Very recently, this computational overhead was mitigated by the works of [ZMG19,CGH+19}, yielding an $O(mn^2)$-time second-order algorithm for training two-layer overparametrized neural networks of polynomial width $m$. We show how to speed up the algorithm of [CGH+19], achieving an $ ilde{O}(mn)$-time backpropagation algorithm for training (mildly overparametrized) ReLU networks, which is near-linear in the dimension ($mn$) of the full gradient (Jacobian) matrix. The centerpiece of our algorithm is to reformulate the Gauss-Newton iteration as an $\ell_2$-regression problem, and then use a Fast-JL type dimension reduction to $\mathit{precondition}$ the underlying Gram matrix in time independent of $M$, allowing to find a sufficiently good approximate solution via $\mathit{first}$-$\mathit{order}$ conjugate gradient. Our result provides a proof-of-concept that advanced machinery from randomized linear algebra -- which led to recent breakthroughs in $\mathit{convex}$ $\mathit{optimization}$ (ERM, LPs, Regression) -- can be carried over to the realm of deep learning as well.
研究动机与目标
- 为解决深度学习中二阶优化的高计算成本问题,特别是以往针对过参数化网络的高斯-牛顿方法每轮迭代成本为 O(mn²) 的问题。
- 将二阶优化扩展至 ReLU 网络,因其相较于平滑激活函数网络更具复杂性和现实意义。
- 在完整梯度维度(mn)上实现近线性时间训练,克服传统二阶方法中 Hessian 矩阵求逆的瓶颈。
- 证明先前用于凸优化的先进随机化线性代数技术可有效迁移至非凸深度学习训练场景。
提出的方法
- 将高斯-牛顿更新重新表述为对雅可比矩阵的 ℓ2-回归问题,以通过共轭梯度法高效求解。
- 应用类 Fast-Johnson-Lindenstrauss(Fast-JL)的降维技术对格拉姆矩阵 J_t J_t^T 进行预处理,降低其规模,同时保持解的质量。
- 使用一阶共轭梯度法求解预处理后的回归问题,其时间复杂度与原始矩阵大小 M 无关。
- 借助神经正切核(NTK)理论,证明过参数化网络的局部线性化成立,从而可简化为核回归问题。
- 通过浓度不等式控制雅可比矩阵近似误差和回归子问题解的误差,确保收敛性。
- 将回归求解器集成至反向传播框架中,每轮迭代时间复杂度为 eO(mnd + n³),主要由雅可比矩阵计算和回归求解主导。
实验结果
研究问题
- RQ1能否通过将每轮迭代成本降低至 O(mn²) 以下,使针对过参数化 ReLU 网络的二阶优化变得实用?
- RQ2能否利用 Fast-JL 降维技术对高斯-牛顿 Hessian 近似进行预处理,同时保持收敛性保证?
- RQ3能否在两层 ReLU 网络中实现以完整梯度维度(mn)为基准的近线性时间训练?
- RQ4能否将随机化线性代数的工具——在凸优化中已证明成功——适配至非凸深度学习训练场景?
- RQ5所提出的算法在实现显著快于 SGD 的收敛速度的同时,是否仍能保持良好的泛化性能?
主要发现
- 该算法相较于 [CGH+19] 的 O(mn²) 方法实现了二次加速,将每轮迭代成本降低至 eO(mnd + n³)。
- 该方法是首个在适度过参数化条件下(m = Ω(max{λ⁻⁴n⁴, λ⁻²n²d log(n/δ)})) 实现近线性时间复杂度(以梯度维度 mn 为基准)的 ReLU 网络二阶算法。
- 该算法保证每轮迭代满足 ∥ft+1 − y∥² ≤ ½∥ft − y∥²,从而实现对目标损失的线性收敛。
- 将训练损失降低至 ǫ 的总运行时间为 eO((mnd + n³) log(1/ǫ)),若使用快速矩阵乘法,可进一步优化为 eO((mnd + n^ω) log(1/ǫ))。
- 该算法维持有界的权重更新,确保网络权重不会显著偏离初始化,支持泛化能力。
- 该方法成功应用于凸优化,将牛顿法的运行时间优化至 eO((nd log(κ) + d³) log(1/ǫ)),适用于 γ-强凸、β-光滑且 Hessian 满足 L-Lipschitz 条件的函数。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。