Variational Bayes Regression

The following provides a function for estimating the parameters of a linear regression via variational inference. See Drugowitsch (2014) for an overview of the method outlined in Bishop (2006).

For the primary function I will use the notation used in the Drugowitsch article in most cases. Here, $w$ represents the coefficients and $\tau$ the precision (inverse variance). The likelihood for the target $y$ is $N(Xw, \tau^{-1})$. The priors for $w$ and $\tau$ are normal-gamma (equivalently, normal-inverse-gamma on the variance): $w \mid \tau \sim N(0, (\tau\alpha)^{-1})$ and $\tau \sim \text{Gamma}(a_0, b_0)$.

References:

Data Setup

We can simulate some data as a starting point, in this case, basic tabular data used in the standard regression problem. Here, I explicitly note the intercept, as it is added to the model matrix within the vb_reg function.

library(tidyverse)

set.seed(1234)

# Simulation settings: sample size, number of predictors, true coefficients
# (intercept first, then one slope per predictor), and residual sd.
n <- 100
d <- 3
coefs <- c(1, 2, 3, 5)
sigma <- 2

# predictors: an n x d matrix of standard normal draws, columns X1, ..., Xd
X <- replicate(d, rnorm(n))
colnames(X) <- paste0('X', seq_len(d))

# target: a column of ones is prepended for the intercept, then Gaussian
# noise with sd = sigma is added to the linear predictor
y <- cbind(1, X) %*% coefs + rnorm(n, sd = sigma)

# predictors and target together, for the formula interface
df <- data.frame(X, y)

We can also look at the higher dimension case as done in Drugowitsch section 2.6.2.

# Settings for the higher dimensional case: many predictors relative to the
# number of training rows, with a separate test set.
n <- 150
ntest <- 50
d <- 100
coefs <- rnorm(d + 1)   # intercept plus d slopes
sigma <- 1

# training data; the design matrix carries an explicit leading column of ones
X_train <- cbind(1, replicate(d, rnorm(n)))
y_train <- X_train %*% coefs + rnorm(n, sd = sigma)

# test data, generated the same way with the same coefficients
X_test <- cbind(1, replicate(d, rnorm(ntest)))
y_test <- X_test %*% coefs + rnorm(ntest, sd = sigma)

Function

First, the main function. For this demo, automatic relevance determination is an argument rather than a separate function.

vb_reg <- function(
  X,
  y,
  a0 = 10e-2,
  b0 = 10e-4,
  c0 = 10e-2,
  d0 = 10e-4,
  tol = 1e-8,
  maxiter = 1000,
  ard = FALSE
  ) {
  # Variational Bayes estimate of a linear regression, optionally with
  # automatic relevance determination (ARD).
  #
  # X: predictors (matrix or data frame) WITHOUT an intercept column; a column
  #    of ones is prepended below
  # y: the response, with one value per row of X
  # a0, b0: prior parameters for tau (note: 10e-2 is 0.1 and 10e-4 is 0.001)
  # c0, d0: hyperprior parameters for alpha
  #    NOTE(review): Drugowitsch's reference code appears to use 1e-2 / 1e-4;
  #    confirm whether the 10e-* defaults are intended.
  # tol: iterations stop once the absolute change in the bound LQ is <= tol
  # maxiter: alternative way to end iterations (at least 1)
  # ard: if TRUE, each coefficient (including the intercept) gets its own
  #    alpha; if FALSE, a single alpha is shared by all coefficients
  #
  # Returns a list with
  #   coef:       D x 1 matrix of coefficient estimates, intercept first
  #   sigma:      sqrt(b_N / a_N), i.e. sqrt(1 / E[tau])
  #   LQ:         value of the variational bound at the last iteration
  #   iterations: number of iterations run
  #   tol:        last absolute change in LQ
  # A warning is issued if maxiter is reached before the tolerance is met.

  stopifnot(
    NROW(y) == NROW(X),
    maxiter >= 1
  )

  # crossprod() needs a matrix; data frames are converted, other inputs are
  # passed to cbind() untouched
  if (is.data.frame(X)) X = as.matrix(X)

  # initializations
  X  = cbind(1, X)
  D  = ncol(X)
  N  = nrow(X)
  w  = rep(0, D)
  XX = crossprod(X)
  Xy = crossprod(X, y)

  a_N = a0 + N/2

  if (!ard) {
    c_N = c0 + D/2
    E_alpha = c0/d0
  } else {
    c_N = c0 + 1/2
    E_alpha = rep(c0/d0, D)
  }

  tolCurrent = 1
  iter = 0
  LQ   = 0

  while (iter < maxiter && tolCurrent > tol) {
    iter = iter + 1

    # b_N is computed from the coefficients of the previous iteration.
    # sum(E_alpha * w^2) covers both cases: a scalar alpha times w'w, or the
    # quadratic form w' diag(alpha) w under ARD.
    b_N  = b0 + 1/2 * (crossprod(y - X %*% w) + sum(E_alpha * w^2))

    # nrow = D keeps diag() from building an identity matrix of size E_alpha
    # when E_alpha has length one
    VInv = diag(E_alpha, nrow = D) + XX
    V    = solve(VInv)
    w    = V %*% Xy

    E_tau = c(a_N / b_N)    # plain scalar

    if (!ard) {
      d_N = d0 + 1/2 * (E_tau * sum(w^2) + sum(diag(V)))
    } else {
      d_N = d0 + 1/2 * (E_tau * c(w)^2 + diag(V))   # one value per coefficient
    }
    E_alpha = c(c_N / d_N)

    LQ_old = LQ

    # as.numeric() drops the 'logarithm' attribute carried by the modulus so
    # it does not leak into LQ and tol
    log_det_V = as.numeric(determinant(V, logarithm = TRUE)$modulus)
    resid_ss  = crossprod(y - X %*% w)

    # variational bound
    # NOTE(review): under ARD the alpha-prior constants (lgamma(c0),
    # c0*log(d0), lgamma(c_N)) enter once rather than once per coefficient;
    # this would only shift LQ by a constant - confirm against Drugowitsch.
    LQ = -N/2 * log(2*pi) -
      1/2 * (E_tau * resid_ss + sum(XX * V)) +
      1/2 * log_det_V + D/2 -
      lgamma(a0) + a0 * log(b0) - b0 * E_tau +
      lgamma(a_N) - a_N * log(b_N) + a_N -
      lgamma(c0) + c0 * log(d0) +
      lgamma(c_N) - sum(c_N * log(d_N))

    tolCurrent = abs(LQ - LQ_old)
    # alternate tolerance: store wold = w at the top of the loop and use
    # tolCurrent = sum(abs(w - wold)) in place of the LQ-based value
  }

  res = list(
    coef  = w,
    sigma = sqrt(b_N / a_N),
    LQ    = LQ,
    iterations = iter,
    tol   = tolCurrent
  )

  # warn, but always return the same five named elements
  if (iter >= maxiter && tolCurrent > tol) {
    warning('Maximum iterations reached.', call. = FALSE)
  }

  res
}

Estimation

First we can estimate the model using the smaller data.

# single shared alpha (no automatic relevance determination)
fit_small <- vb_reg(X, y, tol = 1e-8, ard = FALSE)

glimpse(fit_small)
List of 5
 $ coef      : num [1:4, 1] 1.01 2.29 3.29 5.02
  ..- attr(*, "dimnames")=List of 2
  .. ..$ : chr [1:4] "" "X1" "X2" "X3"
  .. ..$ : NULL
 $ sigma     : num [1, 1] 2.08
 $ LQ        : num [1, 1] -233
  ..- attr(*, "logarithm")= logi TRUE
 $ iterations: num 8
 $ tol       : num [1, 1] 1.11e-10
  ..- attr(*, "logarithm")= logi TRUE
# Same data, now with automatic relevance determination:
# one alpha per coefficient
fit_small_ard <- vb_reg(X, y, tol = 1e-8, ard = TRUE)

glimpse(fit_small_ard)
List of 5
 $ coef      : num [1:4, 1] 0.955 2.269 3.283 5.047
  ..- attr(*, "dimnames")=List of 2
  .. ..$ : chr [1:4] "" "X1" "X2" "X3"
  .. ..$ : NULL
 $ sigma     : num [1, 1] 2.09
 $ LQ        : num [1, 1] -229
  ..- attr(*, "logarithm")= logi TRUE
 $ iterations: num 9
 $ tol       : num [1, 1] 7.46e-09
  ..- attr(*, "logarithm")= logi TRUE
# ordinary least squares on the same data, as a reference for the coefficients
lm_mod <- lm(y ~ ., data = df)

Now with the higher dimensional data. We fit using the training data and will estimate the error on training and test using the yardstick package.

# vb_reg() adds its own intercept, so the leading column of ones is dropped;
# glm.fit() takes the design matrix as is
fit_vb  <- vb_reg(X_train[, -1], y_train)
fit_glm <- glm.fit(X_train, y_train)

# predictions
vb_pred_train <- X_train %*% fit_vb$coef
vb_pred_test  <- X_test %*% fit_vb$coef

glm_pred_train <- fitted(fit_glm)
glm_pred_test  <- X_test %*% coef(fit_glm)

# error: root mean square error on training and test data
# ([, 1] turns the one-column matrices into plain vectors)
vb_train_error <- yardstick::rmse_vec(
  truth    = y_train[, 1],
  estimate = vb_pred_train[, 1]
)
vb_test_error <- yardstick::rmse_vec(
  truth    = y_test[, 1],
  estimate = vb_pred_test[, 1]
)

glm_train_error <- yardstick::rmse_vec(
  truth    = y_train[, 1],
  estimate = glm_pred_train
)
glm_test_error <- yardstick::rmse_vec(
  truth    = y_test[, 1],
  estimate = glm_pred_test[, 1]
)

Comparison

For the smaller data, we will compare the coefficients.

no_ard ard lm
1.010 0.955 1.012
2.291 2.269 2.300
3.286 3.283 3.297
5.024 5.047 5.045

For the higher dimensional data, we will compare root mean square error.

vb glm
train 0.574 0.566
test 1.876 1.982

Visualization

In general, the results are as expected: the standard approach overfits relative to VB regression. The following visualizes them, similar to Drugowitsch figure 1.

Supplemental Example

And now for a notably higher dimension case with irrelevant predictors, as in Drugowitsch section 2.6.3. This is problematic for the GLM, which has more covariates than data points (rank deficient), and as such it will throw a warning, as will the predict function. It’s really not even worth looking at, but I have kept the code (commented out) for consistency.

This will take a while to estimate, and without ARD, even after bumping the iterations up to 2000, it will still likely hit the maximum before reaching the default tolerance level. However, the results appear very similar to those of Drugowitsch Figure 2.

set.seed(1234)

# More predictors than training rows; only the first deff predictors (plus
# the intercept) have nonzero coefficients.
n <- 500
ntest <- 50
d <- 1000
deff <- 100
coefs <- rnorm(deff + 1)
sigma <- 1

# training data: the coefficient vector is padded with zeros for the
# d - deff irrelevant predictors
X_train <- cbind(1, replicate(d, rnorm(n)))
y_train <- X_train %*% c(coefs, rep(0, d - deff)) + rnorm(n, sd = sigma)

# test data, generated the same way
X_test <- cbind(1, replicate(d, rnorm(ntest)))
y_test <- X_test %*% c(coefs, rep(0, d - deff)) + rnorm(ntest, sd = sigma)
# shared alpha; the intercept column is dropped because vb_reg() adds its own
fit_vb <- vb_reg(X_train[, -1], y_train)

# one alpha per coefficient (automatic relevance determination)
fit_vb_ard <- vb_reg(X_train[, -1], y_train, ard = TRUE)

# GLM fit left disabled
# fit_glm = glm(y_train ~ ., data = data.frame(X_train[,-1]))
# predictions: VB without ARD
vb_pred_train <- X_train %*% fit_vb$coef
vb_pred_test  <- X_test %*% fit_vb$coef

# predictions: VB with ARD
vb_ard_pred_train <- X_train %*% fit_vb_ard$coef
vb_ard_pred_test  <- X_test %*% fit_vb_ard$coef

# glm_pred_train = fitted(fit_glm)
# glm_pred_test  = X_test %*% coef(fit_glm)

# error (RMSE): VB without ARD
vb_train_error <- yardstick::rmse_vec(
  truth    = y_train[, 1],
  estimate = vb_pred_train[, 1]
)
vb_test_error <- yardstick::rmse_vec(
  truth    = y_test[, 1],
  estimate = vb_pred_test[, 1]
)

# error (RMSE): VB with ARD
vb_ard_train_error <- yardstick::rmse_vec(
  truth    = y_train[, 1],
  estimate = vb_ard_pred_train[, 1]
)
vb_ard_test_error <- yardstick::rmse_vec(
  truth    = y_test[, 1],
  estimate = vb_ard_pred_test[, 1]
)

# glm_train_error = yardstick::rmse_vec(y_train[,1], glm_pred_train)
# glm_test_error  = yardstick::rmse_vec(y_test[,1], glm_pred_test[,1])
# One column per model, one row per data split. The values are root mean
# square errors.
mse_results <- data.frame(
  vb    = c(vb_train_error, vb_test_error),
  vbARD = c(vb_ard_train_error, vb_ard_test_error)#,
  # glm = c(glm_train_error, glm_test_error)
)
rownames(mse_results) <- c('train', 'test')

kable_df(mse_results)
vb vbARD
train 0.641 0.002
test 8.378 2.323

Note how ARD correctly estimates (nearly) zero for irrelevant predictors.

N Mean SD Min Q1 Median Q3 Max % Missing
900 0 0 -0.2 0 0 0 0.4 0

Visualized, as before.