[论文解读] Training Multi-Layer Over-Parametrized Neural Network in Subquadratic Time
本文提出了一种新颖的框架,通过利用结构化权重矩阵和高效的数据结构,在每次迭代中以次二次方时间训练深度过参数化的神经网络。通过初始化阶段的预处理和自适应梯度计算,该方法将每次迭代的计算成本降低至 O(m²⁻Ω(1)),显著低于标准的 O(m²),从而实现了大规模语言模型的更快微调。
We consider the problem of training a multi-layer over-parametrized neural network to minimize the empirical risk induced by a loss function. In the typical setting of over-parametrization, the network width $m$ is much larger than the data dimension $d$ and the number of training samples $n$ ($m=\mathrm{poly}(n,d)$), which induces a prohibitive large weight matrix $W\in \mathbb{R}^{m imes m}$ per layer. Naively, one has to pay $O(m^2)$ time to read the weight matrix and evaluate the neural network function in both forward and backward computation. In this work, we show how to reduce the training cost per iteration. Specifically, we propose a framework that uses $m^2$ cost only in the initialization phase and achieves \emph{a truly subquadratic cost per iteration} in terms of $m$, i.e., $m^{2-Ω(1)}$ per iteration. Our result has implications beyond standard over-parametrization theory, as it can be viewed as designing an efficient data structure on top of a pre-trained large model to further speed up the fine-tuning process, a core procedure to deploy large language models (LLM).
研究动机与目标
- 解决由于大 m×m 权重矩阵导致的深度过参数化神经网络训练中每次迭代的高昂 O(m²) 成本问题。
- 开发一种方法,将每次迭代的训练成本真正降低至 m 的次二次方,即 O(m²⁻Ω(1))。
- 通过利用过参数化和结构化计算,实现大规模语言模型(LLMs)的高效微调。
- 克服先前方法的局限性,这些方法要么产生 O(nm²) 的成本,要么对输入维度 d 呈指数依赖。
- 在预训练模型之上设计一种数据结构,以加速微调过程,同时保持收敛性保证。
提出的方法
- 使用移位 ReLU 激活函数以在神经元激活中引入稀疏性,从而降低每层的有效计算量。
- 利用截断正态分布随机变量来建模激活神经元输出及其范数的分布。
- 应用截断卡方分布和次高斯分布的集中不等式,以限制各层之间范数波动。
- 设计一个预处理阶段,仅在初始化时产生 O(m²) 的成本,从而实现每次训练迭代的次二次方成本。
- 通过对所有数据点和所有层使用并集界,确保网络中所有层的范数稳定性具有高概率。
- 采用一种基于预计算统计特性的数据结构,对激活神经元进行索引,从而实现快速的前向和反向传播。
实验结果
研究问题
- RQ1我们能否实现对具有 m×m 权重矩阵的深度过参数化神经网络训练,每次迭代达到次二次方时间?
- RQ2我们如何利用过参数化和激活稀疏性,将计算量降低至标准 O(m²) 以下?
- RQ3我们能否设计一种预处理方案,实现在不依赖输入维度 d 的指数关系下,每次迭代成本为 O(m²⁻Ω(1))?
- RQ4理论上实现的次二次方成本在大规模语言模型微调的实际应用中能实现到何种程度?
- RQ5在随机初始化下,是否可能在实现次二次方训练成本的同时,保持各层之间的范数稳定性?
主要发现
- 所提出的框架实现了对深度过参数化网络训练的 O(m²⁻Ω(1)) 每次迭代成本,显著低于标准的 O(m²)。
- 该方法仅在初始化阶段产生 O(m²) 的成本,后续所有训练迭代均以次二次方时间运行。
- 以高概率,所有层和所有数据点的隐藏表示的 ℓ₂ 范数均保持在 [1−ε, 1+ε] 区间内,确保了训练的稳定性。
- 该框架适用于深度 L ≥ 2 且宽度 m = poly(n, d) 的网络,满足典型的过参数化条件。
- 该结果通过高效利用大规模语言模型的过参数化结构,实现了更快速的微调。
- 该分析依赖于截断正态分布和卡方分布的集中不等式,为范数稳定性提供了理论保证。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。