[论文解读] A Simple Framework for Contrastive Learning of Visual Representations
简而言之:SimCLR 提出一个简单、可扩展的自监督视觉表征学习框架,使用对比损失、强数据增强、非线性投影头和大批量训练,在 ImageNet 的线性评估上达到最先进水平,无需专用架构或内存库。
This paper presents SimCLR: a simple framework for contrastive learning of visual representations. We simplify recently proposed contrastive self-supervised learning algorithms without requiring specialized architectures or a memory bank. In order to understand what enables the contrastive prediction tasks to learn useful representations, we systematically study the major components of our framework. We show that (1) composition of data augmentations plays a critical role in defining effective predictive tasks, (2) introducing a learnable nonlinear transformation between the representation and the contrastive loss substantially improves the quality of the learned representations, and (3) contrastive learning benefits from larger batch sizes and more training steps compared to supervised learning. By combining these findings, we are able to considerably outperform previous methods for self-supervised and semi-supervised learning on ImageNet. A linear classifier trained on self-supervised representations learned by SimCLR achieves 76.5% top-1 accuracy, which is a 7% relative improvement over previous state-of-the-art, matching the performance of a supervised ResNet-50. When fine-tuned on only 1% of the labels, we achieve 85.8% top-5 accuracy, outperforming AlexNet with 100X fewer labels.
研究动机与目标
- 在没有监督标签的情况下,激发对有效自监督视觉表征的需求。
- 系统性研究对比框架的哪些组成部分能够带来高质量的表征。
- 展示数据增强、非线性投影头和训练动态如何影响性能。
- 证明更大批量和更长训练相对于监督学习能提升对比学习效果。
- 在 ImageNet 和迁移数据集上提供自监督、半监督与监督基线的实证证据。
提出的方法
- 定义一个简单的对比框架(SimCLR),包含四个组成部分:随机数据增强、基础编码器 f(·)、非线性投影头 g(·)以及对比损失(NT-Xent)。
- 通过对每个样本的两个相关视图进行随机增强;使用余弦相似度和温度参数 τ,在两个视图的投影表征 z_i 和 z_j 之间最大化一致性。
- 在编码器之上使用一个非线性投影头(一个小型多层感知机)将表征映射到对比空间,消融实验显示相较于单位投影或线性投影具有更大的收益,投影前的表示 h 也保留了更多与任务相关的信息。
- 在不使用内存库的情况下,使用大批量(256–8192)训练,采用 LARS 优化器以及跨设备的同步批量归一化;用冻结表征的线性分类器进行评估(线性评估协议)。
- 系统性地消融数据增强、投影头架构、损失函数、批量大小和训练时长,以识别驱动性能的因素。
实验结果
研究问题
- RQ1哪些数据增强组合能为对比学习提供最具信息性的预测任务?
- RQ2非线性投影头是否比直接使用编码器输出提高下游表征质量?
- RQ3批量大小、训练时长和优化选择如何相对于监督学习影响对比学习的性能?
- RQ4在这个框架中,哪些损失函数和归一化/温度设置最适合对比学习?
- RQ5学得的表征如何迁移到下游识别任务和数据集?
主要发现
| 方法 | 架构 | 参数(M) | Top 1 | Top 5 |
|---|---|---|---|---|
| 本地聚合 | ResNet-50 | 24 | 60.2 | - |
| MoCo | ResNet-50 | 24 | 60.6 | - |
| PIRL | ResNet-50 | 24 | 63.6 | - |
| CPC v2 | ResNet-50 | 24 | 63.8 | 85.3 |
| SimCLR(ours) | ResNet-50 | 24 | 69.3 | 89.0 |
| SimCLR(ours) | ResNet-50(2×) | 94 | 74.2 | 92.0 |
| SimCLR(ours) | ResNet-50(4×) | 375 | 76.5 | 93.2 |
- 数据增强的组合至关重要;随机裁剪加颜色失真相比任何单一增强显著提升表征质量。
- 在编码器顶部的非线性投影头(z = g(h))显著提升线性评估的准确性,相较于使用 h 或线性投影;投影前的表示 h 保留了更多任务相关信息。
- 带有温度参数的 NT-Xent 损失对正则化嵌入的归一化和合理的 τ 调整对性能至关重要;使用余弦相似度的 NT-Xent 优于其他替代方案。
- 更大的批量和更长的训练提供更多负样本并改善收敛,对比学习比监督学习更能从规模中受益。
- SimCLR 使用 ResNet-50(4× 宽度)进行线性评估在 ImageNet 上达到 76.5% 的 Top-1 准确率,并与有监督的 ResNet-50 性能相匹配;在 1% 标签(即微调)下,Top-5 提升达到 85.8% 的 ImageNet;在迁移和其他数据集上,在许多任务上表现具有竞争力或优越。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。