Variational Bayes Regression
The following provides a function for estimating the parameters of a linear regression via variational inference. See Drugowitsch (2014) for an overview of the method outlined in Bishop (2006).
For the primary function I will mostly use the notation of the Drugowitsch article. Here w represents the coefficients and τ the precision (inverse variance). The likelihood for the target y is N(Xw, τ⁻¹). The prior for w is normal, N(0, (τα)⁻¹I), and τ and α are given Gamma(a0, b0) and Gamma(c0, d0) priors, respectively.
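Written out (a sketch following Drugowitsch's setup; under ARD each coefficient w_j gets its own α_j rather than a single shared α), the model is:

$$
\begin{aligned}
y \mid w, \tau &\sim \mathcal{N}(Xw,\ \tau^{-1}I) \\
w \mid \tau, \alpha &\sim \mathcal{N}\big(0,\ (\tau\alpha)^{-1}I\big) \\
\tau &\sim \text{Gamma}(a_0, b_0) \\
\alpha &\sim \text{Gamma}(c_0, d_0)
\end{aligned}
$$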
References:
- Drugowitsch: http://arxiv.org/abs/1310.5438
- Drugowitsch also provides Matlab implementations of these models.
- Bishop: Pattern Recognition and Machine Learning
Data Setup
We can simulate some data as a starting point, in this case basic tabular data as used in the standard regression problem. Note that the intercept column is not included in the predictor matrix X, as it is added to the model matrix within the vb_reg function.
library(tidyverse)
set.seed(1234)
n     = 100
d     = 3
coefs = c(1, 2, 3, 5)
sigma = 2

X = replicate(d, rnorm(n))    # predictors
colnames(X) = paste0('X', 1:d)

y = cbind(1, X) %*% coefs + rnorm(n, sd = sigma)    # target

df = data.frame(X, y)
We can also look at a higher-dimensional case, as done in Drugowitsch section 2.6.2.
n     = 150
ntest = 50
d     = 100
coefs = rnorm(d + 1)
sigma = 1

X_train = cbind(1, replicate(d, rnorm(n)))
y_train = X_train %*% coefs + rnorm(n, sd = sigma)

X_test = cbind(1, replicate(d, rnorm(ntest)))
y_test = X_test %*% coefs + rnorm(ntest, sd = sigma)
Function
First, the main function. For this demo, automatic relevance determination (ARD) is an argument rather than a separate function.
vb_reg <- function(
  X,
  y,
  a0 = 10e-2,
  b0 = 10e-4,
  c0 = 10e-2,
  d0 = 10e-4,
  tol = 1e-8,
  maxiter = 1000,
  ard = FALSE
) {
  # X: model matrix (the intercept is added internally)
  # y: the response
  # a0, b0: prior parameters for tau
  # c0, d0: hyperprior parameters for alpha
  # tol: tolerance value to end iterations
  # maxiter: alternative way to end iterations
  # ard: use automatic relevance determination (one alpha per coefficient)

  # initializations
  X  = cbind(1, X)
  D  = ncol(X)
  N  = nrow(X)
  w  = rep(0, D)
  XX = crossprod(X)
  Xy = crossprod(X, y)

  a_N = a0 + N/2

  if (!ard) {
    c_N     = c0 + D/2
    E_alpha = c0/d0
  } else {
    c_N     = c0 + 1/2
    E_alpha = rep(c0/d0, D)
  }

  tolCurrent = 1
  iter = 0
  LQ   = 0

  while (iter < maxiter && tolCurrent > tol) {
    iter = iter + 1
    # wold = w

    if (!ard) {
      # single shared alpha
      b_N     = b0 + 1/2 * (crossprod(y - X %*% w) + E_alpha * crossprod(w))
      VInv    = diag(E_alpha, D) + XX
      V       = solve(VInv)
      w       = V %*% Xy
      E_wtau  = a_N/b_N * crossprod(w) + sum(diag(V))
      d_N     = d0 + 1/2 * E_wtau
      E_alpha = c(c_N/d_N)
    } else {
      # one alpha per coefficient
      b_N     = b0 + 1/2 * (crossprod(y - X %*% w) + t(w) %*% diag(E_alpha) %*% w)
      VInv    = diag(E_alpha) + XX
      V       = solve(VInv)
      w       = V %*% Xy
      E_wtau  = a_N/b_N * crossprod(w) + sum(diag(V))
      d_N     = d0 + 1/2 * (c(w)^2 * c(a_N/b_N) + diag(V))
      E_alpha = c(c_N/d_N)
    }

    LQ_old = LQ

    # variational lower bound
    suppressWarnings({
      LQ = -N/2*log(2*pi) - 1/2 * (a_N/b_N * crossprod(y - crossprod(t(X), w)) + sum(XX * V)) +
        1/2 * determinant(V, log = TRUE)$modulus + D/2 - lgamma(a0) + a0 * log(b0) - b0 * a_N/b_N +
        lgamma(a_N) - a_N * log(b_N) + a_N - lgamma(c0) + c0*log(d0) +
        lgamma(c_N) - sum(c_N*log(d_N))
    })

    tolCurrent = abs(LQ - LQ_old)
    # alternate tolerance; if used, uncomment the wold line above
    # tolCurrent = sum(abs(w - wold))
  }

  res = list(
    coef       = w,
    sigma      = sqrt(1 / (E_wtau / crossprod(w))),
    LQ         = LQ,
    iterations = iter,
    tol        = tolCurrent
  )

  if (iter >= maxiter)
    append(res, warning('Maximum iterations reached.'))
  else
    res
}
Estimation
First we can estimate the model using the smaller data.
fit_small = vb_reg(X, y, tol = 1e-8, ard = FALSE)

glimpse(fit_small)
List of 5
$ coef : num [1:4, 1] 1.01 2.29 3.29 5.02
..- attr(*, "dimnames")=List of 2
.. ..$ : chr [1:4] "" "X1" "X2" "X3"
.. ..$ : NULL
$ sigma : num [1, 1] 2.08
$ LQ : num [1, 1] -233
..- attr(*, "logarithm")= logi TRUE
$ iterations: num 8
$ tol : num [1, 1] 1.11e-10
..- attr(*, "logarithm")= logi TRUE
# With automatic relevance determination
fit_small_ard = vb_reg(X, y, tol = 1e-8, ard = TRUE)

glimpse(fit_small_ard)
List of 5
$ coef : num [1:4, 1] 0.955 2.269 3.283 5.047
..- attr(*, "dimnames")=List of 2
.. ..$ : chr [1:4] "" "X1" "X2" "X3"
.. ..$ : NULL
$ sigma : num [1, 1] 2.09
$ LQ : num [1, 1] -229
..- attr(*, "logarithm")= logi TRUE
$ iterations: num 9
$ tol : num [1, 1] 7.46e-09
..- attr(*, "logarithm")= logi TRUE
lm_mod = lm(y ~ ., data = df)
Now with the higher dimensional data. We fit using the training data and will estimate the error on training and test using the yardstick package.
fit_vb  = vb_reg(X_train[,-1], y_train)
fit_glm = glm.fit(X_train, y_train)

# predictions
vb_pred_train = X_train %*% fit_vb[['coef']]
vb_pred_test  = X_test %*% fit_vb[['coef']]

glm_pred_train = fitted(fit_glm)
glm_pred_test  = X_test %*% coef(fit_glm)

# error
vb_train_error = yardstick::rmse_vec(y_train[,1], vb_pred_train[,1])
vb_test_error  = yardstick::rmse_vec(y_test[,1], vb_pred_test[,1])

glm_train_error = yardstick::rmse_vec(y_train[,1], glm_pred_train)
glm_test_error  = yardstick::rmse_vec(y_test[,1], glm_pred_test[,1])
Comparison
For the smaller data, we will compare the coefficients.
no_ard | ard | lm |
---|---|---|
1.010 | 0.955 | 1.012 |
2.291 | 2.269 | 2.300 |
3.286 | 3.283 | 3.297 |
5.024 | 5.047 | 5.045 |
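For reference, a table like the one above could be assembled along these lines (a sketch; coef_comparison is just an illustrative name, and the fits are those estimated earlier):

```r
# gather the coefficient estimates from the three fits into one table
coef_comparison = data.frame(
  no_ard = c(fit_small[['coef']]),
  ard    = c(fit_small_ard[['coef']]),
  lm     = unname(coef(lm_mod))
)

rownames(coef_comparison) = names(coef(lm_mod))  # (Intercept), X1, X2, X3

round(coef_comparison, 3)
```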
For the higher dimensional data, we will compare root mean square error.
vb | glm | |
---|---|---|
train | 0.574 | 0.566 |
test | 1.876 | 1.982 |
Visualization
In general, the results are as expected: the standard approach overfits relative to the VB regression. The following visualizes them, similar to Drugowitsch Figure 1.
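The figure itself is not reproduced here, but one plot in that spirit, predicted versus observed values on the higher-dimensional test data for both fits, might be sketched as follows (assuming the tidyverse loaded above; pred_df is just an illustrative name):

```r
# predicted vs. observed on the test set for the VB and GLM fits
pred_df = tibble(
  observed = y_test[, 1],
  vb       = vb_pred_test[, 1],
  glm      = glm_pred_test[, 1]
) %>%
  pivot_longer(-observed, names_to = 'model', values_to = 'predicted')

ggplot(pred_df, aes(x = observed, y = predicted)) +
  geom_point(alpha = .5) +
  geom_abline() +          # 45-degree reference line
  facet_wrap(~ model)
```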
Supplemental Example
And now for a notably higher-dimensional case with irrelevant predictors, as in Drugowitsch section 2.6.3. This is problematic for the GLM, which now has more covariates than data points (the model matrix is rank deficient), so the fit will throw a warning, as will the predict function. It's not really worth looking at, but the code is kept (commented out) below for consistency.
This will take a while to estimate, and without ARD, even bumping the iterations up to 2000 will still likely hit the maximum before reaching the default tolerance level. However, the results appear very similar to those of Drugowitsch Figure 2.
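For example (a minimal sketch using the function's own maxiter argument, applied to the training data simulated just below), the iteration cap could be raised like so:

```r
# raise the iteration cap for the non-ARD fit; even at 2000 iterations it may
# still stop at maxiter before the change in the lower bound drops below tol
fit_vb = vb_reg(X_train[, -1], y_train, maxiter = 2000)
```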
set.seed(1234)
n     = 500
ntest = 50
d     = 1000
deff  = 100
coefs = rnorm(deff + 1)
sigma = 1

X_train = cbind(1, replicate(d, rnorm(n)))
y_train = X_train %*% c(coefs, rep(0, d - deff)) + rnorm(n, sd = sigma)

X_test = cbind(1, replicate(d, rnorm(ntest)))
y_test = X_test %*% c(coefs, rep(0, d - deff)) + rnorm(ntest, sd = sigma)
fit_vb     = vb_reg(X_train[,-1], y_train)
fit_vb_ard = vb_reg(X_train[,-1], y_train, ard = TRUE)

# fit_glm = glm(y_train ~ ., data = data.frame(X_train[,-1]))

# predictions
vb_pred_train = X_train %*% fit_vb[['coef']]
vb_pred_test  = X_test %*% fit_vb[['coef']]

vb_ard_pred_train = X_train %*% fit_vb_ard[['coef']]
vb_ard_pred_test  = X_test %*% fit_vb_ard[['coef']]

# glm_pred_train = fitted(fit_glm)
# glm_pred_test  = X_test %*% coef(fit_glm)

# error
vb_train_error = yardstick::rmse_vec(y_train[,1], vb_pred_train[,1])
vb_test_error  = yardstick::rmse_vec(y_test[,1], vb_pred_test[,1])

vb_ard_train_error = yardstick::rmse_vec(y_train[,1], vb_ard_pred_train[,1])
vb_ard_test_error  = yardstick::rmse_vec(y_test[,1], vb_ard_pred_test[,1])

# glm_train_error = yardstick::rmse_vec(y_train[,1], glm_pred_train)
# glm_test_error  = yardstick::rmse_vec(y_test[,1], glm_pred_test[,1])

mse_results = data.frame(
  vb    = c(vb_train_error, vb_test_error),
  vbARD = c(vb_ard_train_error, vb_ard_test_error)#,
  # glm = c(glm_train_error, glm_test_error)
)

rownames(mse_results) = c('train', 'test')

kable_df(mse_results)
vb | vbARD | |
---|---|---|
train | 0.641 | 0.002 |
test | 8.378 | 2.323 |
Note how ARD correctly estimates (nearly) zero coefficients for the irrelevant predictors. The table below summarizes the estimates for the 900 irrelevant coefficients; a quick way to compute such a summary is sketched after it.
N | Mean | SD | Min | Q1 | Median | Q3 | Max | % Missing |
---|---|---|---|---|---|---|---|---|
900 | 0 | 0 | -0.2 | 0 | 0 | 0 | 0.4 | 0 |
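A rough version of that check (a sketch; the indexing assumes the layout used when simulating the data, i.e. the intercept first, then the 100 relevant predictors, then the 900 irrelevant ones):

```r
# coefficients 1:(deff + 1) are the intercept and relevant predictors;
# the remaining d - deff = 900 belong to the irrelevant predictors
irrelevant_coefs = fit_vb_ard[['coef']][(deff + 2):(d + 1), 1]

summary(irrelevant_coefs)   # should be (nearly) zero across the board
```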
Visualized, as before.
Source
Original code available at: https://github.com/m-clark/Miscellaneous-R-Code/tree/master/ModelFitting/Bayesian/multinomial