[論文レビュー] Training event-based neural networks with exact gradients via Differentiable ODE Solving in JAX
Eventaxは、JAXに基づく微分可能ODE solverとイベントベースのスパイク処理を組み合わせ、任意のニューロンモデルに対する正確な前方勾配を生成します。
Existing frameworks for gradient-based training of spiking neural networks face a trade-off: discrete-time methods using surrogate gradients support arbitrary neuron models but introduce gradient bias and constrain spike-time resolution, while continuous-time methods that compute exact gradients require analytical expressions for spike times and state evolution, restricting them to simple neuron types such as Leaky Integrate and Fire (LIF). We introduce the Eventax framework, which resolves this trade-off by combining differentiable numerical ODE solvers with event-based spike handling. Built in JAX, our frame-work uses Diffrax ODE-solvers to compute gradients that are exact with respect to the forward simulation for any neuron model defined by ODEs . It also provides a simple API where users can specify just the neuron dynamics, spike conditions, and reset rules. Eventax prioritises modelling flexibility, supporting a wide range of neuron models, loss functions, and network architectures, which can be easily extended. We demonstrate Eventax on multiple benchmarks, including Yin-Yang and MNIST, using diverse neuron models such as Leaky Integrate-and-fire (LIF), Quadratic Integrate-and-fire (QIF), Exponential integrate-and-fire (EIF), Izhikevich and Event-based Gated Recurrent Unit (EGRU) with both time-to-first-spike and state-based loss functions, demonstrating its utility for prototyping and testing event-based architectures trained with exact gradients. We also demonstrate the application of this framework for more complex neuron types by implementing a multi-compartment neuron that uses a model of dendritic spikes in human layer 2/3 cortical Pyramidal neurons for computation. Code available at https://github.com/efficient-scalable-machine-learning/eventax.
研究の動機と目的
- イベントベースのニューラルネットワークの勾配ベースの訓練の必要性を、柔軟なニューロンモデルで動機づける。
- 任意のODE定義のニューロン動態に対して正確な前方勾配を生み出すフレームワークを提示する。
- ニューロン動態、スパイク条件、リセット規則を定義するための簡易APIを提供する。
- 複数のニューロンモデルとベンチマークでアプローチの柔軟性と精度を検証する。
提案手法
- Diffraxの微分可能ODEソルバを用いて、スパイクのイベント処理を伴う連続時間積分を実行する。
- Implicit function theoremを用いて solverのステップと厳密なイベント時刻を通じて勾配を伝播する。
- 初期状態、ダイナミクス、スパイク条件、入力処理、リセット規則を指定するユーザー定義のNeuronModelインターフェースを許可する。
- 異種ネットワーク用のMultiNeuronModelやAMOSや不応性拡張などのラッパーをサポートする。
- LIF、QIF、EIF、Izhikevich、EGRUなどのニューロンモデルと、多腔室樹状突起スパイクモデルをデモンストレーションする。
実験結果
リサーチクエスチョン
- RQ1微分可能ODEソルバとイベント処理を用いた前向きシミュレーションで、任意のニューロンモデルの正確な勾配を計算できるか。
- RQ2イベントベースのスパイクタイミングと数値ODEソルバを統合することで、閉形式解なしで正確な勾配を得られるか。
- RQ3EventaxはYin–Yang、MNISTなどの古典的ベンチマークと、さまざまなニューロンモデルを用いた再帰的イベントベースネットワークでどのように性能を発揮するか。
- RQ4ソルバの選択とイベント処理が訓練の安定性とスループットに与える影響は何か。
主な発見
- Eventaxは、イベント時刻の閉形式解を必要とせず、ODEで定義された任意のニューロンモデルに対する前方シミュレーションの正確な勾配を可能にする。
- このフレームワークは、LIF、QIF、EIF、Izhikevich、EGRU、および多腔室ニューロンモデルを広くサポートし、最初のスパイクまでの時間と状態ベースの損失を可能にする。
- Yin–YangとMNISTの実験は、時間的および状態ベースの目的で競争力のある精度を示す。
- 遅延メモリ XORタスクは、正確な勾配を用いて再帰的イベントベースアーキテクチャ(EGRU)の訓練を示す。
- パフォーマンス分析は、バッチサイズとイベント数に応じてスループットが拡張されることを示し、Diffraxベースのソルバ特性を反映している。
より良い研究を、今すぐ始めましょう
論文設計から論文執筆まで、研究時間を劇的に削減しましょう。
クレジットカード登録不要
このレビューはAIが作成し、人間の編集者が確認しました。