VI Drives me NUTS

Feel free to be a bit wary.

Vincent Warmerdam

Variational Inference is “hip”, and I can’t say that I am a huge fan. I decided to give it a try and immediately ran into trouble. In this document I want to quickly demonstrate one potential failure scenario.

The Model

Here is the code for the model. It models the weight gain of chicks that are fed one of four different diets.

import pandas as pd
import pymc3 as pm

df = pd.read_csv("", 
                 names=["r", "weight", "time", "chick", "diet"])
time_input = 10

with pm.Model() as mod: 
    intercept = pm.Normal("intercept", 0, 2)
    time_effect = pm.Normal("time_weight_effect", 0, 2, shape=(4,))
    diet = pm.Categorical("diet", p=[0.25, 0.25, 0.25, 0.25], shape=(4,),
    sigma = pm.HalfNormal("sigma", 2)
    sigma_time_effect = pm.HalfNormal("time_sigma_effect", 2, shape=(4,))
    weight = pm.Normal("weight", 
                       mu=intercept +*df.time, 
                       sd=sigma +*df.time, 
                       observed=df.weight)
    trace = pm.sample(5000, chains=1)
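Read together, the mean and the standard deviation of weight each combine a baseline (intercept, sigma) with a time term, and the per-diet parameters time_effect and sigma_time_effect (shape 4) suggest one slope per diet. One consequence is that the spread of the weights is allowed to grow over time. Below is a minimal stdlib sketch of that generative story. It assumes the elided index picks the slope belonging to the chick's diet; the function simulate_weight and all parameter values are hypothetical, chosen for illustration rather than estimated from the data.

```python
import random
import statistics

# Hypothetical parameter values for illustration only; these are NOT
# estimates from the chick weight data.
INTERCEPT = 40.0                          # baseline mean weight
TIME_EFFECT = [7.0, 9.0, 11.0, 10.0]      # mean weight gained per time unit, per diet
SIGMA = 1.0                               # baseline noise
SIGMA_TIME_EFFECT = [0.5, 1.0, 1.5, 1.0]  # noise added per time unit, per diet


def simulate_weight(diet: int, time: float, rng: random.Random) -> float:
    """Draw one weight for a chick on diet `diet` (0..3) at time `time`.

    Assumes (hypothetically) that the slopes are selected by the chick's diet.
    """
    mu = INTERCEPT + TIME_EFFECT[diet] * time
    sd = SIGMA + SIGMA_TIME_EFFECT[diet] * time
    return rng.gauss(mu, sd)


rng = random.Random(42)
early = [simulate_weight(2, 0.0, rng) for _ in range(5000)]
late = [simulate_weight(2, 10.0, rng) for _ in range(5000)]

# True sd is 1.0 at time 0 and 1.0 + 1.5 * 10 = 16.0 at time 10.
print(round(statistics.stdev(early), 1), round(statistics.stdev(late), 1))
```

The growing spread is exactly what the sigma_time_effect term lets the model express.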

Next, I’ll show how the traceplots differ between the inference methods.

NUTS sampling results

I took 5500 samples with NUTS. It took about 7 seconds and this is the output:

Metropolis sampling results

I took 20000 samples with Metropolis. It took about 14 seconds and this is the output:

VI results

I used the fullrank_advi method. Here’s a traceplot of samples drawn from the approximated posterior.
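In PyMC3, method="advi" fits a mean-field approximation (independent Gaussians, one per parameter), while "fullrank_advi" fits a single multivariate Gaussian with a full covariance matrix, so it can capture correlations between parameters. A minimal stdlib sketch of what drawing from such an approximation amounts to: theta = mu + L·eps, with eps standard normal and L a lower-triangular factor of the covariance. The two-parameter MU and L below are hypothetical values, not fitted ones. However well it is fit, the result is still one Gaussian: a single, elliptical mode.

```python
import random
from typing import List

# Hypothetical variational parameters for a two-parameter model.
MU = [1.0, -2.0]
# Lower-triangular factor L; the approximation's covariance is L L^T,
# which here equals [[1.0, 0.8], [0.8, 1.0]].
L = [[1.0, 0.0],
     [0.8, 0.6]]


def draw_fullrank(rng: random.Random) -> List[float]:
    """One draw theta = MU + L @ eps, with eps ~ N(0, I)."""
    eps = [rng.gauss(0.0, 1.0) for _ in range(len(MU))]
    return [MU[i] + sum(L[i][j] * eps[j] for j in range(len(MU)))
            for i in range(len(MU))]


rng = random.Random(0)
draws = [draw_fullrank(rng) for _ in range(20000)]

n = len(draws)
mean = [sum(d[i] for d in draws) / n for i in range(2)]
# Sample covariance between the two parameters (target: 0.8).
cov01 = sum((d[0] - mean[0]) * (d[1] - mean[1]) for d in draws) / (n - 1)
print([round(m, 2) for m in mean], round(cov01, 2))
```

Setting L[1][0] to 0.0 turns this into the mean-field case, where the two parameters are drawn independently and their covariance is zero.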


The interesting thing is that if I change the model slightly, VI suddenly has no issues (this was pointed out to me by a colleague, Mathijs).

n_diets =

with pm.Model() as model:
    mu_intercept = pm.Normal('mu_intercept', mu=40, sd=5)
    mu_slope = pm.HalfNormal('mu_slope', 10, shape=(n_diets,))
    mu = mu_intercept + mu_slope[] * df.time
    sigma_intercept = pm.HalfNormal('sigma_intercept', sd=2)
    sigma_slope = pm.HalfNormal('sigma_slope', sd=2, shape=n_diets)
    sigma = sigma_intercept + sigma_slope[] * df.time
    weight = pm.Normal('weight', mu=mu, sd=sigma, observed=df.weight)
    approx =, random_seed=42, method="fullrank_advi")

The main difference is that I am no longer using pm.Categorical.

With that out of the way, the estimates suddenly look a whole lot better.
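Dropping pm.Categorical means the diet is no longer a random variable in the model: every measurement's diet is known, so it can enter as plain data, an integer code that picks out the slope for that row. The listing above elides the exact index expression; the stdlib sketch below assumes it is the zero-based diet code (labels 1 to 4 shifted to 0 to 3). The diet column, the times, and the parameter values are all hypothetical.

```python
# Hypothetical diet labels for six measurements, as in a 1-based "diet" column.
diet_column = [1, 1, 2, 3, 4, 2]
time = [0, 2, 4, 6, 8, 10]

# Zero-based integer codes: usable directly as an index into a length-4 vector.
diet_idx = [d - 1 for d in diet_column]

# Hypothetical parameter values (one slope per diet).
mu_intercept, mu_slope = 40.0, [7.0, 9.0, 11.0, 10.0]
sigma_intercept, sigma_slope = 1.0, [0.5, 1.0, 1.5, 1.0]

# Row-wise mean and standard deviation, mirroring
#   mu    = mu_intercept    + mu_slope[idx]    * time
#   sigma = sigma_intercept + sigma_slope[idx] * time
mu = [mu_intercept + mu_slope[i] * t for i, t in zip(diet_idx, time)]
sigma = [sigma_intercept + sigma_slope[i] * t for i, t in zip(diet_idx, time)]
print(mu)
print(sigma)
```

With the index as fixed data, every parameter left in the model is continuous, which is the setting ADVI is built for.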


Be careful when using variational inference. It might be faster, but only because it approximates the posterior. I’m not the only person who is a bit skeptical of variational inference.
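One way to see what approximating can cost: ADVI maximizes the ELBO, which is equivalent to minimizing KL(q || p) from the Gaussian approximation q to the posterior p. That direction of the KL divergence heavily penalizes q for putting mass where p has little, so on a multimodal posterior the best Gaussian tends to lock onto one mode and ignore the rest. The stdlib sketch below shows this on a hypothetical toy target (an equal mixture of N(-3, 1) and N(3, 1)), using a brute-force grid search over Gaussians instead of ADVI's stochastic optimizer. The helper names are hypothetical, and this is a toy illustration, not a diagnosis of the chick model above.

```python
import math


def normal_logpdf(x: float, mean: float, sd: float) -> float:
    """Log density of N(mean, sd**2) at x."""
    return -0.5 * math.log(2.0 * math.pi) - math.log(sd) - 0.5 * ((x - mean) / sd) ** 2


def target_logpdf(x: float) -> float:
    """Toy bimodal 'posterior' p: equal mixture of N(-3, 1) and N(3, 1)."""
    return math.log(0.5 * math.exp(normal_logpdf(x, -3.0, 1.0))
                    + 0.5 * math.exp(normal_logpdf(x, 3.0, 1.0)))


# Integration grid on [-20, 20] for a simple Riemann-sum estimate of the KL.
N_POINTS = 800
DX = 40.0 / N_POINTS
XS = [-20.0 + i * DX for i in range(N_POINTS + 1)]
LOG_P = [target_logpdf(x) for x in XS]


def reverse_kl(mean: float, sd: float) -> float:
    """Numerical KL(q || p) for q = N(mean, sd**2): the direction VI minimizes."""
    total = 0.0
    for x, log_p in zip(XS, LOG_P):
        log_q = normal_logpdf(x, mean, sd)
        q = math.exp(log_q)  # underflows to 0.0 far from the mean and adds nothing
        total += q * (log_q - log_p) * DX
    return total


# Brute-force search over Gaussians: means -5..5 (step 0.5), sds 0.5..3 (step 0.25).
candidates = [(0.5 * m, 0.25 * s) for m in range(-10, 11) for s in range(2, 13)]
best_mean, best_sd = min(candidates, key=lambda c: reverse_kl(*c))
best_kl = reverse_kl(best_mean, best_sd)

# A "covering" Gaussian that matches p's overall mean (0) and variance (1 + 9 = 10).
wide_kl = reverse_kl(0.0, math.sqrt(10.0))

print(best_mean, best_sd, round(best_kl, 3), round(wide_kl, 3))
```

The winning Gaussian sits on one of the two modes with a KL of roughly log 2, while the Gaussian that matches the target's overall mean and variance scores worse, so the optimizer has no incentive to cover both modes.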

The alternative, NUTS sampling, still amazes me, even though it isn’t perfect.


For attribution, please cite this work as

Warmerdam (2018, Nov. 9). VI Drives me NUTS. Retrieved from

BibTeX citation

  author = {Warmerdam, Vincent},
  title = { VI Drives me NUTS},
  url = {},
  year = {2018}