[论文解读] Weight-averaged consistency targets improve semi-supervised deep learning results.
本文提出 Mean Teacher,一种半监督学习方法,通过在训练迭代中对模型权重进行平均,生成一致的目标预测,从而提升泛化能力。与时间集成(Temporal Ensembling)相比,该方法更频繁地更新目标,实现了最先进性能:在仅使用 250 个标签的情况下,SVHN 上错误率为 4.35%;在使用 4,000 个标签的情况下,CIFAR-10 上错误率为 6.28%,优于以往方法。
The recently proposed Temporal Ensembling has achieved state-of-the-art results in several semi-supervised learning benchmarks. It maintains an exponential moving average of label predictions on each training example, and penalizes predictions that are inconsistent with this target. However, because the targets change only once per epoch, Temporal Ensembling becomes unwieldy when learning large datasets. To overcome this problem, we propose Mean Teacher, a method that averages model weights instead of label predictions. As an additional benefit, Mean Teacher improves test accuracy and enables training with fewer labels than Temporal Ensembling. Without changing the network architecture, Mean Teacher achieves an error rate of 4.35% on SVHN with 250 labels, outperforming Temporal Ensembling trained with 1000 labels. We also show that a good network architecture is crucial to performance. Combining Mean Teacher and Residual Networks, we improve the state of the art on CIFAR-10 with 4000 labels from 10.55% to 6.28%, and on ImageNet 2012 with 10% of the labels from 35.24% to 9.11%.
研究动机与目标
- 为解决时间集成在大规模数据集上因目标更新频率过低而导致的效率低下问题。
- 通过用权重平均模型一致性替代基于预测的一致性,提升半监督学习性能。
- 减少半监督训练中达到高准确率所需的标注样本数量。
- 证明将 Mean Teacher 与残差网络等强模型架构结合可进一步提升性能。
提出的方法
- Mean Teacher 通过指数移动平均方法对模型自身权重进行计算,形成教师网络。
- 在训练过程中,学生网络的预测结果被正则化,以匹配教师网络在相同输入上的预测结果。
- 教师网络的权重通过动量更新规则进行更新:θ_teacher ← τθ_teacher + (1−τ)θ_student。
- 该方法通过最小化学生与教师在相同增强输入上的预测结果之间的 L2 损失,实现一致性正则化。
- 该方法支持更频繁的目标更新,提升了大规模数据集上的训练稳定性和收敛性。
- 该方法与网络架构无关,可与任意深度神经网络结合,尤其适用于残差网络。
实验结果
研究问题
- RQ1与基于预测平均的目标相比,基于权重平均的一致性目标是否能提升半监督学习性能?
- RQ2通过权重平均实现的频繁目标更新是否能带来更好的泛化能力和更快的收敛速度?
- RQ3Mean Teacher 是否能以远少于以往方法的标注样本数量,实现最先进性能?
- RQ4将 Mean Teacher 与残差网络结合后,在标准基准测试上的性能表现如何?
主要发现
- 在仅使用 250 个标注样本的情况下,Mean Teacher 在 SVHN 上实现了 4.35% 的测试错误率,优于使用 1,000 个标签训练的时间集成方法。
- 在 CIFAR-10 上使用 4,000 个标签时,结合残差网络后,错误率从 10.55% 降低至 6.28%。
- 在 ImageNet 2012 上使用 10% 的训练标签时,结合 Mean Teacher 和残差网络后,错误率从 35.24% 降低至 9.11%。
- 与时间集成相比,该方法实现了更频繁的目标更新,提升了大规模数据集上的训练效率。
- 性能提升不仅归因于一致性机制,还归因于使用了强大的残差网络架构。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。