Bayesian Multilevel Mediation

The following demonstrates an indirect (mediated) effect in a multilevel setting. It is based on Yuan & MacKinnon (2009), which provides corresponding BUGS code. In what follows we essentially have two models: one where the ‘mediator’ is the response, and one for the primary response of interest (denoted y). They will be referred to as the Med and Main models, respectively.

Data Setup

The two main models are expressed conceptually as follows:

\[\textrm{Mediator} \sim \alpha_{Med} + \beta_{Med}\cdot X\]

\[y \sim \alpha_{Main} + \beta_{1_{Main}}\cdot X + \beta_{2_{Main}}\cdot \textrm{Mediator}\]

However, each coefficient will also have a random effect for a grouping variable, i.e. there will be random intercepts and slopes in both the mediator model and the outcome model.
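Spelled out with the random effects for a given group j (mirroring the names used in the simulation and Stan code that follow), the models are:

\[\textrm{Mediator} \sim (\alpha_{Med} + \gamma_{\alpha_{Med}, j}) + (\beta_{Med} + \gamma_{\beta_{Med}, j})\cdot X\]

\[y \sim (\alpha_{Main} + \gamma_{\alpha, j}) + (\beta_{1_{Main}} + \gamma_{\beta_1, j})\cdot X + (\beta_{2_{Main}} + \gamma_{\beta_2, j})\cdot \textrm{Mediator}\]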

Let’s create data to this effect. In the following we will ultimately have 1000 total observations, with 50 groups (20 observations each).

library(tidyverse)

set.seed(8675309)

N  = 1000
n_groups = 50
n_per_group = N/n_groups

# random effect covariance matrix for both models
# first option: no covariance between the two models' random effects
# covmat_RE = matrix(c(1,-.15,0,0,0,
#                       -.15,.4,0,0,0,
#                       0,0,1,-.1,.15,
#                       0,0,-.1,.3,0,
#                       0,0,.15,0,.2), nrow=5, byrow = T)

# or with slight covariance added for the indirect-effect random effects; both matrices are positive definite
covmat_RE = matrix(c( 1.00, -0.15,  0.00,  0.00,  0.00,
                     -0.15,  0.64,  0.00,  0.00, -0.10,
                      0.00,  0.00,  1.00, -0.10,  0.15,
                      0.00,  0.00, -0.10,  0.49,  0.00,
                      0.00, -0.10,  0.15,  0.00,  0.25), nrow = 5, byrow = TRUE)

# inspect
covmat_RE
      [,1]  [,2]  [,3]  [,4]  [,5]
[1,]  1.00 -0.15  0.00  0.00  0.00
[2,] -0.15  0.64  0.00  0.00 -0.10
[3,]  0.00  0.00  1.00 -0.10  0.15
[4,]  0.00  0.00 -0.10  0.49  0.00
[5,]  0.00 -0.10  0.15  0.00  0.25
# inspect as correlation
cov2cor(covmat_RE)
        [,1]    [,2]       [,3]       [,4]  [,5]
[1,]  1.0000 -0.1875  0.0000000  0.0000000  0.00
[2,] -0.1875  1.0000  0.0000000  0.0000000 -0.25
[3,]  0.0000  0.0000  1.0000000 -0.1428571  0.30
[4,]  0.0000  0.0000 -0.1428571  1.0000000  0.00
[5,]  0.0000 -0.2500  0.3000000  0.0000000  1.00
colnames(covmat_RE) = rownames(covmat_RE) = 
  c('alpha_Med', 'beta_Med', 'alpha_Main', 'beta1_Main', 'beta2_Main')

# simulate random effects
re = MASS::mvrnorm(
  n_groups,
  mu    = rep(0, 5),
  Sigma = covmat_RE,
  empirical = TRUE
)
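Since empirical = TRUE is used, the sample covariance of the simulated random effects reproduces covmat_RE exactly (up to numerical precision); a quick check:

# sanity check: empirical covariance of the random effects equals covmat_RE
all.equal(cov(re), covmat_RE, check.attributes = FALSE)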

# random effects for mediator model
ranef_alpha_Med = rep(re[, 'alpha_Med'], e = n_per_group)
ranef_beta_Med  = rep(re[, 'beta_Med'],  e = n_per_group)

# random effects for main model                                                 
ranef_alpha_Main = rep(re[, 'alpha_Main'], e = n_per_group)
ranef_beta1_Main = rep(re[, 'beta1_Main'], e = n_per_group)
ranef_beta2_Main = rep(re[, 'beta2_Main'], e = n_per_group)

## fixed effects
alpha_Med = 2
beta_Med  = .2

alpha_Main = 1
beta1_Main = .3
beta2_Main = -.2

# residual variance
resid_Med  = MASS::mvrnorm(N, 0, .75^2, empirical = TRUE)
resid_Main = MASS::mvrnorm(N, 0, .50^2, empirical = TRUE)


# Collect parameters for later comparison
params = c(
  alpha_Med  = alpha_Med,
  beta_Med   = beta_Med,
  sigma_Med  = sd(resid_Med),
  alpha_Main = alpha_Main,
  beta1_Main = beta1_Main,
  beta2_Main = beta2_Main,
  sigma_y    = sd(resid_Main),
  alpha_Med_sd = sqrt(diag(covmat_RE)[1]),
  beta_Med_sd  = sqrt(diag(covmat_RE)[2]),
  alpha_sd = sqrt(diag(covmat_RE)[3]),
  beta1_sd = sqrt(diag(covmat_RE)[4]),
  beta2_sd = sqrt(diag(covmat_RE)[5])
)

ranefs =  cbind(
  gamma_alpha_Med = unique(ranef_alpha_Med),
  gamma_beta_Med  = unique(ranef_beta_Med),
  gamma_alpha = unique(ranef_alpha_Main),
  gamma_beta1 = unique(ranef_beta1_Main),
  gamma_beta2 = unique(ranef_beta2_Main)
)

Finally, we can create the data for analysis.

X = rnorm(N, sd = 2)

Med = (alpha_Med + ranef_alpha_Med) + (beta_Med + ranef_beta_Med) * X + resid_Med[, 1]

y = (alpha_Main + ranef_alpha_Main) + (beta1_Main + ranef_beta1_Main) * X + 
  (beta2_Main + ranef_beta2_Main) *  Med + resid_Main[, 1]

group = rep(1:n_groups, e = n_per_group)

standat = list(
  X   = X,
  Med = Med,
  y   = y,
  Group = group,
  J = length(unique(group)),
  N = length(y)
)

Model Code

In the following, the Cholesky decomposition of the random effect covariance matrix is used for efficiency. As a rough guide, the default data with N = 1000 observations took about a minute or so to run.
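As a reminder of how the pieces fit together, the covariance matrix is the diagonal matrix of standard deviations times the correlation matrix times that diagonal again, and the correlation matrix can be rebuilt from its lower-triangular Cholesky factor. A small R illustration using the true values from earlier (this mirrors the cov_RE calculation in the generated quantities block below):

# rebuild the RE covariance matrix from its standard deviations and
# the Cholesky factor of the correlation matrix
sds = sqrt(diag(covmat_RE))
L   = t(chol(cov2cor(covmat_RE)))    # lower-triangular Cholesky factor

all.equal(
  diag(sds) %*% tcrossprod(L) %*% diag(sds),    # D * (L L') * D
  covmat_RE,
  check.attributes = FALSE
)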

data {
  int<lower = 1> N;                              // Sample size
  vector[N] X;                                   // Explanatory variable
  vector[N] Med;                                 // Mediator
  vector[N] y;                                   // Response
  int<lower = 1> J;                              // Number of groups
  int<lower = 1,upper = J> Group[N];             // Groups
}

parameters{
  real alpha_Med;                                // mediator model reg parameters and related
  real beta_Med;
  real<lower = 0> sigma_alpha_Med;
  real<lower = 0> sigma_beta_Med;
  real<lower = 0> sigmaMed;

  real alpha_Main;                               // main model reg parameters and related
  real beta1_Main;
  real beta2_Main;
  real<lower = 0> sigma_alpha;
  real<lower = 0> sigma_beta1;
  real<lower = 0> sigma_beta2;
  real<lower = 0> sigma_y;

  cholesky_factor_corr[5] Omega_chol;            // chol decomp of corr matrix for random effects

  vector<lower = 0>[5] sigma_ranef;              // sd for random effects

  matrix[J,5] gamma;                             // random effects
}

transformed parameters{
  vector[J] gamma_alpha_Med;
  vector[J] gamma_beta_Med;

  vector[J] gamma_alpha;
  vector[J] gamma_beta1;
  vector[J] gamma_beta2;

  for (j in 1:J){
    gamma_alpha_Med[j] = gamma[j,1];
    gamma_beta_Med[j]  = gamma[j,2];
    gamma_alpha[j] = gamma[j,3];
    gamma_beta1[j] = gamma[j,4];
    gamma_beta2[j] = gamma[j,5];
  }
}

model {
  vector[N] mu_y;                                // linear predictors for response and mediator
  vector[N] mu_Med;
  matrix[5,5] D;
  matrix[5,5] DC;

  // priors
  // mediator model
  // fixef
  // for scale params the cauchy is a little more informative here due 
  // to the nature of the data
  sigma_alpha_Med ~ cauchy(0, 1);                
  sigma_beta_Med  ~ cauchy(0, 1);
  alpha_Med ~ normal(0, sigma_alpha_Med);   
  beta_Med  ~ normal(0, sigma_beta_Med);

  // residual scale
  sigmaMed ~ cauchy(0, 1);

  // main model
  // fixef
  sigma_alpha ~ cauchy(0, 1);
  sigma_beta1 ~ cauchy(0, 1);
  sigma_beta2 ~ cauchy(0, 1);
  alpha_Main  ~ normal(0, sigma_alpha);      
  beta1_Main  ~ normal(0, sigma_beta1);
  beta2_Main  ~ normal(0, sigma_beta2);

  // residual scale
  sigma_y ~ cauchy(0, 1);

  // ranef sampling via cholesky decomposition
  sigma_ranef ~ cauchy(0, 1);
  Omega_chol  ~  lkj_corr_cholesky(2.0);

  D  = diag_matrix(sigma_ranef);
  DC = D * Omega_chol;
  
  for (j in 1:J)                                 // loop for Group random effects
    gamma[j] ~ multi_normal_cholesky(rep_vector(0, 5), DC);

  // Linear predictors
  for (n in 1:N){
    mu_Med[n] = alpha_Med + gamma_alpha_Med[Group[n]] + 
    (beta_Med + gamma_beta_Med[Group[n]]) * X[n];
    
    mu_y[n]   = alpha_Main + gamma_alpha[Group[n]] + 
    (beta1_Main + gamma_beta1[Group[n]]) * X[n] + 
    (beta2_Main + gamma_beta2[Group[n]]) * Med[n] ;
  }
  
  
  // sampling for primary models
  Med ~ normal(mu_Med, sigmaMed);
  y   ~ normal(mu_y, sigma_y);
}

generated quantities{
  real naive_ind_effect;
  real avg_ind_effect;
  real total_effect;
  matrix[5,5] cov_RE;
  vector[N] y_hat;    
  
  cov_RE = diag_matrix(sigma_ranef) * tcrossprod(Omega_chol) * diag_matrix(sigma_ranef);

  naive_ind_effect = beta_Med*beta2_Main;
  avg_ind_effect   = beta_Med*beta2_Main + cov_RE[2,5];    // add cov of random slopes for mediator effects
  total_effect     = avg_ind_effect + beta1_Main; 
  
  for (n in 1:N){
    y_hat[n]   = alpha_Main + gamma_alpha[Group[n]] + 
    (beta1_Main + gamma_beta1[Group[n]]) * X[n] + 
    (beta2_Main + gamma_beta2[Group[n]]) * Med[n] ;
  }
}

Estimation

Run the model and examine the results. The following assumes the previous model code has been compiled into a stanmodel object, bayes_med_model.
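For example, a minimal sketch assuming the Stan code above is saved to a file (the file name here is just illustrative):

# compile the Stan code once to a stanmodel object for use with sampling()
bayes_med_model = rstan::stan_model(file = 'bayes_med_model.stan')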

library(rstan)

fit = sampling(
  bayes_med_model,
  data    = standat,
  iter    = 3000,
  warmup  = 2000,
  thin    = 4,
  cores   = 4,
  control = list(adapt_delta = .99, max_treedepth = 15)
)

Comparison

The main parameters include the fixed effects and random effect standard deviations, plus those related to the indirect effect.

mainpars = c(
  'alpha_Med',
  'beta_Med',
  'sigmaMed',
  'alpha_Main',
  'beta1_Main',
  'beta2_Main',
  'sigma_y',
  'sigma_ranef',
  'naive_ind_effect',
  'avg_ind_effect',
  'total_effect'
)

print(
  fit,
  digits = 3,
  probs  = c(.025, .5, 0.975),
  pars   = mainpars
)
Inference for Stan model: e771ba356dcfd779b3cf0ac752042346.
4 chains, each with iter=3000; warmup=2000; thin=4; 
post-warmup draws per chain=250, total post-warmup draws=1000.

                   mean se_mean    sd   2.5%    50%  97.5% n_eff  Rhat
alpha_Med         1.996   0.010 0.150  1.693  1.999  2.280   244 1.015
beta_Med          0.142   0.008 0.113 -0.047  0.139  0.369   220 1.021
sigmaMed          0.744   0.001 0.018  0.711  0.744  0.780   940 1.001
alpha_Main        0.964   0.008 0.158  0.646  0.966  1.281   425 1.008
beta1_Main        0.273   0.006 0.108  0.050  0.275  0.483   284 1.002
beta2_Main       -0.153   0.004 0.081 -0.307 -0.151 -0.002   410 1.004
sigma_y           0.495   0.000 0.012  0.472  0.494  0.520   946 1.001
sigma_ranef[1]    1.031   0.004 0.108  0.845  1.026  1.269   841 1.005
sigma_ranef[2]    0.839   0.003 0.089  0.685  0.831  1.031  1058 0.999
sigma_ranef[3]    1.046   0.004 0.123  0.839  1.031  1.313   997 1.002
sigma_ranef[4]    0.776   0.003 0.085  0.625  0.773  0.969  1033 0.998
sigma_ranef[5]    0.514   0.002 0.061  0.401  0.511  0.647   952 1.000
naive_ind_effect -0.024   0.002 0.025 -0.082 -0.018  0.006   231 1.025
avg_ind_effect   -0.115   0.002 0.067 -0.259 -0.113  0.005   863 1.001
total_effect      0.158   0.007 0.125 -0.087  0.159  0.391   356 1.000

Samples were drawn using NUTS(diag_e) at Sat Dec 12 12:18:03 2020.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).

We can fit the two models piecemeal via lme4 for an initial comparison. However, this approach can't directly estimate the mediated effect, and it won't pick up the correlation of random effects across the two models.

library(lme4)
mod_Med = lmer(Med ~ X + (1 + X | group))
summary(mod_Med)
Linear mixed model fit by REML ['lmerMod']
Formula: Med ~ X + (1 + X | group)

REML criterion at convergence: 2647.9

Scaled residuals: 
     Min       1Q   Median       3Q      Max 
-3.14663 -0.67172  0.03569  0.64922  2.87581 

Random effects:
 Groups   Name        Variance Std.Dev. Corr 
 group    (Intercept) 0.9739   0.9869        
          X           0.6414   0.8009   -0.22
 Residual             0.5522   0.7431        
Number of obs: 1000, groups:  group, 50

Fixed effects:
            Estimate Std. Error t value
(Intercept)   2.0046     0.1416  14.155
X             0.1901     0.1139   1.668

Correlation of Fixed Effects:
  (Intr)
X -0.217
mod_Main = lmer(y ~ X + Med + (1 + X + Med | group))
summary(mod_Main)
Linear mixed model fit by REML ['lmerMod']
Formula: y ~ X + Med + (1 + X + Med | group)

REML criterion at convergence: 2030.7

Scaled residuals: 
    Min      1Q  Median      3Q     Max 
-3.6523 -0.6368 -0.0054  0.6344  2.8578 

Random effects:
 Groups   Name        Variance Std.Dev. Corr       
 group    (Intercept) 0.9816   0.9908              
          X           0.5352   0.7316   -0.11      
          Med         0.2367   0.4865    0.35 -0.03
 Residual             0.2445   0.4945              
Number of obs: 1000, groups:  group, 50

Fixed effects:
            Estimate Std. Error t value
(Intercept)  0.98977    0.14891   6.647
X            0.28636    0.10507   2.725
Med         -0.17958    0.07226  -2.485

Correlation of Fixed Effects:
    (Intr) X     
X   -0.097       
Med  0.225 -0.041
# 'naive' indirect effect estimate based on the fixed effects only; compared with the other approaches below
lme_indirect_effect = fixef(mod_Med)['X'] * fixef(mod_Main)['Med']
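This product works out to roughly 0.190 * -0.180 = -0.034, which is the naive lme4 value shown in the comparison of indirect effects below.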

The mediation package provides a better estimate and can handle this relatively simple mixed model setting.

# library(mediation)

mediation_mixed = mediation::mediate(
  model.m  = mod_Med,
  model.y  = mod_Main,
  treat    = 'X',
  mediator = 'Med'
)

summary(mediation_mixed)

Causal Mediation Analysis 

Quasi-Bayesian Confidence Intervals

Mediator Groups: group 

Outcome Groups: group 

Output Based on Overall Averages Across Groups 

               Estimate 95% CI Lower 95% CI Upper p-value    
ACME            -0.1144      -0.1805        -0.06  <2e-16 ***
ADE              0.2877       0.0777         0.50    0.01 ** 
Total Effect     0.1733      -0.0385         0.40    0.11    
Prop. Mediated  -0.5990      -6.6281         4.56    0.11    
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Sample Size Used: 1000 


Simulations: 1000 

Extract parameters for comparison.

pars_primary = get_posterior_mean(fit, pars = mainpars)[, 5]
pars_re_cov  = get_posterior_mean(fit, pars = 'Omega_chol')[, 5] # Cholesky of the RE correlation; or take 'cov_RE' from the generated quantities
pars_re      = get_posterior_mean(fit, pars = c('sigma_ranef'))[, 5]
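As a quick alternative, the posterior mean of the full random effect covariance matrix can be taken from the cov_RE generated quantity directly:

# posterior mean of the 5 x 5 random effect covariance matrix
cov_RE_mean = apply(rstan::extract(fit, pars = 'cov_RE')$cov_RE, c(2, 3), mean)
round(cov_RE_mean, 2)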

Fixed effects and random effect standard deviations.

param                    true   bayes    lme4
alpha_Med                2.00   1.996   2.005
beta_Med                 0.20   0.142   0.190
sigma_Med                0.75   0.744   0.743
alpha_Main               1.00   0.964   0.990
beta1_Main               0.30   0.273   0.286
beta2_Main              -0.20  -0.153  -0.180
sigma_y                  0.50   0.495   0.494
alpha_Med_sd.alpha_Med   1.00   1.031   0.987
beta_Med_sd.beta_Med     0.80   0.839   0.801
alpha_sd.alpha_Main      1.00   1.046   0.991
beta1_sd.beta1_Main      0.70   0.776   0.732
beta2_sd.beta2_Main      0.50   0.514   0.486

Compare the covariances of the random effects. The first elements show the true and estimated full covariance matrix across the mediator and outcome models; the remaining elements break the comparison out by model, including the lme4 estimates.

$true
           alpha_Med beta_Med alpha_Main beta1_Main beta2_Main
alpha_Med       1.00    -0.15       0.00       0.00       0.00
beta_Med       -0.15     0.64       0.00       0.00      -0.10
alpha_Main      0.00     0.00       1.00      -0.10       0.15
beta1_Main      0.00     0.00      -0.10       0.49       0.00
beta2_Main      0.00    -0.10       0.15       0.00       0.25

$estimates
      [,1]  [,2]  [,3]  [,4]  [,5]
[1,]  1.06 -0.16  0.02  0.02 -0.04
[2,] -0.16  0.69 -0.01 -0.03 -0.09
[3,]  0.02 -0.01  1.05 -0.08  0.15
[4,]  0.02 -0.03 -0.08  0.57 -0.01
[5,] -0.04 -0.09  0.15 -0.01  0.24

$vcov_Med
          alpha_Med beta_Med
alpha_Med      1.00    -0.15
beta_Med      -0.15     0.64

$vcov_Med_bayes
      [,1]  [,2]
[1,]  1.06 -0.16
[2,] -0.16  0.69

$vcov_Med_lme4
            (Intercept)     X
(Intercept)        0.97 -0.18
X                 -0.18  0.64

$vcov_Main
           alpha_Main beta1_Main beta2_Main
alpha_Main       1.00      -0.10       0.15
beta1_Main      -0.10       0.49       0.00
beta2_Main       0.15       0.00       0.25

$vcov_Main_bayes
      [,1]  [,2]  [,3]
[1,]  1.05 -0.08  0.15
[2,] -0.08  0.57 -0.01
[3,]  0.15 -0.01  0.24

$vcov_Main_lme4
            (Intercept)     X   Med
(Intercept)        0.98 -0.08  0.17
X                 -0.08  0.54 -0.01
Med                0.17 -0.01  0.24

Compare the indirect effects. Recall that the true average indirect effect is the product of the fixed effects plus the covariance of the corresponding random slopes: 0.2 * -0.2 + -0.10 = -0.14.

 true  est_bayes  naive_bayes  naive_lmer  mediation_pack
-0.14     -0.115       -0.024      -0.034          -0.114

Note that you can use brms to estimate this model as follows. The |i| part of the formulas allows the random effects to be correlated across the mediator and outcome models. We have to convert the correlation estimates back to a covariance to get an indirect effect value comparable to our base Stan result.

library(brms)

f = 
  bf(Med ~ X + (1 + X |i| group)) +
  bf(y   ~ X + Med + (1 + X + Med |i| group)) +
  set_rescor(FALSE)
       

fit_brm = brm(
  f,
  data  = data.frame(X, Med, y, group),
  cores = 4,
  thin  = 4,
  seed  = 1234,
  control = list(adapt_delta = .99, max_treedepth = 15)
)
summary(fit_brm)
 Family: MV(gaussian, gaussian) 
  Links: mu = identity; sigma = identity
         mu = identity; sigma = identity 
Formula: Med ~ X + (1 + X | i | group) 
         y ~ X + Med + (1 + X + Med | i | group) 
   Data: data.frame(X, Med, y, group) (Number of observations: 1000) 
Samples: 4 chains, each with iter = 2000; warmup = 1000; thin = 4;
         total post-warmup samples = 1000

Group-Level Effects: 
~group (Number of levels: 50) 
                               Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(Med_Intercept)                  1.05      0.11     0.86     1.30 1.00      859      781
sd(Med_X)                          0.85      0.09     0.68     1.05 1.00      744      796
sd(y_Intercept)                    1.06      0.12     0.84     1.33 1.00      938      988
sd(y_X)                            0.78      0.09     0.64     0.98 1.00      849     1029
sd(y_Med)                          0.51      0.06     0.41     0.64 1.00      816      876
cor(Med_Intercept,Med_X)          -0.19      0.14    -0.44     0.09 1.00      847      992
cor(Med_Intercept,y_Intercept)     0.02      0.14    -0.26     0.28 1.00      925      993
cor(Med_X,y_Intercept)            -0.02      0.15    -0.31     0.26 1.00      879      994
cor(Med_Intercept,y_X)             0.03      0.14    -0.23     0.31 1.00      945      866
cor(Med_X,y_X)                    -0.05      0.14    -0.33     0.23 1.00      918      813
cor(y_Intercept,y_X)              -0.11      0.15    -0.39     0.18 1.00      886      946
cor(Med_Intercept,y_Med)          -0.07      0.15    -0.36     0.21 1.00      849      990
cor(Med_X,y_Med)                  -0.22      0.14    -0.48     0.06 1.00      871      882
cor(y_Intercept,y_Med)             0.31      0.15    -0.00     0.58 1.00      762      882
cor(y_X,y_Med)                    -0.01      0.15    -0.30     0.28 1.00      903      987

Population-Level Effects: 
              Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Med_Intercept     2.01      0.15     1.70     2.29 1.00      684      850
y_Intercept       0.99      0.16     0.67     1.32 1.00      856      961
Med_X             0.19      0.12    -0.04     0.42 1.00      817      794
y_X               0.29      0.11     0.07     0.51 1.01      841      914
y_Med            -0.18      0.08    -0.33    -0.03 1.00      919      884

Family Specific Parameters: 
          Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma_Med     0.74      0.02     0.71     0.78 1.00     1007      732
sigma_y       0.50      0.01     0.47     0.52 1.00      936      782

Samples were drawn using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
hypothesis(
  fit_brm,
  'b_y_Med*b_Med_X + cor_group__Med_X__y_Med*sd_group__Med_X*sd_group__y_Med = 0',
  class = NULL,
  seed  =  1234
)
Hypothesis Tests for class :
                Hypothesis Estimate Est.Error CI.Lower CI.Upper Evid.Ratio Post.Prob Star
1 (b_y_Med*b_Med_X+... = 0    -0.13      0.08     -0.3        0         NA        NA     
---
'CI': 90%-CI for one-sided and 95%-CI for two-sided hypotheses.
'*': For one-sided hypotheses, the posterior probability exceeds 95%;
for two-sided hypotheses, the value tested against lies outside the 95%-CI.
Posterior probabilities of point hypotheses assume equal prior probabilities.
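The same quantity can also be computed by hand from the posterior draws; a minimal sketch, using the parameter names that appear in the hypothesis above:

# average indirect effect computed from the brms posterior draws
draws = as.data.frame(fit_brm)

ind_draws = with(
  draws,
  b_y_Med * b_Med_X + cor_group__Med_X__y_Med * sd_group__Med_X * sd_group__y_Med
)

c(mean = mean(ind_draws), quantile(ind_draws, c(.025, .975)))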

Visualization
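
A quick posterior predictive check: overlay densities of the observed response against ten posterior predictive draws (y_hat).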

library(bayesplot)

pp_check(
  standat$y, 
  rstan::extract(fit, par = 'y_hat')$y_hat[1:10, ], 
  fun = 'dens_overlay'
)