[论文解读] Self-Attention Between Datapoints: Going Beyond Individual Input-Output Pairs in Deep Learning
本文提出非参数变换器(NPTs),将整个数据集作为输入,并利用数据点之间的自注意力来学习点之间的关系,从而实现跨数据点的查找并在表格数据和图像数据上获得更好的预测。
We challenge a common assumption underlying most supervised deep learning: that a model makes a prediction depending only on its parameters and the features of a single input. To this end, we introduce a general-purpose deep learning architecture that takes as input the entire dataset instead of processing one datapoint at a time. Our approach uses self-attention to reason about relationships between datapoints explicitly, which can be seen as realizing non-parametric models using parametric attention mechanisms. However, unlike conventional non-parametric models, we let the model learn end-to-end from the data how to make use of other datapoints for prediction. Empirically, our models solve cross-datapoint lookup and complex reasoning tasks unsolvable by traditional deep learning models. We show highly competitive results on tabular data, early results on CIFAR-10, and give insight into how the model makes use of the interactions between points.
研究动机与目标
- 质疑监督学习中的参数依赖假设。
- 提出一种通用架构(NPTs),使用整个数据集进行预测。
- 通过注意力机制实现端到端学习数据点之间的相互作用。
- 在表格数据和图像数据集上展示跨数据点的查找与推理。
提出的方法
- 将整个数据集(X)和掩码矩阵(M)输入到 NPTs,以实现对被遮蔽值 p(X^M | X^O) 的重建。
- 应用交替的数据点之间注意力(ABD)和属性之间注意力(ABA)来建模数据点间关系及每个数据点的变换。
- 使用带残差连接和层归一化的多头自注意力,遵循 Transformer 风格的架构。
- 采用受 BERT 启发的掩码目标进行训练,将目标损失与辅助特征屏蔽损失结合起来:L^NPT = (1-λ)L^Targets + λL^Features。
- 通过小批量处理应对大型数据集,保持训练与测试数据在同一批中以实现跨数据点的注意力。
实验结果
研究问题
- RQ1NPTs 是否能在标准监督基准测试中实现有竞争力的性能?
- RQ2NPTs 是否能够通过在理想化的跨点查找任务中利用数据点之间的注意力来学习预测?
- RQ3NPTs 是否在现实世界数据预测中确实依赖数据点之间的相互作用?
- RQ4在使用 NPTs 时,哪类数据点对预测最相关?
主要发现
- NPTs 在 UCI 基准的二分类和多分类任务中获得最高的平均排名,优于若干提升方法。
- 在回归任务中,NPTs 与 XGBoost 并列为最佳平均排名,仅被 CatBoost 超越。
- CIFAR-10 通过 CNN+ABD 架构达到 93.7% 的测试准确率;MNIST 通过线性分块达到 98.3%。
- 在半合成蛋白质回归任务中,NPTs 可以从重复的行中查找目标值,达到接近完美的相关性(r = 99.9%)。
- 腐蚀实验显示当其他数据点被随机化时预测性能下降,表明在真实数据上依赖数据点之间的相互作用(因数据集而异)。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。