Confusion Matrices – Evaluating your classification models
We recently ran two successful and thoroughly engaging ML training sessions. The cohorts really enjoyed the section on ML classification methods, which focused specifically on supervised ML techniques for classification. If you have limited exposure to ML techniques and are interested in attending one of these courses, then this course is for you.
One of the users asked if there was a default way in R (the language we covered in the training) of visualising a confusion matrix. The caret package provides a standard confusion matrix output, but it can be hard to interpret. At D&D we act on feedback from our customers to create custom and standalone solutions.
What is a confusion matrix, you may ask?
Put quite simply, a confusion matrix is a table that counts a classifier's predictions against the actual outcomes. It is used to assess prediction accuracy.
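Here is a minimal sketch of the idea. Base R's table() can build this table directly by cross-tabulating predicted labels against actual labels. The vectors predicted and actual are hypothetical, made-up values (1 = readmitted, 0 = not readmitted), not data from the training.

```r
# Made-up predicted and actual labels for eight patients (1 = readmitted, 0 = not)
predicted <- factor(c(1, 0, 1, 1, 0, 0, 1, 0), levels = c(0, 1))
actual    <- factor(c(1, 0, 0, 1, 0, 1, 1, 0), levels = c(0, 1))

# Rows are predictions, columns are actual outcomes (the orientation caret also uses)
tab <- table(Prediction = predicted, Reference = actual)
print(tab)
```

Setting levels explicitly ensures that both classes appear in the table, even if one of them is never predicted.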
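Throughout the training we used an example, widely published by Google, to explain the components of the table. The main point was to learn how to interpret the table for the case study we covered: using hospital information to predict whether a patient would be readmitted.
The confusion matrix for a binary classification problem (i.e. readmitted vs not readmitted) forms a 2 x 2 matrix. Every prediction lands in one of four cells:
- true positives (TP): readmission predicted, patient readmitted;
- false positives (FP): readmission predicted, patient not readmitted;
- false negatives (FN): no readmission predicted, patient readmitted;
- true negatives (TN): no readmission predicted, patient not readmitted.
Only TP and TN are correct predictions, while FP and FN are errors. Training and tuning a model is therefore about shifting as many predictions as possible out of the FP and FN cells and into the TP and TN cells.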
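The original 2 x 2 figure is not reproduced here, so the sketch below shows where each cell sits. It assumes the orientation that caret's confusionMatrix() uses (rows = predicted class, columns = actual class) and treats "1" (readmitted) as the positive class. The counts are made up. The sketch also shows that as.numeric() reads a table column by column, which matters for the plotting function later in this article.

```r
# Made-up counts; rows = predicted class, columns = actual class, positive class = "1"
tab <- matrix(c(50, 10, 5, 35), nrow = 2,
              dimnames = list(Prediction = c("0", "1"), Reference = c("0", "1")))

# Name each cell by where it sits relative to the positive class "1"
TN <- tab["0", "0"]  # predicted not readmitted, actually not readmitted
FP <- tab["1", "0"]  # predicted readmitted, actually not readmitted
FN <- tab["0", "1"]  # predicted not readmitted, actually readmitted
TP <- tab["1", "1"]  # predicted readmitted, actually readmitted

# as.numeric() flattens column by column, giving TN, FP, FN, TP in this layout
flat <- as.numeric(tab)
print(flat)
```

If your factor levels are ordered differently, the cell positions move with them. It is safer to look cells up by level name than to rely on position alone.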
Assessing a confusion matrix
To assess the confusion matrix there are a few very useful measures:
- Accuracy – how accurate the model is overall. It is used as a benchmark to compare against other classification algorithms / models. The calculation for this metric is (TP + TN) / Total predictions.
- Error (Misclassification) Rate – overall, how often the readmission predictions were wrong. The calculation for this metric is (FP + FN) / Total predictions, which equals 1 - Accuracy.
- Sensitivity (True Positive Rate) – when a patient is actually readmitted, how often do we predict readmission? The calculation for this metric is TP / Actual Positives (i.e. readmitted), where Actual Positives = TP + FN.
- Specificity (True Negative Rate) – when a patient is actually not readmitted, how often do we correctly predict that they won't be? The calculation for this metric is TN / Actual Negatives (i.e. not readmitted), where Actual Negatives = TN + FP. The False Positive Rate, FP / Actual Negatives, is 1 - Specificity.
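The four formulas can be checked with a small helper. The name confusion_metrics() is hypothetical and only for illustration, and the counts are made up.

```r
# Hypothetical helper: the four measures above from the cells of a 2 x 2 confusion matrix
confusion_metrics <- function(TP, TN, FP, FN) {
  total <- TP + TN + FP + FN
  c(
    accuracy    = (TP + TN) / total,
    error_rate  = (FP + FN) / total,
    sensitivity = TP / (TP + FN),  # true positive rate: TP / actual positives
    specificity = TN / (TN + FP)   # true negative rate: TN / actual negatives
  )
}

# Made-up counts for illustration
m <- confusion_metrics(TP = 35, TN = 50, FP = 10, FN = 5)
print(round(m, 3))

# The false positive rate is the complement of specificity
fpr <- 1 - m[["specificity"]]
```

With these counts, accuracy is 0.85, the error rate is 0.15, sensitivity is 0.875 and specificity is about 0.833.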
There are a few other metrics to assess the confusion matrix, but the ones outlined are the most useful in our experience. The aim here is to try to reduce the complexity and the number of metrics you need to look at to assess a classification method.
Why the need for a visual?
We created the visual as a function that complements the confusion matrix output in the caret package. That output is useful for a quick assessment of model accuracy and fit, but it is not the most aesthetically pleasing to the eye.
Using the visualisation function
The candidates on the recent ML training imported and cleaned data, engineered features, trained a model and created a confusion matrix. The following code assumes you have already trained a model and created a confusion matrix object using the caret package.
In response to feedback from the training, D&D's data science team was tasked with creating an example confusion matrix plot in R, so that our candidates could make the best use of their models' outputs.
Creating a confusion matrix object in R
The model was trained – in this case using a decision tree – with the caret package. The confusion matrix was then created by using the following command:
cm <- confusionMatrix(data = dt_pred, reference = test[, 4], positive = "1")
print(cm)
To explain this step briefly: confusionMatrix() is a function in R's caret package. Its data argument is set to dt_pred, which holds the predictions that the trained model made for the test (validation) set. The second argument, test[,4], holds the actual labels from the test set (here, its fourth column), so caret compares the predicted labels against the actual ones. The positive = "1" argument tells caret that the level "1" (readmitted) is the positive class. By labels we mean the known, historical outcomes that supervised ML algorithms learn from: in our case, readmitted vs not readmitted.
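confusionMatrix() expects data and reference to be factors with the same levels. One pitfall is worth guarding against. If a model predicts only one class on the test set, factor() without explicit levels creates a single-level factor that no longer matches the reference.

The sketch below assumes labels coded 0/1. The names pred_raw, actual_raw and class_levels are hypothetical, and the caret call only runs if the package is installed.

```r
# Hypothetical test-set predictions (a model that never predicts readmission)
# and actual outcomes, coded 0 = not readmitted, 1 = readmitted
pred_raw   <- c(0, 0, 0, 0, 0, 0)
actual_raw <- c(0, 1, 0, 0, 1, 0)

# Without explicit levels, the predictions get a single level
print(levels(factor(pred_raw)))  # "0" only

# Fixing the levels keeps both factors aligned, so the 2 x 2 table is complete
class_levels <- c("0", "1")
dt_pred <- factor(pred_raw, levels = class_levels)
actual  <- factor(actual_raw, levels = class_levels)
print(table(Prediction = dt_pred, Reference = actual))

# If caret is available, the aligned factors can be passed straight to confusionMatrix()
if (requireNamespace("caret", quietly = TRUE)) {
  cm <- caret::confusionMatrix(data = dt_pred, reference = actual, positive = "1")
  print(cm$table)
}
```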
Visualising the confusion matrix in R
Once the confusion matrix object has been created in memory, the function below is used to produce the visual. I explain how the function works in the bullets after the R code:
conf_matrix_cust_plot <- function(cm_input, class_label1 = "Class Negative",
                                  class_label2 = "Class Positive", quadrant_col1 = '#3F97D0',
                                  quadrant_col2 = '#F7AD50', custom_title = "Confusion matrix",
                                  text_col = "black", round_dig = 2) {
  # Only base graphics are used; cm_input is the object returned by caret's confusionMatrix()
  # Split the device: the matrix uses the top two-thirds, the statistics the bottom third
  layout(matrix(c(1, 1, 2)))
  op <- par(mar = c(2, 2, 2, 2))
  on.exit(par(op), add = TRUE)  # restore the original margins when the function exits
  # type = "n" sets up the plotting region without plotting anything
  plot(c(100, 345), c(300, 450), type = "n", xlab = "", ylab = "", xaxt = 'n', yaxt = 'n')
  title(custom_title, cex.main = 2)
  # Create the matrix visualisation using custom rectangles and text items on the chart
  rect(150, 430, 240, 370, col = quadrant_col1)
  text(195, 435, class_label1, cex = 1.2)
  rect(250, 430, 340, 370, col = quadrant_col2)
  text(295, 435, class_label2, cex = 1.2)
  text(125, 370, 'Predicted', cex = 1.3, srt = 90, font = 2)
  text(245, 450, 'Actual', cex = 1.3, font = 2)
  rect(150, 305, 240, 365, col = quadrant_col1)
  rect(250, 305, 340, 365, col = quadrant_col2)
  text(140, 400, class_label1, cex = 1.2, srt = 90)
  text(140, 335, class_label2, cex = 1.2, srt = 90)
  # Add the counts saved in cm_input$table. as.numeric() reads the table column by column
  # (rows = predicted, columns = actual), so result[1:4] are:
  # (pred 1, actual 1), (pred 2, actual 1), (pred 1, actual 2), (pred 2, actual 2)
  result <- as.numeric(cm_input$table)
  text(195, 400, result[1], cex = 1.6, font = 2, col = text_col)
  text(195, 335, result[2], cex = 1.6, font = 2, col = text_col)
  text(295, 400, result[3], cex = 1.6, font = 2, col = text_col)
  text(295, 335, result[4], cex = 1.6, font = 2, col = text_col)
  # Add in other confusion matrix statistics, looked up by name in cm_input$byClass
  plot(c(100, 0), c(100, 0), type = "n", xlab = "", ylab = "",
       main = "Confusion matrix statistics", xaxt = 'n', yaxt = 'n')
  stat_names <- c("Sensitivity", "Specificity", "Precision", "Recall", "Balanced Accuracy")
  stat_x <- c(10, 30, 50, 65, 86)
  for (i in seq_along(stat_names)) {
    text(stat_x[i], 85, stat_names[i], cex = 1.6, font = 2)
    text(stat_x[i], 70, round(as.numeric(cm_input$byClass[stat_names[i]]), round_dig), cex = 1.2)
  }
  # Add in the overall accuracy and Kappa stored in cm_input$overall
  text(30, 35, "Accuracy", cex = 1.5, font = 2)
  text(30, 20, round(as.numeric(cm_input$overall["Accuracy"]), round_dig), cex = 1.4)
  text(70, 35, "Kappa", cex = 1.5, font = 2)
  text(70, 20, round(as.numeric(cm_input$overall["Kappa"]), round_dig), cex = 1.4)
}
There is a lot going on in this function, but the guidance below should make it clear:
- The function takes 8 parameters, and the most important is cm_input: the confusion matrix object created in the section on creating the confusion matrix in R. The remaining parameters are:
  - the classification labels, which default to Class Negative and Class Positive. class_label1 describes the first factor level (here "0", not readmitted);
  - the colours of the rectangles, which default to the colours of D&D's BI solutions if nothing is passed;
  - a title for the plot, which defaults to "Confusion matrix";
  - the text colour (text_col);
  - the number of decimal places shown for the summary statistics (round_dig).
- The first couple of lines set out how the matrix will be displayed in the graphics window, using R's native plotting engine.
- Next, custom rectangles and text elements are placed at x and y coordinates to build the top part of the plot (the matrix itself). This is where inputs such as the class labels change the display; for this example they are labelled Not Admitted and Admitted.
- The next step stores the counts from the confusion matrix table in a variable called result. as.numeric() reads the 2×2 table column by column, so each count is written, in bold and in the colour set by text_col, into the quadrant that matches its predicted and actual class.
- Moving further down the code, a second plotting panel is drawn below the matrix, in the bottom third of the layout. It shows sensitivity, specificity, precision, recall and balanced accuracy, taken from the byClass element of the confusion matrix object.
- Finally, at the bottom of that panel, the function adds two values stored in cm$overall: the overall accuracy of the model and a metric called Kappa (Cohen's kappa). Kappa measures how well the predictions agree with the actual labels after correcting for the agreement expected by chance. A value of 1 means perfect agreement and 0 means no better than chance; see https://en.wikipedia.org/wiki/Cohen%27s_kappa.
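To make the Kappa value less of a black box, here is a sketch of Cohen's kappa computed by hand from a 2 x 2 table. It compares the observed agreement (accuracy) with the agreement expected by chance from the row and column totals. The helper cohens_kappa() is hypothetical and the counts are made up. caret's reported Kappa for a two-class problem should be this same statistic, but treat that as an assumption to check against your own cm$overall.

```r
# Hypothetical helper: Cohen's kappa for a confusion matrix (rows = predicted, columns = actual)
cohens_kappa <- function(tab) {
  n   <- sum(tab)
  p_o <- sum(diag(tab)) / n                      # observed agreement (accuracy)
  p_e <- sum(rowSums(tab) * colSums(tab)) / n^2  # agreement expected by chance
  (p_o - p_e) / (1 - p_e)
}

# Made-up counts for illustration; rows and columns share the same level order
tab <- matrix(c(50, 10, 5, 35), nrow = 2,
              dimnames = list(Prediction = c("0", "1"), Reference = c("0", "1")))
k <- cohens_kappa(tab)
print(round(k, 3))
```

Here the observed agreement is 0.85 and the chance agreement is 0.51, giving a kappa of about 0.69. That is reasonable agreement, but noticeably lower than raw accuracy suggests.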
The code below shows how to use the visualisation:
conf_matrix_cust_plot(cm,
class_label1 = "Not Admitted",
class_label2 = "Admitted",
quadrant_col1 = "#e5e5e5",
quadrant_col2 = "#ec008e",
custom_title="Confusion matrix - Readmissions",
round_dig = 3)
Once this code is run, R's Plots window will display the custom plot. You can use it in reports or alongside the summary metrics to validate your ML model.
Summary
We had a great time running the ML training days, and we look forward to seeing anyone who is interested at our next series of sessions. As always, we would love your feedback. If you would like to ask a question, please go to the Questions and Answers section of this website and we will respond as quickly as we can.
Signing off, and thanks!
Gary Hutson – Head of Solutions and AI