[论文解读] Domain Generalization via Model-Agnostic Learning of Semantic Features
MASF 采用分段的模型无关训练,结合全局类别对齐和局部基于度量的聚类,学习在未见域上也能泛化的语义特征空间,在 VLCS 和 PACS 上实现了最先进的结果,并在医学影像分割上有所提升。
Generalization capability to unseen domains is crucial for machine learning models when deploying to real-world conditions. We investigate the challenging problem of domain generalization, i.e., training a model on multi-domain source data such that it can directly generalize to target domains with unknown statistics. We adopt a model-agnostic learning paradigm with gradient-based meta-train and meta-test procedures to expose the optimization to domain shift. Further, we introduce two complementary losses which explicitly regularize the semantic structure of the feature space. Globally, we align a derived soft confusion matrix to preserve general knowledge about inter-class relationships. Locally, we promote domain-independent class-specific cohesion and separation of sample features with a metric-learning component. The effectiveness of our method is demonstrated with new state-of-the-art results on two common object recognition benchmarks. Our method also shows consistent improvement on a medical image segmentation task.
研究动机与目标
- 在测试域统计未知且训练期间没有目标数据的情况下,激励域泛化。
- 学习对跨多个源域有鲁棒性的具有语义意义的特征表示。
- 提出全局和局部正则化,塑造特征空间——全局对齐类间关系和局部、域无关的类聚类。
- 利用 episodic 训练的模型无关元学习来提升对未见域的泛化能力。
提出的方法
- 通过将源域分为 meta-train 和 meta-test 来模拟域转移,采用 episodic 训练。
- 引入一个全局类别对齐损失,使用对称KL散度使 meta-train 和 meta-test 域的软混淆矩阵一致。
- 通过度量嵌入网络引入局部样本聚类损失,利用对比或三元组损失鼓励域无关的类别内聚和分离。
- 用任务损失加元损失更新特征提取器和任务网络;用局部损失更新嵌入网络以强化聚类。
- 使用类别均值特征向量形成每个类别的软标签,并通过温度控制的 softmax 计算软混淆矩阵,指导跨域语义对齐。
- 提供两种具体的局部聚类度量学习损失:对比损失(d_phi)和带半难样本挖掘的三元组损失,以实现高效训练。
实验结果
研究问题
- RQ1我们如何在多个源域上训练模型,使其在没有目标域数据的训练条件下对未见域具备泛化能力?
- RQ2是否通过显式正则化特征空间的语义结构来提升域泛化,而不是仅依赖传统的任务驱动损失?
- RQ3将全局的类间关系对齐与局部样本聚类结合,是否在域转移下带来更好的泛化?
- RQ4模型无关的 episodic 学习框架在自然图像识别基准和医学影像分割任务上是否都有效?
主要发现
- MASF 在 VLCS 上实现了状态至上的平均准确率,目标域从 72.19 提升到 74.11。
- 在 PACS 上,MASF 相比基线平均准确率提升了 3.51 个百分点,在 Sketch 目标域上显示出显著增益。
- 消融研究证实全局类别对齐和局部聚类都对性能有贡献,且二者与 episodic 训练结合时达到最佳结果。
- 深度残差结构(ResNet-18/50)也从 MASF 中受益,展示了对不同网络骨干的鲁棒性。
- 在医学脑部 MRI 分割中,MASF 相比 DeepAll 提升了 Dice 分数,尤其在迁移到新的临床站点(Set-D)时,且通过轮廓分析显示类内聚集更紧凑。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。