[论文解读] Unsupervised Domain Adaptation by Backpropagation
引入梯度反转层,在训练标签预测器的同时学习域不变特征,使在深度网络中通过标准反向传播实现无监督域适应;在 Office 数据集上达到最先进结果,在数字与合成到现实任务上取得强劲结果。
Top-performing deep architectures are trained on massive amounts of labeled data. In the absence of labeled data for a certain task, domain adaptation often provides an attractive option given that labeled data of similar nature but from a different domain (e.g. synthetic images) are available. Here, we propose a new approach to domain adaptation in deep architectures that can be trained on large amount of labeled data from the source domain and large amount of unlabeled data from the target domain (no labeled target-domain data is necessary). As the training progresses, the approach promotes the emergence of "deep" features that are (i) discriminative for the main learning task on the source domain and (ii) invariant with respect to the shift between the domains. We show that this adaptation behaviour can be achieved in almost any feed-forward model by augmenting it with few standard layers and a simple new gradient reversal layer. The resulting augmented architecture can be trained using standard backpropagation. Overall, the approach can be implemented with little effort using any of the deep-learning packages. The method performs very well in a series of image classification experiments, achieving adaptation effect in the presence of big domain shifts and outperforming previous state-of-the-art on Office datasets.
研究动机与目标
- 在目标域标签不可用时激励域适应,并利用丰富的源域带标签数据。
- 将域适应嵌入到深度特征学习中,以产生判别性同时又域不变的表示。
- 开发一个简单的与 SGD 兼容的训练过程,将梯度反转层整合到现有架构中。
提出的方法
- 提出一个三部分网络:特征提取器 G_f、标签预测器 G_y 和在 G_f 输出上工作的域分类器 G_d。
- 在 G_f 与 G_d 之间引入梯度反转层,使反向传播的域损失乘以 -λ,促使学习域不变特征。
- 提出一个鞍点目标,最小化源数据上的标签损失同时最大化域损失,由 λ 控制。
- 通过标准 SGD 风格的更新进行优化,反向通过梯度反转层,产生域不变的特征空间。
- 将该方法与基于 HΔH 距离的泛化界限相关联,并展示域分类器的性能界定域差异。
实验结果
研究问题
- RQ1一个深度网络是否可以端到端地进行无监督域适应,通过联合优化判别标签与域不变性来实现?
- RQ2在学习过程中引入梯度反转机制是否能够有效对齐源域和目标域的特征分布?
- RQ3与以往方法相比,该方法在标准域适应基准上表现如何?
- RQ4是否可以将此方法应用于具有合成到现实和跨域迁移的实际数据集(如 Office 数据集)?
主要发现
| 方法 | 源 | MNIST | Syn Numbers | SVHN | Syn Signs | 目标 | MNIST-M | SVHN | MNIST | GTSRB |
|---|---|---|---|---|---|---|---|---|---|---|
| 仅源数据 | MNIST | 0.5749 | - | - | - | MNIST-M | - | - | - | - |
| SA (Fernando et al., 2013) | MNIST | 0.6078 (7.9%) | 0.8672 (1.3%) | - | - | MNIST-M | - | - | - | - |
| 提出的方法 | MNIST | 0.8149 (57.9%) | 0.9048 (66.1%) | 0.7107 (29.3%) | 0.8866 (56.7%) | MNIST-M | - | - | - | - |
| 在目标上训练 | - | 0.9891 | 0.9244 | 0.9951 | 0.9987 | - | - | - | - | - |
| 仅源数据 | Syn Numbers | - | 0.8665 | - | - | SVHN | - | - | - | - |
| SA (Fernando et al., 2013) | Syn Numbers | - | 0.8672 (1.3%) | - | - | SVHN | - | - | - | - |
| 提出的方法 | Syn Numbers | - | 0.9048 (66.1%) | 0.7107 (29.3%) | 0.8866 (56.7%) | SVHN | - | - | - | - |
| 在目标上训练 | - | - | - | - | SVHN | 0.9244 | - | - | - | - |
| 仅源数据 | Syn Signs | - | - | - | 0.7400 | GTSRB | - | - | - | - |
| SA (Fernando et al., 2013) | Syn Signs | - | - | - | 0.7635 (9.1%) | GTSRB | - | - | - | - |
| 提出的方法 | Syn Signs | - | - | - | 0.8866 (56.7%) | GTSRB | - | - | - | - |
| 在目标上训练 | - | - | - | - | GTSRB | 0.9987 | - | - | - | - |
| 仅源数据 | Office (Amazon→DSLR) | - | - | - | - | DSLR | 0.433? | - | - | - |
| SA (Fernando et al., 2013) | Office (Amazon→DSLR) | - | - | - | - | DSLR | 0.450 | - | - | - |
| 提出的方法 | Office (Amazon→DSLR) | - | - | - | - | DSLR | 0.673±0.017 | - | - | - |
| 仅源数据 | Office (Amazon→Webcam) | - | - | - | - | Webcam | 0.464? | - | - | - |
| 提出的方法 | Office (Amazon→Webcam) | - | - | - | - | Webcam | 0.673±0.017 | - | - | - |
- 在多项跨域任务中显著优于仅源数据模型。
- 在 MNIST 到 MNIST-M 上,方法达到 0.8149 的准确率,对比线下基线 0.5749 并且超过 SA 基线。
- 在 Syn Numbers 到 SVHN 上,方法达到 0.9048 的准确率,对比线下基线 0.8665,且优于 SA。
- 在 SVHN 到 MNIST 上,方法达到 0.7107 的准确率,对比线下基线 0.5919,且优于 SA。
- 在 Syn Signs 到 GTSRB 上,方法达到 0.8866 的准确率,对比线下基线 0.7400,且优于 SA。
- Office 数据集实验表明,所提方法在 Amazon→DSLR/Webcam 为 0.673±0.017,在 DSLR→Webcam 为 0.940±0.008,在 Webcam→DSLR 为 0.937±0.010,优于以前的方法。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。