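[Paper Review] LongNet: Scaling Transformers to 1,000,000,000 Tokens
LongNet introduces dilated attention, which scales Transformers to sequences of more than 1 billion tokens at near-linear computational cost, supports distributed training, and maintains performance on short sequences.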
Scaling sequence length has become a critical demand in the era of large language models. However, existing methods struggle with either computational complexity or model expressivity, rendering the maximum sequence length restricted. To address this issue, we introduce LongNet, a Transformer variant that can scale sequence length to more than 1 billion tokens, without sacrificing the performance on shorter sequences. Specifically, we propose dilated attention, which expands the attentive field exponentially as the distance grows. LongNet has significant advantages: 1) it has a linear computation complexity and a logarithmic dependency between any two tokens in a sequence; 2) it can serve as a distributed trainer for extremely long sequences; 3) its dilated attention is a drop-in replacement for standard attention, which can be seamlessly integrated with the existing Transformer-based optimization. Experimental results demonstrate that LongNet yields strong performance on both long-sequence modeling and general language tasks. Our work opens up new possibilities for modeling very long sequences, e.g., treating a whole corpus or even the entire Internet as a sequence.
Motivation and Objectives
- Motivate the need to scale sequence length in transformers beyond current limits.
- Propose a dilated attention mechanism to expand the attentive field exponentially while maintaining efficiency.
- Enable distributed training across multiple devices to handle extremely long sequences.
- Demonstrate that LongNet preserves performance on short sequences while significantly extending context length.
- Provide practical integration with existing Transformer optimization (kernel fusion, quantization, etc.).
Proposed Method
- Replace vanilla self-attention with dilated attention that sparsifies Q, K, V along the sequence with segment-based dilations.
- Mix several dilated attention patterns with geometrically increasing segment sizes and dilation rates, so that FLOPs grow near-linearly in the sequence length N and the dependency path between any two tokens grows logarithmically.
- Introduce multi-head dilated attention where each head uses shifted sparse patterns to capture diverse local/global information.
- Formulate the overall output as a weighted combination of multiple dilated attentions with weights derived from the softmax denominator.
- Describe a distributed training algorithm that partitions the sequence dimension across devices with all-gather of KV and cross-attention between local Q and global KV.
- Keep the implementation compatible with optimizations built for dense Transformers (FlashAttention, kernel fusion, quantization).
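The steps above can be sketched in a few lines of NumPy. This is a minimal single-head illustration, not the authors' implementation: it omits causal masking, multiple heads and any kernel-level optimization, and the function names (`attention_with_lse`, `dilated_attention`, `mixed_dilated_attention`) and the `offset` parameter are hypothetical names chosen for this example. It assumes at least one pattern uses dilation 1, so every position is covered by some pattern.

```python
import numpy as np

def attention_with_lse(q, k, v):
    """Dense softmax attention. Returns the output and, per query,
    the log of the softmax denominator (log-sum-exp of the scores)."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    row_max = scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores - row_max)
    denom = weights.sum(axis=-1, keepdims=True)
    lse = (row_max + np.log(denom))[:, 0]
    return (weights / denom) @ v, lse

def dilated_attention(q, k, v, segment_len, dilation, offset=0):
    """One dilated pattern: split the sequence into segments, keep every
    `dilation`-th row of each segment, and attend only among the kept rows."""
    n = q.shape[0]
    out = np.zeros((n, v.shape[-1]))
    lse = np.full(n, -np.inf)  # -inf marks positions this pattern skips
    for start in range(0, n, segment_len):
        idx = np.arange(start + offset, min(start + segment_len, n), dilation)
        if idx.size == 0:
            continue
        out[idx], lse[idx] = attention_with_lse(q[idx], k[idx], v[idx])
    return out, lse

def mixed_dilated_attention(q, k, v, segment_lens, dilations):
    """Combine several patterns, weighting each by its softmax denominator."""
    outs, lses = zip(*(dilated_attention(q, k, v, w, r)
                       for w, r in zip(segment_lens, dilations)))
    outs, lses = np.stack(outs), np.stack(lses)  # (P, N, d) and (P, N)
    # weight_i = denom_i / sum_j denom_j, computed stably in log space;
    # skipped positions have lse = -inf and therefore weight 0.
    mix = np.exp(lses - lses.max(axis=0))
    mix /= mix.sum(axis=0)
    return (mix[..., None] * outs).sum(axis=0)

rng = np.random.default_rng(0)
q, k, v = rng.normal(size=(3, 16, 8))
mixed = mixed_dilated_attention(q, k, v, segment_lens=[4, 8, 16], dilations=[1, 2, 4])
```

With a single pattern whose segment covers the whole sequence and whose dilation is 1, the sketch reduces to ordinary dense attention, which is a convenient sanity check. Shifting `offset` per head is one way to realize the shifted per-head patterns described above.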
Experimental Results
Research Questions
- RQ1: Can dilated attention scale Transformer sequence length to 1B tokens without prohibitive compute or memory costs?
- RQ2: Does LongNet maintain or improve performance on shorter sequences while expanding context length?
- RQ3: Is the proposed distributed training approach effective for extremely long sequences beyond single-GPU memory limits?
- RQ4: How does LongNet's context length affect language modeling performance and scaling laws compared to dense/other sparse transformers?
Key Findings
| Model | Training length | Batch size | GitHub | Perplexity (2K) | Perplexity (8K) | Perplexity (32K) |
|---|---|---|---|---|---|---|
| Transformer | 2K | 256 | Yes | 4.24 | 5.07 | 11.29 |
| Sparse Transformer | 8K | 64 | Yes | 4.39 | 3.35 | 8.79 |
| LongNet (ours) | 2K | | | | 3.24 | 3.36 |
| Sparse Transformer | 16K | 32 | Yes | 4.85 | 3.73 | 19.77 |
| LongNet (ours) | 16K | | | | 3.26 | 3.31 |
| Sparse Transformer | 32K | 16 | Yes | 5.15 | 4.00 | 3.64 |
| LongNet (ours) | 32K | | | 3.01 | 3.33 | 3.01 |
- LongNet achieves near-linear computation complexity in sequence length with a token dependency of O(log N).
- Dilated attention enables scalable training of sequences up to 1B tokens using a distributed sequence-parallel approach with constant KV communication cost across devices.
- On language modeling tasks, LongNet consistently outperforms dense Transformer and sparse Transformer baselines across longer sequence lengths (2K, 8K, 16K, 32K) with lower perplexity.
- LongNet preserves strong performance on short sequences while enabling dramatically longer contexts, and it follows scaling laws with respect to training compute similar to those of standard Transformers.
- Experiments demonstrate the practicality of scaling context length during training, with longer contexts yielding better language modeling results under the same or comparable compute.
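The near-linear cost in the first finding can be checked with a back-of-the-envelope FLOP count. The sketch below rests on assumptions that are not taken from this review: attention over n queries and m keys of width d is counted as roughly 2·n·m·d FLOPs, segment lengths and dilation rates both grow by a factor `alpha` per pattern, and the defaults `w0=2048` and `alpha=2` are illustrative. The function names are hypothetical.

```python
def dense_attention_flops(n, d):
    """Approximate FLOPs of dense attention: n queries x n keys x width d."""
    return 2 * n * n * d

def dilated_attention_flops(n, d, w0=2048, alpha=2):
    """Approximate FLOPs of a mix of dilated patterns with segment length
    w0 * alpha**i and dilation alpha**i, up to a segment spanning the sequence."""
    total, w, r = 0, w0, 1
    while w <= n:
        kept = w // r  # rows kept per segment after dilation
        total += (n // w) * 2 * kept * kept * d
        w, r = w * alpha, r * alpha
    return total

d = 64
for n in (2**15, 2**20, 2**25):
    dilated = dilated_attention_flops(n, d)
    ratio = dense_attention_flops(n, d) / dilated
    print(f"N={n:>12,}  dilated FLOPs per token={dilated // n:,}  dense/dilated={ratio:,.0f}x")
```

Under these assumptions each pattern keeps the same number of rows per segment, so the per-pattern cost shrinks geometrically and the total stays below 2·alpha/(alpha-1)·w0·N·d. That makes the dilated FLOPs per token roughly constant while the dense cost per token grows with N.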
This review was generated by AI and checked by a human editor.