flowMC - normalizing flow enhanced MCMC in JAX
Markov chain Monte Carlo (MCMC) is a simple-to-code algorithm used in many domains of science, often to estimate the uncertainty in a model's parameters given some data.
The MCMC procedure can be described as follows:
Start somewhere in your parameter space.
Randomly propose to jump somewhere in the parameter space according to a proposal.
Decide whether to move to the proposed location. This depends on the ratio of the target probability density at the proposed location to the density at your current position. The larger the ratio, the more likely the proposal is accepted.
Rinse and repeat until some convergence criteria are met.
Compute whatever quantity you like using the chain you just constructed, such as estimating the mean of the distribution.
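The steps above can be sketched in a few lines of JAX. This is only a minimal illustration, not flowMC code: the standard-normal `log_target`, the function name `rw_metropolis`, and the step size are all assumptions chosen for the example.

```python
import jax
import jax.numpy as jnp


def log_target(x):
    # Illustrative target (an assumption): standard normal, up to a constant.
    return -0.5 * jnp.sum(x ** 2)


def rw_metropolis(key, x0, n_steps, step_size=1.0):
    """Random-walk Metropolis with a Gaussian proposal (hypothetical helper)."""

    def one_step(carry, step_key):
        x, logp = carry
        key_prop, key_acc = jax.random.split(step_key)
        # Propose a jump from the current position.
        x_new = x + step_size * jax.random.normal(key_prop, x.shape)
        logp_new = log_target(x_new)
        # Accept with probability min(1, p(x_new) / p(x)).
        log_u = jnp.log(jax.random.uniform(key_acc))
        accept = log_u < logp_new - logp
        x = jnp.where(accept, x_new, x)
        logp = jnp.where(accept, logp_new, logp)
        return (x, logp), (x, accept)

    keys = jax.random.split(key, n_steps)
    _, (chain, accepted) = jax.lax.scan(one_step, (x0, log_target(x0)), keys)
    return chain, accepted


# Start somewhere, run the chain, then compute a summary from it.
chain, accepted = rw_metropolis(jax.random.PRNGKey(0), jnp.zeros(2), 5000)
print("mean:", chain.mean(axis=0))
print("acceptance rate:", accepted.mean())
```

Because the Gaussian proposal is symmetric, the acceptance test only needs the ratio of target densities. A rejected step still costs one evaluation of `log_target`.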
The simplest kind of MCMC uses a normal distribution as its proposal distribution. The figure below shows a single proposed step from an illustrative example:
From this example, we can see why MCMC may struggle when we go to a high-dimensional problem. Whenever we propose a jump into a region where the target probability is low, the jump is likely to be rejected, and the chain just stays where it currently is. In other words, computational resources are wasted on the rejection: you pay the price of computing the likelihood at the new location, yet gain no new information. The degree of inefficiency basically depends on the overlap between the proposal distribution and the target distribution, and this overlap gets worse and worse as we go to higher and higher dimensions. If your proposal distribution is exactly the target distribution, basically every proposal cycle gives you an effective sample, which is the best-case scenario. Obviously, one cannot craft such a proposal a priori for every problem in the world; otherwise, why would you need MCMC in the first place?
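To get a feel for how quickly things degrade with dimension, here is a rough sketch that estimates the average acceptance probability of a fixed-width Gaussian random-walk proposal on a standard-normal target. The target, the step size, and the helper name `mean_acceptance` are assumptions made for illustration only.

```python
import jax
import jax.numpy as jnp


def mean_acceptance(key, dim, step_size=1.0, n_trials=20000):
    """Estimate the average Metropolis acceptance probability at stationarity."""
    key_x, key_eps = jax.random.split(key)
    # Current positions drawn from the target itself (standard normal).
    x = jax.random.normal(key_x, (n_trials, dim))
    # One random-walk proposal per position.
    x_new = x + step_size * jax.random.normal(key_eps, (n_trials, dim))
    # log of p(x_new) / p(x) for the standard-normal target.
    log_ratio = -0.5 * jnp.sum(x_new ** 2, axis=1) + 0.5 * jnp.sum(x ** 2, axis=1)
    return jnp.mean(jnp.minimum(1.0, jnp.exp(log_ratio)))


key = jax.random.PRNGKey(0)
rates = {dim: float(mean_acceptance(key, dim)) for dim in (1, 10, 100)}
for dim, rate in rates.items():
    print(f"dim={dim:4d}  mean acceptance probability={rate:.4f}")
```

With the step size held fixed, the estimated acceptance probability falls as the dimension grows. Shrinking the step restores acceptance but makes each move smaller, so the chain still explores slowly.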
While we cannot write down a proposal that is close to the target a priori, we can attempt to construct it alongside an MCMC run. This is the core idea behind flowMC, in which the algorithm goes something like this:
Run an MCMC (typically with as many independent chains as possible) to gain some insight into the local landscape.
Train a normalizing flow (explainer video here: https://www.youtube.com/watch?v=s27I7b3-FMY) on the samples you aggregate across multiple chains.
Use the distribution approximated by the normalizing flow as your proposal distribution in the MCMC.
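The three steps can be sketched as a toy loop. To keep the example short, a single Gaussian fitted to the pooled samples stands in for the normalizing flow; this is a deliberate simplification and not flowMC's API or implementation. The target, the helper names (`local_step`, `global_step`, `run_chains`), and all settings are assumptions.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

TARGET_MEAN = jnp.array([3.0, -2.0])
TARGET_STD = jnp.array([0.5, 2.0])


def log_target(x):
    # Illustrative target (an assumption): an axis-aligned Gaussian, unnormalized.
    return -0.5 * jnp.sum(((x - TARGET_MEAN) / TARGET_STD) ** 2)


def local_step(x, key, step_size=0.5):
    # Step 1 building block: one random-walk Metropolis update.
    key_prop, key_acc = jax.random.split(key)
    x_new = x + step_size * jax.random.normal(key_prop, x.shape)
    log_alpha = log_target(x_new) - log_target(x)
    accept = jnp.log(jax.random.uniform(key_acc)) < log_alpha
    return jnp.where(accept, x_new, x), accept


def global_step(x, key, mean, cov):
    # Step 3 building block: independence Metropolis-Hastings update whose
    # proposal q is the fitted density. q is not symmetric, so the acceptance
    # ratio is p(x_new) q(x) / (p(x) q(x_new)).
    key_prop, key_acc = jax.random.split(key)
    x_new = jax.random.multivariate_normal(key_prop, mean, cov)
    log_alpha = (log_target(x_new) - multivariate_normal.logpdf(x_new, mean, cov)
                 - log_target(x) + multivariate_normal.logpdf(x, mean, cov))
    accept = jnp.log(jax.random.uniform(key_acc)) < log_alpha
    return jnp.where(accept, x_new, x), accept


def run_chains(step_fn, key, x0, n_steps):
    # Run one chain per row of x0 in parallel and record every position.
    def scan_body(x, step_key):
        keys = jax.random.split(step_key, x.shape[0])
        x, accept = jax.vmap(step_fn)(x, keys)
        return x, (x, accept)

    _, (chain, accepted) = jax.lax.scan(scan_body, x0, jax.random.split(key, n_steps))
    return chain, accepted  # shapes: (n_steps, n_chains, dim), (n_steps, n_chains)


key_init, key_local, key_global = jax.random.split(jax.random.PRNGKey(0), 3)
x0 = jax.random.normal(key_init, (50, 2))

# 1. Local exploration with many independent chains.
local_chain, local_acc = run_chains(local_step, key_local, x0, 1000)

# 2. "Train" the stand-in model on the pooled samples (first half discarded as burn-in).
samples = local_chain[500:].reshape(-1, 2)
mean, cov = samples.mean(axis=0), jnp.cov(samples, rowvar=False)

# 3. Use the fitted density as the proposal distribution.
global_fn = lambda x, k: global_step(x, k, mean, cov)
global_chain, global_acc = run_chains(global_fn, key_global, local_chain[-1], 1000)

print("local acceptance: ", local_acc.mean())
print("global acceptance:", global_acc.mean())
```

A single Gaussian can only capture a unimodal target, which is exactly why a more expressive model such as a normalizing flow is needed for multimodal posteriors. The important detail that carries over is the acceptance ratio in `global_step`: it includes the proposal density, so the chain still targets the correct distribution even when the fitted model is imperfect.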
There are a couple of catches and details one needs to pay attention to in order to make sure the MCMC is converging and faithful, but I will leave you to read the documentation of flowMC. I spent time on that and I am proud of it. The figure below is taken from a flowMC example (https://flowmc.readthedocs.io/en/latest/tutorials/dualmoon.html); the target distribution is four separated moon distributions. The lines show how two chains traverse the posterior space. You can see these lines going from one mode to another, which is uniquely enabled by the normalizing flow. If you ran this problem with a normal distribution proposal, the chains would need many more steps to get from one mode to another, instead of teleporting to any mode as in the picture below.
So what is there in flowMC beyond the normalizing flow part? Now that I have told you the secret sauce that makes our algorithm work, why should you use my package and give me kudos instead of taking all the glory yourself? Well, here are some features we have implemented in our code:
Since our code is based on JAX (https://github.com/google/jax), we support automatic differentiation, so one can use gradient-based samplers such as MALA and HMC as the local sampler.
The code runs natively on GPUs because of JAX (again), which allows many parallel chains at once. This is also really important for training the normalizing flow model quickly.
JAX offers Just-In-Time (JIT) compilation for the specific devices you have, which gives you an additional computational advantage compared to other Python-based packages.
A black-box interface. Instead of asking you to refactor your code to fit into our environment, we simply ask you to write your posterior/likelihood function in JAX, i.e. a function p(x), where x is the vector of parameters you want to sample over. Most of our users want to focus on their model rather than the sampling details. This also minimizes the development overhead and the sunk cost of failure.
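As a rough illustration of what writing your density in JAX buys you, the sketch below defines a single function of `x` and then obtains its gradient, a compiled version, and a batched version using standard JAX transformations. It deliberately does not call flowMC; the banana-shaped `log_posterior` is a hypothetical example, and how such a function is passed to flowMC is described in its documentation.

```python
import jax
import jax.numpy as jnp


def log_posterior(x):
    # Hypothetical target: a Rosenbrock-style "banana" density in 2D (unnormalized).
    return -((1.0 - x[0]) ** 2 + 10.0 * (x[1] - x[0] ** 2) ** 2)


# Automatic differentiation: the gradient that MALA or HMC would need.
grad_log_posterior = jax.grad(log_posterior)

# JIT compilation plus vectorization: evaluate many chains' positions at once.
batched_value_and_grad = jax.jit(jax.vmap(jax.value_and_grad(log_posterior)))

positions = jax.random.normal(jax.random.PRNGKey(0), (8, 2))
values, grads = batched_value_and_grad(positions)
print(values.shape, grads.shape)  # (8,) (8, 2)

# The gradient vanishes at the mode of this density, x = (1, 1).
peak_grad = grad_log_posterior(jnp.array([1.0, 1.0]))
print(peak_grad)
```

The same function serves all three purposes without any refactoring, which is the point of the black-box interface: you write the model once and let the transformations do the rest.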
flowMC is published in the Journal of Open Source Software, and freely available on GitHub. If you have a problem in MCMC, give it a shot!