[Paper Summary] Training event-based neural networks with exact gradients via Differentiable ODE Solving in JAX
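Eventax combines differentiable ODE solvers with event-based spike handling to compute, in JAX, gradients that are exact with respect to the forward simulation for any ODE-defined neuron model.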
Existing frameworks for gradient-based training of spiking neural networks face a trade-off: discrete-time methods using surrogate gradients support arbitrary neuron models but introduce gradient bias and constrain spike-time resolution, while continuous-time methods that compute exact gradients require analytical expressions for spike times and state evolution, restricting them to simple neuron types such as Leaky Integrate-and-Fire (LIF). We introduce the Eventax framework, which resolves this trade-off by combining differentiable numerical ODE solvers with event-based spike handling. Built in JAX, our framework uses Diffrax ODE solvers to compute gradients that are exact with respect to the forward simulation for any neuron model defined by ODEs. It also provides a simple API where users can specify just the neuron dynamics, spike conditions, and reset rules. Eventax prioritises modelling flexibility, supporting a wide range of neuron models, loss functions, and network architectures, which can be easily extended. We demonstrate Eventax on multiple benchmarks, including Yin-Yang and MNIST, using diverse neuron models such as Leaky Integrate-and-Fire (LIF), Quadratic Integrate-and-Fire (QIF), Exponential Integrate-and-Fire (EIF), Izhikevich and Event-based Gated Recurrent Unit (EGRU) with both time-to-first-spike and state-based loss functions, demonstrating its utility for prototyping and testing event-based architectures trained with exact gradients. We also demonstrate the application of this framework to more complex neuron types by implementing a multi-compartment neuron that uses a model of dendritic spikes in human layer 2/3 cortical pyramidal neurons for computation. Code available at https://github.com/efficient-scalable-machine-learning/eventax.
Motivation and Objectives
- Motivate the need for gradient-based training of event-based neural networks with flexible neuron models.
- Present a framework that yields gradients that are exact with respect to the forward simulation for arbitrary ODE-defined neuron dynamics.
- Provide a simple API to define neuron dynamics, spike conditions, and reset rules.
- Demonstrate the approach on multiple neuron models and benchmarks to validate flexibility and accuracy.
Proposed Method
- Use Diffrax differentiable ODE solvers to perform continuous-time integration with event handling for spikes.
- Propagate gradients through solver steps and through exact event times using the implicit function theorem.
- Allow user-defined NeuronModel interfaces specifying initial state, dynamics, spike condition, input handling, and reset rules.
- Support a MultiNeuronModel for heterogeneous networks and wrappers like AMOS or refractory extensions.
- Demonstrate with neuron models such as LIF, QIF, EIF, Izhikevich, and EGRU, plus a multi-compartment dendritic-spike model.
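The gradient-through-event-time step can be illustrated with a small, self-contained JAX sketch. It does not use Eventax or Diffrax: the fixed-step RK4 integrator, the bisection root finder, and all names and constants (`TAU`, `THETA`, `N_STEPS`, `membrane_potential`, `find_spike_time`, `spike_time`) are assumptions made for illustration only. The idea is that the spike time t* solves g(t*, I) = V(t*; I) - theta = 0, so the implicit function theorem gives dt*/dI = -(dg/dI) / (dg/dt) without differentiating through the root finder itself.

```python
import jax
import jax.numpy as jnp

TAU = 10.0     # membrane time constant, illustrative value
THETA = 1.0    # spike threshold, illustrative value
N_STEPS = 64   # fixed RK4 steps of the toy solver (not Diffrax)


def dynamics(v, current):
    """LIF membrane dynamics dV/dt = (I - V) / tau."""
    return (current - v) / TAU


def membrane_potential(t, current):
    """Integrate V from 0 to t with fixed-step RK4, starting at V(0) = 0."""
    h = t / N_STEPS

    def rk4_step(_, v):
        k1 = dynamics(v, current)
        k2 = dynamics(v + 0.5 * h * k1, current)
        k3 = dynamics(v + 0.5 * h * k2, current)
        k4 = dynamics(v + h * k3, current)
        return v + (h / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)

    return jax.lax.fori_loop(0, N_STEPS, rk4_step, jnp.zeros_like(t))


def find_spike_time(current, t_lo=0.0, t_hi=100.0, n_iter=60):
    """Locate the threshold crossing by bisection (no gradient needed here)."""

    def body(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        crossed = membrane_potential(mid, current) >= THETA
        return jnp.where(crossed, lo, mid), jnp.where(crossed, mid, hi)

    init = (jnp.float32(t_lo), jnp.float32(t_hi))
    lo, hi = jax.lax.fori_loop(0, n_iter, body, init)
    return 0.5 * (lo + hi)


def spike_time(current):
    """Spike time whose gradient follows the implicit function theorem."""
    current_ng = jax.lax.stop_gradient(current)
    t0 = jax.lax.stop_gradient(find_spike_time(current_ng))
    # g is (numerically) zero at t0 but still carries dg/dI.
    g = membrane_potential(t0, current) - THETA
    # dg/dt at the event, treated as a constant in the backward pass.
    dg_dt = jax.grad(membrane_potential, argnums=0)(t0, current_ng)
    # Value stays t0 (up to solver tolerance); gradient is -(dg/dI) / (dg/dt).
    return t0 - g / jax.lax.stop_gradient(dg_dt)


current = jnp.float32(2.0)
t_star, dt_dcurrent = jax.value_and_grad(spike_time)(current)
# Closed-form LIF reference: t* = tau * ln(I / (I - theta)).
analytic = -TAU * THETA / (current * (current - THETA))
print(t_star, dt_dcurrent, analytic)
```

For the LIF case a closed-form spike time exists, which makes it a convenient check; the same construction applies unchanged when `dynamics` has no analytical solution, which is the situation the framework targets.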
Experimental Results
Research Questions
- RQ1: Can exact gradients be computed for forward simulations of arbitrary neuron models using differentiable ODE solving with event handling?
- RQ2: Does integrating numerical ODE solvers with event-based spike timing yield accurate gradients without closed-form solutions?
- RQ3: How does Eventax perform across classical benchmarks (Yin–Yang, MNIST) and recurrent event-based networks with various neuron models?
- RQ4: What is the impact of solver choices and event handling on training stability and throughput?
Key Findings
- Eventax enables exact gradients with respect to the forward simulation for any neuron model defined by ODEs without requiring closed-form spike times.
- The framework supports a wide range of neuron models (LIF, QIF, EIF, Izhikevich, EGRU) and multi-compartment neurons while enabling time-to-first-spike and state-based losses.
- Experiments on Yin–Yang and MNIST show competitive accuracies across temporal and state-based objectives.
- A delayed-memory XOR task demonstrates training of recurrent event-based architectures (EGRU) with exact gradients.
- Performance analysis indicates throughput scales with batch size and event count, reflecting Diffrax-based solver characteristics.
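To make the time-to-first-spike objective mentioned above concrete, here is a minimal JAX sketch of one common formulation: a cross-entropy over negated first-spike times, so that the earliest-spiking output neuron is treated as the predicted class. Whether Eventax uses exactly this form is not stated in this summary; the function name `ttfs_cross_entropy`, the `temperature` parameter, and the example spike times are hypothetical.

```python
import jax
import jax.numpy as jnp


def ttfs_cross_entropy(first_spike_times, label, temperature=1.0):
    """Cross-entropy on negated first-spike times (earlier spike = larger logit)."""
    logits = -first_spike_times / temperature
    log_probs = jax.nn.log_softmax(logits)
    return -log_probs[label]


# Hypothetical first-spike times of three output neurons; class 1 is the target.
spike_times = jnp.array([4.0, 1.5, 6.0])
loss, grads = jax.value_and_grad(ttfs_cross_entropy)(spike_times, 1)
print(loss, grads)
```

The gradient is positive for the target neuron and negative for the others, so a gradient-descent step moves the target spike earlier and the competing spikes later. This only works when spike times are themselves differentiable, which is what exact event-time gradients provide.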
This summary was generated by AI and reviewed by a human editor.