Gradient Descent

The following demonstrates gradient descent for a standard linear regression model.

Data Setup

Create some basic data for standard regression.



n  = 1000
x1 = rnorm(n)
x2 = rnorm(n)
y  = 1 + .5*x1 + .2*x2 + rnorm(n)
X  = cbind(Intercept = 1, x1, x2)  # model matrix


(Batch) Gradient Descent Algorithm. The function takes as arguments starting values for the parameters to be estimated, the model matrix and response, a tolerance and a maximum number of iterations to provide stopping points, a stepsize (or starting stepsize for the adaptive approach), whether to print out iterations, and whether to plot the loss over each iteration.

gd <- function(
  par,
  X,
  y,
  tolerance = 1e-3,
  maxit     = 1000,
  stepsize  = 1e-3,
  adapt     = FALSE,
  verbose   = TRUE,
  plotLoss  = TRUE
  ) {
  # initialize
  beta = par; names(beta) = colnames(X)
  loss = crossprod(X %*% beta - y)
  tol  = 1
  iter = 1

  while(tol > tolerance && iter < maxit){
    LP   = X %*% beta                       # linear predictor at the current estimates
    grad = t(X) %*% (LP - y)                # gradient of half the residual sum of squares
    betaCurrent = beta - stepsize * grad
    tol  = max(abs(betaCurrent - beta))     # largest absolute change in any coefficient
    beta = betaCurrent
    loss = append(loss, crossprod(LP - y))  # loss at the estimates used for this step
    iter = iter + 1

    # adaptive stepsize: grow it if the loss decreased, otherwise shrink it
    if (adapt)
      stepsize = ifelse(
        loss[iter] < loss[iter - 1],
        stepsize * 1.2,
        stepsize * .8
      )

    if (verbose && iter %% 10 == 0)
      message(paste('Iteration:', iter))
  }

  if (plotLoss)
    plot(loss, type = 'l', bty = 'n')

  list(
    par    = beta,
    loss   = loss,
    RSE    = sqrt(crossprod(LP - y) / (nrow(X) - ncol(X))),
    iter   = iter,
    fitted = LP
  )
}
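
As a quick check on the update used in the loop, the following sketch compares the analytic gradient, t(X) %*% (X %*% beta - y), with a central finite-difference approximation. Strictly, that expression is the gradient of one half of the residual sum of squares; the factor of 2 for the full sum of squares is simply absorbed into the stepsize. The simulated data, the seed, the test point, and the helper name half_rss are assumptions made only for this illustration.

```r
set.seed(123)  # arbitrary seed, only so this sketch is reproducible
n <- 200
X <- cbind(Intercept = 1, x1 = rnorm(n), x2 = rnorm(n))
y <- 1 + .5 * X[, "x1"] + .2 * X[, "x2"] + rnorm(n)

# half the residual sum of squares (hypothetical helper);
# its gradient with respect to beta is t(X) %*% (X %*% beta - y)
half_rss <- function(beta) 0.5 * sum((X %*% beta - y)^2)

beta <- c(0.3, -0.2, 0.1)  # arbitrary point at which to check the gradient

grad_analytic <- drop(t(X) %*% (X %*% beta - y))

# central finite differences, perturbing one coefficient at a time
h <- 1e-4
grad_numeric <- sapply(seq_along(beta), function(j) {
  e <- replace(numeric(length(beta)), j, h)
  (half_rss(beta + e) - half_rss(beta - e)) / (2 * h)
})

rbind(analytic = grad_analytic, numeric = grad_numeric)
```

The two rows should agree to several decimal places, which suggests the gradient in the loop matches the loss being minimized up to that constant factor.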


Set starting values.

init = rep(0, 3)

For any particular data set you would have to experiment with the stepsize, which could be chosen via cross-validation; alternatively, one can use an adaptive approach, a simple version of which is implemented in this function.

fit_gd = gd(
  par = init,
  X = X,
  y = y,
  tolerance = 1e-8,
  stepsize  = 1e-4,
  adapt     = TRUE
)

str(fit_gd)

List of 5
 $ par   : num [1:3, 1] 0.985 0.487 0.218
  ..- attr(*, "dimnames")=List of 2
  .. ..$ : chr [1:3] "Intercept" "x1" "x2"
  .. ..$ : NULL
 $ loss  : num [1:70] 2315 2315 2075 1918 1760 ...
 $ RSE   : num [1, 1] 1.03
 $ iter  : num 70
 $ fitted: num [1:1000, 1] 0.441 1.061 0.43 2.125 1.858 ...
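
To give a sense of why the stepsize needs tuning, here is a rough sketch that assumes a fixed stepsize (no adaptation). For least squares, each update multiplies the error in the coefficients by (I - stepsize * X'X), so the iterations shrink that error only when the stepsize is below 2 divided by the largest eigenvalue of X'X. The data are simulated as in the setup above, the seed is arbitrary, and the names lambda, upper, and rate are hypothetical.

```r
set.seed(123)  # arbitrary seed, only so this sketch is reproducible
n <- 1000
X <- cbind(Intercept = 1, x1 = rnorm(n), x2 = rnorm(n))

# eigenvalues of X'X determine how a fixed-step update behaves
lambda <- eigen(crossprod(X), symmetric = TRUE, only.values = TRUE)$values

# largest fixed stepsize for which the iterations still contract
upper <- 2 / max(lambda)

# per-iteration contraction factor: values below 1 converge, above 1 diverge
steps <- c(1e-4, 1e-3, 3e-3)
rate  <- sapply(steps, function(s) max(abs(1 - s * lambda)))

upper
data.frame(stepsize = steps, rate = rate, converges = rate < 1)
```

With about 1000 roughly standard normal observations, the eigenvalues of X'X are all on the order of 1000, so fixed stepsizes much beyond 2e-3 would be expected to diverge, while very small ones converge slowly.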


We can compare to standard linear regression.

   Intercept    x1    x2
gd     0.985 0.487 0.218
lm     0.985 0.487 0.218
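
A comparison like the one above could be produced along the following lines. This is only a sketch: a bare fixed-step loop stands in for gd(), and the seed, the 500 iterations, and the object names fit_lm and comparison are assumptions rather than part of the original demonstration.

```r
set.seed(123)  # arbitrary seed, only so this sketch is reproducible
n  <- 1000
x1 <- rnorm(n)
x2 <- rnorm(n)
y  <- 1 + .5 * x1 + .2 * x2 + rnorm(n)
X  <- cbind(Intercept = 1, x1, x2)

# plain fixed-step gradient descent, standing in for gd()
beta <- rep(0, ncol(X))
for (i in 1:500) {
  beta <- beta - 1e-4 * drop(t(X) %*% (X %*% beta - y))
}

# ordinary least squares fit for reference
fit_lm <- lm(y ~ x1 + x2)

comparison <- rbind(gd = beta, lm = coef(fit_lm))
colnames(comparison) <- colnames(X)
round(comparison, 3)
```

With a different seed the estimates will not reproduce the values shown above, but the two rows should still agree with each other.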