[论文解读] DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter
本文提出 DistilBERT,一种更小、更快、更高效的 BERT 版本,通过在预训练过程中使用知识蒸馏,保留了 97% 的语言理解能力。通过结合掩码语言建模、蒸馏和余弦距离损失的三重损失函数,DistilBERT 将模型大小减少 40%,推理时间减少 60%,同时在 GLUE 和下游任务中保持了强大的性能。
As Transfer Learning from large-scale pre-trained models becomes more prevalent in Natural Language Processing (NLP), operating these large models in on-the-edge and/or under constrained computational training or inference budgets remains challenging. In this work, we propose a method to pre-train a smaller general-purpose language representation model, called DistilBERT, which can then be fine-tuned with good performances on a wide range of tasks like its larger counterparts. While most prior work investigated the use of distillation for building task-specific models, we leverage knowledge distillation during the pre-training phase and show that it is possible to reduce the size of a BERT model by 40%, while retaining 97% of its language understanding capabilities and being 60% faster. To leverage the inductive biases learned by larger models during pre-training, we introduce a triple loss combining language modeling, distillation and cosine-distance losses. Our smaller, faster and lighter model is cheaper to pre-train and we demonstrate its capabilities for on-device computations in a proof-of-concept experiment and a comparative on-device study.
研究动机与目标
- 开发一种更小、更快、更高效的通用语言模型,同时在自然语言处理任务中保持高性能。
- 解决训练和部署大规模预训练模型(如 BERT)带来的计算和环境成本。
- 探索在预训练阶段而非仅在微调或特定任务适应中应用知识蒸馏。
- 通过创建足够小的模型,实现在移动和边缘设备上的高效推理。
- 证明在预训练阶段进行蒸馏可以生成具有强大泛化能力和迁移能力的紧凑模型。
提出的方法
- DistilBERT 是 BERT 的蒸馏版本,参数量减少 40%,通过将层数减半并移除 token-type 嵌入和池化层实现。
- 学生模型通过从教师 BERT-base 模型中每隔一层取用的方式初始化,以保留知识并提升收敛性。
- 三重损失函数结合了三个部分:掩码语言建模(MLM)损失、教师模型软标签的蒸馏损失,以及对隐藏状态方向对齐的余弦嵌入损失。
- 蒸馏损失使用温度缩放的 softmax,将教师模型的软标签分布知识传递给学生模型。
- 训练在与 BERT 相同的维基百科和 BookCorpus 数据集上进行,采用大批次梯度累积和动态掩码策略。
- 模型在 8×16GB V100 GPU 上训练约 90 小时,显著降低了与完整 BERT 训练相比的计算预算。
实验结果
研究问题
- RQ1在预训练阶段进行知识蒸馏是否能生成一个更小的语言模型,同时保留 BERT 的大部分性能?
- RQ2损失函数中的哪些组件对蒸馏模型实现高性能至关重要?
- RQ3与特定任务的蒸馏相比,在下游任务中,预训练阶段的蒸馏表现如何?
- RQ4蒸馏模型是否足够小且快速,足以在移动和边缘设备上高效运行?
- RQ5从教师模型权重初始化学生模型是否能提升收敛性和性能?
主要发现
- DistilBERT 在 GLUE 基准的宏分上达到 BERT-base 的 97%,参数量仅为后者的 40%。
- 在 STS-B 任务中,DistilBERT 的推理速度比 BERT-base 快 60%,CPU 上的推理时间从 BERT 的 668ms 降低至 410ms。
- 在 IMDb 情感分类任务中,DistilBERT 达到 92.82% 的准确率,仅比 BERT-base 的 93.46% 低 0.64%。
- 在 SQuAD 1.1 上,DistilBERT 达到 77.7/85.8 的 EM/F1 分数,与 BERT-base 的 81.2/88.5 相比仅低约 3.5 分。
- 消融实验证明,若移除三重损失中的任意一个组件(MLM、蒸馏或余弦损失),性能均会下降,其中余弦损失贡献显著。
- 在概念验证的移动应用中,DistilBERT 在 iPhone 7 Plus 上比 BERT-base 快 71%,模型大小为 207 MB。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。