[论文解读] Your Classifier is Secretly an Energy Based Model and You Should Treat it Like One
本论文将标准判别分类器重新解释为联合能量模型(EBMs),并训练它们以建模 p(x, y) 和 p(x),在提高校准、OOD 检测、鲁棒性方面实现与判别和生成性能并驾齐驱的结果。
We propose to reinterpret a standard discriminative classifier of p(y|x) as an energy based model for the joint distribution p(x,y). In this setting, the standard class probabilities can be easily computed as well as unnormalized values of p(x) and p(x|y). Within this framework, standard discriminative architectures may beused and the model can also be trained on unlabeled data. We demonstrate that energy based training of the joint distribution improves calibration, robustness, andout-of-distribution detection while also enabling our models to generate samplesrivaling the quality of recent GAN approaches. We improve upon recently proposed techniques for scaling up the training of energy based models and presentan approach which adds little overhead compared to standard classification training. Our approach is the first to achieve performance rivaling the state-of-the-artin both generative and discriminative learning within one hybrid model.
研究动机与目标
- 将标准分类器重新框架为联合能量基模型,以建模 p(x, y) 和 p(x)。
- 在不损害判别性能的前提下,实现在无标签数据上的训练。
- 通过基于能量的训练提高校准、鲁棒性和对分布外样本的检测。
- 在单一模型中展示与判别准确性并驾齐驱的生成能力。
提出的方法
- 通过 pθ(x,y) = exp(fθ(x)[y]) / Z(θ) 给出来自分类器对数的 p(x, y),其中能量 Eθ(x,y) = -fθ(x)[y]。
- 通过对 y 边缘化获得未归一化的 p(x): pθ(x) = sum_y exp(fθ(x)[y]) / Z(θ)。
- 将 LogSumExp(fθ(x)) 作为 p(x) 的能量代理,并通过标准交叉熵训练 p(y|x)。
- 使用 SGLD 从模型分布采样来训练对 log p(x) 的无偏梯度估计。
- 使用持续对比散度来估计 log p(x) 梯度中的期望。
- 基于 Wide Residual Networks 架构,并在 CIFAR10、SVHN 和 CIFAR100 上进行训练。
实验结果
研究问题
- RQ1Can a standard classifier be interpreted as a joint energy-based model for p(x, y) and p(x)?
- RQ2Does EBMs-based training improve calibration, OOD detection, and adversarial robustness while maintaining discriminative performance?
- RQ3Do joint EBMs offer competitive generative quality alongside discriminative accuracy in large-scale image datasets?
- RQ4How scales of SGLD-based sampling affect training stability and performance?
主要发现
- JEM 在 CIFAR10/SVHN/CIFAR100 上实现与现有最优混合模型相媲美的准确性,同时具备可与之相比的生成能力。
- JEM 提高 CIFAR100 的校准,在期望校准误差(ECE)上接近完美校准。
- JEM 使用多种评分方法提升对分布外样本的检测,包括 log p(x) 和基于梯度的 mass score,优于若干基线。
- 该模型对对抗干扰表现出鲁棒性,相较标准分类器有提高,并接近专门的鲁棒方法。
- 在 CIFAR10 上,JEM 达到 92.9% 的准确率,IS 8.76,FID 38.4,优于若干混合基线,同时与仅判别模型具有竞争力。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。