[论文解读] SymTorch: A Framework for Symbolic Distillation of Deep Neural Networks
SymTorch 将神经网络组件的符号蒸馏自动化为闭式表达式,从而实现可解释的代理与混合神经符号模型,并在 GNN、PINN 与 LLM 的案例研究中展示应用。
Symbolic distillation replaces neural networks, or components thereof, with interpretable, closed-form mathematical expressions. This approach has shown promise in discovering physical laws and mathematical relationships directly from trained deep learning models, yet adoption remains limited due to the engineering barrier of integrating symbolic regression into deep learning workflows. We introduce SymTorch, a library that automates this distillation by wrapping neural network components, collecting their input-output behavior, and approximating them with human-readable equations via PySR. SymTorch handles the engineering challenges that have hindered adoption: GPU-CPU data transfer, input-output caching, model serialization, and seamless switching between neural and symbolic forward passes. We demonstrate SymTorch across diverse architectures including GNNs, PINNs and transformer models. Finally, we present a proof-of-concept for accelerating LLM inference by replacing MLP layers with symbolic surrogates, achieving an 8.3\% throughput improvement with moderate performance degradation.
研究动机与目标
- 提供一个开源框架,用于自动化神经网络组件的符号蒸馏。
- 通过处理 GPU-CPU 数据传输、缓存、序列化和前向传播切换,降低工程门槛。
- 在不同架构(GNN、PINN、Transformer)中的适用性,并展示潜在的 LLM 推理加速。
- 说明符号代理如何恢复已知物理定律并揭示 LLM 的算术偏差。
提出的方法
- 将 NN 组件包装为 SymbolicModel 块,在前向传播过程中收集输入-输出数据。
- 使用 PySR 对收集的 I/O 进行符号回归,以获得每个输出维度的闭式表达式。
- 用符号代理的 Pareto 前线替换选定的神经网络块,构建混合模型。
- 缓存激活并在前向传播中实现神经与符号计算的无缝切换。
- 通过拟合符号代理到兴趣点周围的邻域,提供 SLIME 风格的局部解释。

实验结果
研究问题
- RQ1符号回归是否能够在多样化架构中真实地近似神经网络组件的输入-输出映射?
- RQ2在变换器/LLM 中用符号代理替换神经块时,实际的收益与取舍(准确性与速度)是什么?
- RQ3符号蒸馏在多大程度上能够恢复已知物理定律或揭示 LLM 中学习到的算术偏差?
- RQ4SLIME 风格的符号解释在黑箱模型的局部可解释性方面表现如何?
主要发现
| Label | Perplexity Baseline | Δ Perplexity PCA+MLP | Δ Perplexity PCA+SymTorch | Δ Perplexity Control |
|---|---|---|---|---|
| Baseline | 10.62 | +3.11 | +3.14 | +6.97 |
- SymTorch 实现了跨 GNN、PINN 与 Transformer 块的符号蒸馏。
- 在 Transformer 的 28 层 MLP 中用 3 个符号代理替换,吞吐量提高 8.3%,但困惑度小幅上升(基线 3.14 对比 10.62 的对比)。
- 在 PCA 降维子空间中的符号代理捕捉了 MLP 行为,困惑度衰减与仅 PCA 相当。
- 通过符号蒸馏,PINN 基于 PDE 的解可以回归为闭式表达式。
- 基于 GNN 的符号蒸馏恢复了真实的相互作用规律,与以往的先验偏置发现相一致。

更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。