[论文解读] Learning Transformer Programs
本论文将 Transformers 设计为在机制上可解释的约束,然后将它们转换为人可读的程序(类似 Python/RASP 的风格),而性能损失不大。它在上下文学习、算法任务和 NLP 上展示了结果,并通过代码级调试提供可解释的电路。
Recent research in mechanistic interpretability has attempted to reverse-engineer Transformer models by carefully inspecting network weights and activations. However, these approaches require considerable manual effort and still fall short of providing complete, faithful descriptions of the underlying algorithms. In this work, we introduce a procedure for training Transformers that are mechanistically interpretable by design. We build on RASP [Weiss et al., 2021], a programming language that can be compiled into Transformer weights. Instead of compiling human-written programs into Transformers, we design a modified Transformer that can be trained using gradient-based optimization and then automatically converted into a discrete, human-readable program. We refer to these models as Transformer Programs. To validate our approach, we learn Transformer Programs for a variety of problems, including an in-context learning task, a suite of algorithmic problems (e.g. sorting, recognizing Dyck languages), and NLP tasks including named entity recognition and text classification. The Transformer Programs can automatically find reasonable solutions, performing on par with standard Transformers of comparable size; and, more importantly, they are easy to interpret. To demonstrate these advantages, we convert Transformers into Python programs and use off-the-shelf code analysis tools to debug model errors and identify the "circuits" used to solve different sub-problems. We hope that Transformer Programs open a new path toward the goal of intrinsically interpretable machine learning.
研究动机与目标
- 在高风险任务中需要本质上可解释的 Transformer 模型以便审计和调试。
- 提出一个框架,在约束下训练 Transformer,确保与人类可读程序的确定性映射。
- 展示 Transformer Programs 在多种算法和 NLP 任务上能以具有竞争力的性能解决问题。
- 从训练好的模型中自动提取可执行的 Python/RASP-风格的程序,以实现电路级调试。
提出的方法
- 引入一个解耦的残差流约束,使每个模块读取固定变量集合并写入专门的正交子空间。
- 定义并使用离散、可解释的模块(类别注意头)进行训练,采用硬注意力,通过 Gumbel-Softmax 在优化阶段放松。
- 将每个注意头映射到一个类似 RASP 的谓词-聚合原语;学习对离散权重的分布(πK, πQ, πV, Wpredicate),并使用 Gumbel 重新参数化进行采样。
- 训练完成后,通过最大化离散权重并将注意头转换为具有 select_closest 原语的谓词函数,确定性地提取一个 Python 程序。
- 扩展框架以包括词嵌入、数值注意力,以及前馈/查找类层,从而扩展程序的 repertoire。
- 提供关于训练和映射到可解释程序的扩展与细节,包括示例代码和调试工作流。
实验结果
研究问题
- RQ1是否可以在有约束的情况下训练 Transformer 模型,从而保证映射到可解释程序的确定性?
- RQ2在维持可解释性的同时,这类 Transformer Programs 能在上下文学习、RASP 风格的算法任务和 NLP 基准测试中达到多大程度的性能?
- RQ3将模型转换为可读的 Python/RASP 风格代码后,学习到的程序和电路的 Qualitative 结构是什么?
- RQ4与标准 Transformer 相比,Transformer Programs 在不同难度任务上的准确性和可解释性有何差异?
主要发现
| 数据集 | 描述 | 示例 | k | L | H | M | Acc. |
|---|---|---|---|---|---|---|---|
| Reverse | 翻译一段字符串。 | reverse("abbc") = "cbba" | 8 | 3 | 8 | 2 | 99.79 |
| Histogram | 对于每个 token,序列中该字母的出现次数。 | hist("abbc") = "1221" | 8 | 1 | 4 | 2 | 100.0 |
| Double hist. | 对于每个 token,具有相同直方图值的唯一 token 的数量。 | hist2("abbc") = "2112" | 8 | 3 | 4 | 2 | 98.40 |
| Sort | 按字典序对输入排序。 | sort("cbab") = "abbc" | 8 | 3 | 8 | 4 | 99.83 |
| Most-Freq | 按频率排序的唯一输入 token,若有并列则用位置打破。 | most_freq("abbc") = "bac" | 8 | 3 | 8 | 4 | 75.69 |
| Dyck-1 | 对于每个位置 i,输入直到 i 是否是 Dyck-1 的有效字符串(T);一个有效前缀(P);还是无效(F)。 | dyck1("()())") = "PTPTF" | 16 | 3 | 8 | 2 | 99.30 |
| Dyck-2 | 同上,但在 Dyck-2。 | dyck2("({})(}") = "PPPTPF" | 16 | 3 | 4 | 4 | 99.09 |
| (Table continues as described in text) |
- Transformer Programs 在多种任务上的表现接近同等规模的标准 Transformer,达到合理的性能。
- 在 RASP 风格任务上,在若干任务上准确率超过 99%(在较长输入时有例外)。
- 在一个上下文学习的小型任务中,模型学会通过组装头来再现 induction-head 的行为,并达到完全的测试正确率。
- 在 CoNLL-2003 NER 上,Transformer Programs 的 F1 与标准 Transformer 相当,并超出 unigram 基线。
- 提取的 Python/RASP-风格程序暴露了可解释的电路和特征权重,有助于调试和电路分析。
- 存在权衡:标准 Transformer 通常在较长的序列或更大词汇表上超越 Transformer Programs,突出了扩展性挑战。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。