Gradient Descent implementation in PyTorch

 

The gradient descent algorithm is one of the most popular techniques for training deep neural networks. It has many applications in fields such as computer vision, speech recognition, and natural language processing. While the idea of gradient descent has been around for decades, it has only recently been applied to deep learning.

Gradient descent is an iterative optimization method that finds the minimum of an objective function by repeatedly updating the parameters at each step. With each iteration it takes a small step in the direction of the negative gradient, until convergence or a stopping criterion is reached.
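To make the idea concrete, here is a minimal sketch of the core update rule on a hypothetical one-parameter objective (not the regression model built later in this tutorial): each step moves the parameter against the gradient, scaled by a learning rate.

# Minimal sketch: gradient descent on f(theta) = (theta - 3)**2,
# whose gradient is 2 * (theta - 3). The minimum is at theta = 3.
theta = 0.0          # starting point
learning_rate = 0.1  # step size

for step in range(50):
    grad = 2 * (theta - 3)                # gradient of the objective at theta
    theta = theta - learning_rate * grad  # move against the gradient

print(theta)  # close to 3 after enough iterations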

In this article, you will train a simple linear regression model with two trainable parameters and learn how gradient descent works and how to implement it in PyTorch. In particular, you will learn about:

  • Gradient Descent algorithm and its implementation in PyTorch
  • Batch Gradient Descent and its implementation in PyTorch
  • Stochastic Gradient Descent and its implementation in PyTorch
  • What distinguishes Stochastic Gradient Descent from Batch Gradient Descent
  • How loss diminishes during training in batch gradient descent versus stochastic gradient descent

Overview

This tutorial is in four parts; they are:

  • Preparing Data
  • Batch Gradient Descent
  • Stochastic Gradient Descent
  • Plotting Graphs for Comparison

Preparing Data

To keep the model simple, we will reuse the linear regression example from the previous tutorial. The data is synthetic and generated as follows:

import torch
import numpy as np
import matplotlib.pyplot as plt

# Creating a function f(X) with a slope of -5
X = torch.arange(-5, 5, 0.1).view(-1, 1)
func = -5 * X

# Adding Gaussian noise to the function f(X) and saving it in Y
Y = func + 0.4 * torch.randn(X.size())

As in the previous tutorial, we initialized the variable X with values ranging from −5 to 5 and created a linear function with a slope of −5. Gaussian noise was then added to create the variable Y.

Using Matplotlib, we can plot the data to show the pattern:

...
# Plotting and visualizing the data points in blue
plt.plot(X.numpy(), Y.numpy(), 'b+', label='Y')
plt.plot(X.numpy(), func.numpy(), 'r', label='func')
plt.xlabel('x')
plt.ylabel('y')
plt.legend()
plt.grid(True, color='y')
plt.show()

 

Data points for regression model

Batch Gradient Descent

Now that we have created the data for our model, we will define a forward function based on the simple linear regression equation, with two trainable parameters (w and b). We also need a loss criterion; since this is a regression problem on continuous values, MSE loss is appropriate.

...
# defining the function for forward pass for prediction
def forward(x):
    return w * x + b

# evaluating data points with Mean Square Error (MSE)
def criterion(y_pred, y):
    return torch.mean((y_pred - y) ** 2)

Before we train our model, let’s learn about batch gradient descent. In batch gradient descent, all the samples in the training data are considered in a single step: the parameters are updated using the mean gradient over all the training instances. In other words, each epoch performs only one gradient descent step.

While batch gradient descent works well on smooth error surfaces, it is relatively slow and computationally expensive, especially when training on a larger dataset.
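To make "the mean gradient over all the training instances" concrete, the following check (purely illustrative, using temporary values for w and b; the names manual_grad_w and manual_grad_b are introduced only for this sketch) compares the gradient that autograd computes for the MSE loss with the analytic gradient averaged over the whole dataset.

# Illustrative only: verify that the autograd gradient of the MSE loss equals
# the mean of the per-sample gradients over the entire dataset.
# For loss = mean((w*X + b - Y)**2):
#   d(loss)/dw = mean(2 * (w*X + b - Y) * X)
#   d(loss)/db = mean(2 * (w*X + b - Y))
w = torch.tensor(-10.0, requires_grad=True)
b = torch.tensor(-20.0, requires_grad=True)

loss = criterion(forward(X), Y)
loss.backward()

residual = (w * X + b - Y).detach()
manual_grad_w = torch.mean(2 * residual * X)
manual_grad_b = torch.mean(2 * residual)

print(w.grad, manual_grad_w)  # these two values should match
print(b.grad, manual_grad_b)  # and so should these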

Training with Batch Gradient Descent

Let’s randomly initialize the trainable parameters w and b, and specify some training settings: a learning rate (step size), an empty list to record the loss, and the number of training epochs.

w = torch.tensor(-10.0, requires_grad=True)
b = torch.tensor(-20.0, requires_grad=True)

step_size = 0.1
loss_BGD = []
n_iter = 20

We’ll train our model for 20 epochs using the code below. The forward() function generates the prediction, while the criterion() function measures the loss, which is stored in the loss variable. The backward() method performs the gradient computations, and the updated parameters are stored in w.data and b.data.

for i in range(n_iter):
    # making predictions with forward pass
    Y_pred = forward(X)
    # calculating the loss between original and predicted data points
    loss = criterion(Y_pred, Y)
    # storing the calculated loss in a list
    loss_BGD.append(loss.item())
    # backward pass for computing the gradients of the loss w.r.t. the learnable parameters
    loss.backward()
    # updating the parameters after each iteration
    w.data = w.data - step_size * w.grad.data
    b.data = b.data - step_size * b.grad.data
    # zeroing gradients after each iteration
    w.grad.data.zero_()
    b.grad.data.zero_()
    # printing the values for understanding
    print('{}, \t{}, \t{}, \t{}'.format(i, loss.item(), w.item(), b.item()))

When we use batch gradient descent, the output looks like the following; the parameters are updated once after each epoch.

0, 596.7191162109375, -1.8527469635009766, -16.062074661254883
1, 343.426513671875, -7.247585773468018, -12.83026123046875
2, 202.7098388671875, -3.616910219192505, -10.298759460449219
3, 122.16651153564453, -6.0132551193237305, -8.237251281738281
4, 74.85094451904297, -4.394278526306152, -6.6120076179504395
5, 46.450958251953125, -5.457883358001709, -5.295622825622559
6, 29.111614227294922, -4.735295295715332, -4.2531514167785645
7, 18.386211395263672, -5.206836700439453, -3.4119482040405273
8, 11.687058448791504, -4.883906364440918, -2.7437009811401367
9, 7.4728569984436035, -5.092618465423584, -2.205873966217041
10, 4.808231830596924, -4.948029518127441, -1.777699589729309
11, 3.1172332763671875, -5.040188312530518, -1.4337140321731567
12, 2.0413269996643066, -4.975278854370117, -1.159447193145752
13, 1.355530858039856, -5.0158305168151855, -0.9393846988677979
14, 0.9178376793861389, -4.986582279205322, -0.7637402415275574
15, 0.6382412910461426, -5.004333972930908, -0.6229321360588074
16, 0.45952412486076355, -4.991086006164551, -0.5104631781578064
17, 0.34523946046829224, -4.998797416687012, -0.42035552859306335
18, 0.27213525772094727, -4.992753028869629, -0.3483465909957886
19, 0.22536347806453705, -4.996064186096191, -0.2906789183616638
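Each value appended to loss_BGD corresponds to one epoch, so a quick plot (the full comparison with stochastic gradient descent comes later in the tutorial) shows the loss shrinking steadily:

# Visualizing how the loss recorded in loss_BGD decreases over the epochs
plt.plot(loss_BGD, label='Batch Gradient Descent')
plt.xlabel('epoch')
plt.ylabel('MSE loss')
plt.legend()
plt.show()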

Putting it all together, the following is the complete code:
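(For convenience, the listing below assembles the data generation, model and loss definitions, and the batch gradient descent training loop shown above into a single script.)

import torch
import matplotlib.pyplot as plt

# Creating a function f(X) with a slope of -5
X = torch.arange(-5, 5, 0.1).view(-1, 1)
func = -5 * X

# Adding Gaussian noise to the function f(X) and saving it in Y
Y = func + 0.4 * torch.randn(X.size())

# Plotting and visualizing the data points in blue
plt.plot(X.numpy(), Y.numpy(), 'b+', label='Y')
plt.plot(X.numpy(), func.numpy(), 'r', label='func')
plt.xlabel('x')
plt.ylabel('y')
plt.legend()
plt.grid(True, color='y')
plt.show()

# defining the function for forward pass for prediction
def forward(x):
    return w * x + b

# evaluating data points with Mean Square Error (MSE)
def criterion(y_pred, y):
    return torch.mean((y_pred - y) ** 2)

# randomly initializing the trainable parameters
w = torch.tensor(-10.0, requires_grad=True)
b = torch.tensor(-20.0, requires_grad=True)

# training settings
step_size = 0.1
loss_BGD = []
n_iter = 20

for i in range(n_iter):
    # making predictions with forward pass
    Y_pred = forward(X)
    # calculating the loss between original and predicted data points
    loss = criterion(Y_pred, Y)
    # storing the calculated loss in a list
    loss_BGD.append(loss.item())
    # backward pass for computing the gradients of the loss w.r.t. the learnable parameters
    loss.backward()
    # updating the parameters after each iteration
    w.data = w.data - step_size * w.grad.data
    b.data = b.data - step_size * b.grad.data
    # zeroing gradients after each iteration
    w.grad.data.zero_()
    b.grad.data.zero_()
    # printing the values for understanding
    print('{}, \t{}, \t{}, \t{}'.format(i, loss.item(), w.item(), b.item()))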
