QUICK REVIEW

[Paper Review] ReLU-KAN: New Kolmogorov-Arnold Networks that Only Need Matrix Addition, Dot Multiplication, and ReLU

Qi Qiu, Tao Zhu|arXiv (Cornell University)|Jun 4, 2024
Parallel Computing and Optimization Techniques | 13 citations
TL;DR

ReLU-KAN replaces KAN’s B-spline basis with a ReLU-based basis so that the whole computation can be expressed as matrix operations, yielding substantial GPU speedups and more stable fitting while preserving KAN’s resistance to catastrophic forgetting.

ABSTRACT

Limited by the complexity of basis function (B-spline) calculations, Kolmogorov-Arnold Networks (KAN) suffer from restricted parallel computing capability on GPUs. This paper proposes a novel ReLU-KAN implementation that inherits the core idea of KAN. By adopting ReLU (Rectified Linear Unit) and point-wise multiplication, we simplify the design of KAN's basis function and optimize the computation process for efficient CUDA computing. The proposed ReLU-KAN architecture can be readily implemented on existing deep learning frameworks (e.g., PyTorch) for both inference and training. Experimental results demonstrate that ReLU-KAN achieves a 20x speedup compared to traditional KAN with 4-layer networks. Furthermore, ReLU-KAN exhibits a more stable training process with superior fitting ability while preserving the "catastrophic forgetting avoidance" property of KAN. You can get the code in https://github.com/quiqi/relu_kan

Motivation & Objective

  • Motivate faster, GPU-friendly implementations of Kolmogorov-Arnold Networks (KANs) by simplifying basis functions.
  • Develop a ReLU-based basis that enables matrix-based computations and easy integration with frameworks like PyTorch.
  • Show that ReLU-KAN speeds up training and improves fitting accuracy while preserving KAN properties such as avoiding catastrophic forgetting.

Proposed method

  • Introduce a simplified basis function R_i(x) = [ReLU(e_i − x) × ReLU(x − s_i)]^2 × 16/(e_i − s_i)^4 as the replacement for B-splines in KAN.
  • Express the entire basis computation as matrix operations to enhance GPU parallelism.
  • Pre-generate non-trainable parameters to accelerate computation, analogous to positional encoding.
  • Represent the weighted sum of basis functions as a convolution operation to fit within standard DL frameworks.
  • Provide a concise PyTorch implementation of the ReLU-KAN layer in fewer than 30 lines of code.
  • Derive a layer-wise computation pipeline showing how R_i basis evaluations form a matrix F used in outputs.
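
The basis function and the matrix form described above can be sketched as follows. This is a minimal NumPy illustration, not the authors' PyTorch implementation. The factor 16/(e_i − s_i)^4 normalizes each basis function to a peak of 1: at the midpoint x = (s_i + e_i)/2 the product ReLU(e_i − x) × ReLU(x − s_i) equals (e_i − s_i)^2/4, and its square is (e_i − s_i)^4/16. The uniform grid layout in `make_grid`, the parameter names `g` and `k`, and the helper names are assumptions made for illustration. The weighted sum is written with `np.einsum` rather than the convolution the paper uses.

```python
import numpy as np


def relu(z):
    return np.maximum(z, 0.0)


def make_grid(g, k):
    """Pre-generate the non-trainable start/end points s_i, e_i.

    Assumption: g grid cells on [0, 1], each basis spanning k + 1 cells,
    which gives g + k overlapping basis functions.
    """
    s = np.arange(-k, g) / g      # shape (g + k,)
    e = s + (k + 1) / g           # same shape; e_i - s_i is constant
    return s, e


def relu_basis(x, s, e):
    """Evaluate R_i(x) for every input value and every basis function.

    x: (batch, n_in); s, e: (m,). Returns F with shape (batch, n_in, m).
    """
    x = x[..., None]                          # broadcast against the m basis functions
    prod = relu(e - x) * relu(x - s)          # nonzero only for s_i < x < e_i
    return prod ** 2 * 16.0 / (e - s) ** 4


def relu_kan_layer(x, s, e, w):
    """Weighted sum of basis values. w: (n_out, n_in, m). Returns (batch, n_out)."""
    f = relu_basis(x, s, e)
    return np.einsum("bim,oim->bo", f, w)     # contract over inputs and basis index


# Usage: 4 samples, 2 inputs, 3 outputs, with hypothetical g = 5 and k = 3.
g, k = 5, 3
s, e = make_grid(g, k)
rng = np.random.default_rng(0)
x = rng.random((4, 2))
w = rng.standard_normal((3, 2, g + k))
y = relu_kan_layer(x, s, e, w)
print(y.shape)  # (4, 3)
```

Each R_i is zero outside (s_i, e_i), so an input only touches the weights of the few basis functions whose intervals contain it. That locality is the usual explanation for why KAN-style layers forget less when trained on new regions of the input.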

Experimental results

Research questions

  • RQ1: Does replacing KAN’s B-spline basis with a ReLU-based basis improve training speed on GPUs?
  • RQ2: Can ReLU-KAN maintain or improve the fitting accuracy and stability of KAN across univariate and multivariate functions?
  • RQ3: Does ReLU-KAN preserve KAN’s resistance to catastrophic forgetting while enabling scalable network architectures?

Key findings

  • ReLU-KAN is 5 to 20 times faster than KAN in training across single- to three-layer models.
  • For larger networks, GPU speedups of ReLU-KAN become more pronounced, with up to about 20x speedup observed.
  • ReLU-KAN achieves higher fitting accuracy than KAN, by about two orders of magnitude in reported comparisons.
  • Training of ReLU-KAN shows more stable convergence than KAN across evaluated functions, especially for higher-frequency targets.
  • ReLU-KAN retains KAN’s catastrophic forgetting avoidance property in experiments.

This review was created by AI and reviewed by human editors.