[论文解读] Transformer to CNN: Label-scarce distillation for efficient text classification
本文提出了一种知识蒸馏框架,通过使用大型预训练Transformer模型(OpenAI GPT)作为教师,训练一个轻量级、高效的卷积神经网络(BlendCNN)作为学生模型。尽管参数量少39倍、推理速度提升300倍,该学生模型在标签稀缺条件下的多个文本分类基准上仍优于教师模型,表明经过适当设计以实现层次化表征学习的蒸馏CNN模型,可在特定条件下超越大型注意力机制模型。
Significant advances have been made in Natural Language Processing (NLP) modelling since the beginning of 2018. The new approaches allow for accurate results, even when there is little labelled data, because these NLP models can benefit from training on both task-agnostic and task-specific unlabelled data. However, these advantages come with significant size and computational costs. This workshop paper outlines how our proposed convolutional student architecture, having been trained by a distillation process from a large-scale model, can achieve 300x inference speedup and 39x reduction in parameter count. In some cases, the student model performance surpasses its teacher on the studied tasks.
研究动机与目标
- 为解决大型预训练Transformer模型在工业自然语言处理应用中计算与内存开销过高的问题。
- 探究在标签稀缺条件下,轻量级CNN学生模型是否能够达到甚至超越大型预训练Transformer教师模型的性能。
- 研究利用伪标签未标注数据进行蒸馏,能否在标注样本有限的情况下提升学生模型的泛化能力。
- 设计一种新型CNN架构(BlendCNN),以有效从蒸馏后的logits中捕捉层次化表征,从而提升文本分类性能。
提出的方法
- 使用预训练的OpenAI Transformer模型作为教师,经任务特定数据微调后,用于为有标签和无标签数据生成软标签(logits)。
- 设计了一种新型CNN架构BlendCNN,其包含多个并行的卷积分支,每个分支从不同层进行池化,随后进行拼接并经过全连接融合层。
- 学生模型通过知识蒸馏进行训练,采用学生与教师在有标签数据和伪标签无标签样本上的logits之间的平均绝对误差(MAE)作为损失函数。
- 使用100维可学习的GloVe词嵌入作为学生模型的输入特征,应用迁移学习。
- 蒸馏过程使用每类100个有标签样本和1,000个无标签样本,通过伪标签生成额外的训练信号。
- 模型训练采用Adam优化器,所有实验均使用固定的初始学习率10⁻³。
实验结果
研究问题
- RQ1在标签稀缺条件下,轻量级CNN学生模型能否实现与大型预训练Transformer相当或更优的性能?
- RQ2当仅有少量有标签样本时,从强大教师模型中蒸馏知识在多大程度上能提升小型学生网络的准确率?
- RQ3利用伪标签未标注数据在提升低资源文本分类的蒸馏过程方面有多有效?
- RQ4像BlendCNN这样特别设计的CNN架构,能否有效利用从蒸馏logits中提取的层次化表征,从而超越更大的模型?
主要发现
- 3层BlendCNN学生模型在AG News数据集上达到91.2%的准确率,通过蒸馏训练后,超过OpenAI Transformer教师模型的88.7%。
- 在DBpedia数据集上,8层BlendCNN达到98.5%的准确率,超过教师模型的97.5%,采用相同的蒸馏协议。
- 在Yahoo Answers数据集上,3层BlendCNN达到71.0%的准确率,略高于教师模型的70.4%。
- 尽管性能超越教师模型,3层BlendCNN模型的参数量仅为298万(2.98M),比教师模型的1.165亿(116.5M)少39倍,且推理速度达3,676句/秒,比教师模型的11.76句/秒快300倍。
- 蒸馏带来的性能增益显著:若不使用蒸馏,BlendCNN在AG News上的准确率仅为87.6%,表明蒸馏对实现高性能至关重要。
- 在蒸馏过程中引入无标签数据能显著提升学生模型性能,表现为使用伪标签时的得分明显高于仅使用有标签数据训练的结果。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。