Skip to main content
QUICK REVIEW

[論文レビュー] Understanding Hallucinations in Diffusion Models through Mode Interpolation

Sumukh K Aithal, Pratyush Maini|arXiv (Cornell University)|Jun 13, 2024
Mental Health Research Topics被引用数 5
ひとこと要約

本論文は拡散モデルの故障モードとして mode interpolation(モード補間)を特定し、近傍のデータモード間でサンプルを生成することで訓練データ分布の外に幻覚を生むと説明し、生成時および再帰的訓練中にこのようなサンプルを検出・剪定するための分散ベースの指標を提案します。

ABSTRACT

Colloquially speaking, image generation models based upon diffusion processes are frequently said to exhibit "hallucinations," samples that could never occur in the training data. But where do such hallucinations come from? In this paper, we study a particular failure mode in diffusion models, which we term mode interpolation. Specifically, we find that diffusion models smoothly "interpolate" between nearby data modes in the training set, to generate samples that are completely outside the support of the original training distribution; this phenomenon leads diffusion models to generate artifacts that never existed in real data (i.e., hallucinations). We systematically study the reasons for, and the manifestation of this phenomenon. Through experiments on 1D and 2D Gaussians, we show how a discontinuous loss landscape in the diffusion model's decoder leads to a region where any smooth approximation will cause such hallucinations. Through experiments on artificial datasets with various shapes, we show how hallucination leads to the generation of combinations of shapes that never existed. Finally, we show that diffusion models in fact know when they go out of support and hallucinate. This is captured by the high variance in the trajectory of the generated sample towards the final few backward sampling process. Using a simple metric to capture this variance, we can remove over 95% of hallucinations at generation time while retaining 96% of in-support samples. We conclude our exploration by showing the implications of such hallucination (and its removal) on the collapse (and stabilization) of recursive training on synthetic data with experiments on MNIST and 2D Gaussians dataset. We release our code at https://github.com/locuslab/diffusion-model-hallucination.

研究の動機と目的

  • 近傍データモード間のモード補間として、拡散モデルにおける幻覚を形式化し、特徴づける。
  • 学習されたスコア関数が不連続性を滑らかにするメカニズムを分析し、補間的でサポート外のサンプルが生じることを明らかにする。
  • 生成時に幻覚を検出・フィルタリングするため、軌跡分散に基づく指標を提案する。
  • 再帰的訓練への影響を検討し、合成データおよび MNIST データセットでの事前フィルタリングによる緩和を実証する。

提案手法

  • 1Dおよび2Dのガウス混合分布を研究し、拡散モデルが近傍のモード間で補間することを示す。
  • ニューラルネットワークが真のスコア関数の滑らかな近似を学習し、分離したモード間の領域で補間を生じさせることを示す。
  • 最終的な拡散ステップでのx0予測の高分散軌道を幻覚の特徴として同定する。
  • 時間ステップ全体で予測されたx0の分散に基づく幻覚指標 Hal(x) を定義し、サンプルを分類する。
  • この指標のフィルタリング能力を評価し、幻覚の約95–96%を除去しつつ、サポート内のサンプルの約95–98%を保持する。
Figure 1 : Hallucinations in Diffusion Models : Original Dataset (Left) & Generated Dataset (Right). The original dataset consists of 64x64 images divided into three columns, each containing a triangle, square, or pentagon with a 0.5 probability of the shape being present. Each shape appears at most
Figure 1 : Hallucinations in Diffusion Models : Original Dataset (Left) & Generated Dataset (Right). The original dataset consists of 64x64 images divided into three columns, each containing a triangle, square, or pentagon with a 0.5 probability of the shape being present. Each shape appears at most

実験結果

リサーチクエスチョン

  • RQ1拡散モデルが訓練データのサポート外に位置するサンプルを生成してしまう原因は何か(幻覚とは何か)?
  • RQ2拡散モデルは近傍データモード間でモード補間を示すのか、またスコア関数はどのように寄与するのか?
  • RQ3軌跡分散ベースの指標は、訓練データのサポート内サンプルを大きく損なうことなく幻覚を検出・フィルタリングできるか?
  • RQ4幻覚が再帰的訓練とモデルの安定性に及ぼす影響はどのようなものか?

主な発見

  • 拡散モデルは、合成の1Dおよび2Dガウス混合分布の近傍モード間で補間し、訓練データのサポート外のサンプルを生み出す。
  • 鋭いモードジャンプではなく、滑らかな学習済みスコア関数が分離したモード間の補間を駆動する。
  • 逆拡散の終盤における予測x0軌道の高分散は幻覚と相関し、検出を可能にする。
  • Hal(x) 指標は、設定を横断して幻覚のおよそ95–96%を除去し、サポート内サンプルのおよそ95–98%を保存できる。
  • この指標に基づく事前フィルタリングは、2Dガウス、Simple Shapes、MNISTデータセットでの再帰訓練中のモデル崩壊を緩和する。
Figure 2 : Mode Interpolation in 1D Gaussian . The red curve indicates the PDF of the true data distribution $q(x)$ , which is a mixture of 3 Gaussians (notice that the y-axis is in log-scale). In blue, we show a density histogram of the samples generated by a DDPM trained on varying number of sampl
Figure 2 : Mode Interpolation in 1D Gaussian . The red curve indicates the PDF of the true data distribution $q(x)$ , which is a mixture of 3 Gaussians (notice that the y-axis is in log-scale). In blue, we show a density histogram of the samples generated by a DDPM trained on varying number of sampl

より良い研究を、今すぐ始めましょう

論文設計から論文執筆まで、研究時間を劇的に削減しましょう。

クレジットカード登録不要

このレビューはAIが作成し、人間の編集者が確認しました。