[논문 리뷰] Scaling Transformer to 1M tokens and beyond with RMT
이 논문은 Recurrent Memory Transformer (RMT)를 제시한다. 이는 trainable memory tokens를 부착하고 커리큘럼 학습을 사용하여 Transformer 모델이 최대 2 million tokens를 처리할 수 있도록 하는 기억-강화, 세그먼트-수준 순환 접근법이다.
A major limitation for the broader scope of problems solvable by transformers is the quadratic scaling of computational complexity with input size. In this study, we investigate the recurrent memory augmentation of pre-trained transformer models to extend input context length while linearly scaling compute. Our approach demonstrates the capability to store information in memory for sequences of up to an unprecedented two million tokens while maintaining high retrieval accuracy. Experiments with language modeling tasks show perplexity improvement as the number of processed input segments increases. These results underscore the effectiveness of our method, which has significant potential to enhance long-term dependency handling in natural language understanding and generation tasks, as well as enable large-scale context processing for memory-intensive applications.
연구 동기 및 목표
- 메모리 증강된 세그먼트-수준 재귀(RMT)를 인코더 전용 및 디코더 전용 트랜스포머용 플러그인 래퍼로 시연한다.
- RMT가 추론 중 선형 계산과 일정한 메모리로 극도로 긴 시퀀스(최대 2M tokens)를 처리할 수 있음을 보여준다.
- 메모리 획득/유지 작업을 백만 토큰 컨텍스트에 확장 가능하도록 개발하고 벤치마킹하여 메모리 연산의 일반화 성능을 평가한다.
- 도메인 전반의 실용적 이점을 평가하기 위해 긴 범위의 언어 모델링 및 형식적 추론 작업에서 RMT의 영향을 조사한다.
제안 방법
- 아키텍처를 변경하지 않고 사전 학습된 트랜스포머에 토큰 기반 메모리 모듈을 부착한다.
- 고정 크기 세그먼트로 나누고 세그먼트 내에서만 전체 어텐션을 수행하여 선형 스케일링을 가능하게 한다.
- 메모리 토큰의 순환을 세그먼트 간에 학습하여 메모리 출력이 후속 세그먼트에 영향을 미치도록 한다.
- 학습 커리큘럼을 사용하여 작업 길이를 단일 세그먼트에서 다중 세그먼트 컨텍스트로 점진적으로 확장한다.
- 합성 기억 작업을 통해 메모리 연산을 평가하고 실험을 긴 범위의 언어 모델링 및 정리-증명 스타일 생성으로 확장한다.
실험 결과
연구 질문
- RQ1RMT가 선형 연산 비용으로 사전 학습된 트랜스포머의 유효 컨텍스트 길이를 multi-million-token 규모로 확장할 수 있는가?
- RQ2메모리-강화 트랜스포머가 극도로 긴 시퀀스에서 사실을 암기하고 검색하며 추론하는 능력은 어느 정도인가?
- RQ3점진적으로 긴 세그먼트 작업으로 학습될 때 메모리-강화 모델이 더 긴 시퀀스 길이에 일반화하는가?
- RQ4긴 텍스트 언어 모델링 및 형식적 증명 생성을 대상으로 한 RMT의 perplexity와 예측 품질에 대한 영향은 무엇인가?
주요 결과
- RMT는 고정된 세그먼트 크기에 대해 입력 길이에 따라 선형으로 스케일링되며 다중 세그먼트 입력에서 비순환 모델에 비해 FLOPs를 감소시키고(일부 경우 최대 295× 적은 FLOPs).
- 메모리를 사용하면 사전 학습된 BERT 백본이 최대 2,000,000 토큰(512 토큰의 4,096개 세그먼트) 간 정보를 저장하고 검색할 수 있다.
- 커리큘럼 학습은 안정성과 일반화를 향상시켜 짧은 과제로 학습된 모델이 훨씬 더 긴 과제를 해결하도록 한다.
- 긴 범위 언어 모델링에서 메모리-equipped RMT는 베이스라인 대비 perplexity를 개선하고 세그먼트 경계에서 기억을 이어가며 예측의 안정성을 높인다.
- RMT는 어텐션 패턴 기반의 메모리 연산을 보여주며 매우 긴 시퀀스에서 메모리 검색을 일반화할 수 있어, 적합한 과제에서 2M 토큰을 넘는 확장에 내재적 기술적 한계가 없음을 시사한다.
더 나은 연구,지금 바로 시작하세요
연구 설계부터 논문 작성까지, 연구 시간을 획기적으로 줄여보세요.
카드 등록 없음 · 무료 플랜 제공
이 리뷰는 AI가 만들고, 인간 에디터가 검토했습니다.