[论文解读] SpiralFormer: Looped Transformers Can Learn Hierarchical Dependencies via Multi-Resolution Recursion
SpiralFormer 在循环Transformers中引入多分辨率递归,能够实现分层、按尺度依赖的关系,并在160M–1.4B尺度上相对于循环与非循环基线都具备更高的效率。
Recursive (looped) Transformers decouple computational depth from parameter depth by repeatedly applying shared layers, providing an explicit architectural primitive for iterative refinement and latent reasoning. However, early looped Transformers often underperform non-recursive baselines of equal compute. While recent literature has introduced more effective recursion mechanisms to mitigate this gap, existing architectures still operate at a fixed, full-token resolution, neglecting the potential efficiency of computing over compressed latent representations. In this paper, we propose SpiralFormer, a looped Transformer that executes recurrence under a multi-resolution recursion schedule. We provide probing evidence that multi-resolution recursion enables the model to learn hierarchical dependencies by inducing iteration-wise functional specialization across different scales. Empirically, SpiralFormer achieves better parameter and compute efficiency than both looped and non-looped baselines across model scales from 160M to 1.4B, establishing sequence resolution as a potential axis for scaling recursive architectures.
研究动机与目标
- 研究多分辨率递归是否能使循环Transformer学习到全分辨率循环未能捕捉的分层依赖。
- 开发 SpiralFormer,将标记压缩为潜在槽并在递归内跨分辨率操作。
- 证明粗到细、共享核心的递归在计算/参数效率上优于全分辨率循环。
- 在不同模型尺度(160M–1.4B)上提供经验证据,表明多分辨率递归提升性能与效率。
提出的方法
- 采用中间循环架构,含前/循环/后块以及共享循环核心。
- 通过下采样到块级潜在表示,使用共享核心进行处理,并上采样回到标记长度,同时通过右移实现严格自回归因果性。
- 定义分辨率调度 {r_t},在迭代期间改变有效序列长度 L_t。
- 使用 MeSH 或 Anchor 拓扑更新,将每次迭代更新融合到运行状态。
- 实现带块化、偏移和基于注意力的下/上采样机制的因果下采样/上采样。
- 在 Pythia 序列(160M–1.4B)上以预训练解码器风格的 Transformer 进行评估,并在计算资源和参数预算方面与 Baseline 与 LoopedFormer 进行对比。
实验结果
研究问题
- RQ1多分辨率递归是否能够让循环Transformer学习到全分辨率循环无法捕捉的分层依赖?
- RQ2在压缩潜表示上运行的共享核心是否能实现比传统循环或非循环架构更好的参数与计算效率?
- RQ3分辨率调度如何影响模型性能与放大缩小行为?
- RQ4在 SpiralFormer 中递归比率对验证损失与容量的影响如何?
- RQ5随着分辨率增加,跨循环的注意力模式是否一致地发生变化,指示分层推理?
主要发现
| 模型 | 配置 | 参数量 (M) (总/非嵌入) | FLOPs (1e12) (4096 预填充) | 困惑度 ↓ | 任务准确性 ↑ | 0-shot | 5-shot | Pile | Wiki | LD-O | LD-S |
|---|---|---|---|---|---|---|---|---|---|---|---|
| Pythia-160M | Baseline (Pythia) | 163.5 / 85.1 | 1.65 | 11.31 | 30.32 | 42.86 | 175.62 | 39.88 | 40.54 | ||
| LoopedFormer * | 2+4×{1,1}+2 | 135.2 / 56.7 | 1.65 | 11.63 | 31.69 | 50.38 | 195.11 | 38.81 | 40.15 | ||
| LoopedFormer † | 2+4×{1,1}+2 | 135.2 / 56.7 | 1.65 | 11.37 | 30.43 | 46.60 | 178.77 | 39.41 | 40.60 | ||
| SpiralFormer-B † | 2+4×{1/8,1/4,1/2,1}+2 | 135.2 / 56.8 | 1.48 | 11.29 | 30.27 | 43.27 | 155.78 | 39.73 | 41.02 | ||
| SpiralFormer-L † | 4+4×{1/16,1/8,1/4,1/2}+4 | 163.6 / 85.1 | 1.49 | 10.94 | 28.85 | 41.24 | 147.52 | 39.30 | 41.37 | ||
| Pythia-410M | Baseline (Pythia) | 407.4 / 302.3 | 4.59 | 9.07 | 21.79 | 19.48 | 65.86 | 43.87 | 45.31 | ||
| LoopedFormer * | 4+8×{1,1}+4 | 306.7 / 201.5 | 4.59 | 9.19 | 22.12 | 20.37 | 52.55 | 43.70 | 45.68 | ||
| LoopedFormer † | 4+8×{1,1}+4 | 306.7 / 201.6 | 4.59 | 9.09 | 21.84 | 19.63 | 42.51 | 44.12 | 45.56 | ||
| SpiralFormer-B * | 4+8×{1/8,1/4,1/2,1}+4 | 306.7 / 201.6 | 4.10 | 9.13 | 22.04 | 21.96 | 47.33 | 43.87 | 46.30 | ||
| SpiralFormer-B † | 4+8×{1/8,1/4,1/2,1}+4 | 306.8 / 201.6 | 4.11 | 9.00 | 21.48 | 19.11 | 39.78 | 44.31 | 46.75 | ||
| SpiralFormer-L † | 8+8×{1/16,1/8,1/4,1/2}+8 | 407.5 / 302.4 | 4.16 | 8.73 | 20.55 | 20.38 | 47.89 | 44.97 | 47.06 | ||
| Pythia-1B | Baseline (Pythia) | 1020.2 / 805.7 | 9.67 | 7.96 | 17.66 | 13.53 | 33.65 | 46.95 | 49.07 | ||
| LoopedFormer * | 3+5×{1,1}+3 | 768.4 / 553.9 | 9.67 | 8.10 | 18.15 | 13.32 | 32.34 | 46.73 | 48.83 | ||
| LoopedFormer † | 3+5×{1,1}+3 | 768.4 / 554.0 | 9.67 | 7.90 | 17.54 | 12.19 | 26.71 | 47.53 | 49.51 | ||
| SpiralFormer-B † | 3+5×{1/8,1/4,1/2,1}+3 | 768.6 / 554.1 | 8.95 | 7.80 | 17.21 | 11.96 | 25.55 | 48.14 | 50.25 | ||
| SpiralFormer-L † | 5+6×{1/16,1/8,1/4,1/2}+5 | 1020.4 / 805.9 | 8.96 | 7.64 | 16.73 | 11.94 | 23.90 | 48.97 | 51.83 | ||
| Pythia-1.4B | Baseline (Pythia) | 1423.0 / 1208.6 | 14.08 | 7.44 | 15.97 | 10.51 | 22.81 | 49.50 | 51.93 | ||
| Baseline † | 24 Layers | 1423.1 / 1208.7 | 14.08 | 7.26 | 15.25 | 9.46 | 16.31 | 50.21 | 53.12 | ||
| LoopedFormer * | 4+8×{1,1}+4 | 1020.2 / 805.7 | 14.08 | 7.51 | 16.25 | 10.71 | 19.37 | 49.39 | 51.27 | ||
| LoopedFormer † | 4+8×{1,1}+4 | 1020.2 / 805.8 | 14.08 | 7.39 | 15.84 | 9.72 | 19.39 | 50.56 | 52.79 | ||
| SpiralFormer-B † | 4+8×{1/8,1/4,1/2,1}+4 | 1020.4 / 805.9 | 12.92 | 7.30 | 15.61 | 9.06 | 15.30 | 51.48 | 53.22 | ||
| SpiralFormer-L † | 8+8×{1/16,1/8,1/4,1/2}+8 | 1423.2 / 1208.8 | 13.13 | 7.14 | 15.03 | 9.73 | 14.42 | 51.75 | 54.37 |
- SpiralFormer 在160M–1.4B尺度上对比循环与非循环基线,在参数和计算效率方面具备更优表现。
- 粗到细的多分辨率调度在某些规模下持续降低 FLOPs(约7–11%),同时保持或提升困惑度与下游准确率。
- 在参数量相同的情况下,SpiralFormer-L 将 FLOPs 降低约3–10%,并提升困惑度与小样本准确率(如1.4B:FLOPs 由 14.08 降至 13.13;5-shot 由 1.93→54.37?)
- 注意力探针显示跨循环的位移:较高分辨率的循环变得更具选择性(熵更低)且局部性更强(本地注意力质量更高)。
- 在全分辨率 LoopedFormer 中跨循环注意力动态较弱,表明分层模式与多分辨率设计相关,而非仅与循环本身有关。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。