Andy Pickering

Adventures in Data Science

Introduction to Machine Learning in R

An example of how to build a simple machine-learning model in R.

Machine Learning Classification using a Random Forest Model in R

Introduction

This is an example of building a machine learning model for a classification task. We’ll use the well-known Iris dataset, which contains measurements of several plant features for 3 species of Iris. The goal will be to classify the species of flower based on these 4 measurements.

Exploratory Data Analysis

First let’s load the data and take a quick look at it.

data("iris")
head(iris)
##   Sepal.Length Sepal.Width Petal.Length Petal.Width Species
## 1          5.1         3.5          1.4         0.2  setosa
## 2          4.9         3.0          1.4         0.2  setosa
## 3          4.7         3.2          1.3         0.2  setosa
## 4          4.6         3.1          1.5         0.2  setosa
## 5          5.0         3.6          1.4         0.2  setosa
## 6          5.4         3.9          1.7         0.4  setosa

The ‘str’ (for ‘structure’) function also provides a concise summary of the data:

str(iris)
## 'data.frame':    150 obs. of  5 variables:
##  $ Sepal.Length: num  5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ...
##  $ Sepal.Width : num  3.5 3 3.2 3.1 3.6 3.9 3.4 3.4 2.9 3.1 ...
##  $ Petal.Length: num  1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
##  $ Petal.Width : num  0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
##  $ Species     : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...

The dataset contains 4 ‘features’ (also known as predictor variables or independent variables), and Species is the ‘target’ (or dependent variable) we are trying to predict.

Next we’ll visually examine the data. For datasets such as Iris with only a few variables, a ‘pairs’ plot is a nice compact way of visualizing the data and the relationships between the variables. We can see that the species (represented by color here) are well-separated by some variable combinations (for example Petal.Width vs Petal.Length in the last row, 3rd column). This means we should be able to build a model that classifies them well.

library(GGally)
ggpairs(data = iris, columns = 1:4, ggplot2::aes(colour = Species))

Building a Machine Learning Model

Now it’s time to actually build a model to predict the species. I’m going to use the ‘caret’ package to fit the model, because it makes it easy to apply standard model-fitting procedures to any model and dataset, with a consistent, organized framework. Note that the actual models live in their own packages (e.g. ‘randomForest’); caret works with many models and provides general tools for common data preparation, model fitting, and evaluation tasks. When you want to try many different models and parameters, this is extremely useful.

Before continuing, you might be wondering what a ‘random forest’ is. It is a collection of many decision trees. A decision tree is like a flow-chart, or a game of 20 questions. The data is successively split based on a criterion for one of the variables, until each node contains a sufficiently small number of data points. A random forest is a collection of many randomized versions of such a tree: each tree is grown on a random sample of the data, considering a random subset of variables at each split. Individual trees may have large errors at different points, but by averaging many random trees together, the accuracy is improved.
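For intuition, you can fit a single decision tree yourself with the ‘rpart’ package (an illustration only; the random forest method below builds and combines many such trees internally, and rpart is a standard recommended package but is not otherwise used in this post):

```r
# Fit one decision tree on the iris data, for illustration
library(rpart)

single_tree <- rpart(Species ~ ., data = iris, method = "class")

# Printing the tree shows the sequence of splits; the first split
# is on Petal.Length, consistent with the pairs plot above
print(single_tree)
```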

Our goal for this model is to predict the species based on measurements of sepal and petal width/length. Now, we can always make a model that is very accurate on the data we use to build it. But this is not useful; what we want is a model that will make accurate predictions on new data that was not used to build it. Therefore, before we build our model (or any machine learning model), we need to split the data into a training and test set. The model will be fit on the training data only, and the final model will be evaluated at the end on the test data to see how well it does on new data (the ‘test’ or ‘out-of-sample’ error or accuracy).

The data is split into random training (75%) and testing (25%) sets using the ‘createDataPartition’ function from caret.

suppressPackageStartupMessages(library(caret))
set.seed(153)
inTrain <- createDataPartition(y=iris$Species,p=0.75,list=FALSE)
training_set <- iris[inTrain,]
testing_set <- iris[-inTrain,]
dim(training_set)
## [1] 114   5

Whatever model you use will have some adjustable parameters. We want to fit the model with a variety of parameters and choose the model that will give us the best test accuracy. However, we can only evaluate the model on the test set once, so we can’t use test-set accuracy to choose the best model. One solution is to further split the training data into a training and ‘validation’ set, so that we train the models on the training set and evaluate them on the validation set to choose the best model/parameters. But it turns out the accuracy on the validation set isn’t always a great estimate of how the model will do on unseen data. The results can be sensitive to the training/validation split; we might get lucky and hold out data that the model fits really well, or the opposite.

To get a better estimate of what the test-set performance will be, a method called cross-validation (CV) is commonly used. In CV, the training data is split up again into smaller pieces (usually 5 or 10), and one of those pieces is held out as a ‘test’ set during training. We repeat this so that each piece (‘fold’) of the training data is held out once, and average the error from each iteration to get a more robust estimate of what the true test error will be.
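The mechanics of CV can be sketched in a few lines of base R. This toy version uses a 1-nearest-neighbour classifier from the ‘class’ package (not the random forest, and not what caret runs internally) just to show folds being held out and the accuracies averaged:

```r
# Toy 5-fold CV on the iris data using 1-nearest-neighbour
library(class)

set.seed(1)
k <- 5
# Randomly assign each row to one of k folds
folds <- sample(rep(1:k, length.out = nrow(iris)))

fold_acc <- sapply(1:k, function(i) {
  train <- iris[folds != i, ]   # train on all folds except fold i
  test  <- iris[folds == i, ]   # hold out fold i for evaluation
  preds <- knn(train[, 1:4], test[, 1:4], cl = train$Species, k = 1)
  mean(preds == test$Species)   # accuracy on the held-out fold
})

# The CV estimate is the average accuracy across the held-out folds
mean(fold_acc)
```

In practice you would not write this loop yourself; caret’s ‘train’ function, shown next, does the equivalent work for us.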

The ‘train’ function in caret will perform this cross-validation for us and find the best model parameters. The parameter we will vary in this model is ‘mtry’, the number of variables randomly sampled as candidates at each split.

# Create CV folds, using 5 partitions
myFolds <- createFolds(training_set$Species, k = 5)

# Create a control object to pass to caret::train. If we fit multiple
# different models, this lets us fit them all in the same way so we can
# compare them easily.
myControl <- trainControl(classProbs = TRUE, # IMPORTANT!
                          verboseIter = FALSE,
                          savePredictions = TRUE,
                          index = myFolds)

# fit the model using the caret::train function
mod_rf <- train(Species~.,
              method="rf",
              data=training_set,
              trControl=myControl)
## Loading required package: randomForest
## randomForest 4.6-12
## Type rfNews() to see new features/changes/bug fixes.
## 
## Attaching package: 'randomForest'
## The following object is masked from 'package:ggplot2':
## 
##     margin
# printing the model gives us a nice summary
print(mod_rf)
## Random Forest 
## 
## 114 samples
##   4 predictor
##   3 classes: 'setosa', 'versicolor', 'virginica' 
## 
## No pre-processing
## Resampling: Bootstrapped (5 reps) 
## Summary of sample sizes: 22, 23, 23, 23, 23 
## Resampling results across tuning parameters:
## 
##   mtry  Accuracy   Kappa    
##   2     0.9253703  0.8879378
##   3     0.9231725  0.8846407
##   4     0.9209747  0.8813436
## 
## Accuracy was used to select the optimal model using  the largest value.
## The final value used for the model was mtry = 2.

So the ‘train’ function fit the model for 3 different values of ‘mtry’; for each one it performed the cross-validation, and it found that ‘mtry’ = 2 gave the best CV accuracy. The main drawback of CV is that it requires more computation time; in this case we had to fit the model 15 times (5 CV folds, 3 values of ‘mtry’).
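Here caret chose its own default grid of ‘mtry’ values; you can control exactly which values are tried with the ‘tuneGrid’ argument. A sketch (repeating the data split so the snippet stands alone; the specific grid of 1 to 4 is just an example):

```r
suppressPackageStartupMessages(library(caret))

# Repeat the train/test split from earlier so this snippet is self-contained
set.seed(153)
inTrain <- createDataPartition(y = iris$Species, p = 0.75, list = FALSE)
training_set <- iris[inTrain, ]

# Explicitly specify which values of mtry to try
rf_grid <- expand.grid(mtry = c(1, 2, 3, 4))

mod_rf_grid <- train(Species ~ .,
                     method = "rf",
                     data = training_set,
                     trControl = trainControl(method = "cv", number = 5),
                     tuneGrid = rf_grid)

# bestTune holds the value of mtry with the highest CV accuracy
mod_rf_grid$bestTune
```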

After choosing our final model, we then evaluate its performance on the test set. The confusion matrix provides a nice summary of the performance of a classification model. It shows the predicted classes vs the true classes. We see that 1 sample was misclassified in this case.

preds_final <- predict(mod_rf,newdata = testing_set)
confusionMatrix(preds_final,testing_set$Species)
## Confusion Matrix and Statistics
## 
##             Reference
## Prediction   setosa versicolor virginica
##   setosa         12          0         0
##   versicolor      0         11         0
##   virginica       0          1        12
## 
## Overall Statistics
##                                           
##                Accuracy : 0.9722          
##                  95% CI : (0.8547, 0.9993)
##     No Information Rate : 0.3333          
##     P-Value [Acc > NIR] : 4.864e-16       
##                                           
##                   Kappa : 0.9583          
##  Mcnemar's Test P-Value : NA              
## 
## Statistics by Class:
## 
##                      Class: setosa Class: versicolor Class: virginica
## Sensitivity                 1.0000            0.9167           1.0000
## Specificity                 1.0000            1.0000           0.9583
## Pos Pred Value              1.0000            1.0000           0.9231
## Neg Pred Value              1.0000            0.9600           1.0000
## Prevalence                  0.3333            0.3333           0.3333
## Detection Rate              0.3333            0.3056           0.3333
## Detection Prevalence        0.3333            0.3056           0.3611
## Balanced Accuracy           1.0000            0.9583           0.9792
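As a sanity check, the overall accuracy reported above is just the fraction of predictions that match the true labels; on the real data this would be `mean(preds_final == testing_set$Species)`. A toy illustration:

```r
# Accuracy = fraction of predictions equal to the true labels
preds <- c("setosa", "setosa", "virginica", "virginica")
truth <- c("setosa", "versicolor", "virginica", "virginica")
mean(preds == truth)  # 3 of 4 correct -> 0.75
```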

Summary

I hope this example gave you an idea of how to start building a simple machine-learning model for classification in R, and an understanding of some important concepts. This was a simple example with a nice clean dataset, but the same basic process should apply to more complicated datasets.

  • We loaded the Iris dataset and did some simple exploratory data analysis.
  • We trained a machine learning model (in this case a random forest) to classify the species of Iris based on its measurements.
  • We learned about the importance of training/test errors and how to choose the best model using cross-validation and the Caret package.

Session Info

sessionInfo()
## R version 3.3.2 (2016-10-31)
## Platform: x86_64-apple-darwin13.4.0 (64-bit)
## Running under: OS X Yosemite 10.10.5
## 
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
## [1] randomForest_4.6-12 caret_6.0-73        ggplot2_2.2.1      
## [4] lattice_0.20-34     GGally_1.3.0       
## 
## loaded via a namespace (and not attached):
##  [1] Rcpp_0.12.9        compiler_3.3.2     nloptr_1.0.4      
##  [4] RColorBrewer_1.1-2 plyr_1.8.4         class_7.3-14      
##  [7] prettydoc_0.2.0    iterators_1.0.8    tools_3.3.2       
## [10] digest_0.6.12      lme4_1.1-12        evaluate_0.10     
## [13] tibble_1.2         gtable_0.2.0       nlme_3.1-131      
## [16] mgcv_1.8-17        Matrix_1.2-8       foreach_1.4.3     
## [19] parallel_3.3.2     yaml_2.1.14        SparseM_1.74      
## [22] e1071_1.6-8        stringr_1.2.0      knitr_1.15.1      
## [25] MatrixModels_0.4-1 stats4_3.3.2       rprojroot_1.2     
## [28] grid_3.3.2         nnet_7.3-12        reshape_0.8.6     
## [31] rmarkdown_1.3      minqa_1.2.4        reshape2_1.4.2    
## [34] car_2.1-4          magrittr_1.5       backports_1.0.5   
## [37] scales_0.4.1       codetools_0.2-15   ModelMetrics_1.1.0
## [40] htmltools_0.3.5    MASS_7.3-45        splines_3.3.2     
## [43] assertthat_0.1     pbkrtest_0.4-6     colorspace_1.3-2  
## [46] quantreg_5.29      labeling_0.3       stringi_1.1.2     
## [49] lazyeval_0.2.0     munsell_0.4.3
Written on March 17, 2017