Bayesian Mixed Model

Explore the classic sleepstudy example from lme4. Part of this code is based on code from this old Stan thread. For a fully optimized approach, rather than this conceptual one, see the underlying code of rstanarm or brms.

Data Setup

The data come from the lme4 package. They record subjects' reaction times on a task over 10 days of a sleep deprivation study.

library(tidyverse)
library(lme4)

data(sleepstudy)
# ?sleepstudy

# Data list passed to Stan; element names must match the model's `data` block.
dat = list(
  N       = nrow(sleepstudy),                  # number of observations
  I       = n_distinct(sleepstudy$Subject),    # number of subjects
  # Subject is a factor; its integer codes (1..I) index the subject-level
  # effects. The Stan data block declares it as an int array, so pass
  # integers rather than doubles.
  Subject = as.integer(sleepstudy$Subject),
  Days    = sleepstudy$Days,
  RT      = sleepstudy$Reaction
)

Model Code

Create the Stan model code.

data {
  // sizes
  int<lower = 1> N;                         // number of observations
  int<lower = 1> I;                         // number of subjects

  // per-observation inputs, each of length N
  // NOTE: `Subject[N]` uses the pre-2.33 array syntax; newer Stan releases
  // require `array[N] int<lower = 1, upper = I> Subject;` instead.
  int<lower = 1, upper = I> Subject[N];     // subject index (1..I) of each observation
  vector<lower = 0>[N] Days;                // days in study
  vector<lower = 0>[N] RT;                  // response: reaction time
}

transformed data {
  // Empirical centre and spread of the response. The transformed parameters
  // block uses them to map the standardized Intercept01 back onto the RT scale.
  real IntBase = mean(RT);                  // intercept starting point
  real RTsd = sd(RT);                       // intercept scale
}

parameters {
  // Declaration order fixes the unconstrained parameter layout; keep it stable.
  real Intercept01;                         // fixed intercept, standardized (Intercept = IntBase + Intercept01 * RTsd)
  real beta01;                              // fixed Days slope, scaled down by 10 (beta = beta01 * 10)
  vector<lower = 0>[2] sigma_u;             // sd for ints and slopes
  real<lower = 0> sigma_y;                  // residual sd
  vector[2] gamma[I];                     // individual effects: [1] intercept, [2] Days slope (pre-2.33 array syntax)
  cholesky_factor_corr[2] Omega_chol;       // correlation matrix for random intercepts and slopes (chol decomp)
}

transformed parameters {
  // Subject-level coefficients copied out of gamma under readable names.
  vector[I] gammaIntercept;                 // gamma[i][1] for each subject
  vector[I] gammaDays;                      // gamma[i][2] for each subject
  // Fixed effects mapped back from their rescaled parameters.
  real Intercept = IntBase + Intercept01 * RTsd;
  real beta = beta01 * 10;

  for (s in 1:I) {
    gammaDays[s] = gamma[s][2];
    gammaIntercept[s] = gamma[s][1];
  }
}

model {
  // Population means of the subject-level [intercept, Days slope] vectors.
  vector[2] gamma_mu;
  // Cholesky factor of the random-effects covariance, diag(sigma_u) * Omega_chol,
  // computed without materializing the diagonal matrix.
  matrix[2, 2] DC = diag_pre_multiply(sigma_u, Omega_chol);
  // Linear predictor: each observation uses its own subject's intercept and
  // slope (multi-indexing by Subject replaces the per-observation loop).
  vector[N] mu = gammaIntercept[Subject] + gammaDays[Subject] .* Days;

  gamma_mu[1] = Intercept;
  gamma_mu[2] = beta;

  // priors
  Intercept01 ~ normal(0, 1);               // example of weakly informative priors;
  beta01 ~ normal(0, 1);                    // remove to essentially duplicate lme4 via improper prior

  Omega_chol ~ lkj_corr_cholesky(2.0);

  sigma_u ~ cauchy(0, 2.5);                 // prior for RE scale
  sigma_y ~ cauchy(0, 2.5);                 // prior for residual scale

  // Subject random effects: one vectorized statement over all I subjects,
  // giving the same log density as looping gamma[i] ~ ... for each i.
  gamma ~ multi_normal_cholesky(gamma_mu, DC);

  // likelihood
  RT ~ normal(mu, sigma_y);
}

generated quantities {
  // Correlation matrix of the random intercepts and slopes.
  matrix[2, 2] Omega = tcrossprod(Omega_chol);
  // Fitted values: the subject-specific linear predictor only (no residual
  // noise), so their spread is narrower than that of the observed RT.
  vector[N] y_hat = gammaIntercept[Subject] + gammaDays[Subject] .* Days;
  // Posterior predictive replicates: fitted values plus Normal(0, sigma_y)
  // residual noise. These are the appropriate draws for predictive checks.
  vector[N] y_rep;

  for (n in 1:N)
    y_rep[n] = normal_rng(y_hat[n], sigma_y);
}

Estimation

Run the model and examine the results. The following assumes that `bayes_mixed` holds the previous model code, either as a character string or as a file.

library(rstan)

# sampling() only accepts a compiled `stanmodel`. If `bayes_mixed` is still the
# raw model code (a character string) or a path to a .stan file, compile it
# first. An already-compiled stanmodel is used as-is.
if (is.character(bayes_mixed)) {
  is_stan_file = length(bayes_mixed) == 1 &&
    grepl('\\.stan$', bayes_mixed) &&
    file.exists(bayes_mixed)

  bayes_mixed = if (is_stan_file) {
    stan_model(file = bayes_mixed)
  } else {
    stan_model(model_code = bayes_mixed)
  }
}

fit = sampling(
  bayes_mixed,
  data    = dat,
  thin    = 4,                # keep every 4th post-warmup draw
  verbose = FALSE
)

Comparison

Compare to the lme4 results.

# Posterior summaries of the population-level quantities, in the same order as
# the lme4 output below (fixed effects, residual sd, RE sds, RE correlation).
print(
  x              = fit,
  pars           = c('Intercept', 'beta', 'sigma_y', 'sigma_u', 'Omega[1,2]'),
  probs          = c(.025, .5, .975),
  digits_summary = 3
)
Inference for Stan model: 2506143d3919a87ea11841b6a26e9ada.
4 chains, each with iter=2000; warmup=1000; thin=4; 
post-warmup draws per chain=250, total post-warmup draws=1000.

              mean se_mean    sd    2.5%     50%   97.5% n_eff  Rhat
Intercept  252.224   0.211 6.801 238.987 252.342 266.310  1042 0.999
beta        10.189   0.054 1.668   6.891  10.176  13.476   962 0.998
sigma_y     25.901   0.052 1.568  23.080  25.815  29.188   897 1.000
sigma_u[1]  23.900   0.230 6.125  13.200  23.469  37.385   707 1.003
sigma_u[2]   6.162   0.047 1.452   3.953   5.951   9.660   953 1.000
Omega[1,2]   0.102   0.010 0.266  -0.407   0.106   0.606   778 0.999

Samples were drawn using NUTS(diag_e) at Wed Nov 25 17:34:19 2020.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).
# Frequentist benchmark: correlated random intercepts and Days slopes by Subject.
mod_lme = lmer(
  formula = Reaction ~ Days + (Days | Subject),
  data    = sleepstudy
)
print(mod_lme)
Linear mixed model fit by REML ['lmerMod']
Formula: Reaction ~ Days + (Days | Subject)
   Data: sleepstudy
REML criterion at convergence: 1743.628
Random effects:
 Groups   Name        Std.Dev. Corr
 Subject  (Intercept) 24.741       
          Days         5.922   0.07
 Residual             25.592       
Number of obs: 180, groups:  Subject, 18
Fixed Effects:
(Intercept)         Days  
     251.41        10.47  
# Subject-level coefficients side by side: lme4's per-subject coefficients next
# to the Stan posterior means. get_posterior_mean() returns all gammaIntercept
# rows followed by all gammaDays rows, so filling a two-column matrix
# column-wise puts intercepts in the first column and slopes in the second.
# `pars` is spelled out rather than relying on partial matching of `par`.
stan_coef = matrix(
  get_posterior_mean(fit, pars = c('gammaIntercept', 'gammaDays'))[, 'mean-all chains'],
  ncol     = 2,
  dimnames = list(NULL, c('stan_Intercept', 'stan_Days'))
)

cbind(
  coef(mod_lme)$Subject,
  stan_coef
)
    (Intercept)       Days        1          2
308    253.6637 19.6662617 255.2697 19.4436027
309    211.0064  1.8476053 212.5908  1.6576886
310    212.4447  5.0184295 215.1223  4.5968134
330    275.0957  5.6529356 273.1367  5.9194018
331    273.6654  7.3973743 272.1793  7.6053031
332    260.4447 10.1951090 260.3817 10.1720560
333    268.2456 10.2436499 267.4948 10.3554621
334    244.1725 11.5418676 245.0794 11.2856905
335    251.0714 -0.2848792 249.6814 -0.1358559
337    286.2956 19.0955511 284.6991 19.2866572
349    226.1949 11.6407181 229.1268 11.1449475
350    238.3351 17.0815038 240.3025 16.7056700
351    255.9830  7.4520239 255.2410  7.6661530
352    272.2688 14.0032871 271.0198 14.0739847
369    254.6806 11.3395008 255.0000 11.2681311
370    225.7921 15.2897709 228.4023 14.7781029
371    252.2122  9.4791297 252.3174  9.5119921
372    263.7197 11.7513080 263.0256 11.8830409

Visualize

Visualize the posterior predictive distribution.

# shinystan::launch_shinystan(fit)  # diagnostic plots

library(bayesplot)

# `y_hat` is the fitted subject-level linear predictor for each draw. It has no
# residual noise, so it understates the spread of the data and is not a
# posterior predictive draw. Add Normal(0, sigma_y) noise draw by draw to get
# replicates, then overlay a handful of them on the observed reaction times.
post  = rstan::extract(fit, pars = c('y_hat', 'sigma_y'))
draws = seq_len(min(10, nrow(post$y_hat)))
y_fit = post$y_hat[draws, , drop = FALSE]
# rnorm() recycles `sd` over the column-major matrix entries, so entry [d, n]
# gets the residual sd of draw d.
y_rep = y_fit + rnorm(length(y_fit), mean = 0, sd = post$sigma_y[draws])

pp_check(
  dat$RT,
  y_rep,
  fun = 'dens_overlay'
)