[논문 리뷰] Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention
이 논문은 트랜스포머에 선형 어텐션을 도입하여 시간/메모리 복잡도를 O(N)으로 줄이고, 자동회귀 추론을 수천 배 빠르게 수행하며, 인과 설정에서 RNN과의 동등성을 보여준다.
Transformers achieve remarkable performance in several tasks but due to their quadratic complexity, with respect to the input's length, they are prohibitively slow for very long sequences. To address this limitation, we express the self-attention as a linear dot-product of kernel feature maps and make use of the associativity property of matrix products to reduce the complexity from $\mathcal{O}\left(N^2 ight)$ to $\mathcal{O}\left(N ight)$, where $N$ is the sequence length. We show that this formulation permits an iterative implementation that dramatically accelerates autoregressive transformers and reveals their relationship to recurrent neural networks. Our linear transformers achieve similar performance to vanilla transformers and they are up to 4000x faster on autoregressive prediction of very long sequences.
연구 동기 및 목표
- Transformer의 자기-어텐션이 긴 시퀀스에서 가지는 2차 시간/메모리 비용을 줄인다.
- 결과적으로 O(N)로의 복잡도를 활용하기 위한 커널 기반의 선형 어텐션 형식을 제안한다.
- 선형 복잡도 및 일정한 메모리로 인과(masking) 적용을 가능하게 한다.
- Transformers를 RNN으로 재구성하여 자가회귀 추론 속도 향상을 입증한다.
- 선형 트랜스포머를 이미지, 음성 및 합성 태스크에서 Softmax 및 Reformer와 비교한다
제안 방법
- 소프트맥스 어텐션을 phi라는 특징 맵을 이용한 커널 기반 선형 어텐션으로 대체하여 층당 O(N) 계산을 가능하게 한다.
- 결합 속성(associativity)을 활용하여 주의 연산을 phi(Q) (phi(K)^T V)로 재작성하고 쿼리에 대해 재사용하기 위해 합을 한 번만 계산한다.
- 점근적 누적 합 S_i와 Z_i를 단계적으로 갱신하여 인과 마스킹을 선형 복잡도 및 일정 메모리로 강제한다.
- 정방향 및 역방향 전파에 대한 선형-시간 그래디언트를 도출하여 학습 중 메모리 사용량을 유지한다.
- 인과 마스킹이 적용된 트랜스포머는 두 개의 기억(S, z)과 재귀 업데이트 방정식을 갖는 RNN으로 간주될 수 있음을 보인다.
- PyTorch 구현 및 CUDA 가속 그래디언트를 제공하고 이미지, 오디오, 합성 태스크에서 Softmax 및 Reformer와 비교한다
실험 결과
연구 질문
- RQ1트랜스포머의 자기-어텐션을 재구성하여 성능 손실 없이 선형 시간과 메모리를 달성할 수 있는가?
- RQ2자가회귀 시퀀스 생성을 위한 인과 마스킹 하에서 선형 어텐션의 동작은 어떻게 되는가?
- RQ3인과 마스킹 하에서 트랜스포머를 RNN으로 해석할 수 있으며 이것이 더 빠른 추론을 가능하게 하는가?
- RQ4선형 트랜스포머가 전체 트랜스포머 및 최신 대안에 비해 비전 및 음성 태스크에서까지 경쟁력 있는 정확도를 유지하는가?
주요 결과
| 방법 | 비트/차원 | 초당 이미지 수 |
|---|---|---|
| Softmax | 0.621 | 0.45 |
| LSH-1 | 0.745 | 0.68 |
| LSH-4 | 0.676 | 0.27 |
| Linear (ours) | 0.644 | 142.8 |
| (CIFAR-10에 대한 표 계속) | ||
| Softmax | 3.47 | 0.004 |
| LSH-1 | 3.39 | 0.015 |
| LSH-4 | 3.51 | 0.005 |
| Linear (ours) | 3.40 | 17.85 |
| (WSJ에 대한 표 계속) | ||
| Bi-LSTM | - | 1047 |
| Softmax | - | 2711 |
| LSH-4 | - | 2250 |
| Linear (ours) | - | 824 |
- 선형 어텐션은 층당 시간 및 메모리 복잡도를 O(N^2)에서 O(N)으로 감소시킨다.
- 적절한 특징 맵(elu 기반)을 사용하면 선형 트랜스포머가 평가된 태스크에서 전체 트랜스포머의 성능에 근접하다.
- 내부 상태(S_i, Z_i)를 점진적으로 갱신할 수 있어 자가회귀 추론이 수천 배 빨라진다.
- MNIST 및 CIFAR-10 이미지 생성에서 가까운 Softmax perplexities를 보이면서도 처리량은 크게 향상(수백~수천 배 더 빠름)된다.
- WSJ 음성인식에서 Linear Transformer가 LSTM 및 Reformer 베이스라인보다 더 나은 PER 및 더 빠른 학습을 달성하는 반면, Softmax가 여전히 가장 정확하지만 느리다.
- 이 논문은 인과 마스킹이 적용된 트랜스포머를 두 개의 기억으로 구성된 RNN으로 볼 수 있는 형식적 연결 고리를 밝힌다
더 나은 연구,지금 바로 시작하세요
연구 설계부터 논문 작성까지, 연구 시간을 획기적으로 줄여보세요.
카드 등록 없음 · 무료 플랜 제공
이 리뷰는 AI가 만들고, 인간 에디터가 검토했습니다.