The Fokker-Planck Equation and Diffusion Models¶
In this tutorial, we continue our journey through the theory of diffusion models and stochastic differential equations (SDEs). We are going to derive the Fokker-Planck equation, one of the fundamental equations of stochastic analysis, which expresses SDEs as partial differential equations (PDEs). This equation then allows us to describe the backward sampling process of diffusion models as an ordinary differential equation (ODE). In other words, we can sample from diffusion models by sampling noise only once at the start and then iteratively applying a deterministic function. Such a sampling process is now widely used in generative models such as Stable Diffusion.
More specifically, we are going to learn:
- What it is: What the Fokker-Planck equation actually says.
- Intuition: What the Fokker-Planck equation describes intuitively. This involves understanding the concept of divergence.
- Ito's Lemma: The fundamental building block for deriving the Fokker-Planck equation.
- Derivation: The actual derivation of the Fokker-Planck equation.
- Application to diffusion models: We show how this theory can be applied to diffusion models to enable their ODE formulation.
Prerequisites. This tutorial is part of a tutorial series on diffusion models. We recommend reading the blogpost about diffusion models with SDEs before this tutorial.
1. What is the Fokker-Planck Equation?¶
Setting. Let's consider the Ito stochastic differential equation (Ito-SDE) with drift coefficients $f:\mathbb{R}^d\times\mathbb{R}\to\mathbb{R}^d$ and diffusion coefficients $g:\mathbb{R}^d\times\mathbb{R}\to\mathbb{R}^{d\times d}$: $$\begin{align*} dX_t = f(X_t,t)dt + g(X_t,t)dW_t \end{align*}$$ and some start distribution $X_0\sim \mu$. Let's define the distribution $p_t(x)$ to be the distribution of $X_t$, i.e. $X_t\sim p_t$. In other words, there is a probability flow $p:\mathbb{R}^d\times\mathbb{R}\to\mathbb{R}_{\geq 0}$ changing over time. Let's write $p_t(x) = p(x,t)$.
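Before we dive into the theory, it can help to keep in mind how such an SDE is simulated in practice. Below is a minimal Euler-Maruyama sketch; the function `euler_maruyama` and the example drift/diffusion are illustrative choices, not part of the codebase used later in this tutorial:

```python
import numpy as np

def euler_maruyama(f, g, x0, t_max=1.0, n_steps=1000):
    """Simulate dX_t = f(X_t,t)dt + g(X_t,t)dW_t with the Euler-Maruyama scheme.

    x0: array of shape (n_samples, d) with samples from the start distribution mu.
    g(x, t) is assumed to return a single (d, d) matrix shared by all samples.
    Returns an array of shape (n_steps + 1, n_samples, d) with the trajectories.
    """
    h = t_max / n_steps
    xs = [x0]
    x = x0
    for k in range(n_steps):
        t = k * h
        noise = np.random.randn(*x.shape)
        x = x + h * f(x, t) + np.sqrt(h) * (noise @ g(x, t).T)
        xs.append(x)
    return np.stack(xs)

# Example: 2-d Ornstein-Uhlenbeck process dX_t = -X_t dt + dW_t
f = lambda x, t: -x
g = lambda x, t: np.eye(x.shape[-1])
trajectory = euler_maruyama(f, g, x0=np.random.randn(5000, 2))
```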
Fokker-Planck Equation. The Fokker-Planck equation describes how the distribution $p_t$ changes over time. Its significance is that the change of a probability density over time can be expressed solely via derivatives over the coordinates $x$:
$$ \begin{align*} \frac{d}{dt}p_t(x) =& -\sum\limits_{i=1}^{d}\frac{\partial}{\partial x_i}\left[f_i(x,t)p_t(x)\right] + \frac{1}{2}\sum\limits_{i=1}^{d}\sum\limits_{j=1}^{d}\frac{\partial^2}{\partial x_i\partial x_j}\left[[g(x,t)g(x,t)^T]_{ij}p_t(x)\right]\\ =& -\nabla\cdot [fp_t]+\frac{1}{2}[\nabla^2\cdot [gg^Tp_t]]\\ \end{align*} $$
where we use the notation $\nabla \cdot$ to denote the sum of the partial derivatives $\frac{\partial}{\partial x_i}$ of the components $[f_ip_t]$, and $\nabla^2\cdot$ to denote the sum of all second partial derivatives $\frac{\partial^2}{\partial x_i\partial x_j}$ of the components $[gg^T]_{ij}p_t$. In physics, one also calls $\nabla \cdot [fp_t]$ the divergence of $fp_t$.
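To see the equation in action, here is a small numerical sanity check (a sketch, with illustrative names only) for the one-dimensional Ornstein-Uhlenbeck process $dX_t=-X_tdt+\sigma dW_t$ with $X_0\sim\mathcal{N}(0,v_0)$, whose marginal density $p_t=\mathcal{N}(0,\,v_0e^{-2t}+\tfrac{\sigma^2}{2}(1-e^{-2t}))$ is known in closed form. We compare both sides of the Fokker-Planck equation via finite differences:

```python
import numpy as np

sigma, v0 = 1.0, 0.25

def p(x, t):
    """Closed-form density of dX = -X dt + sigma dW with X_0 ~ N(0, v0)."""
    v = v0 * np.exp(-2 * t) + 0.5 * sigma**2 * (1 - np.exp(-2 * t))
    return np.exp(-x**2 / (2 * v)) / np.sqrt(2 * np.pi * v)

# Fokker-Planck in 1-d: d/dt p = -d/dx[f p] + (sigma^2 / 2) d^2/dx^2 p with f(x,t) = -x.
x, t, eps = 0.7, 0.3, 1e-4
lhs = (p(x, t + eps) - p(x, t - eps)) / (2 * eps)
d_dx_fp = ((-(x + eps)) * p(x + eps, t) - (-(x - eps)) * p(x - eps, t)) / (2 * eps)
d2_dx2_p = (p(x + eps, t) - 2 * p(x, t) + p(x - eps, t)) / eps**2
rhs = -d_dx_fp + 0.5 * sigma**2 * d2_dx2_p
print(lhs, rhs)  # the two values agree up to discretization error
```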
2. Let's intuitively understand the equation¶
In this chapter, we are going to derive an intuitive understanding of the Fokker-Planck formula. As we will see later, one can give a technical proof for it, but I think there is a lot of value in understanding what the formula actually says intuitively. It basically boils down to understanding the divergence $\nabla\cdot$ and the continuity equation of a flow. We go through both concepts step-by-step.
Let us look at the Fokker-Planck equation again. The equation describes how an infinitesimal probability, i.e. a density $p_t(x)$, changes over time. The change can be described by the net inflow of probability. If probability mass moves to $x$, then $\frac{d}{dt}p_t(x)>0$. If it moves away from $x$, then $\frac{d}{dt}p_t(x)<0$. In essence, we know that $$ \begin{align*} \frac{d}{dt}p_t(x) = \text{"Net inflow of change of probability mass to point }x\text{"} \end{align*} $$ However, what is that net inflow? Its negative - the net outflow - is what physicists call the divergence. We are going to explain in the next chapter how that concept is derived. If you are already very familiar with divergence, this section can safely be skipped.
2.1. A derivation of the divergence formula¶
Before we get too deep into SDEs, let's first understand divergence in a simpler setting. Let's assume that we are given a vector field $F:\mathbb{R}^d\to\mathbb{R}^d$. Let's try to describe the infinitesimal flux around $x_0$. When we say "infinitesimal", we mean "in the limit". We define it in the limit because it is not clear a priori how to define the flux in and out of a single point $x_0$. However, we have an intuitive idea of the flux in and out of an $\epsilon$-ball $B(x_0,\epsilon)$ around $x_0$. We can describe it by the formula: $$ \begin{align*} \text{Flux}_{\epsilon}(F)(x_0)=\frac{1}{\text{Vol}(B(x_0,\epsilon))}\int\limits_{\partial B(x_0,\epsilon)}F^T(x)\frac{x-x_0}{\|x-x_0\|}dS(x) \end{align*} $$ where the terms in the formula are:
- Volume: $\text{Vol}(B(x_0,\epsilon))$ is the volume (i.e. the Lebesgue measure) of the epsilon-ball around $x_0$. We normalize by the volume.
- Boundary: $\partial B(x_0,\epsilon)$ is the boundary (i.e. the surface) of the epsilon-ball. We only integrate over the boundary of the epsilon-ball because that's where something can flow in or out.
- Normal vector: the vector $x$ describes a point on the surface and $\frac{x-x_0}{\|x-x_0\|}$ is the unit vector pointing from $x_0$ to $x$.
- Vector flow: the value $F(x)$ gives the flow at $x$. However, not all of the "flow" points inwards or outwards of the epsilon-ball. Therefore, we decompose the value of $F(x)$ by: $$\begin{align*} F(x) =:& \left(F(x)-F(x)^T\frac{x-x_0}{\|x-x_0\|}\frac{x-x_0}{\|x-x_0\|}\right)+F(x)^T\frac{x-x_0}{\|x-x_0\|}\frac{x-x_0}{\|x-x_0\|}\\ =:&v_{\perp}+v_{\parallel} \end{align*}$$ for vectors $v_{\perp}$ and $v_{\parallel}$. We know that $v_{\perp}$ is tangential to the surface of the ball $B(x_0,\epsilon)$ because it is orthogonal to the normal - this can be seen by computing $(x-x_0)^Tv_{\perp}=0$. Therefore, the direction $v_{\perp}$ contributes no inflow or outflow. Hence, the flux is given by the component $v_{\parallel}$ that points in the direction of the normal with (signed) length $$\begin{align*} F(x)^T\frac{x-x_0}{\|x-x_0\|} \end{align*}$$ The length is negative if this parallel component points inwards.
- Integral: Finally, we take the integral over all these flow values to get the total value of the inward/outward flow. We write $dS(x)$ to describe the surface integral over the surface $\partial B(x_0,\epsilon)$. If you are not yet comfortable with surface integrals, it does not really matter: our proof works without ever computing a surface integral, and you can simply think of it as an integral over the surface with uniform weight.
Divergence - infinitesimal outflux. Let's rewrite the $\epsilon$-flux around $x_0$ by substituting $w=x-x_0$ (a vector of length $\epsilon$) and centering the integral at the origin:
$$ \begin{align*} \text{Flux}_{\epsilon}(F)(x_0)=&\frac{1}{\text{Vol}(B(x_0,\epsilon))}\int\limits_{\partial B(x_0,\epsilon)}F^T(x)\frac{x-x_0}{\|x-x_0\|}dS(x)\\ =&\frac{1}{\text{Vol}(B(0,\epsilon))}\int\limits_{\partial B(0,\epsilon)}F^T(x_0+w)\frac{w}{\|w\|}dS(w) \end{align*} $$
Next, let's use the following properties for volumes and surface areas: $$\begin{align*} \text{Vol}(B(0,\epsilon))&=\epsilon^{d}\text{Vol}(B(0,1))\\ \text{S}(\partial B(0,\epsilon))&=\epsilon^{d-1}\text{S}(\partial B(0,1)) \end{align*}$$ and simplify the equation further: $$ \begin{align*} \text{Flux}_{\epsilon}(F)(x_0) =&\frac{1}{\text{Vol}(B(0,\epsilon))}\int\limits_{\partial B(0,\epsilon)}F^T(x_0+w)\frac{w}{\|w\|}dS(w)\\ =&\frac{1}{\epsilon^{d}\text{Vol}(B(0,1))}\epsilon^{d-1}\int\limits_{\partial B(0,1)}F^T(x_0+\epsilon z)zdS(z)\\ =&\frac{1}{\epsilon\text{Vol}(B(0,1))}\int\limits_{\partial B(0,1)}F^T(x_0+\epsilon z)zdS(z) \end{align*} $$
Let's use a first-order Taylor approximation to get: $$F(x_0+\epsilon z) \approx F(x_0) + \epsilon DF(x_0)z$$ And therefore,
$$ \begin{align*} F(x_0+\epsilon z)^Tz \approx& F^T(x_0)z + [\epsilon DF(x_0)z]^Tz\\ =& F^T(x_0)z + \epsilon z^TDF(x_0)z\\ =& \sum\limits_{i=1}^{d}F_i(x_0)z_i + \epsilon\sum\limits_{i=1}^{d}\sum\limits_{j=1}^{d}\frac{\partial}{\partial x_i}F_j(x_0)z_iz_j \end{align*}$$
Finally, let's substitute that into the integral equation of the flux to get: $$\begin{align*} \text{Flux}_{\epsilon}(F)(x_0) =&\frac{1}{\epsilon\text{Vol}(B(0,1))}\int\limits_{\partial B(0,1)}F^T(x_0+\epsilon z)zdS(z)\\ \approx &\frac{1}{\epsilon\text{Vol}(B(0,1))}\int\limits_{\partial B(0,1)} \left[\sum\limits_{i=1}^{d}F_i(x_0)z_i + \epsilon\sum\limits_{i=1}^{d}\sum\limits_{j=1}^{d}\frac{\partial}{\partial x_i}F_j(x_0)z_iz_j\right]dS(z)\\ =&\frac{1}{\epsilon\text{Vol}(B(0,1))}\left[\sum\limits_{i=1}^{d}F_i(x_0)\int\limits_{\partial B(0,1)} z_idS(z) + \epsilon\sum\limits_{i=1}^{d}\sum\limits_{j=1}^{d}\frac{\partial}{\partial x_i}F_j(x_0)\int\limits_{\partial B(0,1)}z_iz_jdS(z)\right] \end{align*} $$
Now, it is easy to see that $$\begin{align*} &\int\limits_{\partial B(0,1)}z_idS(z) = 0 \end{align*}$$ because the sphere is symmetric: for every point $z$ on the sphere, $-z$ also lies on the sphere, so the average of each component is zero. In addition, we derive for an arbitrary $j$ $$ \begin{align*} S(\partial B(0,1))=&\int\limits_{\partial B(0,1)}1 dS(z)=\int\limits_{\partial B(0,1)}\sum\limits_{i=1}^{d}z_i^2dS(z)=\sum\limits_{i=1}^{d}\int\limits_{\partial B(0,1)}z_i^2dS(z)=d\int\limits_{\partial B(0,1)}z_j^2dS(z)\\ \end{align*} $$ where we used the symmetry of the sphere. Therefore, $$\begin{align*} \int\limits_{\partial B(0,1)}z_j^2dS(z) = \frac{1}{d}S(\partial B(0,1)). \end{align*}$$
For $i\neq j$, let $A$ be the function $A(z)=z_iz_j$ and $R$ the rotation matrix defined by: $$ \begin{align*} R_{kl} = \begin{cases} -1 &\text{for }(k,l)=(i,j)\\ 1 &\text{for }(k,l)=(j,i)\\ 1&\text{for }k=l,\ k\notin\{i,j\}\\ 0&\text{otherwise}\\ \end{cases} \end{align*} $$ Then $A(Rz)=(-z_j)z_i=-A(z)$ and therefore, by a change of variables (the rotation $R$ maps the sphere onto itself), $$ \begin{align*} \int\limits_{\partial B(0,1)} z_iz_jdS(z) &= \int\limits_{\partial B(0,1)} A(z) dS(z) = \int\limits_{\partial B(0,1)} A(Rz)dS(z) = \int\limits_{\partial B(0,1)} (-z_j)z_idS(z) = -\int\limits_{\partial B(0,1)} z_jz_idS(z) \end{align*} $$ which implies that $$ \begin{align*} &0=2\int\limits_{\partial B(0,1)} z_jz_idS(z) \Rightarrow 0=\int\limits_{\partial B(0,1)} z_jz_idS(z) \end{align*} $$ for $i\neq j$.
With that, we get the final formula: $$\begin{align*} \text{Flux}_{\epsilon}(F)(x_0) =&\frac{1}{\epsilon\text{Vol}(B(0,1))}\left[\frac{\epsilon S(\partial B(0,1))}{d}\sum\limits_{i=1}^{d}\frac{\partial}{\partial x_i}F_i(x_0)\right]+o(1)\\ =&\frac{1}{\text{Vol}(B(0,1))}\left[\frac{S(\partial B(0,1))}{d}\sum\limits_{i=1}^{d}\frac{\partial}{\partial x_i}F_i(x_0)\right]+o(1)\\ \end{align*} $$
Therefore, in the limit $\epsilon\to 0$, the flux is given by: $$\begin{align*} \text{div}(F)(x_0) = \lim\limits_{\epsilon\to 0}\text{Flux}_{\epsilon}(F)(x_0) =\frac{S(\partial B(0,1))}{d\text{Vol}(B(0,1))}\sum\limits_{i=1}^{d}\frac{\partial}{\partial x_i}F_i(x_0) \end{align*}$$
Getting rid of constant factors. Finally, we use the formulas for the surface area and the volume of the unit ball: $$\begin{align*} S(\partial B(0,1)) &= \frac{2\pi^{d/2}}{\Gamma(\frac{d}{2})}\\ \text{Vol}(B(0,1)) &= \frac{\pi^{d/2}}{\Gamma(\frac{d}{2}+1)} \end{align*} $$ and use the property of the gamma function $\Gamma(x+1)=x\Gamma(x)$ to get: $$\begin{align*} \frac{S(\partial B(0,1))}{d\text{Vol}(B(0,1))} =\frac{\frac{2\pi^{d/2}}{\Gamma(\frac{d}{2})}}{d\frac{\pi^{d/2}}{\Gamma(\frac{d}{2}+1)}} =\frac{2\Gamma(\frac{d}{2}+1)}{d\Gamma(\frac{d}{2})} =\frac{2\cdot\frac{d}{2}\Gamma(\frac{d}{2})}{d\Gamma(\frac{d}{2})}= 1 \end{align*} $$ By plugging this into the above equation, we finally get the desired formula.
Summary. In sum, we have derived that our physically intuitive net outward flow is given by the divergence defined as: $$\text{div}(F)(x_0) = \sum\limits_{i=1}^{d}\frac{\partial}{\partial x_i} F_i(x_0)=[\nabla\cdot F](x_0)$$
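We can also check this numerically. The sketch below estimates the $\epsilon$-flux by Monte Carlo integration over the unit sphere (using $S(\partial B(0,\epsilon))/\text{Vol}(B(0,\epsilon))=d/\epsilon$) and compares it to the analytic divergence; the vector field $F$ and all function names are illustrative choices:

```python
import numpy as np

def flux_monte_carlo(F, x0, eps, n_samples=100_000):
    """Monte Carlo estimate of the epsilon-flux of F around x0."""
    d = x0.shape[0]
    z = np.random.randn(n_samples, d)
    z /= np.linalg.norm(z, axis=1, keepdims=True)       # uniform points on the unit sphere
    # Subtracting F(x0)^T z (which averages to zero over the sphere) reduces the variance
    # of the estimate without changing its expectation.
    values = np.sum((F(x0 + eps * z) - F(x0)) * z, axis=1)
    return (d / eps) * values.mean()

# Example field F(x) = (x_1^2, sin(x_2)) with divergence 2*x_1 + cos(x_2)
F = lambda x: np.stack([x[..., 0]**2, np.sin(x[..., 1])], axis=-1)
x0 = np.array([0.5, -1.0])

print(2 * x0[0] + np.cos(x0[1]))          # analytic divergence, roughly 1.5403
print(flux_monte_carlo(F, x0, eps=1e-2))  # Monte Carlo epsilon-flux, should be close
```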
2.2. The Fokker-Planck Equation as a Continuity Equation¶
Let us revisit our intuitive derivation of the Fokker-Planck equation applying what we just learnt about divergence: $$ \begin{align*} \frac{d}{dt}p_t(x)=&\text{"Net inflow of change of probability mass to point }x\text{"}\\ =&-\text{"Net outflow of change of probability mass to point }x\text{"}\\ =& -\text{div}(\text{"Change of probability mass at point }x\text{"})\\ =& -\nabla\cdot(\text{"Change of probability mass at point }x\text{"}) \end{align*} $$
In fact, we can also rewrite the Fokker-Planck equation in exactly such a form: $$ \begin{align*} \frac{d}{dt}p_t(x) =& -\nabla\cdot \left[fp_t-\frac{1}{2}\nabla^T[gg^Tp_t]\right]\\ \end{align*} $$ where $\nabla^T[gg^Tp_t]$ denotes the vector whose $i$-th component is $\sum_{j}\frac{\partial}{\partial x_j}\left[[gg^T]_{ij}p_t\right]$.
An equation of this form is usually called a continuity equation - the same equation that describes mass conservation of a fluid. It expresses that probability mass is conserved: probability cannot appear or vanish, it can only be transported through space. The two summands of $fp_t-\frac{1}{2}\nabla^T[gg^Tp_t]$ have a very intuitive explanation:
- Drift component: the vector field $f(x,t)$ pushes $x$ into the direction of $f(x,t)$. However, $f(x,t)$ alone cannot be the change of probability for the following reason: if $p_t(x)=0$, then the change is also zero - as there was no probability mass at $x$ in the first place. We should weight it by $p_t(x)$ to account for that, i.e. $f(x,t)p_t(x)$ is the vector field describing the change of probability given by the deterministic drift. In other words, $$ \begin{align*} \text{"Change of probability mass at point }x\text{ by drift $f$"} = f(x,t)p_t(x). \end{align*}$$
- Diffusion component: the diffusion matrix $g(x,t)$ takes a small vector $\xi\sim\mathcal{N}(0,\mathbf{I}_d)$ and adds a small bit of noise $\sqrt{s}g(x,t)\xi \sim\mathcal{N}(0,sg(x,t)g(x,t)^T)$ where $s$ is a small step size. Therefore, the probability mass that is dispersed at point $x$ is proportional to the initial mass $p_t(x)$ and the rate of dispersion $g(x,t)g(x,t)^T$, i.e. $$\begin{align*} \text{"Dispersion of probability mass at }x\text{"} \propto p_t(x)g(x,t)g(x,t)^T \end{align*} $$ However, dispersion in itself does not change the probability mass, as all neighboring points of $x$ also disperse probability mass at approximately the same rate. What matters is how the dispersion rate changes across space: mass effectively flows from regions where $p_tgg^T$ is large to regions where it is small, i.e. against the gradient. This is described by the term: $$ \begin{align*} \text{"Change of probability mass at }x\text{ by diffusion $g$"} = -\frac{1}{2}\nabla[g(x,t)g^T(x,t)p_t(x)] \end{align*}$$ I acknowledge this is not a rigorous derivation (especially, where does the factor $\frac{1}{2}$ come from?) - but rather an intuitive explanation for the formula. We are going to derive it rigorously in the next chapter (e.g. you will see that the factor $\frac{1}{2}$ comes from the 2nd-order Taylor approximation).
In essence, the Fokker-Planck equation describes the change of probability mass at $x$ over time as the net inflow of probability mass at $x$, where the flow is the vector field induced by the SDE.
Note: this nice interpretation of the Fokker-Planck equation came out of a discussion on stack overflow.
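To make this concrete, consider the special case of a constant drift $f(x,t)=v$ and no diffusion, $g\equiv 0$. The Fokker-Planck equation then reduces to a pure transport equation $$\begin{align*} \frac{d}{dt}p_t(x) = -\nabla\cdot[v\,p_t(x)] = -v^T\nabla p_t(x), \end{align*}$$ which is solved by $p_t(x)=p_0(x-vt)$: the initial density is simply shifted along the drift direction, and no probability mass is created or destroyed along the way.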
3. Ito's Lemma¶
In this chapter, we are going to derive Ito's Lemma. This is a fundamental result in the theory of stochastic differential equations and will be the key building block for deriving the Fokker-Planck equation.
Let $X_t$ be the stochastic process defined by the SDE above. Out of such a stochastic process, one can construct a new stochastic process $Y_t=\eta(X_t)$ via an arbitrary smooth function $\eta:\mathbb{R}^d\to\mathbb{R}$. A natural question that arises is: is the stochastic process $Y_t$ again described by an SDE and if yes, what are its drift and diffusion coefficients? The answer is yes, and Ito's Lemma describes the drift and diffusion coefficients of $Y_t$. Namely, Ito's Lemma says that:
$$\begin{align*} dY_t = \bar{f}(X_t,t)dt + \bar{g}(X_t,t)d\bar{W}_t \end{align*} $$ where $\bar{W}_t$ is a one-dimensional Brownian motion and $$ \begin{align*} \bar{f}(x,t) &= \nabla\eta(x)^Tf(x,t)+\frac{1}{2}\text{Tr}(g(x,t)^T\nabla^2\eta(x)g(x,t))\\ \bar{g}(x,t) &= \sqrt{\nabla\eta(x)^Tg(x,t)g(x,t)^T\nabla\eta(x)}\\ \end{align*} $$
Note that the above SDE is slightly different from the ones we have seen so far: the left-hand side defines $Y_t$, but the right-hand side also depends on $X_t$.
Proof. We write $o(h)$ for an arbitrary function $r:\mathbb{R}\to\mathbb{R}$ such that $\lim\limits_{h\to 0}\frac{r(h)}{h}=0$. Using the Euler-Maruyama approximation $X_{t+h}\approx X_t+hf(X_t,t)+\sqrt{h}g(X_t,t)\epsilon$ with $\epsilon\sim\mathcal{N}(0,\mathbf{I}_d)$ and a second-order Taylor expansion of $\eta$, we get: $$ \begin{align*} Y_{t+h} =& \eta(X_{t+h})\\ \approx& \eta(X_{t}+hf(X_t,t)+\sqrt{h}g(X_t,t)\epsilon)\\ =&\eta(X_t)+\nabla\eta(X_t)^T\left[hf(X_t,t)+\sqrt{h}g(X_t,t)\epsilon\right] +\frac{1}{2}\left[hf(X_t,t)+\sqrt{h}g(X_t,t)\epsilon\right]^T\nabla^2\eta(X_t)\left[hf(X_t,t)+\sqrt{h}g(X_t,t)\epsilon\right]+o(h)\\ =&Y_t+\nabla\eta(X_t)^T\left[hf(X_t,t)+\sqrt{h}g(X_t,t)\epsilon\right] +\frac{1}{2}h^2f(X_t,t)^T\nabla^2\eta(X_t)f(X_t,t)+\sqrt{h}h[g(X_t,t)\epsilon]^T\nabla^2\eta(X_t)f(X_t,t)+\frac{1}{2}h\epsilon^Tg(X_t,t)^T\nabla^2\eta(X_t)g(X_t,t)\epsilon+o(h)\\ \end{align*}$$
It holds that: $$ \begin{align*} \bar{f}(x,t) =&\lim\limits_{h\to 0}\frac{1}{h}\left(\mathbb{E}[Y_{t+h}-Y_t|X_t]\right)\\ =&\lim\limits_{h\to 0}\mathbb{E}\left[\nabla\eta(X_t)^T\left[f(X_t,t)+\frac{1}{\sqrt{h}}g(X_t,t)\epsilon\right] +\frac{1}{2}hf(X_t,t)^T\nabla^2\eta(X_t)f(X_t,t)+\sqrt{h}[g(X_t,t)\epsilon]^T\nabla^2\eta(X_t)f(X_t,t)+\frac{1}{2}\epsilon^Tg^T(X_t,t)\nabla^2\eta(X_t)g(X_t,t)\epsilon\Big|X_t\right]\\ =&\nabla\eta(X_t)^Tf(X_t,t)+\frac{1}{2}\text{Tr}\left(g(X_t,t)^T\nabla^2\eta(X_t)g(X_t,t)\right)\\ \end{align*} $$ where we used the fact that $\epsilon$ is independent of $X_t$ with $\mathbb{E}[\epsilon]=0$ and that $\mathbb{E}[\epsilon^TA\epsilon]=\text{Tr}(A)$ for an arbitrary matrix $A\in\mathbb{R}^{d\times d}$ and $\epsilon\sim\mathcal{N}(0,\mathbf{I}_d)$.
Next, we get that: $$ \begin{align*} \bar{g}(x,t)^2=&\lim\limits_{h\to 0}\frac{1}{h}\mathbb{V}[Y_{t+h}-Y_t|X_t] \\ =&\lim\limits_{h\to 0}\frac{1}{h}\mathbb{V}\left[\nabla\eta(X_t)^T\sqrt{h}g(X_t,t)\epsilon +\sqrt{h}h[g(X_t,t)\epsilon]^T\nabla^2\eta(X_t)f(X_t,t)+\frac{1}{2}h\epsilon^Tg(X_t,t)^T\nabla^2\eta(X_t)g(X_t,t)\epsilon\Big|X_t\right]\\ =&\lim\limits_{h\to 0}\mathbb{V}\left[\nabla\eta(X_t)^Tg(X_t,t)\epsilon +h[g(X_t,t)\epsilon]^T\nabla^2\eta(X_t)f(X_t,t)+\frac{\sqrt{h}}{2}\epsilon^Tg(X_t,t)^T\nabla^2\eta(X_t)g(X_t,t)\epsilon\Big|X_t\right]\\ =&\mathbb{V}\left[\nabla\eta(X_t)^Tg(X_t,t)\epsilon|X_t\right]\\ =&\nabla\eta(X_t)^Tg(X_t,t)[\nabla\eta(X_t)^Tg(X_t,t)]^T\\ =&\nabla\eta(X_t)^Tg(X_t,t)g(X_t,t)^T\nabla\eta(X_t)\\ \end{align*} $$
where we used the fact that $\mathbb{V}[A\epsilon]=AA^T$ for any matrix $A\in\mathbb{R}^{c\times d}$ and $\epsilon\sim\mathcal{N}(0,\mathbf{I}_d)$.
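As a quick sanity check of the drift formula, the sketch below estimates $\mathbb{E}[Y_{t+h}-Y_t\mid X_t=x]/h$ for a small $h$ by simulating one Euler-Maruyama step and compares it to $\bar f$. The process ($f(x)=-x$, $g=\sigma\mathbf{I}_d$) and the function $\eta(x)=\|x\|^2$ are arbitrary illustrative choices for which $\bar f$ can be computed by hand ($\nabla\eta(x)=2x$, $\nabla^2\eta(x)=2\mathbf{I}_d$):

```python
import numpy as np

d, sigma, h = 2, 0.5, 1e-3
x = np.array([1.0, -0.5])                       # condition on X_t = x
eta = lambda x: np.sum(x**2, axis=-1)           # eta(x) = ||x||^2

# Ito's Lemma: f_bar(x) = grad_eta(x)^T f(x) + 0.5 * Tr(g^T Hess_eta g)
#                       = 2 x^T (-x) + 0.5 * Tr(sigma^2 * 2 * I) = -2 ||x||^2 + sigma^2 * d
f_bar = -2 * np.sum(x**2) + sigma**2 * d

# Monte Carlo estimate of E[Y_{t+h} - Y_t | X_t = x] / h with one Euler-Maruyama step
n = 2_000_000
eps = np.random.randn(n, d)
x_next = x + h * (-x) + np.sqrt(h) * sigma * eps
estimate = np.mean(eta(x_next) - eta(x)) / h

print(f_bar, estimate)  # agree up to O(h) bias and Monte Carlo noise
```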
4. Deriving the Fokker-Planck equation¶
Next, let's start deriving the Fokker-Planck equation and let's fix a point $x_0\in\mathbb{R}^d$. Let's define $\eta_k(x) =\mathcal{N}(x;x_0,\frac{1}{k^2}\mathbf{I}_d)$ to be the Gaussian kernel centered around $x_0$ with covariance $\frac{1}{k^2}\mathbf{I}_d$. Then we know that for an arbitrary smooth and regular function $F:\mathbb{R}^d\to\mathbb{R}$ it holds that: $$\begin{align} F(x_0) = \lim\limits_{k\to\infty}\int \eta_k(x)F(x)dx \end{align} $$ In particular, it holds that: $$\begin{align*} p_t(x_0) =& \lim\limits_{k\to\infty}\int \eta_k(x)p_{t}(x)dx\\ =&\lim\limits_{k\to\infty}\mathbb{E}[\eta_k(X_t)] \end{align*} $$
By Ito's Lemma, $Y_t^{k}=\eta_k(X_t)$ again follows an SDE. Taking expectations (the diffusion part of Ito's Lemma has zero mean), we get $$\begin{align*} \frac{d}{dt}\mathbb{E}[\eta_{k}(X_t)]=& \mathbb{E}\left[\nabla\eta_k(X_t)^Tf(X_t,t)+\frac{1}{2}\text{Tr}(g(X_t,t)^T\nabla^2\eta_k(X_t)g(X_t,t))\right] \end{align*} $$ This gives us:
$$ \begin{align*} &\frac{d}{dt}p_t(x_0)\\ =&\frac{d}{dt}\lim\limits_{k\to\infty}\mathbb{E}[\eta_k(X_t)]\\ =&\lim\limits_{k\to\infty}\frac{d}{dt}\mathbb{E}[\eta_k(X_t)]\\ =&\lim\limits_{k\to\infty}\mathbb{E}\left[\nabla\eta_k(X_t)^Tf(X_t,t)+\frac{1}{2}\text{Tr}(g(X_t,t)^T\nabla^2\eta_k(X_t)g(X_t,t))\right]\\ =&\lim\limits_{k\to\infty}\int\left[\nabla\eta_k(x)^Tf(x,t)+\frac{1}{2}\text{Tr}(g(x,t)^T\nabla^2\eta_k(x)g(x,t))\right]p_t(x)dx\\ =&\lim\limits_{k\to\infty}\left[\int\nabla\eta_k(x)^Tf(x,t)p_t(x)dx+\frac{1}{2}\int\text{Tr}(g(x,t)^T\nabla^2\eta_k(x)g(x,t))p_t(x)dx\right]\\ =&\lim\limits_{k\to\infty}\left[-\int\eta_k(x)\nabla\cdot[f(x,t)p_t(x)]dx+\frac{1}{2}\int\text{Tr}(g(x,t)^T\nabla^2\eta_k(x)g(x,t))p_t(x)dx\right]\quad\text{(using partial integration)}\\ =&-\nabla\cdot[f(x_0,t)p_t(x_0)]+\frac{1}{2}\lim\limits_{k\to\infty}\int\text{Tr}(g(x,t)^T\nabla^2\eta_k(x)g(x,t))p_t(x)dx \end{align*}$$
Let's compute the remaining integral: $$ \begin{align*} &\int\text{Tr}(g(x,t)^T\nabla^2\eta_k(x)g(x,t))p_t(x)dx\\ =&\sum\limits_{l=1}^{d} \int\left[g(x,t)^T\nabla^2\eta_k(x)g(x,t)\right]_{ll}p_t(x)dx\\ =&\sum\limits_{l=1}^{d} \int p_t(x)\sum\limits_{j=1}^{d}g_{jl}(x,t)[\nabla^2\eta_k(x)g(x,t)]_{jl}dx\\ =&\sum\limits_{l=1}^{d}\sum\limits_{j=1}^{d}\sum\limits_{i=1}^{d} \int \left[p_t(x)g_{jl}(x,t)g_{il}(x,t)\right]\frac{\partial^2}{\partial x_j\partial x_i}\eta_k(x)dx\\ =&\sum\limits_{l=1}^{d}\sum\limits_{j=1}^{d}\sum\limits_{i=1}^{d} \int \eta_k(x)\frac{\partial^2}{\partial x_j\partial x_i}\left[p_t(x)g_{jl}(x,t)g_{il}(x,t)\right]dx\quad\text{(using partial integration twice)}\\ \to& \sum\limits_{l=1}^{d}\sum\limits_{j=1}^{d}\sum\limits_{i=1}^{d} \frac{\partial^2}{\partial x_j\partial x_i}\left[p_t(x_0)g_{jl}(x_0,t)g_{il}(x_0,t)\right]\quad(\text{for }k\to\infty)\\ =& \sum\limits_{j=1}^{d}\sum\limits_{i=1}^{d} \frac{\partial^2}{\partial x_j\partial x_i}\left[p_t(x_0)[g(x_0,t)g^T(x_0,t)]_{ji}\right]\\ \end{align*} $$
This finishes the proof of the Fokker-Planck equation: together with the factor $\frac{1}{2}$ from the previous equation, the last expression is exactly the second term of the Fokker-Planck equation. In essence, the Fokker-Planck equation can be derived by applying Ito's Lemma to kernel functions $\eta_k$ that concentrate around a point $x_0$ as $k\to\infty$.
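As a worked example that will reappear below, consider the one-dimensional VP-SDE $dX_t=-\tfrac{1}{2}\beta(t)X_tdt+\sqrt{\beta(t)}dW_t$ commonly used in diffusion models (and presumably what the VPSDE class in the code below implements). Its Fokker-Planck equation reads $$\begin{align*} \frac{d}{dt}p_t(x) = \frac{1}{2}\beta(t)\frac{\partial}{\partial x}\left[x\,p_t(x)\right]+\frac{1}{2}\beta(t)\frac{\partial^2}{\partial x^2}p_t(x) \end{align*}$$ and plugging in the Gaussian ansatz $p_t=\mathcal{N}(0,v(t))$ reduces the PDE to the ODE $\dot v(t)=\beta(t)(1-v(t))$, i.e. $v(t)=1+(v_0-1)e^{-\int_0^t\beta(s)ds}$. The Fokker-Planck equation thus tells us directly that the VP-SDE drives any Gaussian start distribution towards the $\mathcal{N}(0,1)$ prior.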
5. Expressing SDEs via ODEs¶
Next, we apply the Fokker-Planck equation to show that the marginal distributions of every SDE can also be obtained by running an ODE - if we had full information about the distributions $p_t$. Let us revisit the SDE given by:
$$ \begin{align*} dX_t = f(X_t,t)dt + g(t)dW_t \end{align*}$$
For simplicity, we made a small change and made $g$ independent of the location - which is always the case in standard diffusion models and many other settings.
A natural question to ask is: could we achieve the same effect by running an ODE instead of an SDE, as long as we have the same initial distribution? More specifically, if we started with an initial distribution $X_0\sim p_0$, would there be an ODE of the form $$ \begin{align*} dZ_t = F(Z_t,t)dt \end{align*}$$ for some $F:\mathbb{R}^d\times\mathbb{R}\to\mathbb{R}^d$ such that if $Z_0$ and $X_0$ have the same distribution, then also $Z_t$ and $X_t$ have the same distribution? The answer is yes and we are going to show why now.
Let $X_t\sim p_t$ and let's assume that $Z_0\sim p_0$. Then compute:
$$ \begin{align*} \frac{d}{dt}p_t(x) =& -\sum\limits_{i=1}^{d}\frac{\partial}{\partial x_i}\left[f_i(x,t)p_t(x)\right] + \frac{1}{2}\sum\limits_{i=1}^{d}\sum\limits_{j=1}^{d}\frac{\partial^2}{\partial x_i\partial x_j}\left[[g(t)g(t)^T]_{ij}p_t(x)\right]\\ =& -\sum\limits_{i=1}^{d}\frac{\partial}{\partial x_i}\left[f_i(x,t)p_t(x)\right] + \sum\limits_{i=1}^{d}\frac{\partial}{\partial x_i}\sum\limits_{j=1}^{d}\frac{1}{2}\left[[g(t)g(t)^T]_{ij}\right]\frac{\partial}{\partial x_j}p_t(x)\quad\text{($g$ does not depend on $x$)}\\ =& -\sum\limits_{i=1}^{d}\frac{\partial}{\partial x_i}\left[f_i(x,t)p_t(x)\right] + \sum\limits_{i=1}^{d}\frac{\partial}{\partial x_i}\sum\limits_{j=1}^{d}\frac{1}{2}\left[[g(t)g(t)^T]_{ij}\right]\left[\frac{\partial}{\partial x_j}\log p_t(x)\right]p_t(x)\quad\text{(derivative of log)}\\ =& -\sum\limits_{i=1}^{d}\frac{\partial}{\partial x_i}\left[\left[f_i(x,t) -\frac{1}{2}\sum\limits_{j=1}^{d}\left[[g(t)g(t)^T]_{ij}\right]\left[\frac{\partial}{\partial x_j}\log p_t(x)\right]\right]p_t(x)\right] \quad\text{(derivatives are linear)}\\ =& -\nabla\cdot[Fp_t](x) \end{align*} $$ for $F$ defined by: $$ \begin{align*} F(x,t)= f(x,t)-\frac{1}{2}g(t)g(t)^T\nabla\log p_t(x)\quad\text{(SDE-2-ODE)} \end{align*} $$
In other words, we have shown that the effect of an SDE can be achieved by running an ODE. However, we would need to know the score function $\nabla\log p_t(x)$.
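For a case where the score $\nabla\log p_t$ is available in closed form, we can test this claim directly. The sketch below (illustrative code, not part of the codebase) simulates the Ornstein-Uhlenbeck process $dX_t=-X_tdt+dW_t$ (so $g(t)\equiv 1$) once as an SDE and once as the ODE given by (SDE-2-ODE), and compares the resulting variances at the final time:

```python
import numpy as np

v0, t_max, n_steps, n_samples = 4.0, 1.0, 1000, 100_000
h = t_max / n_steps

def variance(t):
    """Analytic variance of dX = -X dt + dW with X_0 ~ N(0, v0)."""
    return v0 * np.exp(-2 * t) + 0.5 * (1 - np.exp(-2 * t))

score = lambda x, t: -x / variance(t)        # score of the Gaussian marginal N(0, v(t))

x_sde = np.sqrt(v0) * np.random.randn(n_samples)
x_ode = x_sde.copy()                         # same initial samples for both simulations
for k in range(n_steps):
    t = k * h
    # SDE: Euler-Maruyama step for dX = -X dt + dW
    x_sde = x_sde - x_sde * h + np.sqrt(h) * np.random.randn(n_samples)
    # ODE: Euler step for dZ = [-Z - 0.5 * score(Z, t)] dt, i.e. (SDE-2-ODE) with g = 1
    x_ode = x_ode + (-x_ode - 0.5 * score(x_ode, t)) * h

print(variance(t_max), x_sde.var(), x_ode.var())  # all three should roughly agree
```

Note that the two simulations only agree in distribution: individual trajectories differ, but the marginal distributions (here summarized by the variance) coincide up to discretization and Monte Carlo error.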
6. Application to diffusion models¶
Finally, we combine everything to change the sampling process of diffusion models. Sampling in diffusion models works by running the reverse-time SDE given by: $$\begin{align*} d\bar{X}_t= [f(\bar{X}_t,t)-g^2(t)\nabla\log p_{t}(\bar{X}_t)]dt + g(t)d\bar{W}_t \end{align*}$$ where time is running backwards in this case (see this tutorial for an explanation). The functions $f,g$ are user-defined and the score $\nabla\log p_{t}$ is learnt during training. To convert this SDE into an ODE with the same marginal distributions, we use the formula above (SDE-2-ODE) to get: $$ \begin{align*} d\bar{X}_t= [f(\bar{X}_t,t)-\frac{1}{2}g^2(t)\nabla\log p_{t}(\bar{X}_t)]dt \end{align*} $$ where, relative to the reverse-time SDE drift, we added $\frac{1}{2}g^2(t)\nabla\log p_t(\bar{X}_t)$ - instead of subtracting it - because we run time backwards.
We added the option to sample via this ODE formulation to our diffusion model codebase by simply not adding noise and only running the ODE with the drift coefficient above.
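We do not reproduce the implementation of run_backwards here; conceptually, the only difference between the two modes is the update rule of a single reverse step. A minimal sketch could look like the following (the function signature and the score_model interface are hypothetical and only illustrate the idea):

```python
import torch

def reverse_step(x, t, h, f, g, score_model, ode_sampling):
    """One reverse-time step of size h, either as an SDE step or as an ODE step.

    f(x, t) and g(t) are the drift/diffusion of the forward SDE and score_model(x, t)
    approximates the score grad log p_t(x). Conceptual sketch, not the actual
    implementation of run_backwards.
    """
    score = score_model(x, t)
    if ode_sampling:
        # probability-flow ODE: deterministic update, no noise is added
        drift = f(x, t) - 0.5 * g(t) ** 2 * score
        return x - h * drift
    # reverse-time SDE: stochastic update with fresh Gaussian noise
    drift = f(x, t) - g(t) ** 2 * score
    return x - h * drift + g(t) * (h ** 0.5) * torch.randn_like(x)
```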
import numpy as np
import matplotlib.pyplot as plt
from typing import Callable, List
from itertools import product
from tqdm.notebook import tqdm
from IPython.display import HTML
import seaborn as sns
from IPython.display import Video
import os
import pandas as pd
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
from torch.utils.data import Dataset
from torch import Tensor
from abc import ABC, abstractmethod
from torch.nn.functional import relu
from torch.utils.data.dataloader import DataLoader
import scipy.stats as st
from sde import VPSDE
from train import train_diffusion_model
from sampling import run_backwards
import torch.nn.functional as F
from mpl_toolkits.axes_grid1 import ImageGrid
from unet_conditional import Unet
torch._dynamo.config.suppress_errors = True
Load pre-trained MNIST classifier-free guidance model.
classes_by_index = np.arange(0,10).tolist()
token_variables = classes_by_index + [10]
image_size = 28
def load_mnist_model():
token_variables = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
model = Unet(token_variables=token_variables, base_dim=image_size, in_channels=1, out_channels=1, token_embedding_dim=256, time_embedding_dim=256, timesteps=100, dim_mults=[2, 4], temp=100.0)
model = torch.compile(model)
model_state_dict = torch.load("20231218_mnist_diffusion_denoiser_full_training.ckpt")
model.load_state_dict(model_state_dict)
return model
model = load_mnist_model()
6.1. Compare traditional vs ODE sampling¶
Define SDE
from sde import VPSDE
sde = VPSDE(T_max=1,beta_min=0.01, beta_max=10.0)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TRAIN_SCORE = False
Sample with SDE and with ODE.
weight = 2.0
for ode_sampling in [False, True]:
    single_target_shape = [8, 1, image_size, image_size]
    output_list = []
    for target in classes_by_index:
        x_start = torch.clip(torch.randn(size=single_target_shape), -1.0, 1.0)
        output, time_grid = run_backwards(model, sde,
                                          x_start=x_start,
                                          n_steps=1000,
                                          device=DEVICE,
                                          train_score=TRAIN_SCORE,
                                          clip_min=-10.0,
                                          clip_max=10.0,
                                          clfree_guidance=True,
                                          target=int(target),
                                          unconditional_target=10,
                                          clfree_guidance_weight=weight,
                                          ode_sampling=ode_sampling)
        output_list.append(output)
    output_agg = torch.stack([output.transpose(1, 0) for output in output_list], dim=1)
    n_time_steps = output_agg.shape[0]
    n_labels = output_agg.shape[1]
    images_per_label = output_agg.shape[2]
    time_idx = -1
    fig = plt.figure(figsize=(12., 12.))
    grid_idx = 0
    grid = ImageGrid(fig, 111,  # similar to subplot(111)
                     nrows_ncols=(n_labels, images_per_label),  # grid of n_labels x images_per_label axes
                     axes_pad=0.01)
    for label in range(n_labels):
        for image_idx in range(images_per_label):
            grid[grid_idx].imshow(output_agg[-1, label, image_idx].squeeze(), cmap='gray')
            grid[grid_idx].set_xticks([])
            grid[grid_idx].set_yticks([])
            grid[grid_idx].set_ylabel(f"{label}", fontsize=16)
            grid_idx += 1
    plt.savefig(f"diffusion_loss_odesampling={ode_sampling}.png")
    plt.show()
In sum, we see that there is no big difference whether we sample by running the SDE or the ODE derived via the Fokker-Planck equation. Awesome!