[논문 리뷰] Ring Attention with Blockwise Transformers for Near-Infinite Context
Ring Attention은 장치 수에 비례하여 컨텍스트 길이를 확장하는 Transformer 모델의 훈련 및 추론을 링 토폴로지에서 장치 간 통신과 블록 단위 주의/FFN 계산을 겹쳐 근사 없이 거의 무한한 컨텍스트를 달성합니다.
Transformers have emerged as the architecture of choice for many state-of-the-art AI models, showcasing exceptional performance across a wide range of AI applications. However, the memory demands imposed by Transformers limit their ability to handle long sequences, thereby posing challenges in utilizing videos, actions, and other long-form sequences and modalities in complex environments. We present a novel approach, Ring Attention with Blockwise Transformers (Ring Attention), which leverages blockwise computation of self-attention and feedforward to distribute long sequences across multiple devices while fully overlapping the communication of key-value blocks with the computation of blockwise attention. Our approach enables training and inference of sequences that are up to device count times longer than those achievable by prior memory-efficient Transformers, without resorting to approximations or incurring additional communication and computation overheads. Extensive experiments on language modeling and reinforcement learning tasks demonstrate the effectiveness of our approach in allowing millions of tokens context size and improving performance.
연구 동기 및 목표
- 장기 컨텍스트 Transformer의 메모리 병목 현상을 동기 부여하고 해결합니다.
- 장치를 가로질러 긴 시퀀스를 분배하기 위한 링 기반의 블록 단위 계산 방식을 도입합니다.
- 키-값 블록 간의 통신과 계산 간의 중첩이 오버헤드를 제거한다는 것을 보여줍니다.
- 다수의 토큰과 디바이스 수의 확장성(언어 모델링 및 강화학습 과제에서)을 데모합니다.
제안 방법
- 블록 단위 주의 및 피드포워드 계산으로 시퀀스 길이를 여러 디바이스에 분산합니다.
- 링 토폴로지는 호스트를 조정합니다; 각 호스트는 쿼리 블록을 처리하는 동안 키-값 블록은 다음/이전 호스트로 회전합니다.
- 키-값 블록의 통신을 블록 단위 계산과 중첩시켜 통신 지연을 숨깁니다.
- 블록 단위 병렬 트랜스포머를 사용하여 메모리 비용을 블록 크기에 선형으로, 시퀀스 길이에 독립적으로 유지합니다.
- 알고리즘 1은 FSDP 및 링 통신이 있는 링 기반 트랜스포머 학습의 메모리 감소 단계를 개요합니다.
- 구현은 근사 없이 메모리 효율적인 주의 원소와 실제 블록 단위 연산을 활용합니다.
실험 결과
연구 질문
- RQ1Ring Attention이 디바이스 수에 비례하여 Transformer 컨텍스트 길이를 선형적으로 확장하면서 성능을 유지할 수 있는가?
- RQ2블록 단위 주의를 디바이스 링에 분산시킬 때의 메모리 및 계산 트레이드오프는 무엇인가?
- RQ3Ring Attention이 서로 다른 모델 크기 및 하드웨어(GPUs/TPUs)에서 모델 FLOPs 활용도와 처리량에 어떻게 영향을 미치는가?
- RQ4Ring Attention이 강화학습 및 긴 컨텍스트 언어 모델링과 같이 긴 컨텍스트를 활용하는 다운스트림 작업을 개선하는가?
주요 결과
- Ring Attention은 기존의 메모리 효율적 방법들보다 디바이스 개수의 배수만큼 더 긴 시퀀스 학습을 가능하게 합니다.
- 백만 단어를 넘는 컨텍스트 크기도 근사나 추가 오버헤드 없이 달성할 수 있습니다.
- MFU(모델 FLOPs 활용도)는 매우 긴 컨텍스트 길이에서도 높게 유지되며, 일부 기준선과 다릅니다.
- ExoRL의 RL 실험에서 Ring Attention은 더 긴 전개/컨텍스트를 사용할 때 여러 작업에서 baselines 대비 평균 수익을 개선합니다.
- 512K-토큰 컨텍스트에서 Ring Attention으로 LLaMA-13B를 미세조정하면 긴 컨텍스트 행 조회 작업에서 높은 정확도를 유지하면서 일부 짧은 컨텍스트 기저선을 능가합니다.
- 하드웨어(A100 GPU 및 TPU) 전반에 걸쳐 Ring Attention은 바닐라/메모리 효율적 트랜스포머에 비해 최소한의 오버헤드로 컨텍스트 길이를 크게 확장하는 경향을 보입니다.
더 나은 연구,지금 바로 시작하세요
연구 설계부터 논문 작성까지, 연구 시간을 획기적으로 줄여보세요.
카드 등록 없음 · 무료 플랜 제공
이 리뷰는 AI가 만들고, 인간 에디터가 검토했습니다.