[논문 리뷰] Training event-based neural networks with exact gradients via Differentiable ODE Solving in JAX
Eventax는 JAX에서 differentiable ODE 솔버와 이벤트 기반 스파이크 처리를 결합하여 임의의 뉴런 모델에 대한 순방향 시뮬레이션의 정확한 그래디언트를 생성합니다.
Existing frameworks for gradient-based training of spiking neural networks face a trade-off: discrete-time methods using surrogate gradients support arbitrary neuron models but introduce gradient bias and constrain spike-time resolution, while continuous-time methods that compute exact gradients require analytical expressions for spike times and state evolution, restricting them to simple neuron types such as Leaky Integrate and Fire (LIF). We introduce the Eventax framework, which resolves this trade-off by combining differentiable numerical ODE solvers with event-based spike handling. Built in JAX, our frame-work uses Diffrax ODE-solvers to compute gradients that are exact with respect to the forward simulation for any neuron model defined by ODEs . It also provides a simple API where users can specify just the neuron dynamics, spike conditions, and reset rules. Eventax prioritises modelling flexibility, supporting a wide range of neuron models, loss functions, and network architectures, which can be easily extended. We demonstrate Eventax on multiple benchmarks, including Yin-Yang and MNIST, using diverse neuron models such as Leaky Integrate-and-fire (LIF), Quadratic Integrate-and-fire (QIF), Exponential integrate-and-fire (EIF), Izhikevich and Event-based Gated Recurrent Unit (EGRU) with both time-to-first-spike and state-based loss functions, demonstrating its utility for prototyping and testing event-based architectures trained with exact gradients. We also demonstrate the application of this framework for more complex neuron types by implementing a multi-compartment neuron that uses a model of dendritic spikes in human layer 2/3 cortical Pyramidal neurons for computation. Code available at https://github.com/efficient-scalable-machine-learning/eventax.
연구 동기 및 목표
- 가변적인 뉴런 모델을 가진 이벤트 기반 신경망의 그래디언트 기반 학습 필요성을 제시한다.
- 임의의 ODE로 정의된 뉴런 동역학에 대해 정확한 순방향 그래디언트를 생성하는 프레임워크를 제시한다.
- 뉴런 동역학, 스파이크 조건 및 리셋 규칙을 정의할 수 있는 간단한 API를 제공한다.
- 유연성과 정확성을 검증하기 위해 다양한 뉴런 모델과 벤치마크에서 접근법을 시연한다.
제안 방법
- Diffrax의 differentiable ODE 솔버를 사용하여 스파이크를 위한 이벤트 처리와 함께 연속 시간 적분을 수행한다.
- 암시적 함수 정리를 이용해 솔버 스텝과 정확한 이벤트 시간 전체를 통해 그래디언트를 전파한다.
- 초기 상태, 동역학, 스파이크 조건, 입력 처리 및 리셋 규칙을 지정하는 사용자 정의 NeuronModel 인터페이스를 허용한다.
- 이종 네트워크를 위한 MultiNeuronModel과 AMOS 같은 래퍼 또는 refractory 확장을 지원한다.
- LIF, QIF, EIF, Izhikevich, EGRU 등의 뉴런 모델과 다구획 수상돌기-스파이크 모델을 포함한 시연.
실험 결과
연구 질문
- RQ1Can exact gradients be computed for forward simulations of arbitrary neuron models using differentiable ODE solving with event handling?
- RQ2Does integrating numerical ODE solvers with event-based spike timing yield accurate gradients without closed-form solutions?
- RQ3How does Eventax perform across classical benchmarks (Yin–Yang, MNIST) and recurrent event-based networks with various neuron models?
- RQ4What is the impact of solver choices and event handling on training stability and throughput?
주요 결과
- Eventax enables exact gradients with respect to the forward simulation for any neuron model defined by ODEs without requiring closed-form spike times.
- The framework supports a wide range of neuron models (LIF, QIF, EIF, Izhikevich, EGRU) and multi-compartment neurons while enabling time-to-first-spike and state-based losses.
- Experiments on Yin–Yang and MNIST show competitive accuracies across temporal and state-based objectives.
- A delayed-memory XOR task demonstrates training of recurrent event-based architectures (EGRU) with exact gradients.
- Performance analysis indicates throughput scales with batch size and event count, reflecting Diffrax-based solver characteristics.
더 나은 연구,지금 바로 시작하세요
연구 설계부터 논문 작성까지, 연구 시간을 획기적으로 줄여보세요.
카드 등록 없음 · 무료 플랜 제공
이 리뷰는 AI가 만들고, 인간 에디터가 검토했습니다.