Understanding the Denoising Diffusion Probabilistic Model, the Socratic Way | by Wei Yi | Feb, 2023
A deep dive into the motivation behind the denoising diffusion model and detailed derivations for the loss function
The Denoising Diffusion Probabilistic Models by Jonathan Ho et al. is a great paper. But I had difficulty understanding it. My colleagues told me they were also left confused after reading it. So I decided to dive into the model and worked out all the derivations. In this article, I will focus on the two main obstacles to understanding the paper:
- why the denoising diffusion model is designed in terms of the forward process, the forward process posteriors (which I will call the reverse of the forward process to avoid the word “posteriors” because it confuses me) and the reverse process, and what the relationship among these processes is.
- how to derive the mysterious loss function. In the paper, there are many skipped steps in deriving the loss function Lₛᵢₘₚₗₑ. I went through all derivations to fill in the missing steps. Now I realize the derivation of the analytical formula for Lₛᵢₘₚₗₑ tells a truly beautiful Bayesian story. And with all the steps filled in, the whole story is easy to understand.
Medium supports Unicode in text. This allows me to write many math subscript notations such as x₀ and xₜ. But I could not write down some other subscripts, such as a capital T or a range like 0:T.
For those things, I will use an underscore “_” to lead the subscripts, such as x_T, and p(x_0:T).
If some math notations render as question marks on your phone, please try to read this article from a computer. This is a known Unicode rendering issue.
Our goal is to use a neural network to generate natural images from noise. The input to the neural network is noise, and the output should be a natural image, such as a human face. Different noises will result in different natural images, for example, one noise may lead to a woman’s face, another noise to a man’s.
You may ask, what kind of noise? Without other constraints, a sensible researcher who is in love with Bayesian methods will start with Gaussian noise.
What is the dimensionality of this noise? Well, the desirable output is a colorful 2D image with red-green-blue (RGB) values. Let’s simplify it by first transforming a colorful image into grayscale values in [0, 255] and then scaling them to the range [-1, 1]. Then we reshape this 2D array of scaled grayscale values into a long 1D vector of length d. I will mention the name d multiple times in the article. Let’s use the above as our easy definition of the image generation task. But please know that in reality, neural networks can generate colorful images directly.
It is natural to assume the dimension and structure of the input noise are the same as those of the output image, which is a vector of length d. So the noise should be a sample from a d-dimensional multivariate standard Gaussian N(0, I), with a zero mean vector and an identity covariance matrix. That’s the academic default.
Now the task of generating images from noise is more concrete: design a neural network that takes a sample from a d-dimensional multivariate standard Gaussian and outputs a d-dimensional vector of scaled grayscale values. Turning the output vector into a 2D shape and RGB colors is something we all know how to do, and is not of interest in this article.
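To make the simplified task concrete, here is a minimal sketch of the preprocessing described above, together with drawing a noise vector of the same length d. The function name, the random 8×8 stand-in image and the plain channel average used for the grayscale conversion are my own assumptions for illustration; the article does not prescribe any of them.

```python
import numpy as np

def image_to_vector(rgb_image: np.ndarray) -> np.ndarray:
    """Turn an (H, W, 3) uint8 RGB image into a 1D vector scaled to [-1, 1]."""
    # Assumption: a plain channel average serves as the grayscale conversion.
    gray = rgb_image.astype(np.float64).mean(axis=2)   # values in [0, 255]
    scaled = gray / 127.5 - 1.0                        # values in [-1, 1]
    return scaled.reshape(-1)                          # length d = H * W

rng = np.random.default_rng(0)
# A random array stands in for a real photo here.
fake_image = rng.integers(0, 256, size=(8, 8, 3)).astype(np.uint8)

x0 = image_to_vector(fake_image)
d = x0.shape[0]
noise = rng.standard_normal(d)   # one sample from the d-dimensional N(0, I)
```

Both `x0` and `noise` are plain vectors of length d, which is all the rest of the article needs.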
Generating a natural image from noise in one step is difficult. How about generating an image in many smaller steps? Sort of like letting an image emerge from a Kodak film in old-fashioned photography. This way, in each step, the neural network should have a simpler task, as the input and output in each step are more similar to each other than pure noise is to a final natural image.
This iterative generating idea comes with its own problem. What should the in-between images look like? A person old enough (like me) to have experience with old-fashioned photography would suggest that the in-between images should be gradual: it should not be the case that during this iterative process, an image of a cat first appears, and then the cat turns into a human face.
The “gradual-ness” constraint over in-between images is sensible. But how to formulate it mathematically?
Forward process turns a natural image into noise
Even though it is not clear how to formulate the gradual-ness of the iterative generation process, it is easy to formulate the opposite process — the process that turns a natural image into pure noise by successively adding a little bit of Gaussian noise into it.
The process of turning a natural image into pure noise by adding successive noise to it is called the forward diffusion process, or forward process in short.
Reverse process turns noise into a natural image
On the other hand, we call the process of turning a Gaussian noise into a natural image the reverse process.
The following figure from the paper depicts these two processes, with the forward process at the bottom, and the reverse process at the top.
In the figure x₀, x₁, x₂,…, x_T are d-dimensional random variables. x₀ corresponds to the natural image, which means a sample drawn from x₀’s probability density function, which I will get to later, should look like a natural image. And x₁, x₂,…, x_T correspond to the in-between, partially noisy images. I will also introduce their probability density functions later.
To remember that x₀ is the natural image and x_T is pure noise, and not the other way round, please remember: small subscript, small noise, big subscript, big noise. I got this idea from this video.
Now there is a way to mathematically define the gradual-ness of in-between images. Every image xₜ generated from the reverse process must be close to the corresponding image diffused from the forward process.
The forward process is a probabilistic model. Why? Because every step adds Gaussian noise into an image. So the result is not deterministic: starting from the same natural image x₀, you may end up with different samples of standard multivariate Gaussian noise x_T. Just like dropping ink into a glass of water at different times will give you different diffusions each time.
In this probabilistic model, x₀, x₁,… to x_T are random variables. Each of them is a d-dimensional random variable.
Since the forward process is probabilistic, the appropriate mathematical tools to talk about it are probability density functions and probability theory.
The above figure uses q(xₜ|xₜ₋₁) to denote the probability density function for a single step from image xₜ₋₁ to image xₜ in the forward diffusion process. We define its probability density function as:
q(xₜ|xₜ₋₁) = N(xₜ; √(1−βₜ) xₜ₋₁, βₜI)
with βₜ being a value that changes over time, much like a learning rate schedule.
Using the reparameterization trick (see derivation here) the random variable xₜ can be equivalently described as:
xₜ = √(1−βₜ) xₜ₋₁ + √βₜ ϵₜ₋₁
where ϵₜ₋₁ is d-dimensional standard Gaussian noise. This formula reveals that the noisier image xₜ is a weighted sum of the less noisy image xₜ₋₁ and some noise ϵₜ₋₁. In other words, the forward process adds noise ϵₜ₋₁ into the less noisy image xₜ₋₁. The value of βₜ controls the amount of noise to add at timestamp t. That’s why βₜ is scheduled to take really small values, from β₁=10⁻⁴ to β_T=0.02 in the paper; otherwise, noise would quickly dominate the forward process. And T is set to 1000.
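The single forward step can be sketched in a few lines. This is only an illustration under stated assumptions: it uses the linear schedule from the DDPM paper (βₜ rising from 10⁻⁴ to 0.02 over T=1000 steps), and the function name and the constant toy "image" are made up for the example.

```python
import numpy as np

T = 1000
# betas[t - 1] holds beta_t; linear schedule as in the DDPM paper.
betas = np.linspace(1e-4, 0.02, T)

def forward_step(x_prev, t, rng):
    """Sample x_t from q(x_t | x_{t-1}) with the reparameterization trick."""
    beta_t = betas[t - 1]
    eps = rng.standard_normal(x_prev.shape)          # standard Gaussian noise
    return np.sqrt(1.0 - beta_t) * x_prev + np.sqrt(beta_t) * eps

rng = np.random.default_rng(0)
x_start = np.full(16, 0.5)        # toy stand-in for a scaled image x_0
x_one_step = forward_step(x_start, 1, rng)   # barely changed: beta_1 is tiny

# Running all T steps walks the image all the way to (almost) pure noise.
x = x_start
for t in range(1, T + 1):
    x = forward_step(x, t, rng)
```

After one step the image has hardly moved, because √β₁ = 0.01; it is the accumulation over many small steps that destroys the image.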
q(xₜ|xₜ₋₁) is the probability density function for the random variable xₜ, which describes a single step in the forward process. The following joint probability density function describes the full forward process:
q(x_1:T|x₀) = ∏ₜ₌₁ᵀ q(xₜ|xₜ₋₁), a product of the T single-step densities for t = 1 to T
This is the first factorization of the forward process q(x_1:T|x₀).
Why does the above joint probability density function of the forward process q(x_1:T|x₀) depend on the random variable x₀? This is because when t=1, q(xₜ|xₜ₋₁) turns into q(x₁|x₀), which mentions x₀.
The term q(x₁|x₀) also reinforces the fact that x₀ is a random variable. After all, the notation q(x₁|x₀) in probability theory denotes the probability density for the random variable x₁ given the random variable x₀.
But do we know the probability density function q(x₀) for x₀? No, we don’t. q(x₀) describes the probability of natural images. This means:
- given a natural image, say X₀, plugging X₀ in q(x₀=X₀) should return a non-negative density value that indicates how likely this natural image is among all natural images.
- Integrating q(x₀) over all possible images gives us 1.
Obviously, as usual, we don’t know the analytical formula for q(x₀). But this does not prevent us from writing its notation down, and from drawing samples for the random variable x₀— we just randomly pick an image from our training set of natural images.
The forward process consists of T+1 random variables x₀, x₁ to x_T. They form two groups:
- x₀: there are observations of the random variable x₀. The observations are the actual images from the training dataset. We call x₀ an observational random variable.
- x₁ to x_T: there is no observation for them, hence they are latent random variables.
The definition of the forward process brings three important properties.
Property 1: Fully joint probability density function q(x_0:T)
Joint probability density function represents trajectories. Visually, the fully joint probability density function q(x_0:T) describes the set of possible trajectories of images. Each trajectory consists of T+1 images with x₀ representing a noise-free image, and x_T representing a pure noise image.
The following illustration shows some trajectories that start from two natural images X₀ and X₁. These trajectories end at different pure noise images, indicating that the forward process is probabilistic. It is only an illustration because I hand-drew the picture: trajectories starting from X₀ may well overlap with trajectories starting from X₁.
The illustration also shows that at a timestamp t, the random variable xₜ is responsible for explaining all possible images that could be generated by the forward process at this timestamp.
Property 2: Marginal probability density function q(xₜ|x₀)
Using the reparameterization trick repeatedly (see derivation here again) gives us the probability density function for a single latent random variable without the dependence on another latent random variable, hence the resulting probability density is called marginal:
q(xₜ|x₀) = N(xₜ; √ᾱₜ x₀, (1−ᾱₜ)I)
with αₜ = 1−βₜ and ᾱₜ = α₁α₂⋯αₜ.
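This property reveals that given x₀, the distribution of a single latent random variable xₜ can be written down without referring to xₜ₋₁. Note that this is a statement about each variable on its own: x₁ to x_T are still chained together by the single-step densities q(xₜ|xₜ₋₁), so they are not independent of each other given x₀, and the product rule for independent random variables, p(a, b) = p(a) p(b), does not apply to them.
The second factorization of q(x_1:T|x₀) comes instead from the general product rule p(a, b) = p(a|b) p(b), applied from the noisy end of the chain: q(x_1:T|x₀) = q(x_T|x₀) ∏ₜ₌₂ᵀ q(xₜ₋₁|xₜ, x₀). It starts from the marginal q(x_T|x₀) and then steps backwards through the densities q(xₜ₋₁|xₜ, x₀), which Property 3 below derives.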
The two factorizations are equivalent, meaning they describe the same joint probability distribution for the same set of random variables. Sometimes we will choose one over the other to simplify formula derivations.
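Property 2 is what lets us jump from x₀ straight to any xₜ in a single draw, without simulating the steps in between. The sketch below assumes the DDPM paper's linear schedule and uses made-up names; it draws many samples of xₜ for one fixed x₀ value and lets us compare their mean and variance with the closed form √ᾱₜ x₀ and 1−ᾱₜ.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)      # beta_1 .. beta_T (paper's linear schedule)
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)         # alpha_bars[t - 1] = alpha_1 * ... * alpha_t

def sample_xt_given_x0(x0, t, rng):
    """Draw x_t directly from the marginal q(x_t | x_0)."""
    a_bar = alpha_bars[t - 1]
    eps = rng.standard_normal(x0.shape)
    return np.sqrt(a_bar) * x0 + np.sqrt(1.0 - a_bar) * eps

rng = np.random.default_rng(0)
t = 300
x0 = np.full(20_000, 0.8)               # the same pixel value, repeated many times
xt = sample_xt_given_x0(x0, t, rng)

empirical_mean, empirical_var = xt.mean(), xt.var()
expected_mean = np.sqrt(alpha_bars[t - 1]) * 0.8
expected_var = 1.0 - alpha_bars[t - 1]
```

The empirical and expected values agree up to sampling error, which is a quick sanity check on the formula.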
Property 3: The reverse of the forward process q(xₜ₋₁|xₜ, x₀)
Using the Bayes rule, it is possible to derive the probability density function for the reverse of the forward process. In the paper, the reverse of the forward process is called the forward process posterior. But I found that the word “posterior” leads to confusion in the context of the many conditionals in this article.
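Let’s start from the forward process q(xₜ|xₜ₋₁), which is a probability density function for the random variable xₜ, and apply the Bayes rule to it:
q(xₜ|xₜ₋₁) = q(xₜ|xₜ₋₁, x₀) = q(xₜ₋₁|xₜ, x₀) q(xₜ|x₀) / q(xₜ₋₁|x₀)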
Note that we can add the redundant conditioned random variable x₀ into q(xₜ|xₜ₋₁) to turn it into q(xₜ|xₜ₋₁, x₀) because by definition, given xₜ₋₁, the random variable xₜ doesn’t depend on any other random variable. So adding x₀ as a dependence doesn’t change the probability density for xₜ.
However, the random variable xₜ in q(xₜ|x₀) and the random variable xₜ₋₁ in q(xₜ₋₁|x₀) truly depend on x₀. We know the formulas for q(xₜ|x₀) and q(xₜ₋₁|x₀) from Property 2 above, the marginal probability density.
By rearranging the terms, we can derive the probability density function for the reverse of the forward process:
q(xₜ₋₁|xₜ, x₀) = q(xₜ|xₜ₋₁) q(xₜ₋₁|x₀) / q(xₜ|x₀)
First, please notice that when t=1, q(xₜ₋₁|xₜ, x₀) turns into q(x₀|x₁, x₀), with x₀ appearing in both left and right of the bar “|”. The quantity q(x₀|x₁, x₀) always evaluates to 1 because given x₀, the probability of x₀ is 1. Why? Because x₀ has already happened, there is no uncertainty about it anymore.
Bear this in mind, and you will smile when you see later that the first analytical loss function starts with t=2. And you will smile again at the end, when we expand the loss function to cover the case of t=1.
The left hand side of the equation, q(xₜ₋₁|xₜ, x₀), tells us this is a process that takes a noisier image xₜ and generates a less noisy image xₜ₋₁. Remember: large subscript, more noise; small subscript, less noise. So q(xₜ₋₁|xₜ, x₀) describes the process that goes in the opposite direction to the forward process. We call it the reverse of the forward process.
The right hand side of the equation tells us that q(xₜ₋₁|xₜ, x₀) is defined by the forward process probability density q(xₜ|xₜ₋₁), rescaled by q(xₜ₋₁|x₀)/q(xₜ|x₀). We have defined all three of these components. If you plug them into the right hand side of the equation and patiently simplify the formula, you will see that q(xₜ₋₁|xₜ, x₀) is a multivariate Gaussian distribution, which is fully specified by its mean vector and covariance matrix. Mathematically, we have:
q(xₜ₋₁|xₜ, x₀) = N(xₜ₋₁; μₜ(xₜ, x₀), Σₜ)
with
μₜ(xₜ, x₀) = [√ᾱₜ₋₁ βₜ / (1−ᾱₜ)] x₀ + [√αₜ (1−ᾱₜ₋₁) / (1−ᾱₜ)] xₜ
Σₜ = [(1−ᾱₜ₋₁) / (1−ᾱₜ)] βₜ I
where αₜ = 1−βₜ and ᾱₜ = α₁α₂⋯αₜ. The paper writes the mean and the variance with a tilde on top.
We can see the mean vector for xₜ₋₁ is a weighted sum of x₀ and xₜ. And the weights in front of these two random variables depend on the timestamp t. The covariance matrix is a quantity that also depends on the timestamp t. Neither the mean vector nor the covariance matrix mentions trainable parameters (we will introduce trainable parameters a bit later).
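The weights and the variance can be computed directly from the noise schedule. The sketch below uses the closed-form coefficients from the DDPM paper, namely √ᾱₜ₋₁βₜ/(1−ᾱₜ) in front of x₀, √αₜ(1−ᾱₜ₋₁)/(1−ᾱₜ) in front of xₜ, and (1−ᾱₜ₋₁)/(1−ᾱₜ)·βₜ for the variance, with ᾱ₀ taken as 1. The function name and the toy inputs are assumptions for illustration.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)      # paper's linear schedule
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)

def posterior_mean_and_variance(x_t, x0, t):
    """Mean vector and scalar variance of q(x_{t-1} | x_t, x_0), for 1 <= t <= T."""
    beta_t = betas[t - 1]
    a_bar_t = alpha_bars[t - 1]
    a_bar_prev = alpha_bars[t - 2] if t > 1 else 1.0   # empty product when t = 1
    coef_x0 = np.sqrt(a_bar_prev) * beta_t / (1.0 - a_bar_t)
    coef_xt = np.sqrt(alphas[t - 1]) * (1.0 - a_bar_prev) / (1.0 - a_bar_t)
    variance = (1.0 - a_bar_prev) / (1.0 - a_bar_t) * beta_t
    return coef_x0 * x0 + coef_xt * x_t, variance

x0 = np.full(4, 0.5)
x_t = np.full(4, 3.0)

# No neural network weights appear anywhere: only the schedule, x_t and x_0.
mean_mid, var_mid = posterior_mean_and_variance(x_t, x0, t=500)

# At t = 1 the weight on x_0 is 1, the weight on x_1 is 0 and the variance is 0.
mean_first, var_first = posterior_mean_and_variance(x_t, x0, t=1)
```

The t=1 case matches the earlier remark that q(x₀|x₁, x₀) leaves no uncertainty about x₀.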
The reverse of the forward process is important because it describes exactly the generation process that we want: a process that gradually turns noise into a natural image. This importance is reflected later in the loss function derivation.
Where does the reverse of the forward process start?
The q(xₜ₋₁|xₜ, x₀) distribution defines the mean vector and the covariance matrix for the random variable xₜ₋₁, given xₜ and x₀. We can get a sample of x₀ by randomly picking a natural image from the training set, but how do we get a sample of xₜ? Inductively, we can sample from q(xₜ|xₜ₊₁, x₀). But how do we sample the last, that is, the noisiest random variable x_T? Note x_T is actually the starting point of the reverse of the forward process. Well, we can sample from q(x_T|x₀), which is defined in Property 2. Looking at the definition of q(xₜ|x₀) again:
q(xₜ|x₀) = N(xₜ; √ᾱₜ x₀, (1−ᾱₜ)I), with ᾱₜ = (1−β₁)(1−β₂)⋯(1−βₜ)
we realize that when t=T, a large number, say 1000, ᾱ_T is almost 0, so the mean vector is almost 0, and the covariance matrix is almost I. In other words, x_T approximately follows a standard multivariate Gaussian distribution, and the reverse of the forward process essentially starts from pure noise.
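How close to "almost 0" and "almost I" do we actually get? A quick check, assuming the linear schedule from the DDPM paper (β from 10⁻⁴ to 0.02, T=1000); the variable names are mine.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)

# alpha_bar_T is the product of (1 - beta_t) over all T steps.
alpha_bar_T = np.prod(1.0 - betas)

mean_scale = np.sqrt(alpha_bar_T)   # multiplies x_0 in the mean of q(x_T | x_0)
variance = 1.0 - alpha_bar_T        # multiplies I in the covariance of q(x_T | x_0)
```

With this schedule the factor in front of x₀ is well below 0.01 and the variance is within 0.0001 of 1, so treating x_T as standard Gaussian noise is a very mild approximation.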
Where does the reverse of the forward process end?
Another burning question is: can the reverse of the forward process eventually generate the original image X₀? Yes, in expectation, the reverse of the forward process that is conditioned on x₀=X₀ will generate X₀. “In expectation” means that if we run the reverse of the forward process conditioned on x₀=X₀ many, many times, and average the final images at the end of these trajectories, we will get exactly X₀.
How does the reverse of the forward process do that? We can find the answer by staring at the mean vector of q(xₜ₋₁|xₜ, x₀):
μₜ(xₜ, x₀) = [√ᾱₜ₋₁ βₜ / (1−ᾱₜ)] x₀ + [√αₜ (1−ᾱₜ₋₁) / (1−ᾱₜ)] xₜ
At each step, the above mean vector for the random variable xₜ₋₁ contains a bit of x₀=X₀, and a bit of the random variable xₜ, which transitively includes a bit of x₀=X₀ from its mean vector μₜ₊₁(xₜ₊₁, x₀). This way, step by step (from large t to small t), the random variables from the reverse of the forward process converge to x₀ more and more.
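We can watch this convergence happen by simulating the reverse of the forward process for one fixed X₀. This is a sketch under assumptions: it uses the paper's linear schedule and the closed-form mean and variance of q(xₜ₋₁|xₜ, x₀) from the DDPM paper, with ᾱ₀ taken as 1; the function name and the toy X₀ are made up.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)
# alpha_bar_{t-1} for t = 1..T, with alpha_bar_0 = 1.
alpha_bars_prev = np.concatenate(([1.0], alpha_bars[:-1]))

def run_reverse_of_forward(x0, rng):
    """Sample x_T from q(x_T | x_0), then step down with q(x_{t-1} | x_t, x_0)."""
    a_bar_T = alpha_bars[-1]
    x = np.sqrt(a_bar_T) * x0 + np.sqrt(1.0 - a_bar_T) * rng.standard_normal(x0.shape)
    for t in range(T, 0, -1):
        i = t - 1
        coef_x0 = np.sqrt(alpha_bars_prev[i]) * betas[i] / (1.0 - alpha_bars[i])
        coef_xt = np.sqrt(alphas[i]) * (1.0 - alpha_bars_prev[i]) / (1.0 - alpha_bars[i])
        var = (1.0 - alpha_bars_prev[i]) / (1.0 - alpha_bars[i]) * betas[i]
        x = coef_x0 * x0 + coef_xt * x + np.sqrt(var) * rng.standard_normal(x0.shape)
    return x   # the sample of x_0 at the end of the trajectory

rng = np.random.default_rng(0)
X0 = np.linspace(-1.0, 1.0, 16)        # toy stand-in for a scaled image
final = run_reverse_of_forward(X0, rng)
```

In this sketch every trajectory lands on X₀, not just the average of many trajectories, because the last step (t=1) puts weight 1 on x₀ and has zero variance.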
At the top of the figure is the reverse process p(xₜ₋₁|xₜ). The reverse process takes a noisier image xₜ and generates a new image xₜ₋₁ that contains less noise, same as the reverse of the forward process q(xₜ₋₁|xₜ, x₀). Note the figure uses the notation p_θ, but I decided to use p because p_θ doesn’t look good in Medium.
The reverse process must contain the same set of random variables x₀, x₁ to x_T as the reverse of the forward process, because we want to establish the “gradual-ness” correspondence between the variables of the reverse process and those of the reverse of the forward process. In other words, since the reverse of the forward process already tells us how to gradually remove noise from images step by step, we want our neural network to mimic that at every step. The way we convey the stepwise mimicking requirement is to require the random variable xₜ in the reverse process to behave like the corresponding random variable in the reverse of the forward process.
Since the reverse of the forward process is defined using multivariate Gaussian distributions, it makes sense to also define the reverse process using multivariate Gaussian distributions:
p(xₜ₋₁|xₜ) = N(xₜ₋₁; μₚ(xₜ, t), Σₚ(xₜ, t))
p(x_T) = N(x_T; 0, I)
In the inductive formula, p(xₜ₋₁|xₜ), the mean vector μₚ(xₜ, t) and the covariance matrix Σₚ(xₜ, t) are actually two deep neural networks that predict the d-dimensional mean vector and the d×d covariance matrix of the multivariate Gaussian distribution.
In the base formula, p(x_T) is a standard multivariate Gaussian, which confirms that the reverse process starts from pure noise. Note, the starting point of the reverse process is not conditioned on any random variable, unlike the case of the reverse of the forward process.
Both neural networks μₚ(xₜ, t) and Σₚ(xₜ, t) take two inputs: the first is the noisier image xₜ, and the second is the timestamp t. The noisier image xₜ makes sense. After all, we want to use the neural networks to denoise the noisier image. But how should we understand the timestamp t as an input to the neural networks? The intention is the same as in the Transformer model for natural language processing, which uses sines and cosines to encode the position of a word in a sentence and feeds the encoded position as additional input to the Transformer. Here we also want to encode where we are in the reverse process as additional input to the neural networks to give them a bit of positional context.
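A Transformer-style encoding of the timestamp could look like the sketch below. The base of 10000, the embedding size and the layout (all sines first, then all cosines) are assumptions borrowed from the common Transformer convention, not details taken from this article.

```python
import numpy as np

def timestep_embedding(t: int, dim: int = 8) -> np.ndarray:
    """Encode an integer timestamp t as `dim` sine and cosine features."""
    half = dim // 2
    # Geometrically spaced frequencies, from 1 down towards 1/10000.
    freqs = np.exp(-np.log(10000.0) * np.arange(half) / half)
    angles = t * freqs
    return np.concatenate([np.sin(angles), np.cos(angles)])

emb_start = timestep_embedding(0)      # all sines are 0, all cosines are 1
emb_late = timestep_embedding(900)     # a different pattern for a noisy step
```

The network receives this vector next to xₜ, so the same weights can behave differently near pure noise and near a finished image.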
How does a neural network predict a d-dimensional mean vector and a d×d covariance matrix? For the mean vector, the mean predicting neural network will have d output heads, each predicting an entry in the d-dimensional mean vector. The covariance matrix predicting neural network has d×d output heads. This is a rough understanding. With a large d, the number of neural network outputs, especially for the covariance predicting neural network, is huge. There are more concise ways to represent the covariance matrix, see the mean-field parameterization from here.
μₚ(xₜ, t) and Σₚ(xₜ, t) contain model parameters
The weights inside the mean vector predicting network μₚ(xₜ, t) and the covariance matrix predicting network Σₚ(xₜ, t) are the model parameters in this machine learning task. We want to use optimization to find proper values for those model parameters, so that when we start from a noise sample from p(x_T), iteratively sample xₜ₋₁ from the distribution p(xₜ₋₁|xₜ), and finally retrieve a sample for x₀ from the distribution p(x₀|x₁), this sample of x₀ is the scaled grayscale version of a realistic looking natural image.
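The sampling procedure described here is a simple loop. In the sketch below, `toy_mean_network` and `toy_variance` are hypothetical stand-ins for the trained networks μₚ and Σₚ: they are fixed formulas that I made up so the loop runs, they contain no learned weights, and the output is therefore noise rather than an image. A diagonal covariance βₜI is also an assumption.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)

def toy_mean_network(x_t, t):
    """Hypothetical stand-in for mu_p(x_t, t); NOT a trained network."""
    return np.sqrt(1.0 - betas[t - 1]) * x_t

def toy_variance(t):
    """Hypothetical stand-in for Sigma_p(x_t, t): the diagonal covariance beta_t * I."""
    return betas[t - 1]

def sample_image(d, rng):
    x = rng.standard_normal(d)                    # x_T ~ p(x_T) = N(0, I)
    for t in range(T, 0, -1):
        mean = toy_mean_network(x, t)
        std = np.sqrt(toy_variance(t))
        x = mean + std * rng.standard_normal(d)   # x_{t-1} ~ p(x_{t-1} | x_t)
    return x                                      # the final sample of x_0

sample = sample_image(16, np.random.default_rng(0))
```

Parameter learning is about replacing the two stand-ins with networks whose weights make the final sample look like a natural image.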
Another question you may have is whether we should use the same mean neural network to predict the mean vector for all timestamps. The same question applies to the covariance matrix prediction network. Well, it is a design choice. At least one network is needed, and the authors experimentally showed one network is enough. You can have two or more, at the expense of more parameters to learn.
Like the case in the forward process, the joint probability density function p(x_0:T) also represents the set of image trajectories that the reverse process can generate, by starting from some pure Gaussian noise.
Why do we need the reverse process p(xₜ₋₁|xₜ)? Isn’t the reverse of the forward process q(xₜ₋₁|xₜ, x₀) enough?
Since we already know the distribution q(xₜ₋₁|xₜ, x₀) for the reverse of the forward process, which denoises images, you may wonder, why is the reverse process p(xₜ₋₁|xₜ) even needed? Why can’t we directly sample natural images from q(xₜ₋₁|xₜ, x₀)?
Of course we can. But please look at q(xₜ₋₁|xₜ, x₀) closely. xₜ₋₁ not only depends on xₜ, it also depends on the initial image x₀. This means we have to know an initial image to start sampling and the goal is to finally sample an image that is similar to the already known x₀. This is not what we want. We want to be able to sample natural images freely!
At this point, you don’t want to stop. You may ask, can we work out the analytical formula for q(xₜ₋₁|xₜ), that is, the reverse of the forward process without the dependence on x₀? Let’s do it by using the Bayes rule again:
q(xₜ₋₁|xₜ) = q(xₜ|xₜ₋₁) q(xₜ₋₁) / q(xₜ)
Now we can spot the trouble: on the right hand side of the equation, q(xₜ|xₜ₋₁) is defined, but q(xₜ₋₁) and q(xₜ) are not defined. So it is a dead end.
You still wouldn’t stop, and you ask, why can’t we define q(xₜ)? I hear you! Let’s try to define q(xₜ). Conceptually, q(xₜ) represents the possible set of images the forward process can generate at timestamp t. Thinking about it from the trajectory point of view, see again below, the images that the forward process can generate at time t depends on where the trajectories start (natural image X₀, X₁, etc.). So the probability density function of q(xₜ) will inevitably reference the probability density function of the starting point, that is q(x₀).
What is q(x₀)? It is the probability density of the training data. Unfortunately, previously we have already made it clear that q(x₀) is unknown. The best we can do is to sample from it by randomly picking natural images from our training set. Consequently, we will not be able to write down the analytical formula for q(xₜ).
The reverse of the forward process q(xₜ₋₁|xₜ, x₀) smartly defines the probability density function for the random variable xₜ₋₁ to condition on x₀. Conditioning on x₀ allows us to plug in a sample for x₀ to reason about properties of xₜ₋₁. As long as we can sample from x₀, which we can, and reason about xₜ₋₁ in an expectation fashion with respect to the samples from x₀, we are good. For more details about “reasoning a random variable in an expectation fashion”, see sampling-averaging below.
Can mimicking the reverse of the forward process give us a reverse process that can start from any multivariate Gaussian noise?
We’ve already established that the starting point of the reverse of the forward process is pure Gaussian noise. By mimicking the behavior of the reverse of the forward process, which denoises many Gaussian noise samples into natural images from the training set, our unconditioned denoising model, which is the reverse process p(xₜ₋₁|xₜ), should be capable of turning any Gaussian noise into a realistic looking natural image. It is just like linear regression: if a fitted line passes through many data points, we would expect the line to generalize to other unseen data points along the same direction.
With the structure of the reverse process defined and its necessity clearly explained, now it is the time to think about the objective function we minimize to perform parameter learning for the mean vector predicting network μₚ(xₜ, t) and the covariance matrix predicting network Σₚ(xₜ, t).
For a probabilistic model, the likelihood of the data is always a good starting point to think about the objective function. Let’s define what “likelihood of the data” means for our model.
The joint probability density function, with data plugged in
The joint probability density of the reverse process p(x₀, x₁⋯, x_T), or p(x_0:T) for short, is a function that has T+1 random variables as its arguments, namely x₀, x₁ to x_T. Being a probability density function, it evaluates to a non-negative number when concrete values are plugged into its arguments.
The purpose of a probabilistic model is to explain training data well. “explaining the training data well” means images in the training dataset evaluate to a high probability number when they are plugged into the x₀ argument, one image at a time.
Plugging an image X₀ into the argument x₀ of the joint probability density function p(x_0:T), which has T+1 random variables, results in a new function with T random variables: p(x₀=X₀, x₁⋯, x_T). This function cannot be evaluated into a number yet, because it mentions random variables x₁ to x_T, which are not concrete values. x₁ to x_T are latent random variables. There are no observations for them, so we cannot find meaningful concrete values (as we can for the observational random variable x₀) to plug in for them. They need to be removed, or more precisely, integrated away.
The likelihood p(x₀)
By definition, a random variable, say, x₁, describes a spectrum of possible values. The go-to way to remove a random variable from a probability density function is to compute the expected value of the density function with respect to that random variable. In other words, to remove the random variable x₁ from p(x₀, x₁⋯, x_T), compute the average value, or alternatively, expectation, of this function with respect to x₁. Essentially we are saying since we cannot observe concrete values for latent random variables, we have to reason about their average behavior.
Let’s pick x₁ to integrate away first. Since x₁ is a continuous random variable, its expectation is defined by an integration, hence the name “integrating a random variable away”:
p(x₀, x₂, ⋯, x_T) = ∫ p(x₀, x₁, x₂, ⋯, x_T) dx₁
The same “integrating away” approach, applied T times, can remove all latent random variables from p(x₀, x₁⋯, x_T):
p(x₀) = ∫ p(x₀, x₁, ⋯, x_T) dx₁ ⋯ dx_T = ∫ p(x_0:T) dx_1:T
p(x₀) now only describes how likely it is that actual images are generated by our model. We call p(x₀) the likelihood of the data.
Note the above equation is merely a notation to indicate that p(x₀) is what is left after all latent random variables are integrated away. It does not tell us how to integrate them away. This is because the integration symbol “∫” represents the result of the integration, without telling us how to do the integration.
Why not integrate x₀ away from the likelihood p(x₀) as well?
The above T-dimensional integration only integrated away the latent random variables x₁ to x_T, and left the observational random variable x₀ in p(x₀). Why? Because if x₀ is integrated away as well, the whole joint probability density function becomes 1, since every probability density function, when integrated over its full set of random variables, yields 1:
∫ p(x_0:T) dx_0:T = 1
You see, there is no place to plug in actual images to this number 1 to evaluate how well it explains the training data. This prevents us from performing parameter learning.
That’s why we leave x₀ unintegrated-away and work with p(x₀), which is called likelihood of the data, or likelihood for short.
The likelihood p(x₀) mentions all model parameters
Even though the likelihood p(x₀), with an actual image plugged in, that is, p(x₀=X₀), is a function that doesn’t mention any random variable, it still mentions all model parameters, that is, weights in the two neural networks, via the probability density function p(xₜ₋₁|xₜ), shown again here:
p(xₜ₋₁|xₜ) = N(xₜ₋₁; μₚ(xₜ, t), Σₚ(xₜ, t))
The weights are in the mean predicting network μₚ and the covariance matrix predicting network Σₚ. The latent random variables x₁ to x_T are integrated away, but their mean and covariance matrix terms are left in the result of the integrations.
You may ask, since we have not discussed how the latent random variables are integrated away, how do we know μₚ(xₜ, t) and Σₚ(xₜ, t) will survive the integrations? You will see for yourself later in the derivation for the analytical loss section, but here you have to believe me: if after the integrations, those Gaussian latent random variables are gone, and the two very things that describe them, namely, the mean vector μₚ(xₜ, t) and the covariance matrix Σₚ(xₜ, t), are also gone, then it seems that those random variables have never existed in our model. This doesn’t make sense. So the two neural networks μₚ(xₜ, t) and Σₚ(xₜ, t) will survive the integrations. In other words, p(x₀) mentions all model parameters in μₚ(xₜ, t) and Σₚ(xₜ, t).
p(x₀) mentions all model parameters, which are the weights in the two neural networks, but we don’t know the proper values for the model parameters yet. If we pretend to know all parameter values, then we can evaluate p(x₀) into a number by plugging in an image from the training set, one image at a time. This results in many numbers. Averaging these numbers gives us a measurement of how well our model explains the training data.
Of course, we don’t know the values for the neural network weights. We can set them to arbitrary values, but this will likely result in a poor model that doesn’t explain the training data well. In this case, it is not that our model structure is incapable of explaining the data; it is that the model has not been calibrated correctly. By “model structure” I mean the reverse process with two deep neural networks predicting the mean vector and the covariance matrix of a denoised image.
Optimization can find proper parameter values for those neural networks. And it needs a loss function to minimize.
The negative log likelihood of data as the loss function
A loss function must mention all model parameters. The likelihood p(x₀) satisfies this requirement. A good model should let an actual image X₀ evaluate to a high likelihood p(x₀=X₀). We want to minimize the loss function, hence the negative sign. A good model not only needs to work well for a single image from the training set, it needs to work well for all images in the training set, hence the expectation with respect to images sampled from the training set x₀~q(x₀). We can take the log of p(x₀) because log is a monotonically increasing function, so a larger p(x₀) always gives a larger log p(x₀); we want to introduce the log function because it is an essential part of the KL-divergence that we will use below.
All the above thinking leads us to the famous negative expected log likelihood, denoted by L:
(1) L = E_{x₀~q(x₀)} [−log p(x₀)]
(2)   = E_{x₀~q(x₀)} [−log ∫ p(x_0:T) dx_1:T]
Line (1) is the definition of the negative log likelihood of data.
Line (2) plugs in the definition of p(x₀), which integrates all latent random variables x₁ to x_T out of the density function p(x₀, ⋯, x_T).
The negative log likelihood loss function is not optimizable
The standard way to perform parameter learning via optimization is to use gradient descent to minimize the loss function with respect to the model parameters. Gradient descent needs to know the analytical formula for the loss function to take the derivative of it. Unfortunately, the analytical formula for the negative log likelihood loss function is very difficult to derive.
To work out the analytical formula for the loss function L, let’s look at its definition again:
L = E_{x₀~q(x₀)} [−log ∫ p(x_0:T) dx_1:T]
p(x_0:T) is previously defined (shown again here) as:
p(x_0:T) = p(x_T) ∏ₜ₌₁ᵀ p(xₜ₋₁|xₜ), with p(x_T) = N(x_T; 0, I) and p(xₜ₋₁|xₜ) = N(xₜ₋₁; μₚ(xₜ, t), Σₚ(xₜ, t))
It is easy to see that our model parameters, the weights in the mean vector predicting neural network μₚ and the covariance matrix predicting neural network Σₚ, are mentioned in the loss function. But they are mentioned inside the integration symbol ∫.
Things such as the exponential symbol exp, or the squared operator “²”, represent computations that we immediately know how to do. The integration symbol is different: it represents the result of a computation, that is, it asks you to integrate a function, without telling you how to do the integration.
From our Calculus course, we all know that taking derivative is work, but performing integration is art — taking the derivative of a term is a mechanical procedure as long as you have the derivative cheat sheet. But integration requires creativity, and we have so many integrations that we just don’t know how to do.
Unfortunately, the integration of p(x_0:T) inside the loss function L falls into the kind of integrations that are very hard to solve analytically. We call it an intractable integration. Let’s use a short reverse process where T is set to 1 to demonstrate this point. Our goal is to show that L is intractable:
L = E_{x₀~q(x₀)} [−log ∫ p(x_0:T) dx_1:T]
Since we know how to sample x₀ from our training set of natural images, the outside expectation with respect to x₀ can be dealt with using sample-averaging (see the next section about sample-averaging). So the only difficult term is the integration inside the log and we want to show that the integration is hard to solve:
∫ p(x_0:T) dx_1:T
After setting T to 1, the above term turns into:
Line (1) shows the shortened joint probability density function for the reverse process when T=1.
Line (2) factorizes the joint into product of probability density functions, each for a single random variable, x₁ and x₀ respectively.
Line (3) plugs in the names of those single probability density functions.
Line (4) plugs in the actual probability density functions, which are multivariate Gaussians. The first exp is for the random variable x₁ from the standard Gaussian, and the second exp is for the random variable x₀ conditioned on x₁. I used the proportional symbol “∝” to ignore the normalization terms in front of each multivariate Gaussian.
The integration at line (4) is hard to solve analytically. Note, in this case, we know how to compute the integration analytically by using the product rule of integration, it is just hard and messy, especially when T=1000. The variational method (explained later) that the authors of the paper proposed is more elegant.
This little exercise also reveals that we may be able to use a technique called sample-averaging to approximate the integration analytically. This is because in line (4) the probability density function of each random variable is mentioned only once, and this single mention of the density function is computing the expected value with respect to that random variable. Sample-average can approximate such expectations.
Can we use sample-averaging to derive the analytical form for the loss function?
The loss function L contains an intractable integration. There are multiple ways to approximate an intractable integration, such as sample-averaging, importance sampling and Gaussian quadrature. Let’s look at the simplest of them all, sample-averaging.
What is sample-averaging
Sample-averaging is simple: average function evaluations over samples of a random variable to compute an expectation. Formally, let x be a random variable from the distribution h(x). Then the following integration of the function f(x) with respect to x can be approximated by averaging the evaluations of f(x) at samples of x drawn from h(x).
In other words, as long as f(x) can be evaluated when samples of x are plugged in, sample-averaging approximates the integration.
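Here is a minimal sketch of sample-averaging in Python. The choices of h(x) as a standard Gaussian and f(x) = x² are assumptions made only for illustration; they are convenient because the exact value of the integration is known to be 1.

```python
import numpy as np

def sample_average(f, sampler, n, rng):
    """Approximate the integral of f(x) h(x) dx by averaging f over n samples of h."""
    samples = sampler(rng, n)
    return float(np.mean(f(samples)))

rng = np.random.default_rng(0)

# Illustration only: h(x) is a standard Gaussian and f(x) = x**2,
# so the exact value of the integral is the variance of x, which is 1.
estimate = sample_average(
    f=lambda x: x ** 2,
    sampler=lambda r, n: r.standard_normal(n),
    n=200_000,
    rng=rng,
)
print(estimate)
```

The estimate gets closer to the exact value as the number of samples n grows.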
Approximating an integration analytically using sampling-averaging
As you can see, sample-averaging approximates an integration with a sum of evaluations of the integrand. This sum of terms is analytical with respect to our model parameters. A simple example shows this point: we want to write down the analytical formula for the loss function of some model, written as the following integration. In this integration, let’s say x is a random variable that can be sampled from h(x) and μ is our model parameter to optimize.
Drawing two samples S₁ and S₂ of x from h(x) and applying sample-averaging to this integration gives an analytical expression in μ on the right-hand side of the following approximation:
The right-hand side of the approximation is an expression that mentions the model parameter μ. This expression is analytical: it does not mention symbols that represent results of computation, such as the integration ∫. It only mentions symbols that represent computation, such as the exponential function exp and the squared operator “²”, for which we know how to compute gradients.
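To make this concrete, here is a small sketch. The integrand exp(-(x-μ)²) and the standard Gaussian h(x) are hypothetical choices for illustration, not necessarily the ones used in the example above. The point is that the two-sample average is an ordinary expression in μ, so its gradient can be written down by the chain rule and checked against a finite difference.

```python
import numpy as np

def loss_two_samples(mu, s1, s2):
    """Two-sample average of the hypothetical integrand exp(-(x - mu)**2)."""
    return 0.5 * (np.exp(-(s1 - mu) ** 2) + np.exp(-(s2 - mu) ** 2))

def grad_two_samples(mu, s1, s2):
    """Derivative of loss_two_samples with respect to mu, via the chain rule."""
    return ((s1 - mu) * np.exp(-(s1 - mu) ** 2)
            + (s2 - mu) * np.exp(-(s2 - mu) ** 2))

rng = np.random.default_rng(0)
s1, s2 = rng.standard_normal(2)   # two samples of x; h(x) is assumed standard Gaussian
mu = 0.3

# Compare the analytical gradient with a central finite difference.
step = 1e-6
finite_diff = (loss_two_samples(mu + step, s1, s2)
               - loss_two_samples(mu - step, s1, s2)) / (2 * step)
analytic = grad_two_samples(mu, s1, s2)
print(analytic, finite_diff)
```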
We use sample-averaging all the time. For example, to compute the expected height of students in a school, we don’t know the distribution of students’ heights, but we have samples of measured heights. So we approximate the expectation by averaging those samples.
Counter-example that makes sample-averaging inapplicable
As long as the sampling distribution h(x) does not appear in the function f(x) to be integrated, and h(x) is easy to sample from, we can use sample-averaging. Here is a counterexample:
In this case, q(x₀) is the data distribution, whose probability density formula is unknown. Still, we can sample from it by randomly picking images from the training set. Sample-averaging can remove the right q(x₀). But note that q(x₀) also appears inside the g function being integrated. This left q(x₀) cannot be removed by sample-averaging: g(1+q(x₀=X₀)) remains an un-evaluable function even with the sample X₀ plugged in. So in this case, sample-averaging cannot solve the integration.
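Written out, the counterexample has the following shape. This is a reconstruction from the description above, with g standing for an arbitrary function and X₀⁽ⁱ⁾ for the i-th sampled image:

```latex
\int g\bigl(1 + q(x_0)\bigr)\, q(x_0)\, dx_0
\;\approx\;
\frac{1}{n} \sum_{i=1}^{n} g\bigl(1 + q(x_0 = X_0^{(i)})\bigr)
```

The right-hand side still needs the value of the unknown density q at each sample, which is why it cannot be evaluated.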
Luckily, our loss function L doesn’t fall into this category, so we can use sample-averaging to approximate L analytically.
Deriving L’s analytical formula via sample-averaging
To derive the analytical formula for the loss function L, rewrite L as follows:
Now it is easy to see that to apply sample-averaging, we can sample latent random variable x₁ to x_T first from the definition of the reverse process:
Finally we sample the observational random variable x₀ from q(x₀).
Explicitly, the sampling process goes like this:
- First, draw a sample for x_T from the standard multivariate Gaussian distribution, using the base case, to remove the integration with respect to x_T.
- With a sample Sₜ for the random variable xₜ at hand, plug Sₜ into p(xₜ₋₁|xₜ=Sₜ), then sample xₜ₋₁, using the inductive case.
- As long as we don’t lose model parameters during this process, we will end up with an analytical formula for the loss function L. Losing model parameters during sample-averaging means that it is possible that sample-averaging results in a formula that does not mention model parameters anymore. This is bad because a loss function that does not mention model parameters is useless. Reparameterization trick is used to prevent this from happening. But in our case, we don’t need to worry about losing model parameters when applying sample-averaging. Appendix “Why we won’t loss model parameters when applying sample-averaging to derive the analytical formula for the loss function L?” explains why.
- Once all samples for x₁ to x_T are available, let’s call them a sample trajectory. Plug this trajectory into the joint probability density p(x_0:T) to get the analytical expression for p(x_0) under this trajectory.
- Repeat steps 1~4 to get analytical expression for p(x_0) under different trajectories and average them to approximate the inner integration. Say there are m sample trajectories, each trajectory i gives an analytical formula pᵢ(x_0), then the analytical formula for the average is:
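The steps above can be sketched on a toy model. Everything in this sketch is an assumption made for illustration: the data is one-dimensional, and the reverse step is p(xₜ₋₁|xₜ) = N(a·xₜ, s²) with hypothetical parameters a and s instead of a neural network. Because the latent variables are sampled from their own factors of the joint, what remains to be evaluated for each trajectory is p(x₀|x₁=S₁). The toy chain is linear-Gaussian, so the exact p(x₀) is available for comparison.

```python
import numpy as np

def normal_pdf(x, mean, std):
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2.0 * np.pi))

def estimate_p_x0(x0, a, s, T, m, rng):
    """Average p(x0 | x1 = S1) over m trajectories sampled from the toy reverse process."""
    x = rng.standard_normal(m)            # base case: x_T ~ N(0, 1), one per trajectory
    for _ in range(T - 1):                # inductive case: x_{t-1} ~ N(a * x_t, s^2), down to x_1
        x = a * x + s * rng.standard_normal(m)
    return float(np.mean(normal_pdf(x0, a * x, s)))

def exact_p_x0(x0, a, s, T):
    """Closed-form marginal, available only because the toy chain is linear-Gaussian."""
    var = a ** (2 * T) + s ** 2 * sum(a ** (2 * k) for k in range(T))
    return float(normal_pdf(x0, 0.0, np.sqrt(var)))

rng = np.random.default_rng(0)
estimate = estimate_p_x0(x0=0.3, a=0.9, s=0.5, T=5, m=200_000, rng=rng)
exact = exact_p_x0(x0=0.3, a=0.9, s=0.5, T=5)
print(estimate, exact)
```

In the real model the closed form does not exist, and only the sample-based estimate is available.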
Analytical formula for L via sample-averaging is expensive to compute
The above sample-averaging requires a lot of computation because each trajectory requires T samples, one for each latent random variable. Common sense tells us that drawing a single sample for a random variable is not enough. For example, you shouldn’t compute the average student height in a school by measuring the height of a single student. Why is that bad? Because a small sample size gives us an estimate (in this example, of the expected height) with high variance. Every time you draw a single sample, you get a different estimate. That’s variance.
We want to draw many samples for each random variable because more samples means the average more closely approximates the original integration. In other words, less variance. But drawing more samples requires a lot of computation: drawing two samples for each latent random variable with T=1000 results in m=2¹⁰⁰⁰ trajectories to average over. That’s very expensive.
Practically, we can only afford to draw a single sample for each latent random variable. But this brings the high variance problem back.
Analytical formula for L via sample-averaging using small number of samples has high variance
The problem with drawing only a very small number of samples (for example, a single sample per latent random variable) is that the probability number computed for an actual image X₀, that is, p(x₀=X₀), has high variance. This is because p(x₀=X₀) depends on the sampled concrete values of the latent random variables x₁ to x_T. Every time you compute p(x₀=X₀) for the same image X₀, you get a different number. Since there are T concrete sampled values in a sample trajectory, the variance is likely to be quite high.
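The variance problem can be seen in a toy experiment. The sketch below assumes a one-dimensional reverse step p(xₜ₋₁|xₜ) = N(a·xₜ, s²) with hypothetical constants a and s, which stands in for the neural network. It estimates p(x₀=X₀) many times, once with a single trajectory per estimate and once with 1000 trajectories per estimate, and compares the spread of the two sets of estimates.

```python
import numpy as np

def estimate_p_x0(x0, m, rng, a=0.9, s=0.5, T=50):
    """Sample-average estimate of p(x0) from m trajectories of a toy reverse process."""
    x = rng.standard_normal(m)                 # x_T ~ N(0, 1)
    for _ in range(T - 1):                     # sample down to x_1
        x = a * x + s * rng.standard_normal(m)
    density = np.exp(-0.5 * ((x0 - a * x) / s) ** 2) / (s * np.sqrt(2.0 * np.pi))
    return float(density.mean())

rng = np.random.default_rng(0)
single = np.array([estimate_p_x0(0.3, 1, rng) for _ in range(2000)])
many = np.array([estimate_p_x0(0.3, 1000, rng) for _ in range(200)])

print("std with 1 trajectory per estimate:   ", single.std())
print("std with 1000 trajectories per estimate:", many.std())
```

Both sets of estimates target the same number, but the single-trajectory estimates fluctuate far more from run to run.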
To make things worse, at the beginning of the parameter learning process, the weights in the neural networks that predict the mean vector and the covariance matrix are randomly initialised. So the images sampled through those neural networks may be of very bad quality, in the sense that they don’t look like a denoised version of the previous image at all. Even though this doesn’t contribute to a higher variance, low-quality samples make the parameter learning harder.
Why is a high variance in p(x₀=X₀) bad? Because p(x₀=X₀) is our measurement of how well our model explains the training data, in this case, how well it explains the training image X₀. If for the same image our measurement sometimes reports a large p(x₀=X₀) and sometimes a small one, then the optimizer, for example the Adam optimizer, is uncertain whether our current model explains the training data well or not. This uncertainty usually shows up as a very slow or even diverging training process.
So sample-averaging is not a good way to derive the analytical formula for the loss function L. Is there a better way? Once again, variational inference comes to the rescue.
To shorten the current article, I decided not to introduce variational inference and use it as known knowledge. For its introduction and two applications, please see: Demystifying Tensorflow Time Series: Local Linear Trend and Variational Gaussian Process (VGP) — What To Do When Things Are Not Gaussian.
The key idea is to use another distribution to compute (note the word “compute”, not “approximate”) an otherwise intractable integration in an analytical way. I’ll use importance sampling inside the loss function to introduce the new distribution.
Derivation by importance sampling
Importance sampling introduces a distribution that is easy to sample from to help solve an otherwise intractable integration. In our case, the intractable integration is with respect to the random variables x₁ to x_T from the joint reverse process distribution p(x₀, …, x_T).
Note that in our case, p(x₀, …, x_T) is sample-able. As we described previously, sample-averaging is applicable to approximate the loss function analytically.
The true motivation to introduce a new distribution, which is the joint forward process q(x_1:T|x₀) in our case, is that it helps derive the analytical formula for L. It also encodes the “gradual-ness” requirement for the in-between images from the reverse process. These two points may not make sense yet. Both will become clear after we finish deriving the analytical formula of the loss function using importance sampling.
In the loss L, the integration is with respect to the latent random variables x₁ to x_T, shown again in line (2) below, so the new distribution we introduce must be over the same set of random variables. The distribution q(x_1:T|x₀) that we defined for the forward process fits this requirement. Line (3) introduces it into the formula of L.
I would like to point out that the above derivation is valid for any probabilistic model with random variables x₀ to x_T because it only uses properties from probability theory. These properties are true for any valid probability distributions. Only from the next derivation over Lᵥ, when we start to factorize joint probabilities using our definitions of the forward process and the reverse process, do we rely on the specific model structure.
Line (3) introduces the quantity q(x_1:T|x₀)/q(x_1:T|x₀). This quantity evaluates to 1, so its addition does not change the integration.
Line (4) re-organizes the terms, turning the old integration into a new one with respect to q(x_1:T|x₀).
Line (5) represents the integration using the equivalent expectation notation.
Line (6) uses Jensen’s inequality to push the log function into the inner expectation because expectation of logs is easier to compute than the log of expectations. Jensen’s inequality also turns the result we will eventually minimize into a quantity that is larger than the original loss L.
Line (7) replaces the expectation notations into their definitions, that is, integrations. And line (8) re-arranges terms.
Line (9) applies the reverse chain rule in probability theory to derive the joint probability q(x_0:T).
Line (10) represents the integration using the equivalent expectation notation. Note that we started with introducing the q distribution over the random variables x₁ to x_T, and arrived at an expectation with respect to the random variables x₀ to x_T. We give this new quantity the name Lᵥ, standing for variational loss.
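For reference, the steps described above can be typeset as follows. This is a reconstruction from the line-by-line descriptions, with the line numbers chosen to match them:

```latex
\begin{align}
L &= -\mathbb{E}_{q(x_0)}\bigl[\log p(x_0)\bigr] \tag{1} \\
  &= -\mathbb{E}_{q(x_0)}\Bigl[\log \int p(x_{0:T})\, dx_{1:T}\Bigr] \tag{2} \\
  &= -\mathbb{E}_{q(x_0)}\Bigl[\log \int p(x_{0:T})\,
      \frac{q(x_{1:T} \mid x_0)}{q(x_{1:T} \mid x_0)}\, dx_{1:T}\Bigr] \tag{3} \\
  &= -\mathbb{E}_{q(x_0)}\Bigl[\log \int q(x_{1:T} \mid x_0)\,
      \frac{p(x_{0:T})}{q(x_{1:T} \mid x_0)}\, dx_{1:T}\Bigr] \tag{4} \\
  &= -\mathbb{E}_{q(x_0)}\Bigl[\log \mathbb{E}_{q(x_{1:T} \mid x_0)}
      \Bigl[\frac{p(x_{0:T})}{q(x_{1:T} \mid x_0)}\Bigr]\Bigr] \tag{5} \\
  &\le -\mathbb{E}_{q(x_0)}\Bigl[\mathbb{E}_{q(x_{1:T} \mid x_0)}
      \Bigl[\log \frac{p(x_{0:T})}{q(x_{1:T} \mid x_0)}\Bigr]\Bigr] \tag{6} \\
  &= -\int q(x_0) \int q(x_{1:T} \mid x_0)\,
      \log \frac{p(x_{0:T})}{q(x_{1:T} \mid x_0)}\, dx_{1:T}\, dx_0 \tag{7} \\
  &= -\int q(x_0)\, q(x_{1:T} \mid x_0)\,
      \log \frac{p(x_{0:T})}{q(x_{1:T} \mid x_0)}\, dx_{0:T} \tag{8} \\
  &= -\int q(x_{0:T})\,
      \log \frac{p(x_{0:T})}{q(x_{1:T} \mid x_0)}\, dx_{0:T} \tag{9} \\
  &= -\mathbb{E}_{q(x_{0:T})}\Bigl[\log \frac{p(x_{0:T})}{q(x_{1:T} \mid x_0)}\Bigr]
   \;=\; L_v \tag{10}
\end{align}
```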
New loss Lᵥ to derive analytical formula for, and to minimize
From now on, Lᵥ is the quantity to minimize. Our goal is updated to derive the analytical formula for the new loss Lᵥ. Looking at line (10) it is hard to believe that it is analytical. But in math, amazing things do happen. Please read on.
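The relationship between L and Lᵥ can be checked numerically on a toy model. The sketch below is an illustration under stated assumptions, not the model from the paper: T=1, one-dimensional data, a reverse step p(x₀|x₁) = N(a·x₁, s²) with hypothetical constants a and s, and a fixed forward step q(x₁|x₀) = N(√(1-β)·x₀, β). For this toy model -log p(x₀) has a closed form, so we can confirm that the sample-averaged Lᵥ for one data point sits above it.

```python
import numpy as np

def log_normal(x, mean, var):
    """Log density of a one-dimensional Gaussian N(mean, var) at x."""
    return -0.5 * (np.log(2.0 * np.pi * var) + (x - mean) ** 2 / var)

a, s, beta = 0.8, 0.6, 0.9     # hypothetical toy parameters
x0 = 1.0                       # a single "training" data point
rng = np.random.default_rng(0)
n = 200_000

# Sample x1 from the fixed forward distribution q(x1 | x0).
x1 = np.sqrt(1.0 - beta) * x0 + np.sqrt(beta) * rng.standard_normal(n)

# log p(x0, x1) = log p(x1) + log p(x0 | x1), with p(x1) = N(0, 1).
log_p_joint = log_normal(x1, 0.0, 1.0) + log_normal(x0, a * x1, s ** 2)
log_q = log_normal(x1, np.sqrt(1.0 - beta) * x0, beta)

variational_bound = float(-np.mean(log_p_joint - log_q))   # sample-averaged L_v
exact_nll = float(-log_normal(x0, 0.0, a ** 2 + s ** 2))   # closed-form -log p(x0)
print(variational_bound, exact_nll)
```

The gap between the two numbers is the price of Jensen’s inequality. It shrinks when q(x₁|x₀) is closer to the posterior p(x₁|x₀) of the toy model.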
Rewriting Lᵥ to get the important Lₜ₋₁ term
This is an important derivation, please pay attention.
Line (1) shows the new loss Lᵥ that we just derived. Lᵥ mentions the joint probability density of the reverse process p(x_0:T) and that of the forward process q(x_1:T|x₀), both of which we defined previously.
Line (2) factorizes these two joint probability densities. It factorizes p(x_0:T) using the definition of the reverse process. And it factorizes q(x_1:T|x₀) using the first factorization of q.
Starting from this line, we are relying on the model structure that we defined, that is, the structure in which random variable xₜ₋₁ depends on xₜ in the reverse process p, and xₜ depends on xₜ₋₁ in the forward process. This is not true for arbitrary probabilistic models.
Line (3) again performs factorization. Note that the products start with t=2 instead of t=1 because of the factorization at this line.
Line (4) pushes the minus sign from outside of the expectation to inside of the expectation, and it uses the property that log(a×b) = log(a) + log(b).
Line (5) introduces name F_T to represent the first term and Fₒ for the third term inside the expectation to shorten the derivations so they fit in one line.
Line (6) is the key line. It uses the Bayes rule to replace q(xₜ|xₜ₋₁):
Note the addition of the dependency on x₀ to turn q(xₜ|xₜ₋₁) into q(xₜ|xₜ₋₁, x₀). This addition is redundant: it does not change the conditional probability because, by definition, the random variable xₜ only depends on xₜ₋₁. See the definition for xₜ, shown again below. It only mentions xₜ₋₁ and not x₀.
The addition makes it easier for us to apply the Bayes rule because the Bayes rule mentions q(xₜ|x₀) and q(xₜ₋₁|x₀) which explicitly depends on x₀.
Note the dependency on x₀ in q(xₜ₋₁|xₜ, x₀) is not redundant. x₀ appears here because of the Bayes rule.
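Put together, the replacement described above reads as follows (a reconstruction from the text, applied for each t from 2 to T):

```latex
q(x_t \mid x_{t-1})
= q(x_t \mid x_{t-1}, x_0)
= \frac{q(x_{t-1} \mid x_t, x_0)\; q(x_t \mid x_0)}{q(x_{t-1} \mid x_0)}
```

The first equality is the redundant conditioning on x₀, and the second is the Bayes rule with x₀ kept as a condition throughout.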
The reason for using the Bayes rule is to make the term q(xₜ₋₁|xₜ, x₀) appear. q(xₜ₋₁|xₜ, x₀) is a term from the reverse of the forward process. We now have a probability ratio between p(xₜ₋₁|xₜ) and q(xₜ₋₁|xₜ, x₀), seen at line (6). p(xₜ₋₁|xₜ) and q(xₜ₋₁|xₜ, x₀) are both:
- probability density functions for the same random variable xₜ₋₁, and
- multivariate Gaussian distributions with analytical probability densities available: we previously defined the analytical form of both p(xₜ₋₁|xₜ) in the reverse process and q(xₜ₋₁|xₜ, x₀) in the reverse of the forward process.
These two properties make it possible to derive the KL-divergence between p(xₜ₋₁|xₜ) and q(xₜ₋₁|xₜ, x₀) analytically, detailed later.
Line (7) uses the property of log to split terms.
Line (8) uses the property of log to turn a sum of logs into a log of products.
Line (9) realizes that in the log of products, the numerator and the denominator share many terms, which cancel, leaving just one term in the numerator and one term in the denominator.
Line (10) introduces the name F₀ to denote the last term inside the expectation. And it introduces the name Lₜ₋₁ for each of the negative log terms in the summation to make the derivations shorter. That is:
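Based on the description of line (10), the definition of each Lₜ₋₁ term can be typeset as follows (a reconstruction, for t from 2 to T):

```latex
L_{t-1} = -\,\mathbb{E}_{q(x_{0:T})}\Bigl[\log
  \frac{p(x_{t-1} \mid x_t)}{q(x_{t-1} \mid x_t, x_0)}\Bigr]
        = \mathbb{E}_{q(x_{0:T})}\Bigl[\log
  \frac{q(x_{t-1} \mid x_t, x_0)}{p(x_{t-1} \mid x_t)}\Bigr]
```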
Obviously the Lₜ₋₁ terms for t in [2, T] are important. Note that for Lₜ₋₁, t starts from 2 instead of 1 because of the split in line (3). These T-1 terms constitute most of the loss function, leaving only three other terms behind. Let’s worry about those three terms later and focus on the Lₜ₋₁ terms, as they will be the core of the final loss function that we minimize.
Let’s keep manipulating Lₜ₋₁ for t=[2, T]:
Line (1) is the definition of the Lₜ₋₁ term. Line (2) pushes the minus sign into the log. The expectation is with respect to the random variables x₀ to x_T from the q distribution.
Line (3) replaces the expectation notation with its mathematical definition, which is an integration over the random variables x₀ to x_T.
Line (4) factorizes the joint probability density q using the second factorization of the forward process.
Note that the second factorization is a product of many distributions, each of which mentions a single latent random variable. This is correct because, given the observational random variable x₀, all latent random variables x₁ to x_T are independent of each other.
Line (5) organizes all the factors from the q distribution into four parts:
- q(x₀), which is a distribution about x₀, and its formula is unknown.
- q(xₜ₋₁|x₀), which is a distribution about xₜ₋₁.
- q(xₜ|x₀), which is a distribution about xₜ.
- q(xₒₜₕₑᵣ), which is a distribution about the latent random variables other than xₜ₋₁ and xₜ.
The reason for line(5)’s factorization is that the log function only mentions x₀, xₜ₋₁ and xₜ.
Line (6) applies the chain rule to derive the joint probability q(xₜ₋₁, xₜ|x₀).
Line (7) is a key line. Using the reverse chain rule (which is applicable for any joint probability density), it replaces q(xₜ₋₁, xₜ|x₀) with q(xₜ₋₁|xₜ, x₀)q(xₜ|x₀) because
Line (8) splits the integrating variables into 4 parts, corresponding to x₀, xₜ₋₁, xₜ and xₒₜₕₑᵣ, and re-orders them.
Line (9) recognizes that the inner integration is the KL-divergence between q(xₜ₋₁|xₜ, x₀) and p(xₜ₋₁|xₜ). This KL-divergence is between two multivariate Gaussian distributions, whose analytical probability density functions are known. So we can write down the formula for this KL-divergence analytically. It is a function that mentions the random variable xₜ and x₀ (note, it does not mention xₜ₋₁), as well as all model parameters.
Line (10) factorizes q(xₜ, xₒₜₕₑᵣ, x₀) into conditionals.
Now we have the analytical expression for the KL-divergence between q(xₜ₋₁|xₜ, x₀) and p(xₜ₋₁|xₜ), but this KL-divergence is inside an integration. How do we solve the integration analytically?
That’s right, we can use sample-averaging to approximate the expectation with respect to x₀, xₜ and xₒₜₕₑᵣ:
- Sample x₀ by randomly picking natural images from the training set.
- Sample xₜ from the marginal q(xₜ|x₀) after plugging the sample for x₀.
- No need to sample xₒₜₕₑᵣ as line(10) reveals that xₒₜₕₑᵣ is not mentioned in the KL-divergence. The values for random variables inside xₒₜₕₑᵣ won’t change the computed result of the KL-divergence.
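The sampling steps above can be sketched as follows. The sketch assumes the closed-form marginal q(xₜ|x₀) = N(√(αₜ bar)·x₀, (1-αₜ bar)·I) of the forward process and a linear βₜ schedule from 0.0001 to 0.02 with T=1000. The random array standing in for the training set and the function name sample_pair are hypothetical.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)      # assumed linear schedule; betas[t - 1] is beta_t
alpha_bar = np.cumprod(1.0 - betas)     # alpha_bar[t - 1] is the product of alpha_1 .. alpha_t

def sample_pair(training_set, t, rng):
    """Return (S0, St): S0 picked from the training set, St drawn from q(x_t | x_0 = S0)."""
    s0 = training_set[rng.integers(len(training_set))]
    noise = rng.standard_normal(s0.shape)
    st = np.sqrt(alpha_bar[t - 1]) * s0 + np.sqrt(1.0 - alpha_bar[t - 1]) * noise
    return s0, st

rng = np.random.default_rng(0)
training_set = rng.uniform(-1.0, 1.0, size=(16, 8, 8))   # stand-in for real images
s0, st = sample_pair(training_set, t=500, rng=rng)
print(s0.shape, st.shape)
```

Note that xₜ is sampled in one shot from x₀, without simulating the t-1 intermediate forward steps. That is what makes a single training step cheap.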
Phew, after so many steps, we finally arrived at the analytical expression for the Lₜ₋₁ terms for t in [2, T] in our new loss function Lᵥ to minimize.
Sample-averaging to solve the integration
Let me paste the analytical formula for Lₜ₋₁ here, and add the steps that use sample-averaging to approximate the integration analytically.
Line (1) is the analytical formula we derived just now for Lₜ₋₁. It contains a multiple integration over the random variables x₀, xₜ and xₒₜₕₑᵣ. All three are easy to deal with:
- First, sample x₀ from our training set. Let’s call x₀’s sample S₀.
- Plug S₀ into q(xₜ|x₀) to get q(xₜ|x₀=S₀), which is now a fully specified multivariate Gaussian distribution ready to be sampled. Let’s call xₜ’s sample Sₜ.
- Ignore the integration over xₒₜₕₑᵣ: because xₒₜₕₑᵣ does not appear in the KL-divergence, its samples do not change the analytical form of the integration result.
Line (2) uses the above sampling scheme to sample n pairs of (S₀, Sₜ), plugs each pair into the KL-divergence formula to get an analytical term, and then averages these analytical terms.
You may ask, how many pairs n should we sample? The more the better, but empirically, a single pair already gives us good results, so n=1.
So line (3) uses the fact n=1 to remove the summation from line (2) to arrive at this simple formula:
KL(q(xₜ₋₁|xₜ, x₀) || p(xₜ₋₁|xₜ)) serves as regularization
After so much effort to derive the analytical formula for this KL-divergence, it is wise to look at it closely.
For each step t in [2, T], this KL-divergence quantifies the distance between two distributions:
- q(xₜ₋₁|xₜ, x₀) — the reverse of the forward process that we derived from the forward process by using the Bayes rule.
- p(xₜ₋₁|xₜ) — the reverse process that we used deep neural network to implement.
We are minimizing this KL-divergence. That is, we want these two distributions to be similar at each time step t in [2, T]. In other words, we want to find a model p(xₜ₋₁|xₜ) that gives similar results to the reverse of the forward process q(xₜ₋₁|xₜ, x₀) at all steps. “Similar results” means that at timestamp t-1, images sampled from p(xₜ₋₁|xₜ) should be similar to images sampled from q(xₜ₋₁|xₜ, x₀).
Pay attention to the timestamp range t in [2, T] here. This range means that the Lₜ₋₁ terms only cover the timestamps from 2 to T, leaving the first step t=1 unformulated. The timestamp t=1, being the step that finally generates the natural image, is of course important. Remember we left three terms from Lᵥ unanalyzed? Later we will see that those remaining terms cover the first timestamp.
So the reverse of the forward process serves as a regularization for the neural network and establishes the notion of “gradual-ness” among the images generated by the reverse process: the images generated by the reverse process at each timestamp should be similar to the images from the reverse of the forward process at the corresponding timestamp. Since the images from the reverse of the forward process change gradually, the images from the reverse process must also change gradually.
This regularization restricts the neural network to behave according to an already known and much simpler process: the reverse of the forward process. The per-step KL-divergence prevents the learnt neural network from doing weird things, such as first generating an image of a cat at an early step and then morphing the cat into a human face.
Now you should be convinced that the introduction of the forward process distribution q helps establish the gradual-ness of the generated images from the reverse process p.
Trajectory viewpoint
Let’s use the illustration below to reveal what KL(q(xₜ₋₁|xₜ, x₀) || p(xₜ₋₁|xₜ)) is trying to do from the trajectory point of view.
The left subplot shows two natural images X₀ and X₁. Starting from each natural image, if we apply the forward process multiple times, we get multiple trajectories. The black curves starting from X₀ or X₁ represent these trajectories. Timestamps go from left to right, so the images at the end of each trajectory are pure Gaussian noise already.
In this completely unconditioned setting, at timestamp t-1, the random variable xₜ₋₁ in our model can take values from any trajectory, no matter whether the trajectory starts from X₀ or X₁. In other words, at timestamp t-1, our model needs to be able to explain all possible images that can be generated by the forward process, starting from any natural image. Our model can do that by giving the random variable xₜ₋₁ a mean that is in the middle of all the trajectories and a large variance.
The middle subplot shows the situation when x₀ is given, which sets the random variable x₀ to the natural image X₀. This setting restricts the model to only explain the trajectories that start from the natural image X₀. They are the red trajectories in the middle subplot. In other words, our model now only needs to explain the possible values from the red curves at timestamp t-1. The model can do that by offering a more precise mean and a smaller variance, since it does not need to cover the black trajectories starting from the natural image X₁ anymore.
The right subplot shows the situation when x₀ is still conditioned to X₀ and, additionally, xₜ is conditioned on a particular image Sₜ, which is sampled from the distribution q(xₜ|x₀=X₀). This second conditioning further restricts the model to only explain trajectories that go through Sₜ at timestamp t. These are the blue trajectories, which all start from X₀ and pass through Sₜ.
Under this condition, the possible values that the random variable xₜ₋₁ can take at timestamp t-1 are further restricted. This means that our model needs to predict a mean that is around the middle of the blue trajectories, and an even smaller covariance for xₜ₋₁.
But how “around the middle of the blue trajectories” should the predicted mean be, and how “even smaller” should the predicted covariance be for the random variable xₜ₋₁? These two target quantities are defined by the reverse of the forward process q(xₜ₋₁|xₜ, x₀), with its definition shown here again:
with
By conditioning the model on xₜ and x₀, we are giving the model an easier task to learn at each training step because at each step, the model only needs to explain a single time step of a relatively small number of trajectories.
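The shrinking variance can be checked numerically. The sketch below assumes the standard closed forms from the paper: the per-pixel variance of q(xₜ₋₁|x₀) is 1-αₜ₋₁ bar, and the per-pixel variance of q(xₜ₋₁|xₜ, x₀) is (1-αₜ₋₁ bar)/(1-αₜ bar)·βₜ. The linear βₜ schedule and the function name posterior_mean are assumptions for illustration.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)                         # assumed linear schedule
alphas = 1.0 - betas
alpha_bar = np.cumprod(alphas)                             # alpha_bar[t - 1] is "alpha_t bar"
alpha_bar_prev = np.concatenate(([1.0], alpha_bar[:-1]))   # "alpha_(t-1) bar", with alpha_0 bar = 1

# Per-pixel variance of x_{t-1} given x_0 only (middle subplot).
var_given_x0 = 1.0 - alpha_bar_prev
# Per-pixel variance of x_{t-1} given both x_t and x_0 (right subplot).
var_given_xt_x0 = (1.0 - alpha_bar_prev) / (1.0 - alpha_bar) * betas

def posterior_mean(xt, x0, t):
    """Mean of q(x_{t-1} | x_t, x_0): a weighted combination of x_0 and x_t."""
    i = t - 1
    coef_x0 = np.sqrt(alpha_bar_prev[i]) * betas[i] / (1.0 - alpha_bar[i])
    coef_xt = np.sqrt(alphas[i]) * (1.0 - alpha_bar_prev[i]) / (1.0 - alpha_bar[i])
    return coef_x0 * x0 + coef_xt * xt

for t in (2, 500, 1000):
    print(t, var_given_x0[t - 1], var_given_xt_x0[t - 1])
```

For every t from 2 to T, the variance given both xₜ and x₀ is smaller than the variance given x₀ alone, which is the right-subplot picture in numbers.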
Optimization forces p to change by fixing q
Since the reverse of the forward process q(xₜ₋₁|xₜ, x₀) is fixed, that is, there are no trainable parameters in q(xₜ₋₁|xₜ, x₀), the only way for the optimization to make q(xₜ₋₁|xₜ, x₀) and the reverse process p(xₜ₋₁|xₜ) similar to each other is to change the model parameters’ values to move p closer to q.
One thing to note is that many other papers introduce a learnable q and move q closer to p. Not in this paper. In this paper, the q distribution introduced in importance sampling is fixed, and minimizing the KL-divergence between q and p moves p.
Since the KL-divergence Lₜ₋₁=KL(q(xₜ₋₁|xₜ, x₀) || p(xₜ₋₁|xₜ)) is analytical, let’s write it down. Recall that the probability density functions of the two distributions in the KL-divergence are both multivariate Gaussians:
The analytical formula for the KL-divergence between two multivariate Gaussians is:
The above formula has 4 terms.
The first term at line (1) computes the log ratio between the determinants of the two covariance matrices, denoted by the name “det”. This term mentions model parameters.
The second term at line (2) references d, the dimension of the random variable xₜ₋₁, which is the number of pixels in the images that we are working with. This term does not mention any model parameter.
The third term at line (3) computes the trace, denoted by the name “tr”, of a product of two matrices. This term mentions model parameters.
The fourth term at line (4) is the square of the vector μₚ(xₜ, t)-μₜ(xₜ, x₀), scaled by the covariance matrix Σₚ(xₜ, t)⁻¹.
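The four terms can be sketched in code using the standard closed form for the KL-divergence between two multivariate Gaussians. The function name gaussian_kl and the argument names are hypothetical. Here q plays the role of the reverse of the forward process and p the role of the reverse process.

```python
import numpy as np

def gaussian_kl(mu_q, cov_q, mu_p, cov_p):
    """KL(N(mu_q, cov_q) || N(mu_p, cov_p)) for full covariance matrices."""
    d = mu_q.shape[0]
    _, logdet_q = np.linalg.slogdet(cov_q)
    _, logdet_p = np.linalg.slogdet(cov_p)
    log_det_ratio = logdet_p - logdet_q                     # term 1: log(det cov_p / det cov_q)
    trace_term = np.trace(np.linalg.solve(cov_p, cov_q))    # term 3: tr(cov_p^{-1} cov_q)
    diff = mu_p - mu_q
    quad_term = diff @ np.linalg.solve(cov_p, diff)         # term 4: scaled squared distance
    return 0.5 * (log_det_ratio - d + trace_term + quad_term)   # term 2 is the "- d"

# One-dimensional check: KL(N(0, 1) || N(1, 4)).
kl_1d = gaussian_kl(np.array([0.0]), np.array([[1.0]]),
                    np.array([1.0]), np.array([[4.0]]))
print(kl_1d)
```

Using np.linalg.solve instead of an explicit matrix inverse is a design choice for numerical stability. It does not change the formula.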
I know, this formula is terrible. And let’s remind ourselves that we need to minimize it with respect to the model parameters, which appear in:
- μₚ(xₜ, t), the neural network that is responsible for predicting the mean vector of the p(xₜ₋₁|xₜ) multivariate Gaussian distribution.
- Σₚ(xₜ, t), a second neural network that is responsible for predicting the covariance matrix of the p(xₜ₋₁|xₜ) multivariate Gaussian distribution.
Simplifying the model by setting the reverse process covariance matrix to constant
Let’s simplify the model by removing the second neural network that predicts the covariance matrix. Mathematically, we set Σₚ(xₜ, t)=σₜ²I, where one obvious choice for σₜ² is:
The above makes the covariance matrix from the reverse process p(xₜ₋₁|xₜ) the same as the covariance matrix of the reverse of the forward process.
With this simplification, the first three terms become constants. Let’s name their sum C. C does not mention model parameters anymore, so it can be ignored during optimization. This leaves us with only the fourth term. Let’s call it LMₜ₋₁. So we have:
with LMₜ₋₁ being:
Line (1) is the fourth term. Line (2) plugs in the simplified covariance matrix. The ||…||² in line (3) is the squared vector norm, that is, the dot product of a vector with itself. Line (4) swaps the two components inside the square, which does not change the result, just to be consistent with the order of terms in the paper.
Note that I dropped the expectation with respect to x₀ and xₜ in LMₜ₋₁ to make the formula concise. But the computation is the same as before: we need to sample x₀ and xₜ, and plug the samples into LMₜ₋₁ to approximate the integration analytically.
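As a sanity check of this simplification, the sketch below compares the full KL-divergence with the squared-distance term ||μₜ - μₚ||²/(2σₜ²) when both covariance matrices are set to the same σₜ²I. The dimension, the value of σₜ² and the random mean vectors are arbitrary choices for illustration. In this particular setting the constant C happens to be exactly zero, because the log determinant ratio vanishes and the trace term cancels d.

```python
import numpy as np

def gaussian_kl(mu_q, cov_q, mu_p, cov_p):
    """KL(N(mu_q, cov_q) || N(mu_p, cov_p)), standard closed form."""
    d = mu_q.shape[0]
    _, logdet_q = np.linalg.slogdet(cov_q)
    _, logdet_p = np.linalg.slogdet(cov_p)
    diff = mu_p - mu_q
    return 0.5 * ((logdet_p - logdet_q) - d
                  + np.trace(np.linalg.solve(cov_p, cov_q))
                  + diff @ np.linalg.solve(cov_p, diff))

rng = np.random.default_rng(0)
d, sigma2 = 6, 0.05                      # arbitrary dimension and shared variance
mu_q = rng.standard_normal(d)            # stands in for the mean of q(x_{t-1} | x_t, x_0)
mu_p = rng.standard_normal(d)            # stands in for the network's predicted mean
cov = sigma2 * np.eye(d)

full_kl = gaussian_kl(mu_q, cov, mu_p, cov)
lm = np.sum((mu_q - mu_p) ** 2) / (2.0 * sigma2)
print(full_kl, lm)
```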
Interpreting the meaning of LMₜ₋₁
LMₜ₋₁ quantifies the distance between the two vectors μₜ(xₜ, x₀) and μₚ(xₜ, t). This makes a lot of sense now:
- Originally we want to minimize the distance between q(xₜ₋₁|xₜ, x₀) the reverse of the forward process and p(xₜ₋₁|xₜ), which is our neural network implementation of the reverse process, at every time step t from 2 to T. In other words, we want to find a configuration (model parameter values) for the p(xₜ₋₁|xₜ) distribution such that these two distributions are similar to each other.
- These two distributions for the random variable xₜ₋₁ are both multivariate Gaussian. A multivariate Gaussian distribution is fully specified by its mean vector and covariance matrix. If p(xₜ₋₁|xₜ) needs to be similar to q(xₜ₋₁|xₜ, x₀), their mean vectors and covariance matrices must be similar to each other. This is called moment matching, with the mean being the first moment and the covariance being the second central moment. The letter “M” in LMₜ₋₁ stands for moment matching.
- After we simplified the covariance matrix from the p(xₜ₋₁|xₜ) distribution to a quantity that is equal to the covariance matrix from the reverse of the forward process, the only thing that we can still change to make these two distributions similar or different is the mean vector. So we want to minimize the distance between the mean vectors from the p(xₜ₋₁|xₜ) and the q(xₜ₋₁|xₜ, x₀) distribution.
- Since the mean vector from the p(xₜ₋₁|xₜ) distribution is predicted by our neural network, we can use optimization to move the values of the neural network weights around by minimizing LMₜ₋₁.
Simplifying LMₜ₋₁
It is possible to simplify LMₜ₋₁, a lot. In LMₜ₋₁’s formula, the μₚ(xₜ, t) part is from the neural network, it’s like a black box, there is little we can simplify. So let’s try to simplify the other term μₜ(xₜ, x₀), which is the mean vector of the reverse of the forward process q(xₜ₋₁|xₜ, x₀), whose analytical probability density function is already derived:
with the covariance matrix:
and the mean vector:
We only need to look at the mean vector μₜ(xₜ, x₀) because previous derivation of LMₜ₋₁ reveals that we only need to use our neural network to predict a mean vector that is close to, or alternatively, match, μₜ(xₜ, x₀).
We also have the analytical probability density function for q(xₜ|x₀):
Using the reparameterization trick, we can rewrite the above into:
Re-organize the terms in the above equation to get the expression for x₀:
Now plug in this expression of x₀ into the formula for μₜ(xₜ, x₀):
Line (1) is a horrible formula. Line (2) introduces the name A to represent the coefficient in front of xₜ, and the name B for the coefficient in front of ϵₜ. We will simplify A and B separately.
Simplifying A
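One way to carry out this simplification is shown below. It is a reconstruction, assuming the standard mean of q(xₜ₋₁|xₜ, x₀) from the paper with x₀ replaced by (xₜ - √(1-αₜ bar)·ϵₜ)/√(αₜ bar), and using αₜ bar = αₜ · αₜ₋₁ bar and βₜ = 1-αₜ:

```latex
\begin{align}
A &= \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{(1-\bar{\alpha}_t)\sqrt{\bar{\alpha}_t}}
   + \frac{\sqrt{\alpha_t}\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t} \\
  &= \frac{\beta_t}{(1-\bar{\alpha}_t)\sqrt{\alpha_t}}
   + \frac{\alpha_t\,(1-\bar{\alpha}_{t-1})}{(1-\bar{\alpha}_t)\sqrt{\alpha_t}} \\
  &= \frac{(1-\alpha_t) + (\alpha_t - \bar{\alpha}_t)}{(1-\bar{\alpha}_t)\sqrt{\alpha_t}}
   = \frac{1}{\sqrt{\alpha_t}}
\end{align}
```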
Simplifying B
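One way to carry out this simplification is shown below. It is a reconstruction, taking B as the signed coefficient of ϵₜ after x₀ is replaced by (xₜ - √(1-αₜ bar)·ϵₜ)/√(αₜ bar) in the mean of q(xₜ₋₁|xₜ, x₀), and using αₜ bar = αₜ · αₜ₋₁ bar:

```latex
\begin{align}
B &= -\frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1-\bar{\alpha}_t}
     \cdot \frac{\sqrt{1-\bar{\alpha}_t}}{\sqrt{\bar{\alpha}_t}} \\
  &= -\frac{\beta_t}{\sqrt{\alpha_t}\,\sqrt{1-\bar{\alpha}_t}}
\end{align}
```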
Wow, what an amazing simplification! It gives us:
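As a numerical sanity check, the sketch below compares the original mean of q(xₜ₋₁|xₜ, x₀), written in terms of x₀ and xₜ, with the simplified form (xₜ - βₜ/√(1-αₜ bar)·ϵₜ)/√αₜ. The linear βₜ schedule, the chosen timestamp and the random vectors are assumptions for illustration.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)                         # assumed linear schedule
alphas = 1.0 - betas
alpha_bar = np.cumprod(alphas)
alpha_bar_prev = np.concatenate(([1.0], alpha_bar[:-1]))

rng = np.random.default_rng(0)
t = 400
i = t - 1                                                  # array index for timestamp t
x0 = rng.uniform(-1.0, 1.0, size=16)                       # stand-in for an image
eps = rng.standard_normal(16)                              # the noise eps_t
xt = np.sqrt(alpha_bar[i]) * x0 + np.sqrt(1.0 - alpha_bar[i]) * eps

# Original form: a weighted combination of x_0 and x_t.
mean_original = (np.sqrt(alpha_bar_prev[i]) * betas[i] / (1.0 - alpha_bar[i]) * x0
                 + np.sqrt(alphas[i]) * (1.0 - alpha_bar_prev[i]) / (1.0 - alpha_bar[i]) * xt)

# Simplified form: only x_t and the noise appear.
mean_simplified = (xt - betas[i] / np.sqrt(1.0 - alpha_bar[i]) * eps) / np.sqrt(alphas[i])

print(np.max(np.abs(mean_original - mean_simplified)))
```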
Don’t panic, our goal has not changed — we still want our neural network to predict the mean vector of the p(xₜ₋₁|xₜ) distribution and the predicted mean vector should be as close to μₜ(xₜ, x₀) as possible. But upon seeing the simplified formula for μₜ(xₜ, x₀), we realize:
- xₜ is known via sampling, there is no need to predict it.
- Given timestamp t, βₜ is a constant, and so are all the other quantities derived from βₜ, namely αₜ and αₜ bar.
- The only part that needs predicting is the noise ϵₜ.
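The simplification of the two coefficients can be sanity-checked numerically. The sketch below assumes a linear βₜ schedule with the endpoints 10⁻⁴ and 0.02 from the DDPM paper, the definitions αₜ = 1 − βₜ and ᾱₜ = α₁⋯αₜ, and the convention that B carries the minus sign. The helper names are made up for illustration, and plain Python is used so that it runs without extra packages.

```python
import math

def linear_betas(T=1000, beta_1=1e-4, beta_T=0.02):
    """Linear schedule beta_1..beta_T (endpoints assumed from the DDPM paper)."""
    return [beta_1 + (beta_T - beta_1) * i / (T - 1) for i in range(T)]

def alpha_bars(betas):
    """Running products alpha_bar_t = prod_{s<=t} (1 - beta_s)."""
    out, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        out.append(running)
    return out

def coefficients(t, betas, abars):
    """Unsimplified (A, B) and simplified (A_s, B_s) coefficients for 2 <= t <= T."""
    beta = betas[t - 1]
    alpha = 1.0 - beta
    abar, abar_prev = abars[t - 1], abars[t - 2]
    # Coefficient of x_t after substituting x_0 into the posterior mean.
    A = (math.sqrt(abar_prev) * beta / ((1 - abar) * math.sqrt(abar))
         + math.sqrt(alpha) * (1 - abar_prev) / (1 - abar))
    # Coefficient of eps_t, minus sign included.
    B = -(math.sqrt(abar_prev) * beta * math.sqrt(1 - abar)
          / ((1 - abar) * math.sqrt(abar)))
    A_s = 1 / math.sqrt(alpha)
    B_s = -beta / (math.sqrt(alpha) * math.sqrt(1 - abar))
    return A, B, A_s, B_s

betas = linear_betas()
abars = alpha_bars(betas)
# Largest gap between the long and the short form over all t = 2..T.
max_err = max(
    max(abs(A - A_s), abs(B - B_s))
    for A, B, A_s, B_s in (coefficients(t, betas, abars) for t in range(2, 1001))
)
print(max_err)
```

The printed gap should sit at floating-point noise level, which is what one expects if the long and short forms are algebraically identical.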
We can drop the original neural network, and design a new one ϵₚ(xₜ, t) that predicts the noise ϵₜ. Then we can construct the desired mean vector μₚ(xₜ, t) by:
Plugging this formulation into the definition of LMₜ₋₁ gives us:
Line (7) is the simplified objective function to minimize.
Note that this objective function mentions the noise ϵₜ twice. They are the same random variable, not two different noises. This is because they both come from the same source:
The first time, we use the above to express x₀ in terms of xₜ and ϵₜ. The second time, we use it to express xₜ in terms of x₀ and ϵₜ.
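A tiny round trip makes the point concrete. The sketch below assumes the reparameterized form xₜ = √ᾱₜ·x₀ + √(1 − ᾱₜ)·ϵₜ and picks an arbitrary, hypothetical value for ᾱₜ. Going forward and then backward with the same noise recovers x₀, while going backward with an unrelated noise draw does not.

```python
import math
import random

rng = random.Random(0)
alpha_bar_t = 0.5                                   # hypothetical value of alpha_bar_t at some timestep t
x0 = [rng.uniform(-1.0, 1.0) for _ in range(4)]     # stand-in for an image
eps = [rng.gauss(0.0, 1.0) for _ in range(4)]       # the one noise draw used in both directions

a, s = math.sqrt(alpha_bar_t), math.sqrt(1.0 - alpha_bar_t)
x_t = [a * x + s * e for x, e in zip(x0, eps)]              # x_t from x_0 and eps
x0_back = [(v - s * e) / a for v, e in zip(x_t, eps)]       # x_0 from x_t and the same eps

other = [rng.gauss(0.0, 1.0) for _ in range(4)]             # an unrelated noise draw
x0_wrong = [(v - s * e) / a for v, e in zip(x_t, other)]    # inversion with the wrong noise

err_same = max(abs(u - v) for u, v in zip(x0_back, x0))
err_other = max(abs(u - v) for u, v in zip(x0_wrong, x0))
print(err_same, err_other)
```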
Is this objective function still analytical?
Remember that previously we dropped the expectation with respect to xₜ and x₀ for LMₜ₋₁ to shorten our derivations? To answer the question of whether LMₜ₋₁ is still analytical, we have to add them back, because only with those expectations are we computing the correct LMₜ₋₁.
Note:
- In the final formula for LMₜ₋₁, there is no mention of xₜ anymore because xₜ is expressed via x₀ and the noise ϵₜ. So we don’t need to add the expectation with respect to xₜ. Instead, we need to add the expectation with respect to ϵₜ, which is a standard multivariate Gaussian, that is ϵₜ~N(0, 1).
- There is the mention of timestamp t, which represents an integer between 2 and T. We need to add an expectation with respect to t, which comes from a uniform distribution.
- There is the mention of x₀, which comes from the unknown data distribution q(x₀).
So, the complete formula for LMₜ₋₁ is:
where Uni(2,T) denotes the uniform distribution between 2 and T.
This formula is analytical with sample-averaging. When we plug samples for x₀, ϵₜ and t into the above formula, we have an analytical expression, whose gradient we can take to perform stochastic gradient descent.
The authors found that ignoring the constants in front of the vector distance term gives better results:
The following Algorithm 0 minimizes the above loss:
Algorithm 0 evaluates the expectation with respect to x₀, ϵₜ and t by sample-averaging. Note that at line (3), the timestamp t is sampled from the uniform distribution Uni(2, T).
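The sampling part of such a training step can be sketched in a few lines. This is an illustration under assumptions, not the paper's implementation: the linear βₜ schedule with endpoints 10⁻⁴ and 0.02, the two stand-in predictors `zero_model` and `oracle_model`, and the function name `loss_sample` are all hypothetical, and the gradient step itself is left out because it needs an automatic differentiation library.

```python
import math
import random

T = 1000
# Linear beta schedule; the endpoints 1e-4 and 0.02 are assumed from the DDPM paper.
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]
alpha_bars, running = [], 1.0
for beta in betas:
    running *= 1.0 - beta
    alpha_bars.append(running)                        # alpha_bars[t - 1] holds alpha_bar_t

def loss_sample(x0, eps_model, rng, t_min=2):
    """One Monte Carlo sample of ||eps - eps_model(x_t, t)||^2."""
    t = rng.randint(t_min, T)                         # t ~ Uni(t_min, T), both ends included
    eps = [rng.gauss(0.0, 1.0) for _ in x0]           # eps ~ N(0, I)
    a = math.sqrt(alpha_bars[t - 1])
    s = math.sqrt(1.0 - alpha_bars[t - 1])
    x_t = [a * x + s * e for x, e in zip(x0, eps)]    # x_t built from x_0 and eps
    pred = eps_model(x_t, t)
    return sum((e - p) ** 2 for e, p in zip(eps, pred))

rng = random.Random(0)
x0 = [rng.uniform(-1.0, 1.0) for _ in range(8)]       # stand-in for one training image

def zero_model(x_t, t):
    """Hypothetical untrained predictor: always answers 'no noise'."""
    return [0.0] * len(x_t)

def oracle_model(x_t, t):
    """Hypothetical perfect predictor; it cheats by looking at x0."""
    a = math.sqrt(alpha_bars[t - 1])
    s = math.sqrt(1.0 - alpha_bars[t - 1])
    return [(v - a * x) / s for v, x in zip(x_t, x0)]

n = 200
loss_zero = sum(loss_sample(x0, zero_model, rng) for _ in range(n)) / n
loss_oracle = sum(loss_sample(x0, oracle_model, rng) for _ in range(n)) / n
print(loss_zero, loss_oracle)
```

A predictor that recovers the noise exactly drives the loss to zero, while the zero predictor is left with roughly the expected squared norm of the noise, which equals the dimension d (8 here).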
One notational difference between the paper and this article is that in the paper, the authors use ϵ_θ to denote the neural network, and I use ϵₚ. The authors used ϵ_θ to highlight that the neural network has parameter set θ. This is also explicitly shown at line (5) of the above algorithm, where the gradient (notice the ▽ symbol that denotes the derivative with respect to a vector) of the loss function is computed with respect to θ. I use ϵₚ because there is no subscript θ in Unicode, and I don’t want to write too many ϵ_θ as they don’t look good.
Another notational difference is the paper uses ϵ to denote standard Gaussian noise, and I used ϵₜ. I use ϵₜ because I derived my formulas this way. But I think ϵ is better because the standard Gaussian noise does not depend on the timestamp t.
The derivation for Lᵥ shows that it is an expectation with respect to q(x_0:T) and inside the expectation there are multiple terms, shown below:
Previously we only focused on the Lₜ₋₁ terms for t=[2, T]. Now let’s talk about the remaining terms, which I extracted into the first expectation at line (2) using the linearity of expectation property: E[a + b] = E[a] + E[b].
Line (2) replaces the names F_T and F₀ with their actual formula.
Lines (3) and (4) rewrite the terms using the properties of log.
Line (5) simplifies the second log.
Line (6) splits the expectation into 2 using the linearity of expectation property.
Line (7) gives the first expectation the name L_T, same as the paper.
Line (8) gives the negative of the second expectation the name L₀, same as the paper.
The L_T term can be ignored in optimization, while the L₀ needs special treatment. We will see why.
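For orientation, the decomposition being discussed can be written in LaTeX as below. This follows equation (5) of the DDPM paper, with p written without the θ subscript as elsewhere in this article. Treating it as identical to the lines referenced above is an assumption, since the paper attaches the names L_T, Lₜ₋₁ and L₀ to the terms inside the expectation.

```latex
L_v = \mathbb{E}_{q(x_{0:T})}\Bigg[
  \underbrace{D_{\mathrm{KL}}\big(q(x_T \mid x_0)\,\|\,p(x_T)\big)}_{L_T}
  + \sum_{t=2}^{T} \underbrace{D_{\mathrm{KL}}\big(q(x_{t-1} \mid x_t, x_0)\,\|\,p(x_{t-1} \mid x_t)\big)}_{L_{t-1}}
  \underbrace{-\,\log p(x_0 \mid x_1)}_{L_0}
\Bigg]
```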
Ignoring the L_T term
Here is the formula for the L_T term again:
It mentions q(x_T|x₀), which is the marginal probability density for the random variable x_T. The forward process doesn’t include any model parameters.
It also mentions p(x_T), which is the reverse process at timestamp T. We defined p(x_T) = N(0, 1). So p(x_T) doesn’t mention model parameters either.
This means the whole L_T term doesn’t mention model parameters, thus it can be ignored during parameter learning.
Approximating the L₀ term
The L₀ term is:
This term is for the timestamp t=1. Let’s understand what this term is saying. We want to minimize this term, which translates to finding model parameters that maximize the log likelihood log(p(x₀|x₁)). In other words, we want p(x₀|x₁) to evaluate to a high probability number when a natural image is plugged into x₀.
Alternatively, we can understand it by using the formula from Lₜ₋₁:
Line (1) is the definition of Lₜ₋₁ that we derived previously. Note that when we derived it, t starts from 2 because when t≥2, all Lₜ₋₁ terms are KL-divergences between two proper Gaussian distributions. This is not true for t=1 as you will see at line (4).
Line (2) sets t=1 to derive L₀. And line (3) expands the KL notation to its mathematical definition.
Line (4) uses the property that q(x₀|x₁, x₀) = 1. This line also reveals that when t=1, there is no KL anymore. The formula degenerates to an integration of a log. That’s why we cannot handle t=1 in Lₜ₋₁.
Line (5) uses the property of log to simplify the formula.
Line (6) replaces the integration using the expectation notation.
Line (7) simplifies the two expectations over x₀ into one expectation over x₀ since one expectation already removes the random variable x₀. The second expectation over x₀ doesn’t change the result anymore. This line also reveals that the resulting quantity is indeed the L₀ term.
L₀ needs to be minimized differently, it can’t fit into Algorithm 0
Now we should understand that it is not that we cannot derive L₀ from the Lₜ₋₁ point of view. We can, but the derivation of L₀ is not a KL-divergence between two proper multivariate Gaussian distributions, which means the analytical formula of L₀ is different from the analytical formula of Lₜ₋₁ for t≥2. This means we need a different way to minimize L₀. In other words, the minimization of L₀ doesn’t fit into Algorithm 0. Well, it doesn’t fit yet. Later we will introduce an approximation to make it fit.
L₀ is optimizable
Since we want to minimize L₀, it is important that either:
- L₀ does not mention any model parameters so it can be ignored during the optimization. Or
- L₀ mentions model parameters and is analytical so its gradient can be taken for gradient descent.
Since the previous loss function LMₜ₋₁ only handles the case when t≥2, we hope that L₀ falls into the second category above so some part of our loss function covers the case t=1. Indeed that’s the case:
Line (1) is the definition of L₀.
Line (2) plugs in the definition of p(x₀|x₁), which is a multivariate Gaussian distribution with the neural network µₚ(x₁, 1) predicting its mean vector, and with its covariance matrix set to the constant 𝛼₁² I. I ignored the normalization term in front of the exponential, and used the proportional symbol “∝”.
Line (3) and line (4) simplify the formula.
Line (4) reveals that L₀ mentions all the model parameters in µₚ(x₁, 1) and it is analytical after we sample x₀ and x₁. So L₀ is optimizable.
Minimizing an approximation of L₀ inside Algorithm 0
Line (4) from above also shows that to minimize L₀, the neural network µₚ(x₁, 1) needs to predict a mean vector that is close to a natural image, say X₀, sampled for x₀.
Previously, when we derived the analytical formula of Lₜ₋₁ for t≥2, we arrived at the realization that we want our neural network µₚ(xₜ, t) to predict mean vectors that are close to the mean of the reverse of the forward process µₜ(xₜ, x₀).
If we can:
- write down µₜ(xₜ, x₀) for t=1, that is µ₁(x₁, x₀) and,
- if µ₁(x₁, x₀) is close to the natural image sample X₀
then we can turn the original task of “minimizing the distance between µₚ(x₁, 1) and X₀” into an approximation task of “minimizing the distance between µₚ(x₁, 1) and µ₁(x₁, x₀)”. The benefit of the latter is that we can handle the case of t=1 using Algorithm 0, the same way as for the cases of t≥2.
We can write down µ₁(x₁, x₀)
Note that we cannot set t=1 in the first line above. This is because when t=1, quantities such as 𝛼ₜ₋₁ bar are not defined. But we can set t=1 in the second line. This is because the second line replaces x₀ in the first line with an expression that only mentions x₁. And all quantities involving 𝛼₁ and β₁ are defined.
Set t=1 to derive:
After plugging in samples for x₁ and ϵ₁, the above is a constant.
We know µ₁(x₁, x₀) must be close to the natural image X₀
This is because µ₁(x₁, x₀) is the mean vector for the ending random variable x₀ from the reverse of the forward process. So if we draw a sample for x₀ from the reverse of the forward process, we should get an image that is close to the natural image X₀. That’s by the definition of the reverse of the forward process. In fact, if we draw many images for x₀ from the reverse of the forward process and average all those sampled images, the average should be exactly equal to X₀. In other words, the reverse of the forward process can generate the exact starting image in expectation. But if we only sample a single image for x₀ from the reverse of the forward process, that sample is not equal to X₀. That’s why we are approximating the L₀ term.
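How close µ₁(x₁, x₀) is to the starting image can be probed numerically. The sketch below assumes the simplified mean µₜ = (xₜ − βₜ/√(1 − ᾱₜ)·ϵₜ)/√αₜ, the definition ᾱ₁ = α₁ = 1 − β₁, and a hypothetical β₁ = 10⁻⁴. Under these assumptions the two noise contributions cancel at t=1, so the computed mean lands on x₀ up to floating-point error, whatever noise was drawn.

```python
import math
import random

beta_1 = 1e-4                  # assumed first value of the beta schedule
alpha_1 = 1.0 - beta_1
alpha_bar_1 = alpha_1          # the running product has a single factor at t=1

rng = random.Random(0)
x0 = [rng.uniform(-1.0, 1.0) for _ in range(6)]      # stand-in for the natural image X0
eps_1 = [rng.gauss(0.0, 1.0) for _ in range(6)]      # noise draw for t=1

# x_1 from x_0 and eps_1 via the reparameterized marginal q(x_1 | x_0).
x1 = [math.sqrt(alpha_bar_1) * x + math.sqrt(1.0 - alpha_bar_1) * e
      for x, e in zip(x0, eps_1)]

# Simplified posterior mean evaluated at t=1.
mu_1 = [(v - beta_1 / math.sqrt(1.0 - alpha_bar_1) * e) / math.sqrt(alpha_1)
        for v, e in zip(x1, eps_1)]

gap = max(abs(m - x) for m, x in zip(mu_1, x0))
print(gap)
```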
Now we can use Algorithm 0 to handle all timestamps starting from t=1. Mathematically, we expand LMₜ₋₁ which only covers the cases for t≥2, see the t~Uni(2,T) part under the expectation:
to cover the case for t=1 as well, see the t~Uni(1,T) part under the expectation:
Lₛᵢₘₚₗₑ is the final loss function, and it covers all timestamps from 1 to T. Algorithm 1 from the paper, copied below, minimizes Lₛᵢₘₚₗₑ:
We happily notice that at line (3), the timestamp t is sampled from the uniform distribution Uni(1, T) covering all cases t≥1 because of the approximation for the L₀ term.
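In LaTeX, the final loss can be written as below. This is a transcription of equation (14) of the DDPM paper into this article's notation, with ϵₚ in place of ϵ_θ, the noise written without the timestamp subscript, and ᾱₜ standing for “αₜ bar”.

```latex
L_{\mathrm{simple}} = \mathbb{E}_{t \sim \mathrm{Uni}(1, T),\; x_0 \sim q(x_0),\; \epsilon \sim \mathcal{N}(0, I)}
\left[\, \left\| \epsilon - \epsilon_p\!\left(\sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon,\ t\right) \right\|^2 \,\right]
```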
No concern on high variance in sample-averaging Lₛᵢₘₚₗₑ?
Previously I said that we can use sample-averaging to compute the analytical formula for the expectation of the negative log likelihood L with respect to all the latent random variables x₁ to x_T. But this results in high variance in the computed expectation if we can only afford to draw one sample per random variable for practical computational reasons.
Why do we have no problem using sample-averaging to compute the analytical formula Lₛᵢₘₚₗₑ while drawing a single sample per random variable?
The main reason is that in the final loss function Lₛᵢₘₚₗₑ, there are only 3 random variables to sample, compared to the T+1=1000+1 random variables to sample in the case of expectation of the negative log likelihood. So the variance in the final loss function’s case should be much smaller than the case of expected negative log likelihood.
To make things even better, now the samples are not drawn through uncalibrated neural networks any more, they all come from standard distributions whose behaviours do not depend on how much we’ve trained our neural networks. This results in a more predictable parameter learning experience.
But just for fun, let’s consider the alternative to sample-averaging. That is, to compute the expectation in the final loss function Lₛᵢₘₚₗₑ analytically:
- For the random variable x₀, there is no way to compute the expectation with respect to it analytically because the data distribution q(x₀) is unknown. So sample-averaging is the only option.
- For the random variable t, which comes from a uniform distribution, the expectation simply takes all possible values of t, computes the formula inside the expectation for each, and averages the results. This is equivalent to sample-averaging in our context of stochastic gradient descent. Even though Algorithm 1 only works with a single term at a time, instead of adding all those terms together and dividing the sum by T, the algorithm repeats this step until convergence. This is asymptotically equivalent to computing the expectation over t. For more details, please see the proof in Can We Use Stochastic Gradient Descent (SGD) on a Linear Regression Model?
- For the standard multivariate Gaussian random variable ϵₜ, we can use Gaussian quadrature to approximate the expectation analytically. For more details about Gaussian quadrature, please see Variational Gaussian Process (VGP) — What To Do When Things Are Not Gaussian. But Gaussian quadrature works better in low dimensional settings. In our case, ϵₜ is a d-dimensional random variable with d being the number of pixels in the images that we want to generate, so d is a large integer, and applying Gaussian quadrature is not practical. For more details about why it is not practical, please see the Appendix of the above link.
Given the above, using sample-averaging to approximate the expectation in Lₛᵢₘₚₗₑ is a sensible choice.
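The claim about t can be illustrated with a toy computation. In the sketch below, `term` is a hypothetical stand-in for “the formula inside the expectation” as a function of t only. The exact average over all T values and the average over many uniformly sampled values of t come out close to each other, which is the behaviour the argument above relies on.

```python
import random

T = 1000

def term(t):
    """Hypothetical per-timestep value of the formula inside the expectation."""
    return (t / T) ** 2

# Exact expectation over t ~ Uni(1, T): evaluate every t and divide by T.
exact = sum(term(t) for t in range(1, T + 1)) / T

# Sample-averaging: one randomly drawn t per step, repeated many times.
rng = random.Random(0)
n = 200_000
estimate = sum(term(rng.randint(1, T)) for _ in range(n)) / n

print(exact, estimate)
```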
This article established a clear motivation for why the denoising diffusion probabilistic model is designed the way it is, by reasoning about the relationships among the forward process q(xₜ|xₜ₋₁), the reverse of the forward process q(xₜ₋₁|xₜ, x₀) and the reverse process p(xₜ₋₁|xₜ). It also provides a detailed derivation of the loss function used for model parameter learning.
Why we won’t lose model parameters when applying sample-averaging to derive the analytical formula for the loss function L
A typical problem when applying sample-averaging to approximate integrations in a loss function is that the resulting formula does not mention model parameters anymore. The reparameterization trick (see here) is the go-to recipe to prevent this from happening.
Our case of using sample-averaging to derive the analytical approximation for the loss function L does not have the problem of losing model parameters. Let’s use an example with a short reverse process (T=1) to see why.
Let’s show the loss function L together with some manipulations to demonstrate sample-averaging:
Line (1) is the loss L, and line (2) replaces the expectation notation with its mathematical definition.
Line (3) sets T=1 to demonstrate the following derivations on a short reverse trajectory.
Line (4) factorizes the joint probability inside the inner integration using the definition of the reverse process.
Line (5) replaces all probability density function notation with the actual probability density distribution names. It also reveals that the random variable x₁ is sample-able from the standard multivariate Gaussian distribution N(0, 1). Let’s denote S₁ as the sample for x₁.
Line (6) plugs in the sample S₁, removing the inner integration by doing sample-averaging using only one sample, for demonstration purpose. Sample-averaging is an approximation, which is reflected by the approximation sign “≈” in front of the line.
Line (7) draws the sample S₀ for the random variable x₀ from the unknown distribution q(x₀); practically, we just randomly pick a natural image from the training set. It then uses sample-averaging again to remove the integration over x₀.
Line (8) plugs in the formula for the multivariate Gaussian probability density function. The proportional symbol “∝” allows me to drop the normalization terms in front of the exponential function.
Line (9) simplifies the formula. It reveals after sample-averaging, the analytical loss is still a function that mentions all model parameters. So no need for the reparameterization trick.
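A toy version of this last step shows the dependence on the parameters directly. The sketch assumes that the simplified loss has the form of a squared distance ‖S₀ − µₚ(S₁, 1)‖², and it replaces the mean-predicting network with a hypothetical one-parameter function `mu`. Both the function and the sample values are made up for illustration.

```python
def mu(theta, s1):
    """Hypothetical one-parameter stand-in for the mean network: mu_theta(x_1) = theta * x_1."""
    return [theta * v for v in s1]

def single_sample_loss(theta, s0, s1):
    """Squared distance ||S0 - mu_theta(S1)||^2 left after plugging in one sample each."""
    return sum((a - m) ** 2 for a, m in zip(s0, mu(theta, s1)))

s0 = [0.2, -0.5, 0.9]    # stand-in for a sampled training image S0
s1 = [1.0, -1.0, 0.5]    # stand-in for a sampled noise vector S1

# The samples are now fixed numbers, yet the loss still varies with theta.
losses = [single_sample_loss(theta, s0, s1) for theta in (0.0, 0.5, 1.0)]
print(losses)
```

Because the loss still changes with the parameter after the samples are fixed, its gradient with respect to the parameter is available for gradient descent.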
What is the dimensionality of this noise? Well, the desirable output is a colorful 2D image with red-green-blue (RGB) values. Let’s simplify it by first transforming a colorful image into grayscale values in [0, 255] and then scaling them to the range of [-1, 1]. Then we reshape this 2D array of scaled grayscale values into a long 1D vector of length d. I will mention the name d multiple times in the article. Let’s use the above as our easy definition of the image generation task. But please know that in reality, neural networks can generate colorful images directly.
It is natural to assume the dimension and structure of the input noise is the same as the dimension and structure of the output image, which is a vector of length d. So the noise should be a d-dimensional multivariate standard Gaussian N(0, 1) — that’s the academic default.
Now the task of generating images from noise is more concrete: design a neural network that takes a sample from a d-dimensional multivariate standard Gaussian and outputs a d-dimensional vector of scaled grayscale values. Turning the output vector into a 2D shape and RGB colors is something we all know how to do, and not the focus of this article.
Generating a natural image from noise in one step is difficult. How about generating an image in many smaller steps? Sort of like letting an image emerge from a Kodak film in old-fashioned photography. This way, in each step, the neural network should have a simpler task, as the input and output in each step are more similar to each other than pure noise is to a final natural image.
This iterative generating idea comes with its own problem. What should the in-between images look like? A person old enough (like me) to have experience with old-fashioned photography would suggest that the in-between images should be gradual: it should not be the case that during this iterative process, an image of a cat first appears, and then the cat turns into a human face.
The “gradual-ness” constraint over in-between images is sensible. But how to formulate it mathematically?
Forward process turns a natural image into noise
Even though it is not clear how to formulate the gradual-ness of the iterative generation process, it is easy to formulate the opposite process — the process that turns a natural image into pure noise by successively adding a little bit of Gaussian noise into it.
The process of turning a natural image into pure noise by adding successive noise to it is called the forward diffusion process, or forward process in short.
Reverse process turns noise into a natural image
On the other hand, we call the process of turning a Gaussian noise into a natural image the reverse process.
The following figure from the paper depicts these two processes, with the forward process at the bottom, and the reverse process at the top.
In the figure x₀, x₁, x₂,…, x_T are d-dimensional multivariate Gaussian random variables. x₀ corresponds to the natural image, which means a sample drawn from x₀’s probability density function, which I will get to later, should look like a natural image. And x₁, x₂,…, x_T correspond to the in-between denoised images. I will also introduce their probability density functions later.
To remember that x₀ is the natural image and x_T is pure noise, and not the other way round, please remember: small subscript, small noise, big subscript, big noise. I got this idea from this video.
Now there is a way to mathematically define the gradual-ness of in-between images. Every image xₜ generated from the reverse process must be close to the corresponding image diffused from the forward process.
The forward process is a probabilistic model. Why? Because every step adds a Gaussian noise into an image. So the result is not deterministic — starting from the same natural image x₀, you may end up with different samples of standard multivariate Gaussian noise x_T. Just like dipping a drop of ink into a glass of water at different times will give you different diffusions each time.
In this probabilistic model, x₀, x₁,… to x_T are random variables. Each of them is a d-dimensional random variable.
Since the forward process is probabilistic, the appropriate mathematical tool to talk about it is probability density function and probability theory.
The above figure uses q(xₜ|xₜ₋₁) to denote the probability density function for a single step from image xₜ₋₁ to image xₜ in the forward diffusion process. We define its probability density function as:
with βₜ being a value that changes over time, much like scheduling learning rate.
Using the reparameterization trick (see derivation here) the random variable xₜ can be equivalently described as:
where ϵₜ₋₁ is d-dimensional standard Gaussian noise. This formula reveals that the noisier image xₜ is a weighted sum of the less noisy image xₜ₋₁ and some noise ϵₜ₋₁. In other words, the forward process adds noise ϵₜ₋₁ into the less noisy image xₜ₋₁. The value of βₜ controls the amount of noise to add at timestamp t. That’s why βₜ is scheduled to take really small values, from β₁=10⁻⁴ to β_T=0.02, with T set to 1000. Otherwise, noise would quickly dominate the forward process.
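The forward step is simple enough to simulate directly. The sketch below assumes the reparameterized form xₜ = √(1 − βₜ)·xₜ₋₁ + √βₜ·ϵₜ₋₁ and a linear schedule whose endpoints 10⁻⁴ and 0.02 are taken from the DDPM paper. The function name `forward_step` and the random stand-in image are made up for illustration.

```python
import math
import random

def forward_step(x_prev, beta_t, rng):
    """Draw x_t = sqrt(1 - beta_t) * x_{t-1} + sqrt(beta_t) * eps with eps ~ N(0, I)."""
    keep, add = math.sqrt(1.0 - beta_t), math.sqrt(beta_t)
    return [keep * x + add * rng.gauss(0.0, 1.0) for x in x_prev]

T = 1000
# Linear beta schedule; endpoints assumed from the DDPM paper.
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]

rng = random.Random(0)
x = [rng.uniform(-1.0, 1.0) for _ in range(1000)]   # stand-in for a scaled grayscale image, d = 1000
for t in range(1, T + 1):
    x = forward_step(x, betas[t - 1], rng)           # after the loop, x is a sample of x_T

mean = sum(x) / len(x)
var = sum((v - mean) ** 2 for v in x) / len(x)
print(mean, var)
```

After T steps the entries of x should look like draws from a standard Gaussian, with mean near 0 and variance near 1.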
q(xₜ|xₜ₋₁) is the probability density function for the random variable xₜ, which describes a single step in the forward process. The following joint probability density function describes the full forward process:
This is the first factorization of the forward process q(x_1:T|x₀).
Why does the above joint probability density function of the forward process q(x_1:T|x₀) depend on the random variable x₀? This is because when t=1, q(xₜ|xₜ₋₁) turns into q(x₁|x₀), which mentions x₀.
The term q(x₁|x₀) also reinforces the fact that x₀ is a random variable. After all, the notation q(x₁|x₀) in probability theory denotes the probability density for the random variable x₁ given the random variable x₀.
But do we know the probability density function q(x₀) for x₀? No, we don’t. q(x₀) describes the probability of natural images. This means:
- given a natural image, say X₀, plugging X₀ in q(x₀=X₀) should return a probability number between 0 and 1 to indicate how likely this natural image occurs in all natural images.
- Summing up the probability numbers from q(x₀) for all natural images gives us 1.
Obviously, as usual, we don’t know the analytical formula for q(x₀). But this does not prevent us from writing its notation down, and from drawing samples for the random variable x₀— we just randomly pick an image from our training set of natural images.
The forward process consists of T+1 random variables x₀, x₁ to x_T. They form two groups:
- x₀: there are observations to the random variable x₀. The observations are the actual images from the training dataset. We call x₀ observational random variable.
- x₁ to x_T: there is no observation for them, hence they are latent random variables.
The definition of the forward process brings three important properties.
Property 1: Fully joint probability density function q(x_0:T)
Joint probability density function represents trajectories. Visually, the fully joint probability density function q(x_0:T) describes the set of possible trajectories of images. Each trajectory consists of T+1 images with x₀ representing a noise-free image, and x_T representing a pure noise image.
The following illustration shows some trajectories that start from two natural images X₀ and X₁. These trajectories end at different pure noise images, indicating that the forward process is probabilistic. It is only an illustration because I hand-drew the picture, so it is not necessary that the trajectories starting from X₀ do not overlap with the trajectories starting from X₁.
The illustration also shows that at a timestamp t, the random variable xₜ is responsible to explain all possible images that could be generated by the forward process at this timestamp.
Property 2: Marginal probability density function q(xₜ|x₀)
Using the reparameterization trick repeatedly (see derivation here again) gives us the probability density function for a single latent random variable without the dependence on another latent random variable, hence the resulting probability density is called marginal:
with
This property reveals that given x₀, the latent random variable xₜ does not depend on the latent random variable xₜ₋₁ anymore. In other words, given x₀, the latent random variables x₁ to x_T are independent to each other.
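For reference, the marginal from Property 2 can be written in LaTeX as follows. This is the closed form from the DDPM paper (its equation (4)), with ᾱₜ standing for “αₜ bar”.

```latex
q(x_t \mid x_0) = \mathcal{N}\!\left(x_t;\ \sqrt{\bar\alpha_t}\,x_0,\ (1-\bar\alpha_t)\,I\right),
\qquad \alpha_t = 1 - \beta_t,
\qquad \bar\alpha_t = \prod_{s=1}^{t} \alpha_s
```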
For independent random variables a and b, the product rule in probability theory is p(a, b) = p(a) p(b). Applying the product rule gives us the second factorization of q(x_1:T|x₀):
The two factorizations are equivalent, meaning they describe the same joint probability distribution for the same set of random variables. Sometimes we will choose one over the other to simplify formula derivations.
Property 3: The reverse of the forward process q(xₜ₋₁|xₜ, x₀)
Using the Bayes rule, it is possible to derive the probability density function for the reverse of the forward process. In the paper, the reverse of the forward process is called the forward process posterior. But I found the word “posterior” leads to confusion in the context of the many conditionals in this article.
Let’s start from the forward process q(xₜ|xₜ₋₁) which is a probability density function for the random variable xₜ:
Note that we can add the redundant conditioned random variable x₀ into q(xₜ|xₜ₋₁) to turn it into q(xₜ|xₜ₋₁, x₀) because by definition, given xₜ₋₁, the random variable xₜ doesn’t depend on any other random variable. So adding x₀ as a dependence doesn’t change the probability density for xₜ.
However, the random variable xₜ in q(xₜ|x₀) and the random variable xₜ₋₁ in q(xₜ₋₁|x₀) truly depend on x₀. We know the formulas for q(xₜ|x₀) and q(xₜ₋₁|x₀) from the above property 2 of the marginal probability density.
By rearranging the terms, we can derive the probability density function for the reverse of the forward process:
First, please notice that when t=1, q(xₜ₋₁|xₜ, x₀) turns into q(x₀|x₁, x₀), with x₀ appearing in both left and right of the bar “|”. The quantity q(x₀|x₁, x₀) always evaluates to 1 because given x₀, the probability of x₀ is 1. Why? Because x₀ has already happened, there is no uncertainty about it anymore.
Bear this in mind, and you will smile when you see later that the first analytical loss function starts with t=2. And you will smile again at the end, when we expand the loss function to cover the case of t=1.
The left hand side of the equation q(xₜ₋₁|xₜ, x₀) tells us this is a process that takes a noisier image xₜ and generates a less noisy image xₜ₋₁ (remember: large subscript, more noise; small subscript, less noise). So q(xₜ₋₁|xₜ, x₀) describes the process that goes in the opposite direction to the forward process. We call it the reverse of the forward process.
The right hand side of the equation tells us that q(xₜ₋₁|xₜ, x₀) is defined by the forward process probability density q(xₜ|xₜ₋₁), rescaled by q(xₜ₋₁|x₀)/q(xₜ|x₀). We have defined all three of these components. If you plug them into the right hand side of the equation and patiently simplify the formula, you will see that q(xₜ₋₁|xₜ, x₀) is a multivariate Gaussian distribution, which is fully specified by its mean vector and covariance matrix. Mathematically, we have:
with
We can see the mean vector for xₜ₋₁ is a weighted sum of x₀ and xₜ. And the weights in front of these two random variables depend on the timestamp t. The covariance matrix is a quantity that also depends on the timestamp t. Neither the mean vector nor the covariance matrix mentions trainable parameters (we will introduce trainable parameters a bit later).
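Written in LaTeX, the Bayes-rule step and the resulting Gaussian look as follows. The closed forms are taken from equations (6) and (7) of the DDPM paper. The symbol β̃ₜ for the variance is the paper's notation and is an assumption about what the original figure showed.

```latex
\begin{aligned}
q(x_{t-1} \mid x_t, x_0) &= \frac{q(x_t \mid x_{t-1}, x_0)\; q(x_{t-1} \mid x_0)}{q(x_t \mid x_0)}
  = \mathcal{N}\!\left(x_{t-1};\ \mu_t(x_t, x_0),\ \tilde\beta_t I\right) \\
\mu_t(x_t, x_0) &= \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t}\,x_0
  + \frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}\,x_t \\
\tilde\beta_t &= \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\,\beta_t
\end{aligned}
```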
The reverse of the forward process is important because it describes the generation process that we exactly want — a process that gradually turns noise into a natural image. This importance is reflected later in the loss function derivation.
Where does the reverse of the forward process start?
The q(xₜ₋₁|xₜ, x₀) distribution defines the mean vector and the covariance matrix for the random variable xₜ₋₁, given xₜ and x₀. We can get a sample of x₀ by randomly picking a natural image from the training set, but how do we get a sample of xₜ? Inductively, we can sample from q(xₜ|xₜ₊₁, x₀). But how do we sample the last, that is, the noisiest random variable x_T? Note x_T is actually the starting point of the reverse of the forward process. Well, we can sample from q(x_T|x₀), which is defined in Property 2. Looking at the definition of q(xₜ|x₀) again:
we realize when t=T, a large number, say 1000, the mean vector is almost 0, and the covariance matrix is almost I. In other words, x_T follows a standard multivariate Gaussian distribution. In other words, the reverse of the forward process essentially starts from pure noise.
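The “almost 0” and “almost I” can be quantified with a short calculation. It assumes the marginal q(x_T|x₀) = N(√ᾱ_T·x₀, (1 − ᾱ_T)·I) and a linear βₜ schedule from 10⁻⁴ to 0.02 with T=1000, which are the values used in the DDPM paper.

```python
import math

T = 1000
# Linear beta schedule; endpoints assumed from the DDPM paper.
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]

alpha_bar_T = math.prod(1.0 - beta for beta in betas)   # product of all alpha_t up to T
mean_scale = math.sqrt(alpha_bar_T)                      # factor multiplying x_0 in the mean
variance = 1.0 - alpha_bar_T                             # factor multiplying I in the covariance

print(alpha_bar_T, mean_scale, variance)
```

With these settings only a tiny fraction of x₀ survives in the mean, and the variance differs from 1 by less than one part in ten thousand.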
Where does the reverse of the forward process end?
Another burning question is whether the reverse of the forward process can eventually generate the original image X₀. Yes, in expectation, the reverse of the forward process that is conditioned on x₀=X₀ will generate X₀. “In expectation” means that if we run the reverse of the forward process conditioned on x₀=X₀ many times, and average the final images at the end of these trajectories, we will get exactly X₀.
How does the reverse of the forward process do that? We can know the answer by staring at the mean vector of q(xₜ₋₁|xₜ, x₀):
At each step, the above mean vector for the random variable xₜ₋₁ contains a bit of x₀=X₀, and a bit of the random variable xₜ, which transitively includes a bit of x₀=X₀ from its mean vector µₜ₊₁(xₜ₊₁, x₀). This way, step by step (from large t to small t), the random variables from the reverse of the forward process converge to x₀ more and more.
At the top of the Figure is the reverse process p(xₜ₋₁|xₜ). The reverse process takes a noisier image xₜ and generates a new image xₜ₋₁ that contains less noise, same as the reverse of the forward process q(xₜ₋₁|xₜ, x₀). Note the figure uses the notation p_θ, but I decided to use p because p_θ doesn’t look good in Medium.
The reverse process must contain the same set of random variables x₀, x₁ to x_T as in the reverse of the forward process because we want to establish the “gradual-ness” correspondence between variables between the reverse process and the reverse of the forward process. In other words, since the reverse of the forward process already tells us how to gradually remove noise from images step by step, we want our neural network to mimic that at every step. The way we convey the stepwise mimicking requirement is to require the random variable xₜ in the reverse process to behave like the corresponding random variable in the reverse of the forward process.
Since the reverse of the forward process is defined using multivariate Gaussian distributions, it makes sense to also define the reverse process using multivariate Gaussian distributions:
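Written out in LaTeX, the definition has a base formula, an inductive formula, and the joint density they induce. This is a reconstruction from the surrounding description, using this article's μₚ and Σₚ names:

```latex
\begin{aligned}
p(x_T) &= \mathcal{N}\big(x_T;\ 0,\ I\big) \\
p(x_{t-1} \mid x_t) &= \mathcal{N}\big(x_{t-1};\ \mu_p(x_t, t),\ \Sigma_p(x_t, t)\big), \qquad t = T, \dots, 1 \\
p(x_{0:T}) &= p(x_T) \prod_{t=1}^{T} p(x_{t-1} \mid x_t)
\end{aligned}
```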
In the inductive formula p(xₜ₋₁|xₜ), the mean vector μₚ(xₜ, t) and the covariance matrix Σₚ(xₜ, t) are actually two deep neural networks that predict the d-dimensional mean vector and the d×d covariance matrix of the multivariate Gaussian distribution.
In the base formula, p(x_T) is a standard multivariate Gaussian, which confirms that the reverse process starts from pure noise. Note, the starting point of the reverse process is not conditioned on any random variable, unlike the case of the reverse of the forward process.
Both neural networks μₚ(xₜ, t) and Σₚ(xₜ, t) take two inputs: the first is the noisier image xₜ, and the second is the timestamp t. The noisier image xₜ makes sense. After all, we want to use the neural networks to denoise it. But how should we understand the timestamp t as an input to the neural networks? The intention is the same as in the Transformer model for natural language processing, which uses sine and cosine functions to encode the position of a word in a sentence and feeds the encoded position as additional input to the Transformer. Here we also want to encode where we are in the reverse process as additional input to the neural networks to give them a bit of positional context.
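As a rough illustration of such an encoding, the sketch below builds a Transformer-style sinusoidal embedding for timestamps. The function name `timestep_embedding`, the embedding size and the `max_period` constant are assumptions made for this example, not details taken from the paper's code.

```python
import numpy as np

def timestep_embedding(t, dim=8, max_period=10000.0):
    """Sinusoidal encoding of integer timestamps t (length n) into vectors of size dim (dim must be even)."""
    half = dim // 2
    # Geometrically spaced frequencies, starting at 1 and decreasing towards 1/max_period.
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)
    angles = np.asarray(t, dtype=np.float64)[:, None] * freqs[None, :]
    # First half of each vector holds sines, second half holds cosines.
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)

emb = timestep_embedding([0, 1, 500, 999])
print(emb.shape)
```

Each timestamp gets a distinct vector, which a network can consume next to the image xₜ.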
How does a neural network predict a d-dimensional mean vector and a d×d covariance matrix? For the mean vector, the mean predicting neural network will have d output heads, each predicting an entry in the d-dimensional mean vector. The covariance matrix predicting neural network has d×d output heads. This is a rough understanding. With a large d, the number of neural network outputs, especially for the covariance predicting neural network, is huge. There are more concise ways to parameterize the covariance matrix, see the mean-field parameterization from here.
μₚ(xₜ, t) and Σₚ(xₜ, t) contain model parameters
The weights inside the mean vector predicting network μₚ(xₜ, t) and the covariance matrix predicting network Σₚ(xₜ, t) are the model parameters in this machine learning task. We want to use optimization to find proper values for those model parameters so that, when we start from a noise sample from p(x_T), iteratively sample xₜ₋₁ from the distribution p(xₜ₋₁|xₜ), and finally retrieve a sample for x₀ from the distribution p(x₀|x₁), this sample of x₀ is the scaled grayscale of a realistic looking natural image.
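The sampling procedure just described can be sketched in a few lines. This is only an illustration: `mean_fn` is a hypothetical stand-in for the trained network μₚ(xₜ, t), and the covariance is simplified to a fixed σₜ²·I, which is an assumption of this sketch.

```python
import numpy as np

def sample_reverse_process(mean_fn, sigma, T, d, rng):
    """Draw one trajectory x_T, ..., x_0 from the reverse process and return x_0.

    mean_fn(x_t, t) stands in for the trained network mu_p(x_t, t).
    sigma[t] is a fixed per-step standard deviation, i.e. Sigma_p = sigma[t]**2 * I
    (a simplifying assumption made only for this sketch).
    """
    x = rng.standard_normal(d)                          # base case: x_T ~ N(0, I)
    for t in range(T, 0, -1):                           # inductive case: t = T, ..., 1
        mean = mean_fn(x, t)
        x = mean + sigma[t] * rng.standard_normal(d)    # x_{t-1} ~ N(mean, sigma_t^2 I)
    return x

# Toy stand-in for a network: shrink the input a little at every step.
toy_mean_fn = lambda x, t: 0.9 * x
T, d = 50, 4
sigma = np.full(T + 1, 0.1)                             # index 0 unused so sigma[t] matches step t
x0 = sample_reverse_process(toy_mean_fn, sigma, T, d, np.random.default_rng(0))
print(x0.shape)
```

With a trained network in place of `toy_mean_fn`, the returned vector would be the generated image.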
Another question you may have is: should we use the same mean neural network to predict the mean vector for all timestamps? The same question applies to the covariance matrix prediction network. Well, it is a design choice. At least one network is needed, and the authors experimentally showed that one network is enough. You can have two or more, at the expense of more parameters to learn.
Like the case in the forward process, the joint probability density function p(x_0:T) also represents the set of image trajectories that the reverse process can generate, by starting from some pure Gaussian noise.
Why do we need the reverse process p(xₜ₋₁|xₜ)? Isn’t the reverse of the forward process q(xₜ₋₁|xₜ, x₀) enough?
Since we already know the distribution q(xₜ₋₁|xₜ, x₀) for the reverse of the forward process, which denoises images, you may wonder why the reverse process p(xₜ₋₁|xₜ) is even needed. Why can’t we directly sample natural images from q(xₜ₋₁|xₜ, x₀)?
Of course we can. But please look at q(xₜ₋₁|xₜ, x₀) closely. xₜ₋₁ not only depends on xₜ, it also depends on the initial image x₀. This means we have to know an initial image to start sampling and the goal is to finally sample an image that is similar to the already known x₀. This is not what we want. We want to be able to sample natural images freely!
At this point, you don’t want to stop. You may ask, can we work out the analytical formula for q(xₜ₋₁|xₜ), that is, the reverse of the forward process without the dependence on x₀? Let’s do it by using the Bayes rule again:
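Written out in LaTeX, applying the Bayes rule to the x₀-free reverse gives:

```latex
q(x_{t-1} \mid x_t) = \frac{q(x_t \mid x_{t-1})\, q(x_{t-1})}{q(x_t)}
```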
Now we can spot the trouble: on the right hand side of the equation, q(xₜ|xₜ₋₁) is defined, but q(xₜ₋₁) and q(xₜ) are not defined. So it is a dead end.
You still wouldn’t stop, and you ask, why can’t we define q(xₜ)? I hear you! Let’s try to define q(xₜ). Conceptually, q(xₜ) represents the possible set of images the forward process can generate at timestamp t. Thinking about it from the trajectory point of view, see again below, the images that the forward process can generate at time t depend on where the trajectories start (natural image X₀, X₁, etc.). So the probability density function q(xₜ) will inevitably reference the probability density function of the starting point, that is, q(x₀).
What is q(x₀)? It is the probability density of the training data. Unfortunately, previously we have already made it clear that q(x₀) is unknown. The best we can do is to sample from it by randomly picking natural images from our training set. Consequently, we will not be able to write down the analytical formula for q(xₜ).
The reverse of the forward process q(xₜ₋₁|xₜ, x₀) smartly defines the probability density function for the random variable xₜ₋₁ to condition on x₀. Conditioning on x₀ allows us to plug in a sample for x₀ to reason about properties of xₜ₋₁. As long as we can sample from x₀, which we can, and reason about xₜ₋₁ in an expectation fashion with respect to the samples from x₀, we are good. For more details about “reasoning a random variable in an expectation fashion”, see sampling-averaging below.
Can mimicking the reverse of the forward process give us a reverse process that can start from any multivariate Gaussian noise?
We’ve already established that the starting point of the reverse of the forward process is pure Gaussian noise. By mimicking the behavior of the reverse of the forward process, which denoises many Gaussian noises into natural images from the training set, our unconditioned denoising model, which is the reverse process p(xₜ₋₁|xₜ), should be capable of turning any Gaussian noise into a realistic looking natural image. It is just like linear regression: if a fitted line passes through many data points, we would expect the line to interpolate to other unseen data points along the same direction.
With the structure of the reverse process defined and its necessity clearly explained, it is now time to think about the objective function we minimize to perform parameter learning for the mean vector predicting network μₚ(xₜ, t) and the covariance matrix predicting network Σₚ(xₜ, t).
For a probabilistic model, the likelihood of the data is always a good starting point to think about the objective function. Let’s define what “likelihood of the data” means for our model.
The joint probability density function, with data plugged in
The joint probability density of the reverse process p(x₀, x₁⋯, x_T), or p(x_0:T) for short, is a function that has T+1 random variables as its arguments, namely x₀, x₁ to x_T. Being a probability density function, it evaluates to a non-negative number when concrete values are plugged into its arguments. I will loosely call this number a probability number, although, strictly speaking, a density value can exceed 1.
The purpose of a probabilistic model is to explain training data well. “Explaining the training data well” means that images in the training dataset evaluate to a high probability number when they are plugged into the x₀ argument, one image at a time.
Plugging an image X₀ into the argument x₀ of the joint probability density function p(x_0:T), which has T+1 random variables, results in a new function with T random variables: p(x₀=X₀, x₁⋯, x_T). This function cannot be evaluated into a probability number yet, because it mentions random variables x₁ to x_T, which are not concrete values. x₁ to x_T are latent random variables. There are no observations for them, so we cannot find meaningful concrete values (as we can for the observational random variable x₀) to plug in for them. They need to be removed, or more precisely, integrated away.
The likelihood p(x₀)
By definition, a random variable, say, x₁, describes a spectrum of possible values. The go-to way to remove a random variable from a probability density function is to compute the expected value of the density function with respect to that random variable. In other words, to remove the random variable x₁ from p(x₀, x₁⋯, x_T), compute the average value, or alternatively, expectation, of this function with respect to x₁. Essentially we are saying since we cannot observe concrete values for latent random variables, we have to reason about their average behavior.
Let’s pick x₁ to integrate away first. Since x₁ is a continuous random variable, its expectation is defined by an integration, hence the name “integrating a random variable away”:
The same “integrating away” approach, applied T times, can remove all latent random variables from p(x₀, x₁⋯, x_T):
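In LaTeX, integrating away x₁ alone and then all T latent random variables reads as follows. This is a reconstruction from the text, and the shorthand dx_{1:T} for dx₁⋯dx_T is my own abbreviation:

```latex
\begin{aligned}
p(x_0, x_2, \dots, x_T) &= \int p(x_0, x_1, \dots, x_T)\, dx_1 \\
p(x_0) &= \int \cdots \int p(x_0, x_1, \dots, x_T)\, dx_1 \cdots dx_T = \int p(x_{0:T})\, dx_{1:T}
\end{aligned}
```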
p(x₀) now only describes how likely it is that actual images can be generated using our model. We call p(x₀) the likelihood of the data.
Note the above equation is merely a notation to indicate that p(x₀) is what is left after all latent random variables are integrated away. It does not tell us how to integrate them away. This is because the integration symbol “∫” represents the result of the integration, without telling us how to do the integration.
Why not integrate x₀ away from the likelihood p(x₀) as well?
The above T-dimensional integration only integrated away the latent random variables x₁ to x_T, and left the observational random variable x₀ in p(x₀). Why? Because if x₀ is integrated away as well, the whole joint probability density function becomes 1, since any probability density function, when integrated over its full set of random variables, yields 1:
You see, there is no place to plug in actual images to this number 1 to evaluate how well it explains the training data. This prevents us from performing parameter learning.
That’s why we leave x₀ unintegrated-away and work with p(x₀), which is called likelihood of the data, or likelihood for short.
The likelihood p(x₀) mentions all model parameters
Even though the likelihood p(x₀), with an actual image plugged in, that is, p(x₀=X₀), is a function that doesn’t mention any random variable, it still mentions all model parameters, that is, weights in the two neural networks, via the probability density function p(xₜ₋₁|xₜ), shown again here:
The weights are in the mean predicting network μₚ and the covariance matrix predicting network Σₚ. The latent random variables x₁ to x_T are integrated away, but their mean and covariance matrix terms are left in the result of the integrations.
You may ask, since we have not discussed how the latent random variables are integrated away, how do we know μₚ(xₜ, t) and Σₚ(xₜ, t) will survive the integrations? You will see for yourself later in the derivation for the analytical loss section, but here you have to believe me: if, after the integrations, those Gaussian latent random variables are gone, and the two very things that describe them, namely the mean vector μₚ(xₜ, t) and the covariance matrix Σₚ(xₜ, t), are also gone, then it seems that those random variables have never existed in our model. This doesn’t make sense. So the two neural networks μₚ(xₜ, t) and Σₚ(xₜ, t) will survive the integrations. In other words, p(x₀) mentions all model parameters in μₚ(xₜ, t) and Σₚ(xₜ, t).
p(x₀) mentions all model parameters, which are the weights in the two neural networks, but we don’t know the proper values for the model parameters yet. If we pretend to know all parameter values, then we can evaluate p(x₀) into a probability number (strictly, a non-negative density value) by plugging in a training image from the training image set, one image at a time. This results in many probability numbers. Averaging these probability numbers gives us a measurement of how well our model explains the training data.
Of course, we don’t know the values for the neural network weights. We can set them to arbitrary values, but this will likely result in a poor model that doesn’t explain the training data well. In this case, it is not that our model structure is incapable of explaining the data; it is that the model has not been calibrated correctly. By “model structure” I mean the reverse process with two deep neural networks predicting the mean vector and the covariance matrix of a denoised image.
Optimization can find proper parameter values for those neural networks, and it needs a loss function to minimize.
The negative log likelihood of data as the loss function
A loss function must mention all model parameters. The likelihood p(x₀) satisfies this requirement. A good model should let an actual image X₀ evaluate to a high likelihood number p(x₀=X₀). We want to minimize the loss function, hence the negative sign. A good model not only needs to work well for a single image from the training set, it needs to work well for all images in the training set, hence the expectation with respect to images sampled from the training set, x₀~q(x₀). We can take the log of p(x₀) because log is a monotonic function that does not affect the optimal value for the loss function; we want to introduce the log function because it is the essential part in the KL-divergence that we will use below.
All the above thinking leads us to the famous negative expected log likelihood, denoted by L:
Line (1) is the definition of the negative log likelihood of data.
Line (2) plugs in the definition of p(x₀), which integrates all latent random variables x₁ to x_T out of the density function p(x₀, ⋯, x_T).
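For reference, the two lines described above can be written in LaTeX as follows. This is a reconstruction from the line descriptions, with dx_{1:T} as shorthand for dx₁⋯dx_T:

```latex
\begin{aligned}
L &= -\,\mathbb{E}_{x_0 \sim q(x_0)}\big[\log p(x_0)\big] && (1) \\
  &= -\,\mathbb{E}_{x_0 \sim q(x_0)}\Big[\log \int p(x_{0:T})\, dx_{1:T}\Big] && (2)
\end{aligned}
```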
The negative log likelihood loss function is not optimizable
The standard way to perform parameter learning via optimization is to use gradient descent to minimize the loss function with respect to the model parameters. Gradient descent needs to know the analytical formula for the loss function to take the derivative of it. Unfortunately, the analytical formula for the negative log likelihood loss function is very difficult to derive.
To work out the analytical formula for the loss function L, let’s look at its definition again:
p(x_0:T) is previously defined (shown again here) as:
It is easy to see that our model parameters (the weights in the mean vector predicting neural network μₚ and the covariance matrix predicting neural network Σₚ) are mentioned in the loss function. But they are mentioned inside the integration symbol ∫.
Unlike symbols such as the exponential exp or the squared operator “²”, which represent computations that we immediately know how to do, the integration symbol represents the result of a computation: it asks you to integrate a function without telling you how to do the integration.
From our Calculus course, we all know that taking a derivative is work, but performing integration is art: taking the derivative of a term is a mechanical procedure as long as you have the derivative cheat sheet. But integration requires creativity, and there are so many integrations that we just don’t know how to do.
Unfortunately, the integration of p(x_0:T) inside the loss function L falls into the kind of integrations that are very hard to solve analytically. We call it an intractable integration. Let’s use a short reverse process where T is set to 1 to demonstrate this point. Our goal is to show that L is intractable:
Since we know how to sample x₀ from our training set of natural images, the outside expectation with respect to x₀ can be dealt with using sample-averaging (see the next section about sample-averaging). So the only difficult term is the integration inside the log and we want to show that the integration is hard to solve:
After setting T to 1, the above term turns into:
Line (1) shows the shortened joint probability density function for the reverse process when T=1.
Line (2) factorizes the joint into a product of probability density functions, each for a single random variable, x₁ and x₀ respectively.
Line (3) plugs in the names of those single probability density functions.
Line (4) plugs in the actual probability density functions, which are multivariate Gaussians. The first exp is for the random variable x₁ from the standard Gaussian, and the second exp is for the random variable x₀ conditioned on x₁. I used the proportional symbol “∝” to ignore the normalization terms in front of each multivariate Gaussian.
The integration at line (4) is hard to solve analytically. Note, in this case, we know how to compute the integration analytically by using the product rule of integration, it is just hard and messy, especially when T=1000. The variational method (explained later) that the authors of the paper proposed is more elegant.
This little exercise also reveals that we may be able to use a technique called sample-averaging to approximate the integration analytically. This is because in line (4) the probability density function of each random variable is mentioned only once, and this single mention of the density function is computing the expected value with respect to that random variable. Sample-averaging can approximate such expectations.
Can we use sample-averaging to derive the analytical form for the loss function?
The loss function L contains an intractable integration. There are multiple ways to approximate an intractable integration, such as sample-averaging, importance sampling and Gaussian quadrature. Let’s look at the simplest of them all, sample-averaging.
What is sample-averaging
Sample-averaging is simple: average function evaluations based on samples of a random variable to compute an expectation. Formally, let x be a random variable from the distribution h(x). Then the following integration of the function f(x) with respect to x can be approximated by averaging the evaluations of f(x) with samples of x from the distribution h(x) plugged in.
In other words, as long as f(x) can be evaluated when samples for x are plugged in, sample-averaging approximates the integration.
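In symbols, ∫ f(x) h(x) dx ≈ (1/n) Σᵢ f(Sᵢ), where S₁ to Sₙ are samples from h(x). Here is a minimal sketch of that idea. The helper name `sample_average` and the choice of f and h are assumptions made for this example; I picked them because the exact answer is known.

```python
import numpy as np

def sample_average(f, sampler, n, rng):
    """Approximate the integral of f(x) * h(x) dx by averaging f over n samples drawn from h."""
    samples = sampler(rng, n)
    return np.mean(f(samples))

rng = np.random.default_rng(0)
# h(x) is the standard normal and f(x) = x**2, so the exact integral is the variance, which is 1.
estimate = sample_average(lambda x: x ** 2, lambda rng, n: rng.standard_normal(n), 200_000, rng)
print(estimate)
```

The estimate lands close to 1, and it gets closer as the number of samples grows.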
Approximating an integration analytically using sample-averaging
As you can see, sample-averaging approximates an integration with a sum of evaluations of the function f. This sum of terms is analytical with respect to our model parameters. A simple example shows this point: we want to write down the analytical formula for the loss function of some model, written as the following integration. In this integration, let’s say x is a random variable that can be sampled from h(x) and μ is our model parameter to optimize.
Drawing two samples S₁ and S₂ of x from h(x) and applying sample-averaging to approximate this integration results in an analytical expression for μ at the right hand side of the following approximation equation:
The right hand side of the approximation is an expression that mentions the model parameter μ. This expression is analytical: it does not mention symbols that represent results of computation, such as the integration ∫. It only mentions symbols that represent computation, such as the exponential function exp and the squared operator “²”, for which we know how to compute gradients.
We use sample-averaging all the time. For example, to compute the expected height of students in a school, we don’t know the distribution of students’ heights, but we have samples of measured heights, so we compute the expectation by averaging them.
Counter-example that makes sample-averaging inapplicable
As long as h(x) does not appear in the function f(x) to be integrated over, and h(x) is easy to sample from, we can use sample-averaging. Here is a counterexample:
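Judging from the discussion that follows, the integration in question has this shape, where g stands for some arbitrary function (the exact form is my reconstruction):

```latex
\int g\big(1 + q(x_0)\big)\, q(x_0)\, dx_0
```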
In this case, q(x₀) is the data distribution, whose probability density formula is unknown. Still, we can sample from it by randomly picking images from the training set. Sample-averaging can remove the right q(x₀). But note that q(x₀) also appears in the g function being integrated. This left q(x₀) cannot be removed by sample-averaging. g(1+q(x₀=X₀)) remains an un-evaluable function even with the sample X₀ plugged in. So in this case, sample-averaging cannot solve the integration.
Luckily, our loss function L doesn’t fall into this category, so we can use sample-averaging to approximate L analytically.
Deriving L’s analytical formula via sample-averaging
To derive the analytical formula for the loss function L, rewrite L as follows:
Now it is easy to see that to apply sample-averaging, we can sample latent random variable x₁ to x_T first from the definition of the reverse process:
Finally we sample the observational random variable x₀ from q(x₀).
Explicitly, the sampling process goes like this:
- First, draw a sample for x_T from the standard multivariate Gaussian distribution using the base case to remove the integration with respect to x_T.
- With a sample Sₜ for the random variable xₜ at hand, plug Sₜ into p(xₜ₋₁|xₜ=Sₜ), then sample xₜ₋₁, using the inductive case.
- As long as we don’t lose model parameters during this process, we will end up with an analytical formula for the loss function L. Losing model parameters during sample-averaging means that it is possible that sample-averaging results in a formula that does not mention model parameters anymore. This is bad because a loss function that does not mention model parameters is useless. Reparameterization trick is used to prevent this from happening. But in our case, we don’t need to worry about losing model parameters when applying sample-averaging. Appendix “Why we won’t loss model parameters when applying sample-averaging to derive the analytical formula for the loss function L?” explains why.
- Once all samples for x₁ to x_T are available, let’s call them a sample trajectory. Plug this trajectory into the joint probability density p(x_0:T) to get the analytical expression for p(x_0) under this trajectory.
- Repeat steps 1~4 to get the analytical expression for p(x_0) under different trajectories and average them to approximate the inner integration. Say there are m sample trajectories, and each trajectory i gives an analytical formula pᵢ(x_0); then the analytical formula for the average is:
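In symbols, the average is (1/m) Σᵢ pᵢ(x_0). The steps above can be sketched on a toy one-dimensional reverse process. Everything specific here is an assumption made for illustration: the mean "network" is the linear function a·xₜ, the variance s² is fixed, and pᵢ(x_0) is taken to be p(x₀=X₀|x₁=S₁) for the sampled trajectory. The toy model is chosen because its exact likelihood is known, which lets us check the estimate.

```python
import numpy as np

def normal_pdf(x, mean, var):
    return np.exp(-0.5 * (x - mean) ** 2 / var) / np.sqrt(2.0 * np.pi * var)

def estimate_likelihood(X0, a, s, T, m, rng):
    """Sample-average estimate of p(x_0 = X0) from m sampled trajectories."""
    x = rng.standard_normal(m)                  # step 1: x_T ~ N(0, 1), one per trajectory
    for _ in range(T - 1):                      # step 2: x_{t-1} ~ N(a * x_t, s^2), down to x_1
        x = a * x + s * rng.standard_normal(m)
    p_i = normal_pdf(X0, a * x, s ** 2)         # p_i(x_0) = p(x_0 = X0 | x_1 = S_1)
    return p_i.mean()                           # (1/m) * sum_i p_i(x_0)

def exact_likelihood(X0, a, s, T):
    """Closed form, available only because the toy model is linear-Gaussian."""
    var = 1.0                                   # variance of x_T
    for _ in range(T):                          # propagate the variance down to x_0
        var = a ** 2 * var + s ** 2
    return normal_pdf(X0, 0.0, var)

rng = np.random.default_rng(0)
est = estimate_likelihood(0.3, a=0.9, s=0.5, T=5, m=200_000, rng=rng)
exact = exact_likelihood(0.3, a=0.9, s=0.5, T=5)
print(est, exact)
```

With many trajectories the estimate agrees with the exact value. Calling `estimate_likelihood` with m=1 returns a noticeably different number on every call.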
Analytical formula for L via sample-averaging is expensive to compute
The above sample-averaging requires a lot of computation because each trajectory requires T samples, one for each latent random variable. Common sense tells us that drawing a single sample for a random variable is not enough. For example, you shouldn’t compute the average student height in a school by just measuring the height of a single student. Why is that bad? Because a small sample size gives us an estimate (in this example, of the expected height) with high variance. Every time you draw a single sample, you get a different estimate of the expectation. That’s variance.
We want to draw many samples for each random variable because more samples means the average more closely approximates the original integration. In other words, less variance. But drawing more samples requires a lot of computation: drawing two samples for each latent random variable with T=1000 results in m=2¹⁰⁰⁰ trajectories to average over. That’s very expensive.
Practically, we can only afford to draw a single sample for each latent random variable. But this brings the high variance problem back.
Analytical formula for L via sample-averaging using small number of samples has high variance
The problem with drawing only a very small number of samples (for example, a single sample) for each latent random variable is that the probability number computed for an actual image X₀, that is, p(x₀=X₀), has high variance. This is because the probability number p(x₀=X₀) depends on the sampled concrete values for the latent random variables x₁ to x_T. Every time you compute p(x₀=X₀) for the same image X₀, this probability number is different. Since there are T concrete sampled values in a sample trajectory, the variance is likely to be quite high.
To make things worse, at the beginning of the parameter learning process, the weights in the mean vector and covariance matrix predicting neural networks are randomly initialised. So the images sampled through those neural networks may be of very bad quality, in the sense that they don’t look like a denoised version of the previous image at all. Even though this doesn’t contribute to a higher variance, low quality samples make the parameter learning harder.
Why is a high variance in p(x₀=X₀) bad? Because p(x₀=X₀) is our measurement of how well our model explains the training data, in this case, how well it explains the training image X₀. If for the same image, our measurement sometimes reports a large p(x₀=X₀) probability number and sometimes reports a small probability number, then the optimizer, for example the Adam optimizer, is uncertain whether our current model can explain the training data well or not. This uncertainty usually shows up as a very slow or even diverging training process.
So sample-averaging is not a good way to derive the analytical formula for the loss function L. Is there a better way? Once again, variational inference comes to the rescue.
To shorten the current article, I decided not to introduce variational inference and to treat it as assumed knowledge. For its introduction and two applications, please see: Demystifying Tensorflow Time Series: Local Linear Trend and Variational Gaussian Process (VGP) — What To Do When Things Are Not Gaussian.
The key idea is to use another distribution to compute (note the word “compute”, not “approximate”) an otherwise intractable integration in an analytical way. I’ll use importance sampling inside the loss function to introduce the new distribution.
Derivation by importance sampling
Importance sampling introduces a distribution that is easy to sample from to help solve an otherwise intractable integration. In our case, the intractable integration is with respect to the random variables x₁ to x_T from the joint reverse process distribution p(x₀, …, x_T).
Note that in our case, p(x₀, …, x_T) is sample-able. As we described previously, sample-averaging is applicable to approximate the loss function analytically.
The true motivation to introduce a new distribution, which is the joint forward process q(x_1:T|x₀) in our case, is that it helps derive the analytical formula for L. It also encodes the “gradual-ness” requirements for in-between images from the reverse process. You may not know what I’m talking about yet. Both points will become clear later, after we finish deriving the analytical formula of the loss function using importance sampling.
In the loss L, the integration is with respect to the latent random variables x₁ to x_T, shown again in line (2) below, so the new distribution we introduce must be over the same set of random variables. The distribution q(x_1:T|x₀) that we defined for the forward process fits this requirement. Line (3) introduces it into the formula of L.
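The line numbers in the explanation that follows refer to a derivation of this shape. The LaTeX below is a reconstruction from those line-by-line descriptions, so the exact typesetting is an assumption. dx_{1:T} is shorthand for dx₁⋯dx_T:

```latex
\begin{aligned}
L &= -\,\mathbb{E}_{q(x_0)}\big[\log p(x_0)\big] && (1) \\
  &= -\,\mathbb{E}_{q(x_0)}\Big[\log \int p(x_{0:T})\, dx_{1:T}\Big] && (2) \\
  &= -\,\mathbb{E}_{q(x_0)}\Big[\log \int p(x_{0:T})\, \frac{q(x_{1:T} \mid x_0)}{q(x_{1:T} \mid x_0)}\, dx_{1:T}\Big] && (3) \\
  &= -\,\mathbb{E}_{q(x_0)}\Big[\log \int q(x_{1:T} \mid x_0)\, \frac{p(x_{0:T})}{q(x_{1:T} \mid x_0)}\, dx_{1:T}\Big] && (4) \\
  &= -\,\mathbb{E}_{q(x_0)}\Big[\log \mathbb{E}_{q(x_{1:T} \mid x_0)}\Big[\frac{p(x_{0:T})}{q(x_{1:T} \mid x_0)}\Big]\Big] && (5) \\
  &\le -\,\mathbb{E}_{q(x_0)}\Big[\mathbb{E}_{q(x_{1:T} \mid x_0)}\Big[\log \frac{p(x_{0:T})}{q(x_{1:T} \mid x_0)}\Big]\Big] && (6) \\
  &= -\int q(x_0) \int q(x_{1:T} \mid x_0)\, \log \frac{p(x_{0:T})}{q(x_{1:T} \mid x_0)}\, dx_{1:T}\, dx_0 && (7) \\
  &= -\int q(x_0)\, q(x_{1:T} \mid x_0)\, \log \frac{p(x_{0:T})}{q(x_{1:T} \mid x_0)}\, dx_{0:T} && (8) \\
  &= -\int q(x_{0:T})\, \log \frac{p(x_{0:T})}{q(x_{1:T} \mid x_0)}\, dx_{0:T} && (9) \\
  &= -\,\mathbb{E}_{q(x_{0:T})}\Big[\log \frac{p(x_{0:T})}{q(x_{1:T} \mid x_0)}\Big] \;=:\; L_v && (10)
\end{aligned}
```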
I would like to point out that the above derivation is valid for any probabilistic model with random variables x₀ to x_T because it only uses properties from probability theory. These properties are true for any valid probability distributions. Only starting from the next derivation over Lᵥ, when we start to factorize joint probabilities using our definition of the forward process and the reverse process, do we start to rely on the specific model structure.
Line (3) introduces the q(x_1:T|x₀)/q(x_1:T|x₀) quantity. This quantity evaluates to 1, so multiplying by it does not change the integration.
Line (4) re-organizes the terms, turning the old integration into a new one with respect to q(x_1:T|x₀).
Line (5) represents the integration using the equivalent expectation notation.
Line (6) uses Jensen’s inequality to push the log function into the inner expectation because an expectation of logs is easier to compute than the log of an expectation. Jensen’s inequality also turns the result we will eventually minimize into a quantity that is at least as large as the original loss L, that is, an upper bound on L.
Line (7) replaces the expectation notations with their definitions, that is, integrations. And line (8) re-arranges terms.
Line (9) applies the reverse chain rule in probability theory to derive the joint probability q(x_0:T).
Line (10) represents the integration using the equivalent expectation notation. Note that we started with introducing the q distribution over the random variables x₁ to x_T, and arrived at an expectation with respect to the random variables x₀ to x_T. We give this new quantity the name Lᵥ, standing for variational loss.
New loss Lᵥ to derive analytical formula for, and to minimize
From now on, Lᵥ is the quantity to minimize. Our goal is updated to derive the analytical formula for the new loss Lᵥ. Looking at line (10) it is hard to believe that it is analytical. But in math, amazing things do happen. Please read on.
Rewriting Lᵥ to get the important Lₜ₋₁ term
This is an important derivation, please pay attention.
Line (1) shows the derivation of the new loss Lᵥ. Lᵥ mentions the joint probability density of the reverse process p(x_0:T) and the forward process q(x_1:T|x₀) that we defined previously.
Line (2) factorizes these two joint probability densities. It factorizes p(x_0:T) using the definition of the reverse process. And it factorizes q(x_1:T|x₀) using the first factorization of q.
Starting from this line, we are relying on the model structure that we defined, that is, the structure in which random variable xₜ₋₁ depends on xₜ in the reverse process p, and xₜ depends on xₜ₋₁ in the forward process. This is not true for arbitrary probabilistic models.
Line (3) again performs factorization. Note that the products start with t=2 instead of t=1 because of the factorization at this line.
Line (4) pushes the minus sign from outside of the expectation to inside of the expectation, and it uses the property that log(a×b) = log(a) + log(b).
Line (5) introduces the name F_T to represent the first term and Fₒ for the third term inside the expectation to shorten the derivations so they fit in one line.
Line (6) is the key line. It uses the Bayes rule to replace q(xₜ|xₜ₋₁):
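In LaTeX, the replacement is the following. It is a reconstruction from the explanation around it, where the first equality adds the redundant dependency on x₀ and the second applies the Bayes rule:

```latex
q(x_t \mid x_{t-1}) \;=\; q(x_t \mid x_{t-1}, x_0) \;=\; \frac{q(x_{t-1} \mid x_t, x_0)\, q(x_t \mid x_0)}{q(x_{t-1} \mid x_0)}
```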
Note the addition of the dependency on x₀ to turn q(xₜ|xₜ₋₁) into q(xₜ|xₜ₋₁, x₀). This addition is redundant: it does not change the conditional probability because, by definition, the random variable xₜ only depends on xₜ₋₁. See the definition for xₜ, shown again below. It only mentions xₜ₋₁ and not x₀.
The addition makes it easier for us to apply the Bayes rule because the Bayes rule mentions q(xₜ|x₀) and q(xₜ₋₁|x₀), which explicitly depend on x₀.
Note the dependency on x₀ in q(xₜ₋₁|xₜ, x₀) is not redundant. x₀ appears here because of the Bayes rule.
The reason for using the Bayes rule is to make the term q(xₜ₋₁|xₜ, x₀) pop up. q(xₜ₋₁|xₜ, x₀) is a term from the reverse of the forward process. We now have a probability ratio between p(xₜ₋₁|xₜ) and q(xₜ₋₁|xₜ, x₀), seen at line (6). p(xₜ₋₁|xₜ) and q(xₜ₋₁|xₜ, x₀) are both:
- probability density functions for the same random variable xₜ₋₁, and
- multivariate Gaussian distributions with their analytical probability density available: previously we have defined the analytical form for both p(xₜ₋₁|xₜ) in the reverse process, and q(xₜ₋₁|xₜ, x₀) in the reverse of the forward process.
These two properties make it possible to derive the KL-divergence between p(xₜ₋₁|xₜ) and q(xₜ₋₁|xₜ, x₀) analytically, detailed later.
Line (7) uses the property of log to split terms.
Line (8) uses the property of log to turn a sum of logs into the log of a product.
Line (9) realizes that in the log of products, the numerator and the denominator share many terms, which cancel, leaving just one term in the numerator and one term in the denominator.
Line (10) introduces the name F₀ to denote the last term inside the expectation. And it introduces the name Lₜ₋₁ for each of the negative log terms in the summation to make the derivations shorter. That is:
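For the Lₜ₋₁ part, the naming can be written in LaTeX as below. This is my reconstruction, shown together with the outer expectation over q(x_0:T) so that it can be read on its own:

```latex
L_{t-1} \;=\; -\,\mathbb{E}_{q(x_{0:T})}\Big[\log \frac{p(x_{t-1} \mid x_t)}{q(x_{t-1} \mid x_t, x_0)}\Big], \qquad t = 2, \dots, T
```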
Obviously the Lₜ₋₁ terms for t=[2, T] are important. Note that for Lₜ₋₁, t starts from 2 instead of 1 because of the split in line (3). These T-1 terms constitute most of the whole loss function, leaving only three other terms behind. Let’s worry about those three terms later and focus on the Lₜ₋₁ terms, as they will be the core of the final loss function that we minimize.
Let’s keep manipulating Lₜ₋₁ for t=[2, T]:
Line (1) is the definition of the Lₜ₋₁ term. Line (2) pushes the minus sign into the log. The expectation is with respect to the random variables x₀ to x_T from the q distribution.
Line (3) replaces the expectation notation with its mathematical definition, which is an integration over the random variables x₀ to x_T.
Line (4) factorizes the joint probability density q using the second factorization of the forward process.
Note the second factorization is a product of many distributions, each mentioning a single latent random variable. This is correct because given the observational random variable x₀, all latent random variables x₁ to x_T are independent of each other.
Line (5) organizes all the factors from the q distribution into four parts:
- q(x₀), which is a distribution about x₀, and its formula is unknown.
- q(xₜ₋₁|x₀), which is a distribution about xₜ₋₁.
- q(xₜ|x₀), which is a distribution about xₜ.
- q(xₒₜₕₑᵣ), which is a distribution about the latent random variables other than xₜ₋₁ and xₜ.
The reason for line (5)’s factorization is that the log function only mentions x₀, xₜ₋₁ and xₜ.
Line (6) applies the chain rule to derive the joint probability q(xₜ₋₁, xₜ|x₀).
Line (7) is a key line. Using the reverse chain rule (which is applicable for any joint probability density), it replaces q(xₜ₋₁, xₜ|x₀) with q(xₜ₋₁|xₜ, x₀)q(xₜ|x₀) because
Line (8) splits the integrating variables into 4 parts, corresponding to x₀, xₜ₋₁, xₜ and xₒₜₕₑᵣ, and re-orders them.
Line (9) recognizes that the inner integration is the KL-divergence between q(xₜ₋₁|xₜ, x₀) and p(xₜ₋₁|xₜ). This KL-divergence is between two multivariate Gaussian distributions, whose analytical probability density functions are known. So we can write down the formula for this KL-divergence analytically. It is a function that mentions the random variables xₜ and x₀ (note, it does not mention xₜ₋₁), as well as all model parameters.
Line (10) factorizes q(xₜ, xₒₜₕₑᵣ, x₀) into conditionals.
Now we have the analytical expression for the KL-divergence between q(xₜ₋₁|xₜ, x₀) and p(xₜ₋₁|xₜ), but this KL-divergence is inside an integration. How do we solve the integration analytically?
That’s right, we can use sample-averaging to approximate the expectation with respect to x₀, xₜ and xₒₜₕₑᵣ:
- Sample x₀ by randomly picking natural images from the training set.
- Sample xₜ from the marginal q(xₜ|x₀) after plugging the sample for x₀.
- No need to sample xₒₜₕₑᵣ as line(10) reveals that xₒₜₕₑᵣ is not mentioned in the KL-divergence. The values for random variables inside xₒₜₕₑᵣ won’t change the computed result of the KL-divergence.
Phew, after so many steps, we finally arrived at the analytical expression for the Lₜ₋₁ terms for t in [2, T] in our new loss function Lᵥ to minimize.
Sample-averaging to solve the integration
Let me paste the analytical formula for Lₜ₋₁ here, and add the steps that use sample-averaging to approximate the integration analytically.
Line (1) is the analytical formula we derived just now for Lₜ₋₁. It has a multiple integration over the random variables x₀, xₜ and xₒₜₕₑᵣ. All three are easy to deal with:
- First, sample x₀ from our training set. Let’s call x₀’s sample S₀.
- Plug S₀ into q(xₜ|x₀) to get q(xₜ|x₀=S₀), which is now a fully specified multivariate Gaussian distribution ready to be sampled. Let’s call the sample of xₜ Sₜ.
- Ignore the integration over xₒₜₕₑᵣ. Because xₒₜₕₑᵣ does not appear in the KL-divergence, its samples do not change the analytical form of the integration result.
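The sampling scheme above can be sketched in Python with NumPy. This is only a minimal sketch under my own assumptions: the linear β schedule (1e-4 to 0.02 over T steps), the function names, and the callable `kl_fn` (a stand-in for the analytical KL-divergence formula) are all hypothetical and not from the paper's code.

```python
import numpy as np

def make_schedule(T=1000, beta_1=1e-4, beta_T=0.02):
    """Assumed linear beta schedule; index 0 is a placeholder so betas[t] is beta_t."""
    betas = np.concatenate([[0.0], np.linspace(beta_1, beta_T, T)])
    alphas = 1.0 - betas
    alpha_bars = np.cumprod(alphas)  # alpha_bars[0] = 1 by construction
    return betas, alphas, alpha_bars

def sample_pair(dataset, t, alpha_bars, rng):
    """Draw (S0, St): S0 from the training set, St from q(x_t | x_0 = S0)."""
    s0 = dataset[rng.integers(len(dataset))]
    noise = rng.standard_normal(s0.shape)
    st = np.sqrt(alpha_bars[t]) * s0 + np.sqrt(1.0 - alpha_bars[t]) * noise
    return s0, st

def estimate_L(kl_fn, dataset, t, alpha_bars, rng, n=1):
    """Average kl_fn(S0, St, t) over n sampled pairs; x_other is never sampled."""
    values = [kl_fn(*sample_pair(dataset, t, alpha_bars, rng), t) for _ in range(n)]
    return np.mean(values)
```

Calling `estimate_L` with `n=1` corresponds to drawing a single (S₀, Sₜ) pair.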
Line (2) uses the above sampling scheme to sample n pairs of (S₀, Sₜ), plugs each pair into the KL-divergence formula to get an analytical term, and then averages these analytical terms.
You may ask, how many pairs n should we sample? The more the better, but empirically, a single pair already gives us good results, so n=1.
So line (3) uses the fact n=1 to remove the summation from line (2) to arrive at this simple formula:
KL(q(xₜ₋₁|xₜ, x₀) || p(xₜ₋₁|xₜ)) serves as regularization
After so much effort to derive the analytical formula for this KL-divergence, it is wise to look at it closely.
For each step t in [2, T], this KL-divergence quantifies the distance between two distributions:
- q(xₜ₋₁|xₜ, x₀) — the reverse of the forward process that we derived from the forward process by using the Bayes rule.
- p(xₜ₋₁|xₜ) — the reverse process that we used deep neural network to implement.
We are minimizing this KL-divergence. That is, we want these two distributions to be similar at each time step t in [2, T]. In other words, we want to find a model p(xₜ₋₁|xₜ) that gives similar results to the reverse of the forward process q(xₜ₋₁|xₜ, x₀) at all steps. “Similar results” means that at timestamp t-1, images sampled from p(xₜ₋₁|xₜ) should be similar to images sampled from q(xₜ₋₁|xₜ, x₀).
Pay attention to the timestamp range t=[2, T] here. This range means that the Lₜ₋₁ terms only cover the timestamps from 2 to T, leaving the first step t=1 unformulated. The timestamp t=1, being the step that finally generates the natural image, is of course important. Remember we left three terms from Lᵥ unanalyzed? Later we will see that the remaining terms cover the first timestamp.
So the reverse of the forward process serves as a regularization for the neural network and establishes the notion of “gradual-ness” among the images generated by the reverse process: the images generated by the reverse process at each timestamp should be similar to the images from the reverse of the forward process at the corresponding timestamp. Since the images from the reverse of the forward process change gradually, the images from the reverse process must also change gradually.
This regularization restricts the neural network to behave according to an already known and much simpler process, namely the reverse of the forward process. The per-step KL-divergence prevents the learnt neural network from doing weird things, such as first generating an image of a cat at an early step and then morphing the cat into a human face.
Now you should be convinced that the introduction of the forward process distribution q helps establish the gradual-ness of the generated images from the reverse process p.
Trajectory viewpoint
Let’s use the illustration below to reveal what KL(q(xₜ₋₁|xₜ, x₀) || p(xₜ₋₁|xₜ)) is trying to do from the trajectory point of view.
The left subplot shows two natural images X₀ and X₁. Starting from each natural image, if we apply the forward process multiple times, we get multiple trajectories. The black curves starting from X₀ or X₁ represent these trajectories. Timestamps go from left to right, so the images at the end of each trajectory are pure Gaussian noise already.
In this completely unconditioned setting, at timestamp t-1, the random variable xₜ₋₁ in our model can take values from any trajectory, no matter whether the trajectory starts from X₀ or X₁. In other words, at timestamp t-1, our model needs to be able to explain all possible images that can be generated by the forward process, starting from any natural image. Our model can do that by giving the random variable xₜ₋₁ a mean that is in the middle of all the trajectories and a large variance.
The middle subplot shows the situation when x₀ is given, which sets the random variable x₀ to the natural image X₀. This setting restricts the model to only explain the trajectories that start from the natural image X₀. They are the red trajectories in the middle subplot. In other words, our model now only needs to explain the possible values from the red curves at timestamp t-1. The model can do that by offering a more precise mean and a smaller variance, since it does not need to cover the black trajectories starting from the natural image X₁ anymore.
The right subplot shows the situation when x₀ is still conditioned to X₀, and additionally, xₜ is conditioned to a particular image Sₜ, which is sampled from the distribution q(xₜ|x₀=X₀). This second conditioning further restricts the model to only need to explain trajectories that go through Sₜ at timestamp t. These are the blue trajectories, which all start from X₀ and pass through Sₜ.
Under this condition, the possible values that the random variable xₜ₋₁ can take at timestamp t-1 are further restricted. This means that our model needs to predict a mean that is around the middle of the blue trajectories, and predict an even smaller covariance for xₜ₋₁.
But how “around the middle of the blue trajectories” should the predicted mean be, and how “even smaller” should the predicted covariance be for the random variable xₜ₋₁? These two target quantities are defined by the reverse of the forward process q(xₜ₋₁|xₜ, x₀), with its definition shown here again:
with
By conditioning the model on xₜ and x₀, we are giving the model an easier task to learn at each training step because at each step, the model only needs to explain a single time step over a relatively small number of trajectories.
Optimization forces p to change by fixing q
Since the reverse of the forward process q(xₜ₋₁|xₜ, x₀) is fixed, that is, there are no trainable parameters in q(xₜ₋₁|xₜ, x₀), the only way for the optimization to make q(xₜ₋₁|xₜ, x₀) and the reverse process p(xₜ₋₁|xₜ) similar to each other is to change the model parameters’ values to move p closer to q.
One thing to note is that many other papers introduce a learnable q and move q closer to p. Not in this paper. In this paper, the q distribution introduced in importance sampling is fixed, and minimizing the KL-divergence between q and p moves p.
Since the KL-divergence Lₜ₋₁=KL(q(xₜ₋₁|xₜ, x₀) || p(xₜ₋₁|xₜ)) is analytical, let’s write it down. Recall that the probability density functions of the two distributions in the KL-divergence are both multivariate Gaussians:
The analytical formula for the KL-divergence between two multivariate Gaussians is:
The above formula has 4 terms.
The first term at line (1) computes the log ratio between the determinants of the two covariance matrices, denoted by the name “det”. This term mentions model parameters.
The second term at line (2) references d, the dimension of the random variable xₜ₋₁, which is the number of pixels in the images that we are working with. This term does not mention any model parameter.
The third term at line (3) computes the trace, denoted by the name “tr”, of a product of two matrices. This term mentions model parameters.
The fourth term at line (4) is the square of the vector μₚ(xₜ, t)-μₜ(xₜ, x₀), scaled by the inverse covariance matrix Σₚ(xₜ, t)⁻¹.
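To make the four terms concrete, here is a small NumPy sketch of the standard closed-form KL-divergence between two multivariate Gaussians, with q as the first argument and p as the second. The function name and the variable names are my own, and the grouping into four terms follows the description above rather than any particular implementation.

```python
import numpy as np

def gaussian_kl(mu_q, cov_q, mu_p, cov_p):
    """KL(N(mu_q, cov_q) || N(mu_p, cov_p)), written as four separate terms."""
    d = mu_q.shape[0]
    cov_p_inv = np.linalg.inv(cov_p)
    _, logdet_p = np.linalg.slogdet(cov_p)
    _, logdet_q = np.linalg.slogdet(cov_q)
    term_logdet = logdet_p - logdet_q         # log(det(cov_p) / det(cov_q))
    term_dim = -d                             # minus the dimension d
    term_trace = np.trace(cov_p_inv @ cov_q)  # tr(cov_p^{-1} cov_q)
    diff = mu_p - mu_q
    term_quad = diff @ cov_p_inv @ diff       # squared mean gap scaled by cov_p^{-1}
    return 0.5 * (term_logdet + term_dim + term_trace + term_quad)
```

For two identical Gaussians all four terms cancel and the function returns zero, which is a handy sanity check.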
I know, this formula is terrible. And let’s remind ourselves that we need to minimize this term with respect to the model parameters, which appear in:
- μₚ(xₜ, t), the neural network that is responsible for predicting the mean vector of the p(xₜ₋₁|xₜ) multivariate Gaussian distribution.
- Σₚ(xₜ, t), a second neural network that is responsible for predicting the covariance matrix of the p(xₜ₋₁|xₜ) multivariate Gaussian distribution.
Simplifying the model by setting the reverse process covariance matrix to constant
Let’s simplify the model by removing the second neural network that predicts the covariance matrix. Mathematically, we set Σₚ(xₜ, t)=σₜ²I, where one obvious choice for σₜ² is:
The above makes the covariance matrix from the reverse process p(xₜ₋₁|xₜ) the same as the covariance matrix of the reverse of the forward process.
With this simplification, the first three terms become constants. Let’s name their sum C. C does not mention model parameters anymore, so it can be ignored during optimization. This leaves us with only the fourth term. Let’s call it LMₜ₋₁. So we have:
with LMₜ₋₁ being:
Line (1) is the fourth term. Line (2) plugs in the simplified covariance matrix. The ||…||² in line (3) is the vector square operation, that is, the dot product of a vector with itself. Line (4) swaps the two components in the square, which does not change the result; it just makes the order of terms more consistent with the paper.
Note that I dropped the expectation with respect to x₀ and xₜ in LMₜ₋₁ to make the formula concise. But the computation is the same as before: we need to sample x₀ and xₜ, and plug the samples into LMₜ₋₁ to approximate the integration analytically.
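The split into a constant C plus the mean term can be illustrated with a short NumPy sketch for isotropic covariances. The function name and the scalar variances `var_q` and `var_p` are hypothetical names I chose for illustration; `var_p` plays the role of σₜ².

```python
import numpy as np

def kl_isotropic(mu_q, var_q, mu_p, var_p):
    """KL between N(mu_q, var_q*I) and N(mu_p, var_p*I), split into (C, LM)."""
    d = mu_q.shape[0]
    # First three terms: log-determinant ratio, minus d, and the trace.
    c = 0.5 * (d * np.log(var_p / var_q) - d + d * var_q / var_p)
    # Fourth term: squared distance between the means, scaled by 1 / (2 * var_p).
    lm = np.sum((mu_q - mu_p) ** 2) / (2.0 * var_p)
    return c, lm
```

Notice that `c` depends only on the two variances and the dimension, never on the means. When `var_p` equals `var_q`, as in the choice above, `c` evaluates to zero and the whole KL-divergence is the mean term `lm`.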
Interpreting the meaning of LMₜ₋₁
LMₜ₋₁ quantifies the distance between the two vectors μₜ(xₜ, x₀) and μₚ(xₜ, t). This makes a lot of sense now:
- Originally we want to minimize the distance between q(xₜ₋₁|xₜ, x₀) the reverse of the forward process and p(xₜ₋₁|xₜ), which is our neural network implementation of the reverse process, at every time step t from 2 to T. In other words, we want to find a configuration (model parameter values) for the p(xₜ₋₁|xₜ) distribution such that these two distributions are similar to each other.
- These two distributions for the random variable xₜ₋₁ are both multivariate Gaussian. A multivariate Gaussian distribution is fully specified by its mean vector and covariance matrix. If p(xₜ₋₁|xₜ) needs to be similar to q(xₜ₋₁|xₜ, x₀), their mean vectors and covariance matrices must be similar to each other. This is called moment matching, with the mean being the first moment and the covariance being the second (central) moment. The letter “M” in LMₜ₋₁ stands for moment matching.
- After we simplified the covariance matrix from the p(xₜ₋₁|xₜ) distribution to a quantity that is equal to the covariance matrix from the reverse of the forward process, the only thing that we can still change to make these two distributions similar or different is the mean vector. So we want to minimize the distance between the mean vectors from the p(xₜ₋₁|xₜ) and the q(xₜ₋₁|xₜ, x₀) distribution.
- Since the mean vector from the p(xₜ₋₁|xₜ) distribution is predicted by our neural network, we can use optimization to move the values of the neural network weights around by minimizing LMₜ₋₁.
Simplifying LMₜ₋₁
It is possible to simplify LMₜ₋₁, a lot. In LMₜ₋₁’s formula, the μₚ(xₜ, t) part is from the neural network, it’s like a black box, there is little we can simplify. So let’s try to simplify the other term μₜ(xₜ, x₀), which is the mean vector of the reverse of the forward process q(xₜ₋₁|xₜ, x₀), whose analytical probability density function is already derived:
with the covariance matrix:
and the mean vector:
We only need to look at the mean vector μₜ(xₜ, x₀) because the previous derivation of LMₜ₋₁ reveals that we only need our neural network to predict a mean vector that is close to, or in other words matches, μₜ(xₜ, x₀).
We also have the analytical probability density function for q(xₜ|x₀):
Using the reparameterization trick, we can rewrite the above into:
Re-organize the terms in the above equation to get the expression for x₀:
Now plug in this expression of x₀ into the formula for μₜ(xₜ, x₀):
Line (1) is a horrible formula, and line (2) introduces the name A to represent the coefficient in front of xₜ, and the name B for the coefficient in front of ϵₜ. We will simplify A and B separately.
Simplifying A
Simplifying B
Wow, what an amazing simplification! It gives us:
Don’t panic, our goal has not changed — we still want our neural network to predict the mean vector of the p(xₜ₋₁|xₜ) distribution and the predicted mean vector should be as close to μₜ(xₜ, x₀) as possible. But upon seeing the simplified formula for μₜ(xₜ, x₀), we realize:
- xₜ is known via sampling, there is no need to predict it.
- Given timestamp t, βₜ is constant, and so are all the other quantities derived from βₜ, namely αₜ and αₜ bar.
- The only part that needs predicting is the noise ϵₜ.
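The simplification can be sanity-checked numerically. The sketch below compares the two-coefficient form of the mean of q(xₜ₋₁|xₜ, x₀), as given in the DDPM paper, with the noise-based form (xₜ − βₜ/√(1−ᾱₜ)·ϵₜ)/√αₜ. The linear β schedule, the dimension 8 and the choice t=500 are assumptions made only for this illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
T = 1000
betas = np.concatenate([[0.0], np.linspace(1e-4, 0.02, T)])  # betas[t] is beta_t
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)                               # alpha_bars[0] = 1

def posterior_mean(x_t, x_0, t):
    """Mean of q(x_{t-1} | x_t, x_0) written with x_0 and x_t."""
    coef_x0 = np.sqrt(alpha_bars[t - 1]) * betas[t] / (1.0 - alpha_bars[t])
    coef_xt = np.sqrt(alphas[t]) * (1.0 - alpha_bars[t - 1]) / (1.0 - alpha_bars[t])
    return coef_x0 * x_0 + coef_xt * x_t

def posterior_mean_from_noise(x_t, eps, t):
    """Same mean written with x_t and the noise eps instead of x_0."""
    return (x_t - betas[t] / np.sqrt(1.0 - alpha_bars[t]) * eps) / np.sqrt(alphas[t])

t = 500
x_0 = rng.standard_normal(8)
eps = rng.standard_normal(8)
# x_t is built from x_0 and eps through the reparameterization of q(x_t | x_0).
x_t = np.sqrt(alpha_bars[t]) * x_0 + np.sqrt(1.0 - alpha_bars[t]) * eps
gap = np.max(np.abs(posterior_mean(x_t, x_0, t) - posterior_mean_from_noise(x_t, eps, t)))
```

If the algebra is right, `gap` is at the level of floating-point rounding error, as long as the same ϵₜ is used both to build xₜ and inside the simplified formula.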
We can drop the original neural network, and design a new one ϵₚ(xₜ, t) that predicts the noise ϵₜ. Then we can construct the desired mean vector μₚ(xₜ, t) by:
Plugging this formulation into the definition of LMₜ₋₁ gives us:
Line (7) is the simplified objective function to minimize.
Note that this objective function mentions the noise ϵₜ twice. They are the same random variable, not two different noises. This is because they both come from the same source:
The first time we use the above to get x₀ as an expression of xₜ and ϵₜ. The second time we use it to get xₜ as an expression of x₀ and ϵₜ.
Is this objective function still analytical?
Remember that previously we dropped the expectation with respect to xₜ and x₀ for LMₜ₋₁ to shorten our derivations? To answer the question of whether LMₜ₋₁ is still analytical, we have to add them back, because only with those expectations are we computing the correct LMₜ₋₁.
Note:
- In the final formula for LMₜ₋₁, there is no mention of xₜ anymore, xₜ is expressed via x₀ and the noise ϵₜ. So we don’t need to add the expectation with respect to xₜ. Instead, we need to add the expectation with respect to ϵₜ, which is a standard multivariate Gaussian, that is ϵₜ~N(0, 1).
- There is the mention of timestamp t, which represents an integer between 2 and T. We need to add an expectation with respect to t, which comes from a uniform distribution.
- There is the mention of x₀, which comes from the unknown data distribution q(x₀).
So, the complete formula for LMₜ₋₁ is:
where Uni(2,T) denotes the uniform distribution between 2 and T.
This formula is analytical with sample-averaging. When we plug in the samples for x₀, ϵₜ and t into the above formula, we have an analytical expression, from which we can take gradient to perform stochastic gradient descent.
The authors found that ignoring the constant in front of the vector distance term gives better results:
The following Algorithm 0 minimizes the above loss:
Algorithm 0 evaluates the expectation with respect to x₀, ϵₜ and t by sample-averaging. Note at line (3), the timestamp t is sampled from the uniform distribution Uni(2, T).
One notational difference between the paper and this article is that in the paper, the authors use ϵ_θ to denote the neural network, and I use ϵₚ. The authors used ϵ_θ to highlight that the neural network has parameter set θ. This is also explicitly shown at line (5) of the above algorithm when the gradient (notice the ▽ symbol that denotes the derivative with respect to a vector) is computed on the loss function with respect to θ. I use ϵₚ, because there is no subscript θ in Unicode, and I don’t want to write too many ϵ_θ as they don’t look good.
Another notational difference is the paper uses ϵ to denote standard Gaussian noise, and I used ϵₜ. I use ϵₜ because I derived my formulas this way. But I think ϵ is better because the standard Gaussian noise does not depend on the timestamp t.
The derivation for Lᵥ shows that it is an expectation with respect to q(x_0:T) and inside the expectation there are multiple terms, shown below:
Previously we only focused on the Lₜ₋₁ terms for t=[2, T]. Now let’s talk about the remaining terms, which I extracted into the first expectation at line (2) using the linearity of expectation property: E[a + b] = E[a] + E[b].
Line (2) replaces the names F_T and F₀ with their actual formula.
Lines (3) and (4) rewrite the terms using the properties of log.
Line (5) simplifies the second log.
Line (6) splits the expectation into 2 using the linearity of expectation property.
Line (7) gives the first expectation the name L_T, same as the paper.
Line (8) gives the negative of the second expectation the name L₀, same as the paper.
The L_T term can be ignored in optimization, while the L₀ needs special treatment. We will see why.
Ignoring the L_T term
Here is the formula for the L_T term again:
It mentions q(x_T|x₀), which is the marginal probability density for the random variable x_T in the forward process. The forward process doesn’t include any model parameters.
It also mentions p(x_T), which is the reverse process at timestamp T. We defined p(x_T) = N(0, 1). So p(x_T) doesn’t mention model parameters either.
This means the whole L_T term doesn’t mention model parameters, thus it can be ignored during parameter learning.
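As an illustration, the KL-divergence inside L_T can be evaluated in closed form without touching any model parameter. The sketch below assumes the linear β schedule with T=1000 (β from 1e-4 to 0.02) and uses a random array as a stand-in for a training image scaled to [-1, 1]; both are assumptions for this example only.

```python
import numpy as np

def l_T(x_0, alpha_bar_T):
    """KL( N(sqrt(alpha_bar_T) * x_0, (1 - alpha_bar_T) * I) || N(0, I) ) in closed form."""
    var = 1.0 - alpha_bar_T
    mean_sq = alpha_bar_T * x_0 ** 2
    # Per-dimension KL between N(m, var) and N(0, 1), summed over all dimensions.
    return 0.5 * np.sum(var + mean_sq - 1.0 - np.log(var))

betas = np.linspace(1e-4, 0.02, 1000)  # assumed linear schedule with T = 1000
alpha_bar_T = np.prod(1.0 - betas)     # product of all alpha_t
x_0 = np.random.default_rng(0).uniform(-1.0, 1.0, size=3 * 32 * 32)  # toy "image"
value = l_T(x_0, alpha_bar_T)
```

The function only takes the data sample and the schedule as inputs, which is exactly why this term drops out of parameter learning. Under this schedule ᾱ_T is very small, so q(x_T|x₀) is already close to the standard Gaussian and `value` is small.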
Approximating the L₀ term
The L₀ term is:
This term is for the timestamp t=1. Let’s understand what this term is saying. We want to minimize this term, which translates to finding model parameters that maximize the log likelihood log(p(x₀|x₁)). In other words, we want p(x₀|x₁) to evaluate to a high probability density when a natural image is plugged into x₀.
Alternatively, we can understand it by using the formula from Lₜ₋₁:
Line (1) is the definition of Lₜ₋₁ that we derived previously. Note that when we derived it, t starts from 2 because when t≥2, all Lₜ₋₁ terms are KL-divergences between two proper Gaussian distributions. This is not true for t=1 as you will see at line (4).
Line (2) sets t=1 to derive L₀. And line (3) expands the KL notation to its mathematical definition.
Line (4) uses the property that q(x₀|x₁, x₀) = 1. This line also reveals that when t=1, there is no KL anymore. The formula degenerates into an integration of a log. That’s why we cannot handle t=1 in Lₜ₋₁.
Line (5) uses the property of log to simplify the formula.
Line (6) replaces the integration using the expectation notation.
Line (7) simplifies the two expectations over x₀ into one expectation over x₀ since one expectation already removes the random variable x₀. The second expectation over x₀ doesn’t change the result anymore. This line also reveals that the resulting quantity is indeed the L₀ term.
L₀ needs to be minimized differently, it can’t fit into Algorithm 0
Now we should understand that it is not that we cannot derive L₀ from the Lₜ₋₁ point of view. We can, but the derivation of L₀ is not a KL-divergence between two proper multivariate Gaussian distributions, which means the analytical formula of L₀ is different from the analytical formula of Lₜ₋₁ for t≥2. This means we need a different way to minimize L₀. In other words, the minimization of L₀ doesn’t fit into Algorithm 0. Well, it doesn’t fit yet; later we will introduce an approximation to make it fit.
L₀ is optimizable
Since we want to minimize L₀, it is important that either:
- L₀ does not mention any model parameters so it can be ignored during the optimization. Or
- L₀ mentions model parameters and is analytical so its gradient can be taken for gradient descent.
Since the previous loss function LMₜ₋₁ only handles the case when t≥2, we hope that L₀ falls into the second category above so some part of our loss function covers the case t=1. Indeed that’s the case:
Line (1) is the definition of L₀.
Line (2) plugs in the definition of p(x₀|x₁), which is a multivariate Gaussian distribution with the neural network µₚ(x₁, 1) predicting its mean vector, and with its covariance matrix set to the constant σ₁² I. I ignored the normalization term in front of the exponential, and used the proportional symbol “∝”.
Lines (3) and (4) simplify the formula.
Line (4) reveals that L₀ mentions all the model parameters in µₚ(x₁, 1) and it is analytical after we sample x₀ and x₁. So L₀ is optimizable.
Minimizing an approximation of L₀ inside Algorithm 0
Line (4) from above also shows that to minimize L₀, the neural network µₚ(x₁, 1) needs to predict a mean vector that is close to a natural image, say X₀, sampled for x₀.
Previously, when we derived the analytical formula of Lₜ₋₁ for t≥2, we arrived at the realization that we want our neural network µₚ(xₜ, t) to predict mean vectors that are close to the mean of the reverse of the forward process µₜ(xₜ, x₀).
If we can:
- write down µₜ(xₜ, x₀) for t=1, that is µ₁(x₁, x₀) and,
- if µ₁(x₁, x₀) is close to the natural image sample X₀
then we can turn the original task of “minimizing the distance between µₚ(x₁, 1) and X₀” into an approximation task of “minimizing the distance between µₚ(x₁, 1) and µ₁(x₁, x₀)”. The benefit of the latter is that we can handle the case of t=1 using Algorithm 0, the same way as for the cases of t≥2.
We can write down µ₁(x₁, x₀)
Note that we cannot set t=1 in the first line above. This is because when t=1, quantities such as 𝛼ₜ₋₁ bar are not defined. But we can set t=1 in the second line. This is because the second line replaces x₀ in the first line with an expression that only mentions x₁. And all quantities involving 𝛼₁ and β₁ are defined.
Set t=1 to derive:
After plugging in the samples for x₁ and ϵ₁, the above is a constant.
We know µ₁(x₁, x₀) must be close to the natural image X₀
This is because µ₁(x₁, x₀) is the mean vector for the ending random variable x₀ from the reverse of the forward process. So if we draw a sample for x₀ from the reverse of the forward process, we should get an image that is close to the natural image X₀. That’s by the definition of the reverse of the forward process. In fact, if we draw many images for x₀ from the reverse of the forward process and average all those sampled images, the average should be exactly equal to X₀. In other words, the reverse of the forward process can generate the exact starting image in expectation. But if we only sample a single image for x₀ from the reverse of the forward process, that sample is not equal to X₀. That’s why we are approximating the L₀ term.
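A quick numerical check of µ₁(x₁, x₀) supports this. The sketch below uses the simplified formula at t=1, with αₜ bar at t=1 equal to 𝛼₁ = 1 − β₁. The value of β₁, the dimension 8 and the random stand-in for the natural image X₀ are assumptions made only for this illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
beta_1 = 1e-4           # assumed value of beta_1
alpha_1 = 1.0 - beta_1  # alpha_bar_1 equals alpha_1 (a one-factor product)

X0 = rng.uniform(-1.0, 1.0, size=8)  # stand-in for a natural image
eps_1 = rng.standard_normal(8)       # the noise used to produce x_1

# Forward step: x_1 = sqrt(alpha_1) * X0 + sqrt(1 - alpha_1) * eps_1
x_1 = np.sqrt(alpha_1) * X0 + np.sqrt(1.0 - alpha_1) * eps_1

# Simplified mean at t=1: (x_1 - beta_1 / sqrt(1 - alpha_bar_1) * eps_1) / sqrt(alpha_1)
mu_1 = (x_1 - beta_1 / np.sqrt(1.0 - alpha_1) * eps_1) / np.sqrt(alpha_1)

gap = np.max(np.abs(mu_1 - X0))
```

In this sketch the noise contributions cancel when the same ϵ₁ is used to build x₁ and to evaluate the formula, so `gap` is at the level of floating-point rounding error.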
Now we can use Algorithm 0 to handle all timestamps starting from t=1. Mathematically, we expand LMₜ₋₁ which only covers the cases for t≥2, see the t~Uni(2,T) part under the expectation:
to cover the case for t=1 as well, see the t~Uni(1,T) part under the expectation:
Lₛᵢₘₚₗₑ is the final loss function, and it covers all timestamps from 1 to T. Algorithm 1 from the paper, copied below, minimizes Lₛᵢₘₚₗₑ:
We happily notice that at line (3), the timestamp t is sampled from the uniform distribution Uni(1, T) covering all cases t≥1 because of the approximation for the L₀ term.
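The loss evaluated inside one iteration of the training loop can be sketched in NumPy as follows. This is a sketch under assumptions, not the paper's implementation: the linear β schedule, the function names and the `zero_model` stand-in for the neural network are hypothetical, and the gradient step itself is left out because NumPy has no automatic differentiation.

```python
import numpy as np

def linear_alpha_bars(T=1000, beta_1=1e-4, beta_T=0.02):
    """alpha_bars[t] for t = 0..T under an assumed linear beta schedule."""
    betas = np.concatenate([[0.0], np.linspace(beta_1, beta_T, T)])
    return np.cumprod(1.0 - betas)  # alpha_bars[0] = 1

def l_simple_sample(eps_model, x_0, alpha_bars, rng, t_min=1):
    """One-sample estimate of ||eps - eps_model(x_t, t)||^2 with t ~ Uni(t_min, T)."""
    T = len(alpha_bars) - 1
    t = int(rng.integers(t_min, T + 1))   # uniform over {t_min, ..., T}
    eps = rng.standard_normal(x_0.shape)  # standard Gaussian noise
    x_t = np.sqrt(alpha_bars[t]) * x_0 + np.sqrt(1.0 - alpha_bars[t]) * eps
    return np.sum((eps - eps_model(x_t, t)) ** 2)

def zero_model(x_t, t):
    """Hypothetical stand-in for the network: always predicts zero noise."""
    return np.zeros_like(x_t)
```

With `t_min=1` the timestamp is drawn from Uni(1, T) as in Algorithm 1, while `t_min=2` reproduces the sampling range of Algorithm 0. In a real training loop, `eps_model` would be a neural network and the returned value would be differentiated with respect to its parameters.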
No concern on high variance in sample-averaging Lₛᵢₘₚₗₑ?
Previously I said that we can use sample-averaging to compute the analytical formula for the expectation of the negative log likelihood L with respect to all the latent random variables x₁ to x_T. But this results in high variance in the computed expectation if we can only afford to draw one sample per random variable for practical computation reasons.
Why do we have no problem using sample-averaging to compute the analytical formula Lₛᵢₘₚₗₑ while drawing a single sample per random variable?
The main reason is that in the final loss function Lₛᵢₘₚₗₑ, there are only 3 random variables to sample, compared to the T+1=1000+1 random variables to sample in the case of expectation of the negative log likelihood. So the variance in the final loss function’s case should be much smaller than the case of expected negative log likelihood.
To make things even better, now the samples are not drawn through uncalibrated neural networks any more, they all come from standard distributions whose behaviours do not depend on how much we’ve trained our neural networks. This results in a more predictable parameter learning experience.
But just for fun, let’s consider the alternative to sample-averaging. That is, to compute the expectation in the final loss function Lₛᵢₘₚₗₑ analytically:
- For the random variable x₀, there is no way to compute the expectation with respect to it analytically because the data distribution q(x₀) is unknown. So sample-averaging is the only option.
- For the random variable t, which comes from a uniform distribution, the expectation is obtained by taking all possible values of t, computing the formula inside the expectation for each, and averaging the results. This is equivalent to sample-averaging in our context of stochastic gradient descent. Even though Algorithm 1 only works with a single term in each stochastic gradient descent step, instead of adding all those terms together and dividing the sum by T, the algorithm does so repeatedly until convergence. This is equivalent to computing the expectation over t asymptotically. For more details, please see the proof in Can We Use Stochastic Gradient Descent (SGD) on a Linear Regression Model?
- For the standard multivariate Gaussian random variable ϵₜ, we can use Gaussian quadrature to approximate the expectation analytically. For more details about Gaussian quadrature, please see Variational Gaussian Process (VGP) — What To Do When Things Are Not Gaussian. But Gaussian quadrature works better in low dimensional settings. In our case, ϵₜ is a d-dimensional random variable with d being the number of pixels in the images that we want to generate, so d is a large integer, and applying Gaussian quadrature is not practical. For more details about why it is not practical, please see the Appendix of the above link.
Given the above, using sample-averaging to approximate the expectation in Lₛᵢₘₚₗₑ is a sensible choice.
This article established a clear motivation for why the denoising diffusion probabilistic model is designed the way it is, by reasoning about the relationships among the forward process q(xₜ|xₜ₋₁), the reverse of the forward process q(xₜ₋₁|xₜ, x₀) and the reverse process p(xₜ₋₁|xₜ). It also provides a detailed derivation of the loss function used for model parameter learning.
Why we won’t lose model parameters when applying sample-averaging to derive the analytical formula for the loss function L
A typical problem applying sample-averaging to approximate integrations in a loss function is that the resulting formula does not mention model parameters anymore. The reparameterization trick (see here) is the go-to recipe to prevent this from happening.
Our case of using sample-averaging to derive the analytical approximation for the loss function L does not have the problem of losing model parameters. Let’s use an example with a short reverse process (T=1) to see why.
Let’s show the loss function L together with some manipulations to demonstrate sample-averaging:
Line (1) is the loss L, and line (2) replaces the expectation notation with its mathematical definition.
Line (3) sets T=1 to demonstrate the following derivations on a short reverse trajectory.
Line (4) factorizes the joint probability inside the inner integration using the definition of the reverse process.
Line (5) replaces all probability density function notation with the actual probability density distribution names. It also reveals that the random variable x₁ is sample-able from the standard multivariate Gaussian distribution N(0, 1). Let’s denote S₁ as the sample for x₁.
Line (6) plugs in the sample S₁, removing the inner integration by doing sample-averaging using only one sample, for demonstration purpose. Sample-averaging is an approximation, which is reflected by the approximation sign “≈” in front of the line.
Line (7) draws the sample S₀ for the random variable x₀ from the unknown distribution q(x₀); practically, we just randomly pick a natural image from the training set. It then uses sample-averaging again to remove the integration over x₀.
Line (8) plugs in the formula for the multivariate Gaussian probability density function. The proportional symbol “∝” allows me to drop the normalization terms in front of the exponential function.
Line (9) simplifies the formula. It reveals that after sample-averaging, the analytical loss is still a function that mentions all model parameters. So there is no need for the reparameterization trick.
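A toy version of this T=1 example can be sketched in Python. Here the neural network is replaced by a hypothetical one-parameter model µₚ(x₁) = θ·x₁ and the covariance by a constant, which are simplifications I made up purely to show that the sampled loss still depends on the parameter θ.

```python
import numpy as np

def sampled_loss(theta, s0, s1, sigma2=1.0):
    """Negative log p(x_0 = s0 | x_1 = s1) up to a constant, with toy mean theta * s1."""
    mu_p = theta * s1  # toy stand-in for the neural network mean
    return np.sum((s0 - mu_p) ** 2) / (2.0 * sigma2)

rng = np.random.default_rng(0)
s1 = rng.standard_normal(4)            # sample S1 for x_1 from the standard Gaussian
s0 = np.array([0.5, -0.2, 0.1, 0.9])   # stand-in for a training image S0

loss_a = sampled_loss(0.0, s0, s1)
loss_b = sampled_loss(1.0, s0, s1)
```

After both samples are plugged in, `loss_a` and `loss_b` differ only because θ differs, so the sampled loss remains a function of the model parameter and its gradient with respect to θ is well defined.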