Schrödinger Bridges and CLIP
Written by Shingo Kodama and Ines Kevine
Simple Diffusion
Last time, we learned about a model that can generate new images: diffusion! Our motivation was to generate new data (for example, images) from \(P_{\textnormal{data}}\), the probability distribution that we sample \(x_0\) from. We derived a closed form for generating \(x_t\) using only \(x_0\): \[ x_t = \sqrt{\alpha^{t}} \cdot x_0 + \sqrt{1 - \alpha^{t}} \cdot z , \quad\quad \text{where } z \sim \mathcal{N}(0,I) \] We used this formula to directly compute \(x_0\) using \(x_T\) and a predicted \(z\) value, allowing us to generate new images in an efficient manner.
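To make the closed form concrete, here is a minimal NumPy sketch of it, together with the inversion that recovers \(x_0\) from \(x_t\) and \(z\). The function names `forward_sample` and `recover_x0`, the value of \(\alpha\), and the random 4x4 "image" are all assumptions made for illustration.

```python
import numpy as np

def forward_sample(x0, t, alpha, rng):
    """Draw x_t directly from x_0 using the closed form above."""
    z = rng.standard_normal(x0.shape)      # z ~ N(0, I)
    signal = np.sqrt(alpha ** t)           # weight on the clean data
    noise = np.sqrt(1.0 - alpha ** t)      # weight on the noise
    return signal * x0 + noise * z, z

def recover_x0(xt, t, alpha, z):
    """Solve the closed form for x_0, given x_t and a value (or prediction) of z."""
    return (xt - np.sqrt(1.0 - alpha ** t) * z) / np.sqrt(alpha ** t)

rng = np.random.default_rng(0)
x0 = rng.uniform(-1.0, 1.0, size=(4, 4))   # stand-in for an image
alpha, t = 0.99, 50                        # illustrative values only
xt, z = forward_sample(x0, t, alpha, rng)
x0_hat = recover_x0(xt, t, alpha, z)       # exact here because z is the true noise
```

In the real model the true \(z\) is not available at generation time, so a network's prediction of \(z\) takes its place and the recovered \(x_0\) is only approximate.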
However, we could also do diffusion in a “simpler” way!
(Note that the jagged lines on f denote that f is being trained.)
Here, we compute each \(x_t\) by adding scaled Gaussian noise to \(x_{t-1}\). We don’t express \(x_t\) in terms of \(x_0\); instead, we create each new \(x_t\) from the previously created \(x_{t-1}\). This is certainly less efficient than our model from yesterday, but it’s simpler and easier to understand (and it will lead to Schrödinger Bridges, which we will talk about next)!
So what would our training and evaluation steps look like in this scenario? For the training phase, we would do: \[ \begin{array}{l} x_0 \sim P_{\textnormal{data}} \\ \text{for } t \text{ in } \{0,...,T-1\}:\\ \quad\quad z \sim \mathcal{N}(0,I)\\ \quad\quad x_{t+1} = x_t + \gamma z\\ \quad\quad \mathcal{L}_f = \|\mathbf{f(x_{t+1}, t+1) - x_t}\|_2^2 \end{array} \]
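The training loop above can be sketched in a few lines of NumPy. This is only a toy under stated assumptions: the data are 1-D points near \(+2\), \(f\) is a hypothetical affine map per timestep rather than a neural network, and the values of \(T\), \(\gamma\), and the learning rate are picked for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
T, gamma, lr = 10, 0.3, 0.05     # illustrative hyperparameters (assumptions)

# Hypothetical stand-in for f: one affine map per timestep, f(x, t) = w[t] * x + c[t].
w = np.ones(T + 1)
c = np.zeros(T + 1)

def f(x, t):
    return w[t] * x + c[t]

def sample_data(n):
    """Toy P_data: 1-D points clustered around +2 (an assumption for the demo)."""
    return 2.0 + 0.1 * rng.standard_normal(n)

losses = []
for step in range(2000):
    x_t = sample_data(64)                      # x_0 ~ P_data
    total = 0.0
    for t in range(T):                         # t = 0, ..., T-1
        z = rng.standard_normal(x_t.shape)     # z ~ N(0, I)
        x_next = x_t + gamma * z               # x_{t+1} = x_t + gamma * z
        err = f(x_next, t + 1) - x_t           # residual inside L_f
        total += np.mean(err ** 2)
        # One gradient descent step on L_f for the parameters of timestep t+1.
        w[t + 1] -= lr * 2.0 * np.mean(err * x_next)
        c[t + 1] -= lr * 2.0 * np.mean(err)
        x_t = x_next
    losses.append(total / T)                   # average L_f over the T steps
```

Since \(f\) starts as the identity map, the first loss values sit near \(\gamma^2\), and they drop as \(f\) learns to pull noisy points back toward the data.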
And for the evaluation phase, we have: \[ \begin{array}{l} x_T \sim \mathcal{N}(0,I) \\ \text{for } t \text{ in } \{T-1,...,0\}:\\ \quad\quad x_t = f(x_{t+1}, t+1) \end{array} \] The for loop in the evaluation phase is just one transition from \(x_T\) to \(x_0\). Just like in the original diffusion we learned about yesterday, we can repeat this process \(T\) times to create a more accurate image.
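Here is a small sketch of the evaluation loop. Since we have no trained network in these notes, \(f\) is replaced by a hypothetical closed-form stand-in: for a toy 1-D Gaussian \(P_{\textnormal{data}}\) (an assumption), the best squared-error guess of \(x_{t-1}\) given \(x_t\) can be written down directly.

```python
import numpy as np

rng = np.random.default_rng(0)
T, gamma = 10, 0.3               # illustrative values (assumptions)
mean, std = 2.0, 0.1             # toy P_data = N(mean, std^2), also an assumption

def f(x, t):
    """Hypothetical stand-in for a trained f: the best squared-error guess of
    x_{t-1} given x_t when P_data is the Gaussian above."""
    var_prev = std ** 2 + (t - 1) * gamma ** 2   # variance of x_{t-1}
    shrink = var_prev / (var_prev + gamma ** 2)
    return mean + shrink * (x - mean)

x = rng.standard_normal(1000)    # x_T ~ N(0, I), one scalar per sample
for t in range(T - 1, -1, -1):   # t = T-1, ..., 0
    x = f(x, t + 1)              # x_t = f(x_{t+1}, t+1)

samples = x                      # these play the role of x_0
```

With this stand-in, the samples end up clustered tightly around the data mean, because each step moves the point toward where the data live.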
Schrödinger Bridges
Now let’s move on to a more complex model. Before (in the two diffusion models that we’ve seen so far), we wanted to transition between \(\mathcal{N}(0,I)\) and \(P_{\textnormal{data}}\), so that we could generate new images starting from a normal distribution. Now, we want to move from one probability distribution \(P_{\textnormal{data}}\) to another probability distribution \(P_{\textnormal{data}}^{\prime}\)!
Why would this be helpful? Let’s imagine a situation where we have an image of a daylight cityscape. We can use a Schrödinger bridge diffusion model to transition this image into a night-time scene, because we can move from one probability distribution to another. Here are some examples of ways that this model would be helpful:
\[\begin{array}{c@{\hspace{8pt}\leftrightarrow\hspace{8pt}}c@{\hspace{8pt}}c} \mathbf{P_{\textnormal{data}}} & \mathbf{P_{\textnormal{data}}^{\prime}} & \text{Explanation}\\ \hline \text{day} & \text{night} & \text{turning a day scene into a night scene}\\ \text{blurry} & \text{clear} & \text{turning a blurry image into a clear one}\\ \text{realistic} & \text{comic} & \text{turning a realistic image into a comic-style image}\\ \text{cat} & \text{dog} & \text{morphing a cat into a dog with similar features}\\ \end{array}\]
So now our motivation is: Given \(x_0 \sim P_{\textnormal{data}}\), find a “close” \(x_T \sim P_{\textnormal{data}}^{\prime}\)!
We want to move all the points (more exactly, the probability distribution) from \(P_{\textnormal{data}}\) to \(P_{\textnormal{data}}^{\prime}\) in an \(\textit{optimal}\) way. What does it mean to be optimal? We want to transform one distribution to another with the least amount of total work. We call this “optimal transport”.
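A tiny worked example can make "least total work" concrete. In this sketch we assume that work is measured as the sum of squared distances each point travels (one common choice, not the only one), and we use four hand-picked 1-D points per distribution. We brute-force every one-to-one assignment and compare the winner with the sorted matching, which is known to be optimal in one dimension for this cost.

```python
import itertools
import numpy as np

source = np.array([0.0, 1.0, 4.0, 6.0])   # points drawn from one distribution
target = np.array([5.0, 2.5, 7.0, 0.5])   # points drawn from the other

def total_work(perm):
    """Sum of squared distances when source[i] is moved to target[perm[i]]."""
    return float(np.sum((source - target[list(perm)]) ** 2))

# Brute force: try every one-to-one assignment and keep the cheapest.
best_perm = min(itertools.permutations(range(len(source))), key=total_work)

# Sorted matching: the k-th smallest source point goes to the k-th smallest target point.
ranks = np.argsort(np.argsort(source))
sorted_perm = tuple(int(i) for i in np.argsort(target)[ranks])
```

Both approaches pick the same assignment here. Brute force stops being practical very quickly as the number of points grows, which is one reason we want a learned model for images.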
The Schrödinger Bridge model attempts to find this optimal transport between \(P_{\textnormal{data}}\) and \(P_{\textnormal{data}}^{\prime}\). How do we do this?
First, let’s take a look at the backward phase (going from \(x_0\) to \(x_T\)). Instead of just adding noise from a normal distribution, we now use a function \(b\) that steps from \(x_{t-1}\) to \(x_t\) for each \(t\) between \(1\) and \(T\). Along this path, we train the function \(f\) that takes us from \(x_t\) back to \(x_{t-1}\), just as in the simple diffusion model we looked at earlier.
Now, let’s suppose that \(f\) has been trained. Then we can start from \(y_T \sim P_{\textnormal{data}}^{\prime}\) and use \(f\) to step from \(y_{t+1}\) to \(y_t\) for all \(t\) between \(0\) and \(T-1\). We can now train \(b\)!
So we want to train both \(f\) and \(b\) at the same time. Where have we seen something similar recently? It’s just like GANs, where we trained the generator and the discriminator at the same time! Let’s see how we can represent the transitions and loss functions mathematically.
\[\begin{array}{l} x_{t+1} = x_t + \gamma z + b(x_t, t), \quad\quad \text{where } z \sim \mathcal{N}(0,I) \\ \mathcal{L}_f = \|\mathbf{f(x_{t+1}, t+1) - x_t}\|_2^2 \\\\ y_{t-1} = y_t + \gamma z + f(y_t, t), \quad\quad \text{where } z \sim \mathcal{N}(0,I) \\ \mathcal{L}_b = \|\mathbf{b(y_{t-1}, t-1) - y_t}\|_2^2 \end{array}\]
Can we train both \(f\) and \(b\) at the same time, when they are both randomly initialized? It’s certainly possible, but there is an alternative that improves the generated images: first pretrain \(f\) and \(b\) using regular diffusion, then use those pretrained networks as the starting point for the training phase.
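The back-and-forth training can be sketched on a toy problem. Everything specific here is an assumption made for illustration: the two distributions are 1-D Gaussians around \(-1\) and \(+1\), each network is a hypothetical table of affine maps (one per timestep), and a least-squares fit via `np.polyfit` stands in for gradient descent on the squared loss. One more assumption matters: because the update rules add the network output to the current point, the sketch fits each network to the displacement it has to add (for example \(x_t - x_{t+1}\) for \(f\)).

```python
import numpy as np

rng = np.random.default_rng(0)
T, gamma, n = 5, 0.3, 4000                 # illustrative values (assumptions)

def sample_p(n):                           # toy P_data: 1-D Gaussian around -1
    return -1.0 + 0.2 * rng.standard_normal(n)

def sample_p_prime(n):                     # toy P'_data: 1-D Gaussian around +1
    return 1.0 + 0.2 * rng.standard_normal(n)

# Each network is a table of affine maps, net[t] = (slope, intercept), starting at zero.
f_net = np.zeros((T + 1, 2))
b_net = np.zeros((T + 1, 2))

def run_net(net, x, t):
    return net[t, 0] * x + net[t, 1]

def fit(inputs, targets):
    """Least-squares affine fit, standing in for gradient descent on the squared loss."""
    return np.polyfit(inputs, targets, 1)  # returns (slope, intercept)

for round_ in range(3):
    # Roll x forward with b, and fit f to undo each step.
    x = sample_p(n)
    for t in range(T):
        x_next = x + gamma * rng.standard_normal(n) + run_net(b_net, x, t)
        f_net[t + 1] = fit(x_next, x - x_next)
        x = x_next
    # Roll y backward with f, and fit b to undo each step.
    y = sample_p_prime(n)
    for t in range(T, 0, -1):
        y_prev = y + gamma * rng.standard_normal(n) + run_net(f_net, y, t)
        b_net[t - 1] = fit(y_prev, y - y_prev)
        y = y_prev

# After training, push fresh P_data samples through b.
x = sample_p(n)
for t in range(T):
    x = x + gamma * rng.standard_normal(n) + run_net(b_net, x, t)
x_T = x
```

In this toy setup, the points that start near \(-1\) end up centered near \(+1\) after being pushed through \(b\). The sketch starts both networks at zero instead of pretraining them, which is enough for two Gaussians but is not a claim about what works for images.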
CLIP
So now we know how to (1) embed words and images as vectors in latent spaces and (2) generate images using diffusion. Could we possibly combine these techniques to generate new images from text prompts?
Yes, we can! First, how can we find data to train the model on? We can find many images online that have captions attached to them. We can assume that the captions are describing their images, so each (image, caption) pair is a positive pair! That means we can use contrastive learning to embed images and words in the same latent space.
So how can we represent images and text as vectors of the same dimension? For images, we can use a network of convolutional layers to produce a vector in \(\mathbb{R}^d\). For text, we can use attention layers to produce a vector that represents the meaning of an entire sentence or even paragraph. We can then use a linear network to transform that vector into a vector in \(\mathbb{R}^d\).
Now we have vector representations of both images and captions in the same dimension! For any pair of (image, caption), we can calculate:
\[ | f(\text{image})^\top g(\text{caption}) | \]
If the pair is a positive pair, we want this value to be large. If it is a negative pair, we want it to be small.
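Here is a minimal sketch of that score for a small batch. The encoders \(f\) and \(g\) are hypothetical stand-ins (untrained random linear maps), and the feature sizes and batch size are assumptions; the point is only to show how all the (image, caption) scores fit in one matrix.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8                                      # shared embedding dimension (assumption)

# Hypothetical stand-ins for the two encoders: random linear maps into R^d.
W_image = rng.standard_normal((32, d))     # 32 = toy image feature size
W_text = rng.standard_normal((16, d))      # 16 = toy caption feature size

def f(image_features):
    return image_features @ W_image

def g(caption_features):
    return caption_features @ W_text

images = rng.standard_normal((4, 32))      # a batch of 4 images
captions = rng.standard_normal((4, 16))    # their 4 captions, in the same order

# scores[i, j] = | f(image_i)^T g(caption_j) |
scores = np.abs(f(images) @ g(captions).T)

positive_scores = np.diag(scores)                   # matching (image, caption) pairs
negative_scores = scores[~np.eye(4, dtype=bool)]    # every mismatched pair
```

Because these encoders are untrained, the diagonal is not expected to stand out yet. Training pushes the diagonal entries up and the off-diagonal entries down.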
We’ve now been able to represent images and words in the \(\textit{same}\) space in a meaningful way. Captions and images that represent similar concepts are now close to each other in this latent space!
Finally, let’s think about how we can create images from text prompts. We still use diffusion to create images, but with a slight change in the model.
Now, instead of sampling \(x_0 \sim P_{\textnormal{data}}\), we sample \((x_0, \text{text}) \sim P_{\textnormal{data}}\). The loss function becomes \[ \begin{array}{l} \mathcal{L}_f = \|\mathbf{f(x_t, t, \text{text}) - x_{t-1}}\|_2^2 \end{array} \] where we pass in the text as an argument of \(f\).
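One simple way to pass the text into \(f\) is to append its embedding to the other inputs. The sketch below does exactly that; it is an assumption about how the conditioning could be wired, with a hypothetical single linear layer for \(f\), random vectors in place of real caption embeddings, and toy sizes throughout.

```python
import numpy as np

rng = np.random.default_rng(0)
d_img, d_txt = 6, 4                        # toy sizes (assumptions)

# Hypothetical f: one linear layer acting on [x_t, t, text embedding].
W = 0.1 * rng.standard_normal((d_img + 1 + d_txt, d_img))

def f(x_t, t, text_emb):
    t_col = np.full((x_t.shape[0], 1), float(t))    # timestep as an extra feature
    inputs = np.concatenate([x_t, t_col, text_emb], axis=1)
    return inputs @ W

def loss_f(x_t, x_prev, t, text_emb):
    """Mean over the batch of || f(x_t, t, text) - x_{t-1} ||_2^2."""
    diff = f(x_t, t, text_emb) - x_prev
    return float(np.mean(np.sum(diff ** 2, axis=1)))

x_prev = rng.standard_normal((5, d_img))                # x_{t-1} for 5 examples
x_t = x_prev + 0.3 * rng.standard_normal((5, d_img))    # one noising step
text_emb = rng.standard_normal((5, d_txt))              # stand-in caption embeddings
loss = loss_f(x_t, x_prev, t=3, text_emb=text_emb)
```

The same \(x_t\) with a different text embedding gives a different prediction, which is what lets the prompt steer the image at generation time.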
After training the model in this way, we can create new images using whatever text prompt we want! We simply represent the text prompt as an embedded vector and use it to transition from \(x_T \sim \mathcal{N}(0,I)\) to \(x_0\). We can now generate images from text prompts that we have no prior data about.