Skip to main content
QUICK REVIEW

[论文解读] Generalizing Across Domains via Cross-Gradient Training

Shiv Shankar, Vihari Piratla|arXiv (Cornell University)|Apr 28, 2018
Domain Adaptation and Few-Shot Learning参考文献 27被引用 268
一句话总结

CrossGrad 使用领域引导的输入扰动和联合标签-领域训练,在没有目标领域数据或显式领域特征的情况下,使分类器对未见领域具泛化能力。它在多项任务中优于领域对抗和通用扰动方法。

ABSTRACT

We present CROSSGRAD, a method to use multi-domain training data to learn a classifier that generalizes to new domains. CROSSGRAD does not need an adaptation phase via labeled or unlabeled data, or domain features in the new domain. Most existing domain adaptation methods attempt to erase domain signals using techniques like domain adversarial training. In contrast, CROSSGRAD is free to use domain signals for predicting labels, if it can prevent overfitting on training domains. We conceptualize the task in a Bayesian setting, in which a sampling step is implemented as data augmentation, based on domain-guided perturbations of input instances. CROSSGRAD parallelly trains a label and a domain classifier on examples perturbed by loss gradients of each other's objectives. This enables us to directly perturb inputs, without separating and re-mixing domain signals while making various distributional assumptions. Empirical evaluation on three different applications where this setting is natural establishes that (1) domain-guided perturbation provides consistently better generalization to unseen domains, compared to generic instance perturbation methods, and that (2) data augmentation is a more stable and accurate method than domain adversarial training.

研究动机与目标

  • 在多领域数据上进行学习以在不进行目标领域自适应的情况下实现对未见领域的泛化。
  • 提出一种领域引导的数据增强方法,通过沿领域损失梯度扰动输入。
  • 在避免过拟合到训练域的同时,保留对域内预测有帮助的域信号。
  • 在手写、字体识别、MNIST 旋转和语音单词任务上展示经验性泛化增益。

提出的方法

  • 将输入 x 模型化为在潜在领域特征 g 的作用下受标签 y 和领域 d 影响。
  • 使用领域分类器 G 从 x 提取连续领域特征 g 并预测 d。
  • 在领域损失梯度方向扰动 x,以生成具有不同领域特征的增强样本。
  • 用交叉扰动数据训练标签和领域分类器,以防止过拟合到训练域。
  • 将训练形式化为在交叉梯度下交替更新两个目标 J_l(标签)和 J_d(领域)。
  • 在四个数据集上,对比基线 Baseline、DAN 以及 LabelGrad,使用不同的训练/测试域来评估 CrossGrad。

实验结果

研究问题

  • RQ1CrossGrad 是否能够在没有显式目标域数据或域特征的情况下,对未见域实现泛化?
  • RQ2领域引导的扰动是否比通用扰动或领域对抗方法提供更好的泛化?
  • RQ3CrossGrad 在多样化任务(字体/手写识别、MNIST 旋转、语音指令)和多种架构上的表现如何?
  • RQ4在何种训练条件(域数量)下 CrossGrad 最有效?

主要发现

方法名称字体手写MNIST语音
Baseline68.582.595.672.6
DAN68.983.898.070.4
LabelGrad71.486.397.872.7
CrossGrad72.688.698.673.5
  • CrossGrad 在所有四个数据集上都比 Baseline、DAN 和 LabelGrad 提升准确率。
  • 在 Fonts、Handwriting、MNIST、Speech 上,CrossGrad 的准确率分别为 72.6、88.6、98.6 和 73.5(相对 Baseline 为 68.5、82.5、95.6、72.6)。
  • CrossGrad 在 Fonts 和 Handwriting 上对不同架构(LeNet 和 ResNet)保持了增益。
  • 当训练域数量较少时,CrossGrad 的性能增益更大,随着域覆盖范围增大而减小。
  • 领域对抗网络(DAN)在这些设置中提供的增益不稳定且难以调优。
  • LabelGrad 虽有帮助,但通常比 CrossGrad 提升更小,且域多样性增加时效果下降。

更好的研究,从现在开始

从论文设计到论文写作,大幅缩短您的研究时间。

无需绑定信用卡

本解读由 AI 生成,并经人工编辑审核。