Classification Example with an rpart Tree Model in R

   Classification and Regression Tree (CART) models can be implemented with the rpart package in R. In this post, we'll briefly learn how to classify data with the rpart() function, using two implementations. The tutorial covers:
  1. Classification with the rpart() function
  2. Applying the 'caret' package's train() method
  3. Source code listing
First, we'll start by loading the required libraries.

> library(rpart)
> library(caret)

   Next, we'll generate a sample classification dataset and split it into train and test parts. You may use any other classification dataset instead.

> ds <- data.frame(a = sample(1:20, 200, replace = T),
                   b = sample(1:20, 200, replace = T)) 
> ds <- cbind(ds, level = ifelse((ds$a+ds$b) > 25, "high",
                         ifelse((ds$a+ds$b) > 15, "normal", "low")))
> head(ds)
   a  b  level
1  7  5    low
2  4  2    low
3 16 16   high
4 13  1    low
5  2 18 normal
6  2  6    low

> indexes <- createDataPartition(ds$level, p = .9, list = F)
> train <- ds[indexes, ]
> test <- ds[-indexes, ]
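Note that createDataPartition() samples within each class, so the train and test parts keep roughly the same class proportions as the full dataset. A quick sketch to verify this (the dataset is rebuilt here, with a seed added, so the snippet runs on its own):

```r
library(caret)

set.seed(1)
# Rebuild the same kind of generated dataset as above
ds <- data.frame(a = sample(1:20, 200, replace = TRUE),
                 b = sample(1:20, 200, replace = TRUE))
ds$level <- ifelse(ds$a + ds$b > 25, "high",
            ifelse(ds$a + ds$b > 15, "normal", "low"))

idx <- createDataPartition(ds$level, p = .9, list = FALSE)

# Class proportions should be close in both parts
round(prop.table(table(ds$level[idx])), 2)
round(prop.table(table(ds$level[-idx])), 2)
```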



Classification with the 'rpart' function


We will use the rpart() function to fit the model.

> fit = rpart(level~., data = train, control = rpart.control(cp = 0.0001))
> printcp(fit)

Classification tree:
rpart(formula = level ~ ., data = train, control = rpart.control(cp = 1e-04))

Variables actually used in tree construction:
[1] a b

Root node error: 93/182 = 0.51099

n= 182 

        CP nsplit rel error  xerror     xstd
1 0.240143      0   1.00000 1.00000 0.072513
2 0.010753      3   0.27957 0.34409 0.055221
3 0.000100      4   0.26882 0.40860 0.058960


Next, we can prune the tree with the CP value that gives the lowest cross-validated error (the xerror column above).

> fit.pruned = prune(fit, cp = 0.0107)
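Instead of copying the cp value out of the printcp() table by hand, it can be picked programmatically from the model's cptable. A minimal sketch, using the built-in iris data so it stands alone:

```r
library(rpart)

# Fit a deliberately deep tree on the built-in iris data (a stand-in
# for the tutorial's generated dataset) so the snippet runs on its own.
set.seed(1)
fit <- rpart(Species ~ ., data = iris,
             control = rpart.control(cp = 0.0001))

# Pick the cp whose cross-validated error (the xerror column of the
# cp table) is smallest, instead of reading it off printcp() by hand.
best.cp <- fit$cptable[which.min(fit$cptable[, "xerror"]), "CP"]
fit.pruned <- prune(fit, cp = best.cp)
```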

We can plot the model as follows.

> plot(fit.pruned)
> text(fit.pruned, cex = 0.9, xpd = TRUE)

We'll predict the test data.

> pred <- predict(fit.pruned, test, type = "class")
> data.frame(test,pred)
     a  b  level   pred
6    2  6    low    low
7    7  1    low    low
9   11 12 normal normal
17   4 18 normal normal
25   4 20 normal normal
26  17 13   high   high
31  16 17   high   high
33   5 15 normal normal
57  15 18   high   high
60   5  2    low    low
65  12 18   high   high
82  19  2 normal normal
84  10  5    low normal
131  1  2    low    low
138  5 11 normal normal
173  2 16 normal normal
175 16  8 normal normal
192  7 16 normal normal

Checking the result with the confusion matrix.

> confusionMatrix(pred, test$level)
Confusion Matrix and Statistics

          Reference
Prediction high low normal
    high      4   0      0
    low       0   4      0
    normal    0   1      9

Overall Statistics
                                          
               Accuracy : 0.9444          
                 95% CI : (0.7271, 0.9986)
    No Information Rate : 0.5             
    P-Value [Acc > NIR] : 7.248e-05       
                                          
                  Kappa : 0.9091          
 Mcnemar's Test P-Value : NA 
.... 


Applying the 'caret' package's train() method

In this method, we need a 'trainControl' object that defines the resampling scheme, and we can define it as below.

> trainCtrl <- trainControl(method = "cv", number=10)
> 
> fit <- caret::train(level~., data = train, 
                     trControl = trainCtrl, method = "rpart")

> print(fit)
CART 

182 samples
  2 predictor
  3 classes: 'high', 'low', 'normal' 

No pre-processing
Resampling: Cross-Validated (10 fold) 
Summary of sample sizes: 163, 164, 164, 163, 164, 164, ... 
Resampling results across tuning parameters:

  cp          Accuracy   Kappa    
  0.00000000  0.8289474  0.7221735
  0.01075269  0.8178363  0.7015998
  0.24014337  0.5614035  0.1982245

Accuracy was used to select the optimal model using the largest value.
The final value used for the model was cp = 0.
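By default, train() evaluates only a few cp candidates taken from rpart's own cp table, as seen above. The search can be widened with an explicit tuneGrid. A sketch, again on the built-in iris data so it runs on its own:

```r
library(caret)
library(rpart)

set.seed(1)
trainCtrl <- trainControl(method = "cv", number = 10)

# Explicit grid of cp values to evaluate, instead of caret's default
# candidates (tuneLength would be the coarser alternative).
grid <- expand.grid(cp = seq(0, 0.05, by = 0.01))

fit <- caret::train(Species ~ ., data = iris,
                    method = "rpart",
                    trControl = trainCtrl,
                    tuneGrid = grid)

fit$bestTune  # cp chosen by the highest cross-validated accuracy
```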

Predicting data and checking results.

> pred <- predict(fit, test)
> data.frame(test,pred)
     a  b  level   pred
6    2  6    low    low
7    7  1    low    low
9   11 12 normal normal
17   4 18 normal normal
25   4 20 normal normal
26  17 13   high   high
31  16 17   high   high
33   5 15 normal normal
57  15 18   high   high
60   5  2    low    low
65  12 18   high   high
82  19  2 normal normal
84  10  5    low normal
131  1  2    low    low
138  5 11 normal    low
173  2 16 normal normal
175 16  8 normal normal
192  7 16 normal normal

> confusionMatrix(pred, test$level)
 
Confusion Matrix and Statistics

          Reference
Prediction high low normal
    high      4   0      0
    low       0   4      1
    normal    0   1      8

Overall Statistics
                                          
               Accuracy : 0.8889          
                 95% CI : (0.6529, 0.9862)
    No Information Rate : 0.5             
    P-Value [Acc > NIR] : 0.0006561       
                                          
                  Kappa : 0.8218          
 Mcnemar's Test P-Value : NA              
 ........


   In this tutorial, we've briefly learned how to classify data with a CART model using the rpart() function and the caret train() method. The full source code is listed below.


Source code listing

library(rpart)
library(caret)
 
ds <- data.frame(a = sample(1:20, 200, replace = T),
                 b = sample(1:20, 200, replace = T)) 
ds <- cbind(ds, level = ifelse((ds$a + ds$b) > 25, "high",
                        ifelse((ds$a + ds$b) > 15, "normal", "low")))
head(ds)

indexes <- createDataPartition(ds$level, p = .9, list = F)
train <- ds[indexes, ]
test <- ds[-indexes, ]
 
fit <- rpart(level ~ ., data = train, control = rpart.control(cp = 0.0001))
printcp(fit)
 
fit.pruned <- prune(fit, cp = 0.0107)
 
plot(fit.pruned)
text(fit.pruned, cex = 0.9, xpd = TRUE)
 
pred <- predict(fit.pruned, test, type = "class")
print(data.frame(test, pred))
 
confusionMatrix(pred, test$level)


# caret train method
trainCtrl <- trainControl(method = "cv", number = 10)
fit <- caret::train(level ~ ., data = train, 
                    trControl = trainCtrl, method = "rpart")
print(fit)
 
pred <- predict(fit, test)
print(data.frame(test, pred))
confusionMatrix(pred, test$level)

