Bayesian Multinomial Models

I spent some time with these models to better understand them in both the traditional and Bayesian contexts, as well as to profile potential speed gains in the Stan code. For what many would call ‘multinomial regression’ without qualification, I can recommend brms with the ‘categorical’ distribution. However, I’m not aware of it easily accommodating choice-specific variables, i.e. ones whose values vary across the choices (though it does accommodate choice-specific effects). I show the standard model here with the usual demonstration, then provide code for the most complex setting of choice-specific, individual-specific, and choice-constant variables.
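For reference, the standard model in brms would look something like the following (a sketch using the program data created below; not run here).

library(brms)

# baseline-category multinomial logit; the first factor level
# ('academic') serves as the reference category
fit_brm = brm(
  prog ~ ses + write,
  family = categorical(),
  data   = program
)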

See the [multinomial chapter][Multinomial] for the non-Bayesian approach.

Data Setup

Depending on the complexity of the model, you may need to create a data set specific to the problem. We start with the standard wide format, then create the long format that mlogit expects.

library(haven)
library(tidyverse)

program = read_dta("https://stats.idre.ucla.edu/stat/data/hsbdemo.dta") %>% 
  as_factor() %>% 
  mutate(prog = relevel(prog, ref = "academic"))  # 'academic' as the reference class


head(program[,1:5])
# A tibble: 6 x 5
     id female ses    schtyp prog    
  <dbl> <fct>  <fct>  <fct>  <fct>   
1    45 female low    public vocation
2   108 male   middle public general 
3    15 male   high   public vocation
4    67 male   low    public vocation
5   153 male   middle public vocation
6    51 female high   public general 
library(mlogit)

programLong = program %>% 
  select(id, prog, ses, write) %>% 
  mlogit.data(
    shape  = 'wide',
    choice = 'prog',
    id.var = 'id'
  )

head(programLong)
~~~~~~~
 first 10 observations out of 600 
~~~~~~~
   id  prog    ses write chid      alt      idx
1   1 FALSE    low    44   11 academic  11:emic
2   1 FALSE    low    44   11  general  11:eral
3   1  TRUE    low    44   11 vocation  11:tion
4   2 FALSE middle    41    9 academic   9:emic
5   2 FALSE middle    41    9  general   9:eral
6   2  TRUE middle    41    9 vocation   9:tion
7   3  TRUE    low    65  159 academic 159:emic
8   3 FALSE    low    65  159  general 159:eral
9   3 FALSE    low    65  159 vocation 159:tion
10  4  TRUE    low    50   30 academic  30:emic

~~~ indexes ~~~~
   chid id      alt
1    11  1 academic
2    11  1  general
3    11  1 vocation
4     9  2 academic
5     9  2  general
6     9  2 vocation
7   159  3 academic
8   159  3  general
9   159  3 vocation
10   30  4 academic
indexes:  1, 1, 2 
X = model.matrix(prog ~ ses + write, data = program)
y = program$prog

# order the observations by class (purely for convenience)
X = X[order(y), ]
y = y[order(y)]
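Stan will receive the outcome as an integer coded in factor-level order, so the reference class ‘academic’ is 1. A quick check:

levels(y)             # 'academic' first, i.e. the reference class
table(as.integer(y))  # class counts under the integer coding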

Model Code

data {
  int K;
  int N;
  int D;
  int y[N];
  matrix[N,D] X;
}

transformed data {
  vector[D] zeros;
  
  zeros = rep_vector(0, D);
}

parameters {
  matrix[D, K-1] beta_raw;
}

transformed parameters {
  matrix[D, K] beta;
  
  beta = append_col(zeros, beta_raw);
}

model {
  matrix[N, K] L;                   // linear predictor
  
  L = X * beta;
  
  // prior for coefficients
  to_vector(beta_raw) ~ normal(0, 10);
  
  // likelihood
  for (n in 1:N)
    y[n] ~ categorical_logit(to_vector(L[n]));
}
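The zero column appended in transformed parameters handles identification: the reference class gets a fixed utility of zero, and categorical_logit applies the softmax to each row of the linear predictor. A quick R illustration with made-up utilities:

# softmax over utilities, with the reference class fixed at zero
softmax = function(x) exp(x) / sum(exp(x))

softmax(c(0, 1.2, -0.5))  # reference, class 2, class 3 probabilities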

Estimation

We’ll get the data prepped for Stan; the compiled model code is assumed to be in an object bayes_multinom.
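If the Stan code above were saved to a file, compiling it would look something like this (the file name is an assumption):

library(rstan)

# compile the model code to an object we can sample from
bayes_multinom = stan_model(file = 'multinomial.stan')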

# N = sample size, X = model matrix, y = integer version of the class outcome,
# K = number of classes, D = number of columns of the model matrix
stan_data = list(
  N = nrow(X),
  X = X,
  y = as.integer(y),
  K = n_distinct(y),
  D = ncol(X)
)


library(rstan)

fit = sampling(
  bayes_multinom,
  data  = stan_data,
  thin  = 4,
  cores = 4
)
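Before looking at results, it doesn’t hurt to run the standard HMC checks, which rstan bundles into a convenience function.

# check divergences, tree depth, and energy diagnostics
check_hmc_diagnostics(fit)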

Comparison

We’ll need to do a bit of reordering, but otherwise we can see that the models come to similar conclusions.

print(
  fit,
  digits = 3,
  par    = c('beta'),
  probs  = c(.025, .5, .975)
)
Inference for Stan model: 6f6b51695c6eeda4b74c0665f4447c97.
4 chains, each with iter=2000; warmup=1000; thin=4; 
post-warmup draws per chain=250, total post-warmup draws=1000.

            mean se_mean    sd   2.5%    50%  97.5% n_eff  Rhat
beta[1,1]  0.000     NaN 0.000  0.000  0.000  0.000   NaN   NaN
beta[1,2]  2.834   0.037 1.214  0.453  2.823  5.212  1100 0.999
beta[1,3]  5.271   0.039 1.187  2.906  5.227  7.640   910 0.999
beta[2,1]  0.000     NaN 0.000  0.000  0.000  0.000   NaN   NaN
beta[2,2] -0.537   0.014 0.437 -1.388 -0.533  0.332  1030 1.002
beta[2,3]  0.340   0.016 0.488 -0.616  0.346  1.282   956 1.001
beta[3,1]  0.000     NaN 0.000  0.000  0.000  0.000   NaN   NaN
beta[3,2] -1.207   0.016 0.508 -2.193 -1.199 -0.232  1073 1.000
beta[3,3] -1.001   0.021 0.617 -2.187 -0.979  0.145   864 0.999
beta[4,1]  0.000     NaN 0.000  0.000  0.000  0.000   NaN   NaN
beta[4,2] -0.058   0.001 0.022 -0.103 -0.058 -0.015  1068 1.000
beta[4,3] -0.116   0.001 0.023 -0.160 -0.115 -0.072   912 1.000

Samples were drawn using NUTS(diag_e) at Wed Nov 25 18:00:51 2020.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).
fit_coefs    = get_posterior_mean(fit, par = 'beta_raw')[, 5]  # column 5 = mean over all chains

mlogit_mod   = mlogit(prog ~ 1 | ses + write, data = programLong)
mlogit_coefs = coef(mlogit_mod)[c(1, 3, 5, 7, 2, 4, 6, 8)]
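Binding the two sets of estimates gives a side-by-side comparison (a minimal sketch).

# compare mlogit estimates with the Stan posterior means
round(data.frame(mlogit = mlogit_coefs, fit = fit_coefs), 3)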
                     mlogit    fit
(Intercept):general   2.852  2.834
(Intercept):vocation  5.218  5.271
sesmiddle:general    -0.533 -0.537
sesmiddle:vocation    0.291  0.340
seshigh:general      -1.163 -1.207
seshigh:vocation     -0.983 -1.001
write:general        -0.058 -0.058
write:vocation       -0.114 -0.116
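As in the non-Bayesian setting, the coefficients are on the log relative-risk scale, so exponentiating gives relative risk ratios versus the ‘academic’ reference, e.g.:

# relative risk ratios vs. the reference category
exp(fit_coefs)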

Adding Complexity

The following adds choice-specific (a.k.a. alternative-specific) variables, e.g. among product choices this might include price. Along with these, we may also have choice-constant variables, as well as the typical individual-varying covariates.

This code worked at the time, but I wasn’t interested enough to try it again recently. You can use the classic ‘travel’ data as an example (available as TravelMode in AER), or the Fishing data from mlogit. Essentially you’ll have three separate data components: a matrix for the individual-specific covariates, one for the alternative-specific covariates, and one for the alternative-constant covariates.
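As a rough, untested sketch of the data prep, the following constructs the three components from TravelMode. It assumes rows are stacked by alternative (all of alternative 1, then alternative 2, and so on), that ‘air’ is the reference alternative, and that the alternative-constant component enters as differences from the reference; the particular variables used (income, wait, gcost) are purely illustrative.

data("TravelMode", package = "AER")

tm = TravelMode %>% arrange(mode, individual)  # stack rows by alternative

N = n_distinct(tm$individual)
K = n_distinct(tm$mode)

ref    = filter(tm, mode == 'air')             # reference alternative rows
nonref = filter(tm, mode != 'air')

# individual-specific covariates (one row per person)
X = model.matrix(~ income, data = ref)

# alternative-specific covariates with choice-specific coefficients
Y = model.matrix(~ -1 + wait, data = tm)

# alternative-specific covariates with a single (choice-constant)
# coefficient, as differences from the reference alternative
Z = model.matrix(~ -1 + I(gcost - rep(ref$gcost, K - 1)), data = nonref)

stan_data = list(
  K      = K,
  N      = N,
  D      = ncol(X),
  G      = ncol(Y),
  T      = ncol(Z),
  y      = as.integer(tm$mode),
  choice = as.numeric(tm$choice == 'yes'),
  X      = X,
  Y      = Y,
  Z      = Z
)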

data {
  int K;                               // number of choices
  int N;                               // number of individuals
  int D;                               // number of indiv specific variables
  int G;                               // number of alt specific variables
  int T;                               // number of alt constant variables
  
  int y[N*K];                          // choices (the choice indicator below is what enters the likelihood)
  vector[N*K] choice;                  // choice made (logical)
  
  matrix[N, D]       X;                // data for indiv specific effects
  matrix[N*K, G]     Y;                // data for alt specific effects
  matrix[N*(K-1), T] Z;                // data for alt constant effects
}

parameters {   
  matrix[D, K-1] beta;                 // individual specific coefs
  matrix[G, K]  gamma;                 // choice specific coefs for alt-specific variables
  vector[T]     theta;                 // choice constant coefs for alt-specific variables
  
}

model {
  matrix[N, K-1] Vx;                   // Utility for individual vars
  
  vector[N*K]   Vy0;
  matrix[N, K-1] Vy;                   // Utility for alt-specific/alt-varying vars
  
  vector[N*(K-1)] Vz0;
  matrix[N, (K-1)] Vz;                 // Utility for alt-specific/alt-constant vars

  matrix[N, K-1] V;                    // combined utilities
  
  vector[N] baseProbVec;               // reference group probabilities
  real ll0;                            // intermediate log likelihood
  real loglik;                         // final log likelihood


  // priors  
  to_vector(beta)  ~ normal(0, 10);    // diffuse priors on coefficients
  to_vector(gamma) ~ normal(0, 10);
  to_vector(theta) ~ normal(0, 10);
  

  // likelihood
  
  // 'Utilities'
  Vx = X * beta;
  
  for(alt in 1:K){
    vector[G] par;
    int start;
    int end;

    par   = gamma[,alt];
    start = N*alt - N+1;
    end   = N*alt;
    
    Vy0[start:end] = Y[start:end,] * par;
    if(alt > 1) Vy[,alt-1] = Vy0[start:end] - Vy0[1:N];
  }
  
  Vz0 = Z * theta;
  
  for(alt in 1:(K-1)){
    int start;
    int end;

    start = N*alt - N+1;
    end   = N*alt;
    Vz[,alt] = Vz0[start:end];
  }

  V = Vx + Vy + Vz;

  for(n in 1:N)  baseProbVec[n] = 1/(1 + sum(exp(V[n])));
  
  ll0 = dot_product(to_vector(V), choice[(N+1):(N*K)]); // just going to assume no neg index
  loglik  = sum(log(baseProbVec)) + ll0;
  target += loglik;
  
}


generated quantities {
  matrix[N, K-1] fitted_nonref;
  vector[N] fitted_ref;
  matrix[N, K] fitted;
  
  matrix[N, K-1] Vx;                   // Utility for individual variables
  
  vector[N*K] Vy0;
  matrix[N, K-1] Vy;                   // Utility for alt-specific/alt-varying variables
  
  vector[N*(K-1)] Vz0;
  matrix[N, (K-1)] Vz;                 // Utility for alt-specific/alt-constant variables

  matrix[N, K-1] V;                    // combined utilities
  
  vector[N] baseProbVec;               // reference group probabilities

  Vx = X * beta;
  
  for(alt in 1:K) {
    vector[G] par;
    int start;
    int end;

    par   = gamma[,alt];
    start = N*alt - N+1;
    end   = N*alt;
    
    Vy0[start:end] = Y[start:end, ] * par;
    
    if (alt > 1) Vy[,alt-1] = Vy0[start:end] - Vy0[1:N];
  }
  
  Vz0 = Z * theta;
  
  for(alt in 1:(K-1)){
    int start;
    int end;

    start = N*alt-N+1;
    end = N*alt;
    
    Vz[,alt] = Vz0[start:end];
  }

  V = Vx + Vy + Vz;
  
  for(n in 1:N)  baseProbVec[n] = 1 / (1 + sum(exp(V[n])));
  fitted_nonref = exp(V) .* rep_matrix(baseProbVec, K-1);
  
  for(n in 1:N) fitted_ref[n] = 1 - sum(fitted_nonref[n]);
  fitted = append_col(fitted_ref, fitted_nonref);
}
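Regarding the likelihood in the model block: with the reference utility fixed at zero, log P(choice = k) = V_k - log(1 + sum(exp(V))), so summing log(baseProbVec) and adding the dot product of the utilities with the choice indicator gives the full log-likelihood. A quick R check of the identity with made-up utilities:

# log P(y = k) = V_k - log(1 + sum(exp(V))), with V = 0 for the reference
V = c(general = 0.5, vocation = -1)    # made-up non-reference utilities

p = c(1, exp(V)) / (1 + sum(exp(V)))   # reference probability first
log(p[2])                              # direct log probability for 'general'
V[1] + log(1 / (1 + sum(exp(V))))      # V_k + log(baseProb): same value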