Comprehensive Guide To Multiclass Classification With Sklearn
Bex T.
Learn how to tackle any multiclass classification problem with Sklearn. The tutorial covers how to choose a model selection strategy, several multiclass evaluation metrics, and how to use them, finishing off with hyperparameter tuning to optimize for user-defined metrics.
Introduction
Even though multiclass classification is not as common as binary classification, it certainly poses a much bigger challenge.
You can literally take my word for it because this article has been the most challenging post I have ever written (I have written close to 70).
I found that the topic of multiclass classification is deep and full of nuances. I have read so many articles, read multiple StackOverflow
threads, created a few of my own, and spent several hours exploring the Sklearn user guide and doing experiments. The core topics of
multiclass classification such as
choosing a classification strategy that suits your problem,
filtering out a single metric that solves your business problem and customizing it
and finally putting all the theory into practice with Sklearn
have all been scattered in the dark, sordid corners of the Internet. This was enough to conclude that no single resource shows an end-to-
end workflow of dealing with multiclass classification problems on the Internet (maybe, I missed it).
For this reason, this article will be a comprehensive tutorial on how to solve any multiclass supervised classification problem using
Sklearn. You will learn both the theory and the implementation of the above core concepts. It is going to be a long and technical read, so
get a coffee!
The first and the biggest group of estimators are the ones that support multi-class classification natively:
naive_bayes.BernoulliNB
tree.DecisionTreeClassifier
tree.ExtraTreeClassifier
ensemble.ExtraTreesClassifier
naive_bayes.GaussianNB
neighbors.KNeighborsClassifier
For an N-class problem, they produce an N-by-N confusion matrix, and most of the evaluation metrics are derived from it:
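As a minimal sketch of the idea, here is one of the natively multiclass estimators fitted on a small 3-class dataset (the wine dataset and the train/test split are purely illustrative), with the confusion matrix derived from its predictions:

from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import confusion_matrix

# The wine dataset has 3 classes, so this is a multiclass problem out of the box
X, y = load_wine(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# DecisionTreeClassifier handles multiclass targets natively
clf = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)

# For an N-class problem, the confusion matrix is N by N (here, 3 by 3)
print(confusion_matrix(y_test, clf.predict(X_test)))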
Not all estimators support multiclass classification natively. For those that don't, Sklearn offers two main strategies: One-vs-One (OVO) and One-vs-Rest (OVR).
OVO splits a multi-class problem into a single binary classification task for each pair of classes. In other words, for each pair, a single binary classifier is built. For example, a target with 4 classes (brain, lung, breast, and kidney cancer) uses 6 individual classifiers to binarize the problem:
Classifier 1: brain vs. lung
Classifier 2: brain vs. breast
Classifier 3: brain vs. kidney
Classifier 4: lung vs. breast
Classifier 5: lung vs. kidney
Classifier 6: breast vs. kidney
The following Sklearn estimators use the OVO strategy under the hood:
svm.NuSVC
svm.SVC
Sklearn also provides a wrapper estimator for the above models under sklearn.multiclass.OneVsOneClassifier :
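A minimal sketch of the wrapper in use (the wrapped estimator and dataset are illustrative):

from sklearn.datasets import load_wine
from sklearn.multiclass import OneVsOneClassifier
from sklearn.svm import SVC

X, y = load_wine(return_X_y=True)

# One binary SVC is trained for every pair of classes
ovo = OneVsOneClassifier(SVC()).fit(X, y)
print(len(ovo.estimators_))  # 3 classes -> 3 pairwise classifiers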
A major downside of this strategy is its computational workload. As each pair of classes requires a separate binary classifier, targets with high cardinality may take too long to train. To compute the number of classifiers that will be built for an N-class problem, the following formula is used:
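number_of_classifiers = N * (N - 1) / 2

For the 4-class cancer example above, that is 4 * 3 / 2 = 6 classifiers; for a 100-class target, it would already be 4950.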
In practice, the One-vs-Rest (OVR) strategy is much preferred because of this disadvantage. OVR splits the multi-class problem into one binary classification task per class: that class is treated as positive and all the other classes as negative. For our 4-class cancer example, that gives:
Classifier 1: lung vs. [breast, kidney, brain] — (lung cancer, not lung cancer)
Classifier 2: breast vs. [lung, kidney, brain] — (breast cancer, not breast cancer)
Classifier 3: kidney vs. [lung, breast, brain] — (kidney cancer, not kidney cancer)
Classifier 4: brain vs. [lung, breast, kidney] — (brain cancer, not brain cancer)
Sklearn suggests that these classifiers work best with the OVR approach:
ensemble.GradientBoostingClassifier
linear_model.SGDClassifier
linear_model.Perceptron
Alternatively, you can use the above models with the OneVsRestClassifier wrapper:
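A minimal sketch of the wrapper (again, the wrapped estimator and dataset are illustrative):

from sklearn.datasets import load_wine
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

X, y = load_wine(return_X_y=True)

# One binary classifier per class: that class vs. all the others
ovr = OneVsRestClassifier(LogisticRegression(max_iter=1000)).fit(X, y)
print(len(ovr.estimators_))  # 3 classes -> 3 classifiers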
Even though this strategy significantly lowers the computational cost, the fact that only one class is considered positive and the rest negative makes each binary problem an imbalanced classification task. This issue is even more pronounced for classes with low proportions in the target.
In both approaches, depending on the passed estimator, the results of all binary classifiers can be summarized in two ways:
majority vote: each binary classifier predicts one class, and the class that gets the most votes from all classifiers is chosen
argmax of class membership probability scores: classifiers such as LogisticRegression compute probability scores for each class ( .predict_proba() ). Then, the argmax of the sum of the scores is chosen.
We will talk more about how to score each of these strategies later in the tutorial.
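As a toy illustration of the second approach (the wrapper estimators handle this aggregation internally, and the scores below are made up):

import numpy as np

# Made-up probability scores for one sample: each one-vs-rest classifier
# reports how likely the sample is to belong to "its" class
per_class_scores = np.array([0.15, 0.62, 0.40, 0.08])

# The class whose classifier is most confident is the final prediction
print(np.argmax(per_class_scores))  # -> 1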
Throughout the rest of the tutorial, we will work with the Diamonds dataset:

import pandas as pd

diamonds = pd.read_csv("data/diamonds.csv").drop("Unnamed: 0", axis=1)
diamonds.head()

>>> diamonds.shape
(53940, 10)

>>> diamonds.describe().T.round(3)
The above output shows the features are on different scales, suggesting we use some type of normalization. This step is essential for many
linear-based models to perform well.
Let's also look at the distribution of the cut column:

>>> diamonds.cut.value_counts()

Ideal        21551
Premium      13791
Very Good    12082
Good          4906
Fair          1610
Name: cut, dtype: int64
The dataset contains a mixture of numeric and categorical features. I covered preprocessing steps for binary classification in my last article
in detail. You can easily apply the ideas to the multi-class case, so I will keep the explanations here nice and short.
The target is ‘cut’, which has 5 classes: Ideal, Premium, Very Good, Good, and Fair (descending quality). We will encode the textual
features with OneHotEncoder.
Let’s take a quick look at the distributions of each numeric feature to decide what type of normalization to use:
Price and carat show skewed distributions. We will use a logarithmic transformer to make them as normally distributed as possible. For
the rest, simple standardization is enough. If you are not familiar with numeric transformations, check out my article on the topic. Also,
the below code contains an example of Sklearn pipelines, and you can learn all about them from here.
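Here is a minimal sketch of what such a pipeline might look like, assuming a log transform for carat and price, standardization for the remaining numeric columns, one-hot encoding for the categorical columns, and a RandomForestClassifier step named 'base' (the step name referenced later during hyperparameter tuning). The column lists and parameters are illustrative:

import numpy as np
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer, OneHotEncoder, StandardScaler

# Feature groups (illustrative)
log_cols = ["carat", "price"]
num_cols = ["depth", "table", "x", "y", "z"]
cat_cols = ["color", "clarity"]

preprocess = ColumnTransformer(
    transformers=[
        ("log", FunctionTransformer(np.log), log_cols),
        ("scale", StandardScaler(), num_cols),
        ("onehot", OneHotEncoder(handle_unknown="ignore"), cat_cols),
    ]
)

pipeline = Pipeline(
    steps=[
        ("preprocess", preprocess),
        ("base", RandomForestClassifier(random_state=42)),
    ]
)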
The first version of our pipeline uses RandomForestClassifier . Let's look at its confusion matrix by generating predictions:
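A sketch of generating the predictions and plotting the matrix (variable names carried over from the pipeline sketch above):

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.model_selection import train_test_split

X = diamonds.drop("cut", axis=1)
y = diamonds["cut"]
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=42)

pipeline.fit(X_train, y_train)
y_pred = pipeline.predict(X_test)

# Build the matrix and plot it with the class names taken from the pipeline
cm = confusion_matrix(y_test, y_pred, labels=pipeline.classes_)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=pipeline.classes_)
disp.plot()
plt.show()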
In the snippet, we create the matrix with confusion_matrix and plot it with a dedicated Sklearn helper. ConfusionMatrixDisplay also has a display_labels argument, to which we pass the class names accessed through the pipeline.classes_ attribute.
Interpreting an N-by-N confusion matrix
If you read my other article on binary classification, you know that confusion matrices are the holy grail of supervised classification
problems. In a 2 by 2 matrix, the matrix terms are easy to interpret and locate.
Even though it gets more difficult to interpret the matrix as the number of classes increases, there are sure-fire ways to find your way
around any matrix of any shape.
The first step is always identifying your positive and negative classes. This depends on the problem you are trying to solve. As a jewelry
store owner, I may want my classifier to differentiate Ideal and Premium diamonds better than other types, making these types of
diamonds my positive class. Other classes will be considered negative.
Establishing positive and negative classes early on is very important in evaluating model performance and in hyperparameter tuning. After
doing this, you should define your true positives, true negatives, false positives, and false negatives. In our case:
True Positives: actual value is either Ideal or Premium, and the prediction is also either Ideal or Premium
True Negatives: actual value belongs to any of the 3 negative classes, and the prediction is also one of the 3 negative classes
False Positives: actual value belongs to any of the 3 negative classes but is predicted as either Ideal or Premium
False Negatives: actual value is either Ideal or Premium but is predicted as any of the 3 negative classes.
Always list out the terms of your matrix in this manner, and the rest of your workflow will be much easier, as you will see in the next
section.
The first metric we will discuss is the ROC AUC score, or area under the receiver operating characteristic curve. It is mostly used when we want to measure how well a classifier differentiates between the classes. This means that ROC AUC is better suited for balanced classification tasks.
In essence, the ROC AUC score is designed for binary classification and for models that can generate class membership probabilities, from which predictions are made at some threshold. Here is a brief overview of the steps to calculate ROC AUC for binary classification:
1. A binary classifier that can generate class membership probabilities, such as LogisticRegression with its predict_proba method, is chosen.
2. An initial decision threshold close to 0 is chosen. For example, if the probability is higher than 0.1, the class is predicted positive; otherwise, negative.
3. Class labels are predicted for all samples using this threshold.
4. True positive rate (TPR) and false positive rate (FPR) are found.
5. The resulting (FPR, TPR) pair gives a single point on the curve.
6. Steps 2-5 are repeated for various thresholds between 0 and 1 to create a set of TPRs and FPRs.
7. All TPRs are plotted against FPRs to generate the receiver operating characteristic curve. The ROC AUC score is the area under this curve.
For multiclass classification, you can calculate the ROC AUC for all classes using either OVO or OVR strategies. Since we agreed that OVR
is a better option, here is how ROC AUC is calculated for OVR classification:
1. Each binary classifier created using OVR finds the ROC AUC score for its own class using the above steps.
2. ROC AUC scores of all classifiers are then averaged using either of these 2 methods:
"macro": a simple, unweighted mean of the per-class ROC AUC scores, treating every class equally regardless of its size.
"weighted": this takes class imbalance into account by finding a weighted average. Each ROC AUC is multiplied by its class weight (the number of samples in that class) and summed, then divided by the total number of samples.
As an example, let’s say there are 100 samples in the target — class 1 (45), class 2 (30), class 3 (25). OVR creates 3 binary classifiers, 1 for
each class, and their ROC AUC scores are 0.75, 0.68, 0.84, respectively. The weighted ROC AUC score across all classes will be:
ROC AUC (weighted): ((45 * 0.75) + (30 * 0.68) + (25 * 0.84)) / 100 = 0.7515
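Roughly, the call looks like this (variable names follow the earlier sketches):

from sklearn.metrics import roc_auc_score

# Probability scores for every class are required, not hard predictions
y_proba = pipeline.predict_proba(X_test)

roc_auc_score(y_test, y_proba, multi_class="ovr", average="weighted")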
Above, we calculated ROC AUC for our diamond classification problem and got an excellent score. Don’t forget to set the multi_class and
average parameters properly when using roc_auc_score . If you want to generate the score for a particular class, here is how you do it:
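One way to do it is to binarize the true labels and score a single class's probability column, reusing y_proba from the previous snippet (the class chosen here is illustrative):

from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import label_binarize

# Binarize the true labels: one column per class, in pipeline.classes_ order
y_test_bin = label_binarize(y_test, classes=pipeline.classes_)

# ROC AUC for the "Ideal" class only
ideal_idx = list(pipeline.classes_).index("Ideal")
roc_auc_score(y_test_bin[:, ideal_idx], y_proba[:, ideal_idx])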
The ROC AUC score is only a good metric for seeing how well the classifier differentiates between classes. A higher ROC AUC score does not necessarily mean a better model. On top of that, we care more about our model's ability to classify Ideal and Premium diamonds, so a metric like ROC AUC is not a good option for our case.
In a multiclass case, these 3 metrics (precision, recall, and F1) are calculated on a per-class basis. For example, let's look again at the confusion matrix we plotted earlier.
Precision tells us what proportion of predicted positives is truly positive. If we want to calculate precision for Ideal diamonds, true positives would be the number of Ideal diamonds predicted correctly (the center of the matrix, 6626). False positives would be any cells that count the number of times our classifier predicted other types of diamonds as Ideal. These would be the cells above and below the center of the matrix (1013 + 521 + 31 + 8 = 1573). Using the formula of precision, we calculate it to be 6626 / (6626 + 1573) ≈ 0.81.
Recall is calculated similarly. We know the number of true positives: 6626. False negatives would be any cells that count the number of times the classifier predicted an actual Ideal diamond as belonging to any other, negative class. These would be the cells to the right and left of the center of the matrix (3 + 9 + 363 + 111 = 486). Using the formula of recall, we calculate it to be 6626 / (6626 + 486) ≈ 0.93.
So, how do we choose between recall and precision for the Ideal class? It depends on the type of problem you are trying to solve. If you
want to minimize the instances where other, cheaper types of diamonds are predicted as Ideal, you should optimize precision. As a jewelry
store owner, you might be sued for fraud for selling cheaper diamonds as expensive Ideal diamonds.
On the other hand, if you want to minimize the instances where you accidentally sell Ideal diamonds for a lower price, you should
optimize for recall of the Ideal class. Indeed, you won’t get sued, but you might lose money.
The third option is to have a model that is equally good at the above 2 scenarios. In other words, a model with high precision and recall.
Fortunately, there is a metric that measures just that: the F1 score. F1 score takes the harmonic mean of precision and recall and produces
a value between 0 and 1:
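F1 = 2 * (precision * recall) / (precision + recall)

For the Ideal class, plugging in the values computed above gives roughly 2 * (0.81 * 0.93) / (0.81 + 0.93) ≈ 0.87.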
Up to this point, we calculated the 3 metrics only for the Ideal class. But in multiclass classification, Sklearn computes them for all classes.
You can use classification_report to see this:
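Using the test predictions from earlier, the call is simply:

from sklearn.metrics import classification_report

print(classification_report(y_test, y_pred))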
You can check that our calculations for the Ideal class were correct. The last column of the table, support, shows how many samples there are for each class. Also, the last 2 rows show averaged scores for the 3 metrics. We already covered what macro and weighted averages are in the example of ROC AUC.
For imbalanced classification tasks such as this one, you rarely choose averaged precision, recall, or F1 scores. Again, choosing one metric to optimize for a particular class depends on your business problem. For our case, we will choose to optimize the F1 score of the Ideal and Premium classes (yes, you can choose multiple classes simultaneously). First, let's see how to calculate the weighted F1 across all classes:
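A sketch of the call, reusing the earlier predictions:

from sklearn.metrics import f1_score

f1_score(y_test, y_pred, average="weighted")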
The above is consistent with the output of classification_report . To choose the F1 scores for Ideal and Premium classes, specify the
labels parameter:
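A sketch restricting the scoring to the two classes we care about (passing average=None returns one score per requested label):

from sklearn.metrics import f1_score

# Per-class F1 for just the Ideal and Premium classes
f1_score(y_test, y_pred, labels=["Ideal", "Premium"], average=None)

# Or a single weighted score across just these two classes
f1_score(y_test, y_pred, labels=["Ideal", "Premium"], average="weighted")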
Finally, let’s see how to optimize these metrics with hyperparameter tuning.
Up until now, we were using the RandomForestClassifier pipeline, so we will create a hyperparameter grid for this estimator:
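An illustrative grid (the values are placeholders; the 'base__' prefix targets the RandomForestClassifier step, as explained below):

# Illustrative hyperparameter values for the 'base' RandomForestClassifier step
param_grid = {
    "base__n_estimators": [100, 300, 500],
    "base__max_depth": [None, 10, 20],
    "base__min_samples_split": [2, 5, 10],
    "base__max_features": ["sqrt", "log2"],
}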
Don't forget to prepend each hyperparameter name with the step name you chose for your estimator in the pipeline. When we created our pipeline, we named the RandomForestClassifier step 'base'. See this discussion for more info.
We will use HalvingGridSearchCV (HGS), which is much faster than a regular GridSearchCV. You can read this article to see my experiments:
Before we feed the above grid to HGS, let's create a custom scoring function. In the binary case, we could pass string values as the names of the metrics we wanted to use, such as 'precision' or 'recall'. But in the multiclass case, those functions accept additional parameters, which we cannot provide if we pass the metric names as strings. To solve this, Sklearn provides the make_scorer function:
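A sketch of the scorer, reusing the average and labels values from the previous section (extra keyword arguments are forwarded to f1_score at scoring time):

from sklearn.metrics import f1_score, make_scorer

# Keyword arguments are passed through to f1_score whenever the scorer is called
custom_scorer = make_scorer(
    f1_score, average="weighted", labels=["Ideal", "Premium"]
)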
As we did in the last section, we passed custom values for the average and labels parameters.
Finally, let’s initialize the HGS and fit it to the full data with 3-fold cross-validation:
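A sketch of the search (HalvingGridSearchCV must be enabled via its experimental import; the remaining arguments follow the earlier sketches):

# HalvingGridSearchCV still sits behind an experimental flag
from sklearn.experimental import enable_halving_search_cv  # noqa: F401
from sklearn.model_selection import HalvingGridSearchCV

hgs = HalvingGridSearchCV(
    pipeline, param_grid, scoring=custom_scorer, cv=3, n_jobs=-1
)
hgs.fit(X, y)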
After the search is done, you can get the best score and estimator with .best_score_ and .best_estimator_ attributes, respectively.
Your model is only as good as the metric you choose to evaluate it with. Hyperparameter tuning is time-consuming, but assuming you did everything right up to this point and provided a good enough parameter grid, everything will turn out as expected. If not, it is an iterative process: take your time tweaking the preprocessing steps, take a second look at your chosen metrics, and maybe widen your search grid. Thank you for reading!
Related Articles
Multi-Class Metrics Made Simple, Part I: Precision and Recall
Discussions
How to choose between ROC AUC and the F1 score?