
Classification Algorithm Part 1: Logistic Regression Using R Language

Margi Patel


What is Logistic Regression?


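Logistic regression is used when the dependent variable is categorical. In its basic (binary) form, which this post covers, the dependent variable has two classes coded 1 and 0. For example:

  • To predict whether a tumor is malignant (1) or benign (0).

  • To predict whether an email is spam (1) or not spam (0).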

Consider a scenario where we need to classify whether a tumor is malignant or benign. If we use linear regression for this problem, we need to set a threshold on its continuous output to decide the class. Say the actual class is malignant, the predicted value is 0.4, and the threshold is 0.5. The data point will then be classified as benign, which can have serious consequences in practice. The deeper problem is that linear regression is unbounded: its predictions can be negative or greater than 1, so they cannot be interpreted as probabilities. This is why linear regression is not well suited to classification, and it brings logistic regression into the picture. Logistic regression passes a linear combination of the inputs through the sigmoid (logistic) function, so its output always lies strictly between 0 and 1 and can be read as the probability of class 1.
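To make the bounded-output idea concrete, here is a minimal R sketch of the sigmoid function, σ(z) = 1 / (1 + e^(-z)). The function name `sigmoid` and the input values are our own illustrative choices. Base R already provides the same function as `plogis()`.

```r
# Sigmoid (logistic) function: maps any real number into the open interval (0, 1)
sigmoid <- function(z) 1 / (1 + exp(-z))

# Hypothetical outputs of a linear model: unbounded, can be far below 0 or above 1
z <- c(-10, -2, 0, 2, 10)
p <- sigmoid(z)
print(p)

# Base R's plogis() computes the same logistic function
stopifnot(isTRUE(all.equal(p, plogis(z))))
```

However large or small `z` gets, `p` stays between 0 and 1, and `z = 0` maps to exactly 0.5. This is why 0.5 is the natural default threshold.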


Let's implement logistic regression using the Social Network Ads data set, which is available on Kaggle. This data set contains information about users of a social network: user ID, gender, age, estimated salary, and whether the user purchased the product (the Purchased column, 0 or 1). In other words, each row is the profile of a user who interacted with an advertisement and then either purchased the product or did not.


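Required R packages


First, install the caTools and ggplot2 packages and load them with library(). After that, you can perform the operations below. caTools provides sample.split(), which is used to split the data. (The plots later in this post use base R graphics, so ggplot2 is not strictly needed for them.) So let's start implementing the logistic regression model.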


  • Import libraries

# Install once per R installation
install.packages('caTools')   # provides sample.split()
install.packages('ggplot2')
# Load in every new R session
library(caTools)
library(ggplot2)

Note: Packages only need to be installed once (whether or not you use RStudio), but library() must be called again in every new R session.


  • Importing the dataset

dataset <- read.csv('../input/social-network-ads/Social_Network_Ads.csv')
dataset <- dataset[3:5]   # keep Age, EstimatedSalary, Purchased
dim(dataset)              # number of rows and columns

The read.csv() function reads the CSV file, and dim() shows how many rows and columns the data frame contains. dataset[3:5] keeps only the three columns we need: Age, EstimatedSalary, and Purchased.
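If you want to try these two steps without downloading the file, `read.csv(text = ...)` can parse a CSV string instead. The four rows below are made up for illustration and are not taken from the Kaggle data. Only the column layout (User ID, Gender, Age, EstimatedSalary, Purchased) mirrors the real file.

```r
# Made-up rows with the same five columns as Social_Network_Ads.csv
csv_text <- "User ID,Gender,Age,EstimatedSalary,Purchased
1,Male,19,19000,0
2,Female,35,20000,0
3,Female,46,41000,1
4,Male,48,29000,1"

dataset <- read.csv(text = csv_text)   # "User ID" becomes User.ID via check.names
dim(dataset)                           # 4 rows, 5 columns

dataset <- dataset[3:5]                # columns 3 to 5: Age, EstimatedSalary, Purchased
names(dataset)
dim(dataset)                           # 4 rows, 3 columns
```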


  • Splitting the data set into the Training set and Test set


set.seed(123)   # make the random split reproducible
split <- sample.split(dataset$Purchased, SplitRatio = 0.75)
training_set <- subset(dataset, split == TRUE)
test_set <- subset(dataset, split == FALSE)

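Now split the data set into a training set and a test set. Here, training_set contains 75% of the data and test_set contains the remaining 25%. Because sample.split() receives the label vector (Purchased), it keeps the proportion of 0s and 1s roughly the same in both subsets, and set.seed(123) makes the random split reproducible. You can also change the training/testing split ratio, for example to 80% - 20% or 70% - 30%.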
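To show what a label-preserving (stratified) split means, here is a base-R sketch of the same idea. This is not caTools' actual implementation. `stratified_split()` is a hypothetical helper, and the label vector is made up. For each class, the helper samples 75% of that class's rows for training.

```r
# Hypothetical helper: stratified split, sampling the same ratio within each class
stratified_split <- function(y, ratio = 0.75) {
  keep <- logical(length(y))                 # FALSE = test, TRUE = training
  for (cls in unique(y)) {
    idx <- which(y == cls)                   # rows belonging to this class
    n_train <- round(length(idx) * ratio)
    keep[idx[sample.int(length(idx), n_train)]] <- TRUE
  }
  keep
}

set.seed(123)
labels <- c(rep(0, 60), rep(1, 40))          # made-up labels: 60% zeros, 40% ones
split <- stratified_split(labels, 0.75)
table(labels[split])    # training set keeps the 60/40 class balance
table(labels[!split])   # and so does the test set
```

Stratifying matters most when one class is rare. With a purely random split, the small test set could end up with very few positive cases.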


  • Feature Scaling

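train_scaled <- scale(training_set[-3])   # standardize Age and EstimatedSalary
training_set[-3] <- train_scaled
# reuse the training mean and sd so the test set is on the same scale
test_set[-3] <- scale(test_set[-3], attr(train_scaled, 'scaled:center'), attr(train_scaled, 'scaled:scale'))

Feature scaling puts the independent variables on comparable ranges. scale() subtracts each column's mean and divides by its standard deviation, so Age and EstimatedSalary end up with mean 0 and standard deviation 1. Strictly speaking, this is standardization rather than min-max normalization. The test set is transformed with the mean and standard deviation learned from the training set, so both sets are measured on the same scale. For glm(), this rescaling changes only the coefficients, not the fitted probabilities. However, the plotting code below relies on the scaled features: its grid step of 0.01 would be far too fine for raw salaries.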
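The transformation that `scale()` applies can be written out by hand. This makes it clear what "using the training statistics on the test set" means. The sketch below uses made-up Age and EstimatedSalary values for illustration. The `scaled:center` and `scaled:scale` attributes are the ones `scale()` attaches to its result.

```r
# Made-up training and test values for (Age, EstimatedSalary), for illustration only
train <- cbind(Age = c(19, 35, 26, 27, 19, 27),
               EstimatedSalary = c(19000, 20000, 43000, 57000, 76000, 58000))
test <- cbind(Age = c(30, 45),
              EstimatedSalary = c(87000, 26000))

# scale() computes z = (x - mean) / sd column by column
train_scaled <- scale(train)
mu    <- attr(train_scaled, "scaled:center")   # training column means
sigma <- attr(train_scaled, "scaled:scale")    # training column standard deviations

# The same transformation written out by hand
manual <- sweep(sweep(train, 2, colMeans(train)), 2, apply(train, 2, sd), "/")
stopifnot(isTRUE(all.equal(train_scaled, manual, check.attributes = FALSE)))

# Apply the *training* statistics to the test rows
test_scaled <- scale(test, center = mu, scale = sigma)
print(test_scaled)
```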


  • Fitting Logistic Regression to the Training set

classifier <- glm(formula = Purchased ~ ., family = binomial, data = training_set)

Now, create the logistic regression classifier. Here, glm() (generalized linear models) is used because logistic regression is a generalized linear model: a linear combination of the inputs linked to the probability through the logit function. The first argument is the formula. Purchased, on the left of ~, is the dependent variable, and the dot on the right means "use all the other columns as independent variables", here Age and EstimatedSalary. The next argument is the family; for logistic regression, you specify binomial, whose default link is the logit. The last argument, data, is the training data on which the model is fitted.
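The following sketch runs the same `glm()` call on simulated data, so it works without the Kaggle file. The "true" coefficients used to generate labels (-1, 2, 1.5) are arbitrary assumptions. The sketch also shows how the fitted model connects to the sigmoid: `predict(..., type = "response")` is `plogis()` applied to the linear predictor.

```r
set.seed(42)
n <- 400
# Simulated, already-standardized stand-ins for Age and EstimatedSalary
sim <- data.frame(Age = rnorm(n), EstimatedSalary = rnorm(n))
# Assumed data-generating model: logit(p) = -1 + 2 * Age + 1.5 * EstimatedSalary
p_true <- plogis(-1 + 2 * sim$Age + 1.5 * sim$EstimatedSalary)
sim$Purchased <- rbinom(n, size = 1, prob = p_true)

# Same call pattern as the tutorial: binomial family, logit link by default
fit <- glm(Purchased ~ ., family = binomial, data = sim)
print(coef(fit))   # estimates should land near -1, 2 and 1.5

# type = "link" gives the log-odds; type = "response" applies the sigmoid to them
eta  <- predict(fit, type = "link")
prob <- predict(fit, type = "response")
stopifnot(isTRUE(all.equal(prob, plogis(eta))))
```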

  • Predicting the Test set results

prob_pred <- predict(classifier, type = 'response', newdata = test_set[-3])
y_pred <- ifelse(prob_pred > 0.5, 1, 0)

predict() with type = 'response' returns, for each test-set observation, the predicted probability that Purchased = 1. Without this argument, predict() returns values on the log-odds scale. ifelse() then turns these probabilities into classes: if prob_pred is greater than 0.5, the prediction is 1; otherwise, it is 0.
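The 0.5 cutoff is a choice, not a rule. In the tumor example from the introduction, missing a malignant case is costly, so a lower threshold may be preferable. The probabilities below are hypothetical, chosen only to show how the threshold changes the predicted classes.

```r
# Hypothetical predicted probabilities for six test users
prob_pred <- c(0.10, 0.45, 0.50, 0.51, 0.80, 0.97)

# Rule used in the tutorial: strictly greater than 0.5 -> class 1
y_pred <- ifelse(prob_pred > 0.5, 1, 0)
print(y_pred)       # 0 0 0 1 1 1

# A lower threshold labels more users as 1: fewer missed positives, more false alarms
y_pred_low <- ifelse(prob_pred > 0.3, 1, 0)
print(y_pred_low)   # 0 1 1 1 1 1
```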


  • Making the Confusion Matrix

cm <- table(Actual = test_set[, 3], Predicted = y_pred)

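Finally, evaluate the predictions with a confusion matrix, which counts the correct and incorrect predictions. Rows are the actual classes and columns are the predicted classes, so the diagonal holds the correct predictions and the off-diagonal cells hold the errors.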
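The usual summary metrics can be read off the confusion matrix. The sketch below uses made-up actual and predicted labels for ten users. Wrapping both vectors in `factor(..., levels = c(0, 1))` keeps both columns in the table even if the model never predicts one of the classes.

```r
# Made-up actual labels and predictions for ten test users
actual    <- c(0, 0, 0, 0, 0, 0, 1, 1, 1, 1)
predicted <- c(0, 0, 0, 0, 0, 1, 1, 1, 1, 0)

# Rows = actual class, columns = predicted class; fixed levels keep a 2 x 2 shape
cm <- table(Actual = factor(actual, levels = c(0, 1)),
            Predicted = factor(predicted, levels = c(0, 1)))
print(cm)

accuracy  <- sum(diag(cm)) / sum(cm)            # share of correct predictions
precision <- cm["1", "1"] / sum(cm[, "1"])      # of predicted 1s, how many are truly 1
recall    <- cm["1", "1"] / sum(cm["1", ])      # of actual 1s, how many were found
print(c(accuracy = accuracy, precision = precision, recall = recall))
```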



  • Visualizing the Training set results

set <- training_set
# Build a fine grid covering the (scaled) Age and EstimatedSalary ranges
x1 <- seq(min(set[, 1]) - 1, max(set[, 1]) + 1, by = 0.01)
x2 <- seq(min(set[, 2]) - 1, max(set[, 2]) + 1, by = 0.01)
grid_set <- expand.grid(x1, x2)
colnames(grid_set) <- c('Age', 'EstimatedSalary')
# Predict a class for every grid point
prob_set <- predict(classifier, type = 'response', newdata = grid_set)
y_grid <- ifelse(prob_set > 0.5, 1, 0)
# Draw the axes, the decision boundary, the coloured regions, then the observations
plot(set[, -3], main = 'LR Plot for Training set', xlab = 'Age', ylab = 'Estimated Salary', xlim = range(x1), ylim = range(x2))
contour(x1, x2, matrix(as.numeric(y_grid), length(x1), length(x2)), add = TRUE)
points(grid_set, pch = '.', col = ifelse(y_grid == 1, 'springgreen3', 'tomato'))
points(set, pch = 21, bg = ifelse(set[, 3] == 1, 'green4', 'red3'))

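The plot shows red points and green points. The red points are the training-set observations for which the dependent variable Purchased equals 0, and the green points are the training-set observations for which Purchased equals 1. Each user is characterized by age on the x-axis and estimated salary on the y-axis, both in standardized units because the features were scaled. The background colour shows the classifier's prediction at each grid point (tomato for 0, spring green for 1). The line between the two regions is the decision boundary, where the predicted probability equals 0.5. Red points in the green region and green points in the red region are misclassified.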
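Because logistic regression is linear in its inputs, the decision boundary in this plot is a straight line. It is the set of points where b0 + b1·Age + b2·EstimatedSalary = 0, which gives EstimatedSalary = -(b0 + b1·Age) / b2. The sketch below fits a model to simulated data (the generating coefficients are arbitrary assumptions) and draws that line directly with `abline()` instead of using a grid.

```r
set.seed(1)
n <- 300
# Simulated, standardized stand-ins for Age and EstimatedSalary
sim <- data.frame(Age = rnorm(n), EstimatedSalary = rnorm(n))
sim$Purchased <- rbinom(n, 1, plogis(-0.5 + 2 * sim$Age + 1 * sim$EstimatedSalary))
fit <- glm(Purchased ~ Age + EstimatedSalary, family = binomial, data = sim)

# p = 0.5 exactly where b0 + b1 * Age + b2 * EstimatedSalary = 0
b <- coef(fit)
intercept <- -b[["(Intercept)"]] / b[["EstimatedSalary"]]
slope     <- -b[["Age"]] / b[["EstimatedSalary"]]

plot(sim$Age, sim$EstimatedSalary, pch = 21,
     bg = ifelse(sim$Purchased == 1, "green4", "red3"),
     xlab = "Age (scaled)", ylab = "Estimated Salary (scaled)")
abline(a = intercept, b = slope, lwd = 2)   # the decision boundary

# Any point on that line should get a predicted probability of 0.5
on_line <- data.frame(Age = 0.7, EstimatedSalary = intercept + slope * 0.7)
print(predict(fit, newdata = on_line, type = "response"))
```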


  • Visualizing the Test set results

# Same plotting steps, now with the test-set observations
set <- test_set
x1 <- seq(min(set[, 1]) - 1, max(set[, 1]) + 1, by = 0.01)
x2 <- seq(min(set[, 2]) - 1, max(set[, 2]) + 1, by = 0.01)
grid_set <- expand.grid(x1, x2)
colnames(grid_set) <- c('Age', 'EstimatedSalary')
prob_set <- predict(classifier, type = 'response', newdata = grid_set)
y_grid <- ifelse(prob_set > 0.5, 1, 0)
plot(set[, -3], main = 'LR plot for Test set', xlab = 'Age', ylab = 'Estimated Salary', xlim = range(x1), ylim = range(x2))
contour(x1, x2, matrix(as.numeric(y_grid), length(x1), length(x2)), add = TRUE)
points(grid_set, pch = '.', col = ifelse(y_grid == 1, 'springgreen3', 'tomato'))
points(set, pch = 21, bg = ifelse(set[, 3] == 1, 'green4', 'red3'))

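The code above produces the same kind of plot for the test-set observations. The decision regions and boundary are unchanged, because they come from the classifier trained on the training set. Only the plotted points differ, showing how well the model classifies users it has not seen.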
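The training and test plots repeat the same plotting code, with only `set` and the title changed. One way to avoid the duplication is a small function. `plot_decision_regions()` below is a hypothetical helper, not part of any package. It wraps the same base-graphics calls and assumes the first two columns of `set` are the features and the third is the 0/1 label. The usage lines run it on simulated data with a coarser grid step so they finish quickly.

```r
# Hypothetical helper: decision regions of a fitted binomial glm over two features
plot_decision_regions <- function(set, classifier, title, step = 0.01) {
  x1 <- seq(min(set[, 1]) - 1, max(set[, 1]) + 1, by = step)
  x2 <- seq(min(set[, 2]) - 1, max(set[, 2]) + 1, by = step)
  grid_set <- expand.grid(x1, x2)
  colnames(grid_set) <- colnames(set)[1:2]   # must match the model's predictor names
  prob_set <- predict(classifier, type = "response", newdata = grid_set)
  y_grid <- ifelse(prob_set > 0.5, 1, 0)
  plot(set[, 1:2], main = title, xlab = colnames(set)[1], ylab = colnames(set)[2],
       xlim = range(x1), ylim = range(x2))
  contour(x1, x2, matrix(as.numeric(y_grid), length(x1), length(x2)), add = TRUE)
  points(grid_set, pch = ".", col = ifelse(y_grid == 1, "springgreen3", "tomato"))
  points(set[, 1:2], pch = 21, bg = ifelse(set[, 3] == 1, "green4", "red3"))
  invisible(y_grid)                          # predicted class for every grid point
}

# Usage on simulated, standardized data (an assumption, not the Kaggle file)
set.seed(7)
sim <- data.frame(Age = rnorm(200), EstimatedSalary = rnorm(200))
sim$Purchased <- rbinom(200, 1, plogis(1.5 * sim$Age + sim$EstimatedSalary))
train_idx <- sample.int(200, 150)
fit <- glm(Purchased ~ ., family = binomial, data = sim[train_idx, ])
plot_decision_regions(sim[train_idx, ], fit, "LR Plot for Training set", step = 0.05)
regions <- plot_decision_regions(sim[-train_idx, ], fit, "LR plot for Test set", step = 0.05)
```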


The code is available on my GitHub account.


If you liked the blog or found it helpful, please leave a clap!


Thank you.
