
Regression Algorithm Part 5: Decision Tree Regression Using R Language

Writer: Margi Patel

What is Decision Tree Regression?

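Decision trees are a non-parametric supervised learning method used for both classification and regression tasks. A decision tree is built by repeatedly splitting the data set into smaller subsets, each time choosing a condition on one of the independent variables (for example, Level < 6.5) that best separates the target values. In a regression tree, each final subset (a leaf) predicts the average target value of the training rows it contains. Decision trees are among the most widely used and practical methods for supervised learning.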
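To make the idea of splitting concrete, here is a minimal base-R sketch of how a regression tree can choose a single split on one numeric variable: try every midpoint between neighbouring values and keep the one that minimises the total sum of squared errors (SSE) of the two resulting groups. The helper names sse() and best_split() and the toy numbers are hypothetical; the sketch illustrates the criterion behind rpart's default "anova" method for numeric targets, but it is not rpart's actual code.

```r
# Minimal sketch (hypothetical helper names): choose one split point on a
# numeric feature x by minimising the summed squared error of the two groups,
# the criterion behind rpart's "anova" method. Not rpart's actual code.

# Sum of squared errors of a group around its own mean
sse <- function(v) sum((v - mean(v))^2)

best_split <- function(x, y) {
  xs <- sort(unique(x))
  # Candidate thresholds: midpoints between consecutive distinct x values
  thresholds <- (head(xs, -1) + tail(xs, -1)) / 2
  # Total SSE of the left (x < t) and right (x >= t) groups for each threshold
  total <- vapply(thresholds,
                  function(t) sse(y[x < t]) + sse(y[x >= t]),
                  numeric(1))
  list(threshold = thresholds[which.min(total)],
       sse_parent = sse(y),
       sse_children = min(total))
}

# Toy data (hypothetical): y jumps once x passes 3
x <- c(1, 2, 3, 4, 5, 6)
y <- c(10, 11, 9, 30, 31, 29)
result <- best_split(x, y)
result$threshold     # 3.5
result$sse_parent    # 604
result$sse_children  # 4
```

A full tree repeats this search over every independent variable inside each resulting group, and stops when a rule such as rpart's minsplit or cp says a node should not be split further.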


Let’s understand Decision Tree Regression using the Position_Salaries data set, which is available on Kaggle. It lists the positions in a company together with their band levels and associated salaries. The data set has three columns: Position, with values ranging from Business Analyst and Junior Consultant up to CEO; Level, ranging from 1 to 10; and Salary, ranging from $45,000 to $1,000,000.



Required R packages


First, install the rpart and ggplot2 packages and load them with library(); after that you can run the steps below. rpart builds the decision tree and ggplot2 draws the plot. So let’s start implementing our non-linear regression model.


  • Import libraries

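# Install the packages (needed only once per R installation)
install.packages('rpart')
install.packages('ggplot2')
# Load them (needed in every new R session)
library(rpart)    # decision tree models
library(ggplot2)  # plotting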

Note: Packages need to be installed only once, whether or not you use RStudio, but library() must be called again in every new R session.


  • Importing the dataset

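# Path used in the Kaggle notebook; change it to where your copy of the CSV lives
dataset <- read.csv('../input/position-salaries/Position_Salaries.csv')
# Keep only Level (column 2) and Salary (column 3); Position is just a label for Level
dataset <- dataset[2:3]
dim(dataset)  # number of rows and columns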
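If you are not working in the Kaggle notebook, one option is to build the data frame in memory instead. This is a minimal sketch, and the values are an assumption: they are the commonly circulated version of Position_Salaries.csv and match the description of the data set (Level 1 to 10, Salary from $45,000 to $1,000,000), but check them against your own download. The name dataset_full is hypothetical. Selecting the columns by name is a small robustness choice over dataset[2:3].

```r
# Minimal sketch: an in-memory copy of the data set. The values are an
# assumption (the commonly circulated Position_Salaries.csv); verify them
# against your own download. 'dataset_full' is a hypothetical name.
dataset_full <- data.frame(
  Position = c('Business Analyst', 'Junior Consultant', 'Senior Consultant',
               'Manager', 'Country Manager', 'Region Manager', 'Partner',
               'Senior Partner', 'C-level', 'CEO'),
  Level = 1:10,
  Salary = c(45000, 50000, 60000, 80000, 110000, 150000,
             200000, 300000, 500000, 1000000)
)

# Select the modelling columns by name instead of by position (dataset[2:3]),
# so the code keeps working if the column order ever changes
dataset <- dataset_full[, c('Level', 'Salary')]
dim(dataset)  # 10 2
```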

Please refer to Regression Algorithm Part 4 for more information on this step. Here, we will use Decision Tree Regression to predict an employee's salary from their level.


  • Apply Decision Tree Regression to the data set

# Regression tree of Salary on Level; minsplit = 1 allows splits on this 10-row data set
regressor <- rpart(formula = Salary ~ ., data = dataset, control = rpart.control(minsplit = 1))

The rpart() function creates the Decision Tree Regression model. In this data set we have one dependent variable, Salary, and one independent variable, Level. The formula Salary ~ . therefore means that Salary is modelled as a function of Level; the dot stands for all the remaining columns, used as independent variables. The second argument is the data set on which the model is trained. The final argument, control, sets how the tree is grown. By default, rpart() only tries to split a node that contains at least 20 observations (minsplit = 20). Our data set has only 10 rows, so with the defaults the tree would never split and would predict the overall mean salary for every level. Setting minsplit = 1 removes that restriction.
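The effect of minsplit can be checked directly. This sketch assumes the same Level and Salary values as the commonly circulated Kaggle file (verify them against your copy); fit_default and fit_split are hypothetical names. It fits the tree once with rpart's default control and once with minsplit = 1, then inspects the frame component of each fit, which has one row per node.

```r
library(rpart)

# Assumed values (commonly circulated Kaggle file): Level and Salary only
dataset <- data.frame(
  Level = 1:10,
  Salary = c(45000, 50000, 60000, 80000, 110000, 150000,
             200000, 300000, 500000, 1000000)
)

# Default control: minsplit = 20 is larger than the 10 rows available,
# so no split is attempted and the tree is just the root node
fit_default <- rpart(Salary ~ ., data = dataset)

# minsplit = 1: any node with at least one row may be considered for a split
fit_split <- rpart(Salary ~ ., data = dataset,
                   control = rpart.control(minsplit = 1))

nrow(fit_default$frame)               # 1 (root only)
sum(fit_split$frame$var == '<leaf>')  # number of leaves (intervals on Level)
print(fit_split)
```

With the defaults, the frame has a single row, so every prediction is simply the mean salary of all ten rows.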


  • Predicting a new result with Decision Tree Regression

# Predict the salary for a level of 6.5
y_pred <- predict(regressor, data.frame(Level = 6.5))

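This code predicts the salary for level 6.5 with the Decision Tree Regression model, and the result is about $250k. Rather than interpolating between levels 6 and 7, the tree returns the average salary of the training rows in the interval (leaf) that contains 6.5.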
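To see where that number comes from, this sketch (same assumed data values as the commonly circulated Kaggle file; leaf_means and fitted_vals are hypothetical names) uses the where component of the fitted rpart object, which records the row of regressor$frame (that is, the leaf) each training row falls into. It then checks that every fitted value equals the mean Salary of its leaf.

```r
library(rpart)

# Assumed values (commonly circulated Kaggle file): Level and Salary only
dataset <- data.frame(
  Level = 1:10,
  Salary = c(45000, 50000, 60000, 80000, 110000, 150000,
             200000, 300000, 500000, 1000000)
)

regressor <- rpart(Salary ~ ., data = dataset,
                   control = rpart.control(minsplit = 1))

# regressor$where: for each training row, the row of regressor$frame
# (i.e. the leaf) that the row ends up in
leaf_means <- tapply(dataset$Salary, regressor$where, mean)
print(leaf_means)

# Every fitted value is the mean Salary of the row's leaf
fitted_vals <- unname(predict(regressor))
print(all.equal(fitted_vals, ave(dataset$Salary, regressor$where)))  # TRUE

# A new level is sent down the same splits and receives its leaf's mean
print(predict(regressor, data.frame(Level = 6.5)))
```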


  • Visualize the Decision Tree Regression results

# Fine grid of levels so the step shape of the prediction is visible
x_grid <- seq(min(dataset$Level), max(dataset$Level), 0.01)
grid_df <- data.frame(Level = x_grid, Salary = predict(regressor, data.frame(Level = x_grid)))
ggplot() +
  geom_point(data = dataset, aes(x = Level, y = Salary), colour = 'red') +
  geom_line(data = grid_df, aes(x = Level, y = Salary), colour = 'blue') +
  ggtitle('Decision Tree Regression') +
  xlab('Level') +
  ylab('Salary')

The Decision Tree Regression model is not continuous: it is a step function. Using the ANOVA criterion (rpart's default for a numeric target), it chooses the splits that most reduce the sum of squared errors of Salary, which divides the whole range of the independent variable into intervals. Here the first interval runs from 1 to 6.5, the second from 6.5 to 8.5, the third from 8.5 to 9.5, and the last from 9.5 to 10. Within each interval the model predicts the average of the dependent variable values of the training points it contains. From this we conclude that the decision tree regression model is not a very interesting model in one dimension, but it can be a very interesting and powerful model in more dimensions.
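The intervals can also be read off the fitted model instead of the plot. This sketch (same assumed data values as the commonly circulated Kaggle file; y_grid and jumps are hypothetical names) predicts on the same 0.01-spaced grid used for the plot and records the first grid point at which the prediction changes; each change marks a split on Level. Because of the grid spacing, the detected points are only accurate to about 0.01.

```r
library(rpart)

# Assumed values (commonly circulated Kaggle file): Level and Salary only
dataset <- data.frame(
  Level = 1:10,
  Salary = c(45000, 50000, 60000, 80000, 110000, 150000,
             200000, 300000, 500000, 1000000)
)

regressor <- rpart(Salary ~ ., data = dataset,
                   control = rpart.control(minsplit = 1))

# Predict on the same fine grid as the plot
x_grid <- seq(min(dataset$Level), max(dataset$Level), 0.01)
y_grid <- predict(regressor, data.frame(Level = x_grid))

# First grid point of each new interval: where the prediction changes value
jumps <- x_grid[which(diff(y_grid) != 0) + 1]
print(jumps)

# One constant prediction per interval
print(unique(unname(y_grid)))
```

The same reasoning explains why the plot uses x_grid: drawing the line only through the ten observed levels would connect the points with slanted segments and hide the steps.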


The previous parts of this series (Part 1, Part 2, Part 3 and Part 4) covered Linear Regression, Multiple Linear Regression, Polynomial Linear Regression and Support Vector Regression.


If you like the blog or found it helpful please leave a clap!


Thank you.
