Gradient Descent
This demonstration applies gradient descent to a standard linear regression model.
Data Setup
Create some basic data for standard regression.
library(tidyverse)
set.seed(8675309)
n  = 1000
x1 = rnorm(n)
x2 = rnorm(n)
y  = 1 + .5*x1 + .2*x2 + rnorm(n)
X  = cbind(Intercept = 1, x1, x2) # model matrix
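Written out, the simulated data appear to follow the model below. The symbol $\varepsilon_i$ for the noise term is notation introduced here, not taken from the original; it corresponds to the final `rnorm(n)` term in the code.

$$
y_i = 1 + 0.5\,x_{1i} + 0.2\,x_{2i} + \varepsilon_i, \qquad x_{1i},\ x_{2i},\ \varepsilon_i \sim \mathcal{N}(0, 1), \qquad i = 1, \dots, 1000
$$

The true coefficients are therefore 1, 0.5, and 0.2, which gives a rough benchmark for the estimates obtained later.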
Function
(Batch) gradient descent algorithm. The function takes starting values for the parameters to be estimated (par), the model matrix (X), the response (y), a tolerance and a maximum number of iterations as stopping criteria, a stepsize (the starting stepsize when the adaptive approach is used), whether to print progress messages, and whether to plot the loss over iterations.
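As a sketch of what the function computes at each iteration, the loss, the gradient, and the update can be written as follows. The symbols $\eta$ (for `stepsize`) and the superscript $(t)$ (for the iteration) are notation assumed here, not part of the original.

$$
L(\beta) = (X\beta - y)^\top (X\beta - y), \qquad
g(\beta) = X^\top (X\beta - y), \qquad
\beta^{(t+1)} = \beta^{(t)} - \eta\, g\big(\beta^{(t)}\big)
$$

The exact gradient of $L$ is $2X^\top(X\beta - y)$; the code uses `t(X) %*% (LP - y)` without the factor of 2, which only rescales the stepsize. Iteration stops once $\max_j \big|\beta_j^{(t+1)} - \beta_j^{(t)}\big|$ is no larger than `tolerance`, or once the iteration counter reaches `maxit`.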
gd <- function(
  par,
  X,
  y,
  tolerance = 1e-3,
  maxit     = 1000,
  stepsize  = 1e-3,
  adapt     = FALSE,
  verbose   = TRUE,
  plotLoss  = TRUE
) {
  # initialize
  beta = par; names(beta) = colnames(X)
  loss = crossprod(X %*% beta - y)
  tol  = 1
  iter = 1

  while(tol > tolerance && iter < maxit){
    LP   = X %*% beta                        # linear predictor
    grad = t(X) %*% (LP - y)                 # gradient (up to a constant factor)
    betaCurrent = beta - stepsize * grad     # update step
    tol  = max(abs(betaCurrent - beta))      # largest change in any coefficient
    beta = betaCurrent
    loss = append(loss, crossprod(LP - y))
    iter = iter + 1

    # grow the stepsize if the loss decreased, shrink it otherwise
    if (adapt)
      stepsize = ifelse(
        loss[iter] < loss[iter - 1],
        stepsize * 1.2,
        stepsize * .8
      )

    if (verbose && iter %% 10 == 0)
      message(paste('Iteration:', iter))
  }

  if (plotLoss)
    plot(loss, type = 'l', bty = 'n')

  list(
    par    = beta,
    loss   = loss,
    RSE    = sqrt(crossprod(LP - y) / (nrow(X) - ncol(X))),
    iter   = iter,
    fitted = LP
  )
}
Estimation
Set starting values.
init = rep(0, 3)
For any particular data set you would have to tune the stepsize, which could be assessed via cross-validation. Alternatively, one can use an adaptive approach, a simple version of which is implemented in this function.
fit_gd = gd(
  init,
  X = X,
  y = y,
  tolerance = 1e-8,
  stepsize  = 1e-4,
  adapt     = TRUE
)
str(fit_gd)
List of 5
$ par : num [1:3, 1] 0.985 0.487 0.218
..- attr(*, "dimnames")=List of 2
.. ..$ : chr [1:3] "Intercept" "x1" "x2"
.. ..$ : NULL
$ loss : num [1:70] 2315 2315 2075 1918 1760 ...
$ RSE : num [1, 1] 1.03
$ iter : num 70
$ fitted: num [1:1000, 1] 0.441 1.061 0.43 2.125 1.858 ...
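The adaptive rule used when `adapt = TRUE` can be looked at in isolation. The sketch below pulls it into a standalone helper; `adapt_stepsize` and its `grow` and `shrink` arguments are hypothetical names introduced for illustration and are not part of the original code. The defaults of 1.2 and 0.8 are the multipliers used in the function above.

```r
# Hypothetical helper (not in the original code): returns the next stepsize
# given the current one and the loss history so far.
adapt_stepsize = function(stepsize, loss, grow = 1.2, shrink = 0.8) {
  k = length(loss)

  # with fewer than two loss values there is nothing to compare
  if (k < 2) return(stepsize)

  # grow the stepsize when the latest loss improved, shrink it otherwise
  if (loss[k] < loss[k - 1]) stepsize * grow else stepsize * shrink
}

step_after_drop = adapt_stepsize(1e-3, c(10, 9))  # loss fell, so the step grows
step_after_rise = adapt_stepsize(1e-3, c(9, 10))  # loss rose, so the step shrinks
```

The design is deliberately crude: it reacts only to the last two loss values, so the stepsize can oscillate, but it removes some of the need to hand-tune a fixed value.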
Comparison
We can compare to standard linear regression.
| | Intercept | x1 | x2 |
|---|---|---|---|
| gd | 0.985 | 0.487 | 0.218 |
| lm | 0.985 | 0.487 | 0.218 |
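A comparison of this kind could be produced along the following lines. This is a minimal sketch rather than the original code: so that the block runs on its own, it regenerates the data and uses a short fixed-step gradient descent loop as a stand-in for `gd()`. The names `fit_lm`, `beta`, and `comparison`, and the choice of 5000 iterations, are assumptions made here. With the fitted object from above available, `fit_gd$par[, 1]` could be used in place of `beta`.

```r
set.seed(8675309)
n  = 1000
x1 = rnorm(n)
x2 = rnorm(n)
y  = 1 + .5*x1 + .2*x2 + rnorm(n)
X  = cbind(Intercept = 1, x1, x2)

# reference fit from standard linear regression
fit_lm = lm(y ~ x1 + x2)

# simplified stand-in for gd(): fixed stepsize, fixed number of iterations
beta = rep(0, ncol(X))
for (i in 1:5000) {
  beta = beta - 1e-4 * drop(crossprod(X, X %*% beta - y))
}

# stack the two sets of estimates for side-by-side inspection
comparison = rbind(gd = beta, lm = coef(fit_lm))
colnames(comparison) = colnames(X)
round(comparison, 3)
```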
Source
Original code available at https://github.com/m-clark/Miscellaneous-R-Code/blob/master/ModelFitting/gradient_descent.R