[論文レビュー] Self-Attention Between Datapoints: Going Beyond Individual Input-Output Pairs in Deep Learning
本論文は、Non-Parametric Transformers (NPTs) を導入し、全データセットを入力としてデータポイント間の自己注意を用いてポイント間の関係を学習し、データポイント間の照合を可能にし、データ表形式および画像データの予測を改善します。
We challenge a common assumption underlying most supervised deep learning: that a model makes a prediction depending only on its parameters and the features of a single input. To this end, we introduce a general-purpose deep learning architecture that takes as input the entire dataset instead of processing one datapoint at a time. Our approach uses self-attention to reason about relationships between datapoints explicitly, which can be seen as realizing non-parametric models using parametric attention mechanisms. However, unlike conventional non-parametric models, we let the model learn end-to-end from the data how to make use of other datapoints for prediction. Empirically, our models solve cross-datapoint lookup and complex reasoning tasks unsolvable by traditional deep learning models. We show highly competitive results on tabular data, early results on CIFAR-10, and give insight into how the model makes use of the interactions between points.
研究の動機と目的
- パラメトリック依存性の仮定を教師あり学習において問い直す。
- 全データセットを予測に用いる一般的なアーキテクチャ(NPTs)を提案する。
- アテンション機構を用いてデータポイント間の相互作用をエンドツーエンドで学習できるようにする。
- 表形式データと画像データセット上でデータポイント間の照合と推論を実証する。
提案手法
- 全データセット(X)とマスキング行列(M)をNPTsに入力し、マスクされた値 p(X^M | X^O) の再構成を可能にする。
- データポイント間の交互注意(ABD)と属性間注意(ABA)を適用してデータポイント間の関係と各データポイントの変換をモデル化する。
- 残差接続と層正規化を用いたマルチヘッド自己注意を、Transformer風のアーキテクチャに従って使用する。
- BERTに着想を得たマスク付き目的関数で学習し、ターゲット損失と補助的な特徴マスキング損失を組み合わせる: L^NPT = (1-λ)L^Targets + λL^Features.
- ミニバッチ処理で大規模データセットを扱い、訓練データとテストデータを同じバッチに保持してクロスポイント注意を可能にする。
実験結果
リサーチクエスチョン
- RQ1NPTs は標準的な教師付きベンチマークで競争力のある性能を達成できるか。
- RQ2理想化されたクロスポイント照合タスクにおいて、データポイント間の注意を活用して予測を学習できるか。
- RQ3実データの予測において、NPTs は実際にデータポイント間の相互作用に依存しているか。
- RQ4NPTs を用いた場合、予測に最も関連するデータポイントはどのようなものか。
主な発見
- NPTs は UCI ベンチマークの二値分類および多クラス分類タスクで最高の平均ランクを達成し、いくつかのブースティング法を上回った。
- 回帰タスクでは、NPTs は XGBoost と並んで最良の平均ランクを獲得し、CatBoost にのみ劣る。
- CIFAR-10 は CNN+ABD アーキテクチャで 93.7% のテスト精度を達成; MNIST は linear patching で 98.3% に到達。
- 半合成タンパク質回帰タスクで、NPTs は複製行からターゲット値を照合でき、ほぼ完璧な相関 (r = 99.9%) を達成。
- 破損実験では、他のデータポイントをランダム化すると予測性能が低下することを示し、実データにおけるデータポイント間の相互作用に依存していることを示す(データセットによって異なる)。
より良い研究を、今すぐ始めましょう
論文設計から論文執筆まで、研究時間を劇的に削減しましょう。
クレジットカード登録不要
このレビューはAIが作成し、人間の編集者が確認しました。