Site cover image

Site icon imageSen(Qian)’s Memo

This website is Donglin Qian (Torin Sen)’s memo, especially about machine learning papers and competitive programming.

Diffusion Modelについてのメモ

確率微分方程式による分布からのサンプリング

非常に高次元で複雑な分布p(x)p(x)について、全容を知るのは難しいので、以下のような式からサンプリングする。KKは比例定数。

π=Kexp(p(x))\pi = K\exp(-p(x))

つまり、p(x)p(x)が高いとサンプリングされる確率が低くなり、p(x)p(x)が低いとサンプリングされやすい。

このp(x)p(x)は物理的な背景からポテンシャル関数と呼ばれる。

これは単にサンプリングされる確率が指数的にポテンシャル関数に従い変わることを示しており、具体的なサンプリング方法はこれを満たす必要があるような手法が必要って感じ。

Langevin Monte Carlo

ランジュバンって読むぞ。

確率微分方程式で、ある拡散していく様子を記述できるが、その拡散の軌跡をサンプリング結果とすることで、このような指数的な条件に従った分布である。以下の確率微分方程式に従う。

BtB_tは時間に依存したノイズ項で局所最適解にとどまることを防いでいる。2\sqrt{2}は正規分布関連のノイズのためのスケーリングらしい。

dxt=p(xt)dt+2dBtd x_t = - \nabla p(x_t) dt + \sqrt{2}d B_t

確率密度の勾配が下がる向きに進むことで、欲しいサンプリングを実現できる。

これを離散化すると、以下のような漸化式になる。

xt+1=xtp(xt)dt+2dBtx_{t+1} = x_t - \nabla p(x_t) dt + \sqrt{2}d B_t

これは既知の確率密度の勾配がわかっているとき、これに従って動くと正しくサンプリングできるってこと。

Diffusion Modelの学習で使うわけではない。

Ornstein-Uhlenbeck Process

オルンシュタイン=ウーレンベックとよむ。OU過程。

Diffusionでは基本的に、OU過程の順過程と逆過程をたどることで学習させる。

Langevin力学の特別な場合がOU過程。ポテンシャル関数p(x)p(x)が二次関数の時らしい。

なぜポテンシャルが二次関数の時こうなるのか

ポテンシャル関数p(x)=θ(xμ)2/2p(x) = \theta(x - \mu)^2/2とすると、勾配を計算すると確かにこの形になる。

ブラウン運動に、今の座標に応じた何かしらの点へ戻るバネのような力が加わったものである。

dxt=θ(xtμ)dt+σdBtdx_t = -\theta(x_t - \mu) dt + \sigma dB_t

性質として以下のようなものがある。

  • 平均回帰性 時間がたつにつれて、平均値μ\muに収束していくし、確率分布も一定の分散のもとで安定する。
  • ガウス過程 ランダム性を持つ過程だが、各時刻ttにおける状態xtx_tはガウス分布に従う。

拡散モデルによる学習

流れは

  1. 前進過程により、何かしらの過程に従い拡散する
  2. 逆過程により、前進過程の動きを復元するように、重要な関数をDNNでモデリングして学習する。
  3. 新たなものを生成するときは、ガウシアンに従うノイズを元に学習したDNNを使い、逆過程でどんどん戻していく。

前進過程

先ほどのOU過程に従い、ノイズを加えつつ拡散をしていく。OU過程の性質によって最終的に安定した分散を持ち期待値が0のガウス分布になる。

dxt=xtdt+dBtd x_t = -x_t dt + dB_t

これを漸化式にすると以下のようになる。ηt\eta_tはノイズ。

xt+1=xtxtΔt+Δtηtx_{t+1} = x_t - x_t \Delta t + \sqrt{\Delta t} \eta_t

逆過程

OU過程の逆過程は以下のようになる。

なぜlogp(x)\nabla \log p(x)があるのかというと、真の分布p(x)p(x)はガウス分布にゆくゆくなるから。p(x)p(x)がガウス分布の形ならば、logp(x)=(xμ)\nabla \log p(x) = -(x - \mu)になり、確かに復元する項である。

dxt={12xt+logp(xt)}dt+dBtd x_t = -\{\frac{1}{2} x_t + \nabla \log p(x_t) \} dt + dB_t

ここでは、xt/2-x_t/2と原点へ戻る力を半分にして、logp(xt)\nabla \log p(x_t)に復元を任せる感じ。

これを漸化式にすると、以下のようになる。これを満たすように、sθs_\thetaを学習していく。

xt=xt+112xtΔt+sθ(xt+1,t)Δt+Δtηtx_t = x_{t+1} - \frac{1}{2} x_t \Delta_t + s_\theta(x_{t+1}, t) \Delta t + \sqrt{\Delta t} \eta_t