[Paper Review] ReLU-KAN: New Kolmogorov-Arnold Networks that Only Need Matrix Addition, Dot Multiplication, and ReLU
ReLU-KAN replaces KAN’s B-spline basis with a ReLU-based basis to enable full matrix operations, achieving substantial GPU speedups, improved fitting stability, and preserved catastrophic forgetting resistance.
Limited by the complexity of basis function (B-spline) calculations, Kolmogorov-Arnold Networks (KAN) suffer from restricted parallel computing capability on GPUs. This paper proposes a novel ReLU-KAN implementation that inherits the core idea of KAN. By adopting ReLU (Rectified Linear Unit) and point-wise multiplication, we simplify the design of KAN's basis function and optimize the computation process for efficient CUDA computing. The proposed ReLU-KAN architecture can be readily implemented on existing deep learning frameworks (e.g., PyTorch) for both inference and training. Experimental results demonstrate that ReLU-KAN achieves a 20x speedup compared to traditional KAN with 4-layer networks. Furthermore, ReLU-KAN exhibits a more stable training process with superior fitting ability while preserving the "catastrophic forgetting avoidance" property of KAN. The code is available at https://github.com/quiqi/relu_kan.
Motivation & Objective
- Enable faster, GPU-friendly implementations of Kolmogorov-Arnold Networks (KANs) by simplifying their basis functions.
- Develop a ReLU-based basis that enables matrix-based computations and easy integration with frameworks like PyTorch.
- Show that ReLU-KAN speeds up training and improves fitting accuracy while preserving KAN properties such as avoiding catastrophic forgetting.
Proposed method
- Introduce a simplified basis function R_i(x) = [ReLU(e_i − x) × ReLU(x − s_i)]^2 × 16/(e_i − s_i)^4 as the replacement for B-splines in KAN, where s_i and e_i are the start and end points of the interval on which R_i is nonzero.
- Express the entire basis computation as matrix operations to enhance GPU parallelism.
- Pre-generate non-trainable parameters to accelerate computation, analogous to positional encoding.
- Represent the weighted sum of basis functions as a convolution operation so that it fits within standard deep learning frameworks.
- Provide a concise PyTorch implementation of the ReLU-KAN layer in fewer than 30 lines of code.
- Derive a layer-wise computation pipeline showing how the R_i basis evaluations form a matrix F from which the layer outputs are computed.
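The normalization constant 16/(e_i − s_i)^4 in R_i(x) can be checked with a short worked derivation. This is our own check, not text from the paper. The layer-output line uses an assumed index notation: w_{o,j,i} is the weight of basis i on input j for output o.

```latex
R_i(x) = \bigl[\mathrm{ReLU}(e_i - x)\,\mathrm{ReLU}(x - s_i)\bigr]^2 \cdot \frac{16}{(e_i - s_i)^4}

% Outside the support, one ReLU factor is zero:
R_i(x) = 0 \quad \text{for } x \le s_i \text{ or } x \ge e_i

% At the midpoint x^* = (s_i + e_i)/2, both factors equal (e_i - s_i)/2:
R_i(x^*) = \left[\frac{(e_i - s_i)^2}{4}\right]^2 \cdot \frac{16}{(e_i - s_i)^4}
         = \frac{(e_i - s_i)^4}{16} \cdot \frac{16}{(e_i - s_i)^4} = 1

% Assumed layer output: a weighted sum over inputs j and basis functions i
y_o = \sum_{j} \sum_{i} w_{o,j,i}\, R_i(x_j)
```

On (s_i, e_i), the product (e_i − x)(x − s_i) is a concave parabola maximized at the midpoint. Each R_i is therefore a bell-shaped bump with peak height exactly 1, playing the role of a B-spline basis function.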
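The matrix formulation described above can be sketched in NumPy; the paper's own reference code uses PyTorch. This is a minimal, hypothetical sketch built on three assumptions:

- The helper names `make_grid`, `relu_kan_basis`, and `relu_kan_layer` are ours.
- The grid layout is assumed: g + k basis functions, start points spaced 1/g apart, support width (k + 1)/g, inputs in [0, 1].
- The convolution is replaced by the equivalent full `np.einsum` contraction over the basis matrix F, without a bias term.

```python
import numpy as np


def relu(z: np.ndarray) -> np.ndarray:
    return np.maximum(z, 0.0)


def make_grid(input_size: int, g: int, k: int) -> tuple[np.ndarray, np.ndarray]:
    """Pre-generate non-trainable start/end points s_i, e_i (assumed layout).

    Each of the g + k basis functions spans an interval of width (k + 1) / g;
    start points are spaced 1 / g apart so the bumps cover the range [0, 1].
    """
    s = np.arange(-k, g, dtype=float) / g      # shape (g + k,)
    e = s + (k + 1) / g                        # shape (g + k,)
    # One copy per input feature: shape (input_size, g + k)
    return np.tile(s, (input_size, 1)), np.tile(e, (input_size, 1))


def relu_kan_basis(x: np.ndarray, s: np.ndarray, e: np.ndarray) -> np.ndarray:
    """Evaluate R_i(x) = [ReLU(e_i - x) * ReLU(x - s_i)]^2 * 16 / (e_i - s_i)^4.

    x: (batch, input_size); s, e: (input_size, n_basis)
    returns the basis matrix F with shape (batch, input_size, n_basis)
    """
    x = x[:, :, None]                          # add a basis axis for broadcasting
    r = relu(e - x) * relu(x - s)              # only ReLU and elementwise products
    return r * r * 16.0 / (e - s) ** 4


def relu_kan_layer(x: np.ndarray, s: np.ndarray, e: np.ndarray, w: np.ndarray) -> np.ndarray:
    """y[b, o] = sum over j, i of w[o, j, i] * F[b, j, i].

    A convolution whose kernel covers the whole (input_size, n_basis) map
    computes the same full contraction; einsum is used here instead.
    """
    f = relu_kan_basis(x, s, e)
    return np.einsum("bji,oji->bo", f, w)


# Usage with hypothetical random weights (output_size=4, input_size=2, g + k = 8)
rng = np.random.default_rng(0)
s, e = make_grid(input_size=2, g=5, k=3)
w = rng.normal(size=(4, 2, 8))
x = rng.random((16, 2))                        # batch of 16 inputs in [0, 1]
y = relu_kan_layer(x, s, e, w)
print(y.shape)  # → (16, 4)
```

Every step is an elementwise ReLU, a multiplication, or a single tensor contraction, so the whole layer maps onto batched array operations. That is the property the paper relies on for GPU speedups.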
Experimental results
Research questions
- RQ1: Does replacing KAN’s B-spline basis with a ReLU-based basis improve training speed on GPUs?
- RQ2: Can ReLU-KAN maintain or improve the fitting accuracy and stability of KAN across univariate and multivariate functions?
- RQ3: Does ReLU-KAN preserve KAN’s resistance to catastrophic forgetting while enabling scalable network architectures?
Key findings
- ReLU-KAN is 5 to 20 times faster than KAN in training across single- to three-layer models.
- For larger networks, GPU speedups of ReLU-KAN become more pronounced, with up to about 20x speedup observed.
- ReLU-KAN achieves higher fitting accuracy than KAN, by about two orders of magnitude in reported comparisons.
- ReLU-KAN converges more stably than KAN across the evaluated functions, especially for higher-frequency targets.
- ReLU-KAN retains KAN’s catastrophic forgetting avoidance property in experiments.
This review was created by AI and reviewed by human editors.