Stacking classifier
Stacking can be applied in a similar way to classification problems. For demonstration, we consider the Spambase Data Set from the UCI Machine Learning Repository. We load the data and shuffle the observations, since they are ordered by outcome.
. insheet using ///
> https://archive.ics.uci.edu/ml/machine-learning-databases/spambase/spambase.data, ///
> clear comma
. set seed 42
. gen uni=runiform()
. sort uni
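Since the stacking model below is estimated on the first 3,000 observations only, the shuffle ensures that this training sample is a random subset. As a quick check of the class balance (v58 is the spam indicator, with 1 = spam), one might tabulate the outcome:
. tab v58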
Stacking classification works very similarly to stacking regression. The example below is somewhat more involved, so let's go through it step by step (a schematic of the call follows the list):
- We use 6 base learners: logit, random forest, gradient boosting and three neural nets.
- We apply the poly2 pipeline to the logistic regressor, which creates squares and interaction terms of the predictors.
- For both gradient boosting and random forest, we increase the number of classification trees to 1000.
- We consider three types of neural nets: (1) one layer of 100 neurons (the default), (2) three layers with 50 neurons each, (3) one layer of 200 neurons.
- We use type(class) to specify that we consider a classification task.
- Finally, njobs(-1) switches parallelization on, using all available CPUs.
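Schematically, the pattern used in this example is (see the pystacked help file for the full syntax):

pystacked depvar predictors || m(method) [pipe(pipeline)] [opt(options)] || ... , type(class) njobs(#)

Each base learner occupies its own segment separated by ||; pipe() and opt() are optional within a segment, while the global options after the comma apply to the whole stack.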
Please note that this might take a while to run.
. pystacked v58 v1-v57 || ///
> m(logit) pipe(poly2) || ///
> m(rf) opt(n_estimators(1000)) || ///
> m(gradboost) opt(n_estimators(1000)) || ///
> m(nnet) || ///
> m(nnet) opt(hidden_layer_sizes(50 50 50)) || ///
> m(nnet) opt(hidden_layer_sizes(200)) || ///
> if _n<=3000 , type(class) njobs(-1)
Stacking weights:
---------------------------------------
Method | Weight
-----------------+---------------------
logit | 0.0016062
rf | 0.0762563
gradboost | 0.7524429
nnet | 0.0810773
nnet | 0.0574056
nnet | 0.0312117
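The weights indicate that gradient boosting dominates the ensemble with a weight of roughly 0.75, while the logit learner contributes almost nothing. To post-process the weights or other stored results, Stata's standard ereturn list shows everything pystacked saves in e() (the exact names of the stored results are documented in the pystacked help):
. ereturn list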
Confusion matrix
After estimation, we can obtain the predicted classes using predict with the class option, which allows us to construct in-sample and out-of-sample confusion matrices:
. predict spam, class
. tab spam v58 if _n<=3000, cell
| v58
spam | 0 1 | Total
-----------+----------------------+----------
0 | 1,792 2 | 1,794
| 59.73 0.07 | 59.80
-----------+----------------------+----------
1 | 1 1,205 | 1,206
| 0.03 40.17 | 40.20
-----------+----------------------+----------
Total | 1,793 1,207 | 3,000
| 59.77 40.23 | 100.00
. tab spam v58 if _n>3000, cell
| v58
spam | 0 1 | Total
-----------+----------------------+----------
0 | 962 43 | 1,005
| 60.09 2.69 | 62.77
-----------+----------------------+----------
1 | 33 563 | 596
| 2.06 35.17 | 37.23
-----------+----------------------+----------
Total | 995 606 | 1,601
| 62.15 37.85 | 100.00
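The two tables can be condensed into overall misclassification rates by comparing predicted and actual classes directly; a minimal sketch (miss is a hypothetical variable name):

. gen byte miss = (spam != v58)
. sum miss if _n<=3000   // in-sample error: about 0.1% (0.07% + 0.03% above)
. sum miss if _n>3000    // holdout error: about 4.8% (2.69% + 2.06% above)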
The table option makes this even easier. The table below shows the in-sample, cross-validated, and holdout confusion matrices for stacking and each base learner, all in one table.
. pystacked, table holdout
Number of holdout observations: 1601
Confusion matrix: In-Sample, CV, Holdout
-----------------------------------------------------------------------------
Method | Weight In-Sample CV Holdout
| 0 1 0 1 0 1
-----------------+-----------------------------------------------------------
STACKING 0 | . 1792 2 1728 76 962 43
STACKING 1 | . 1 1205 65 1131 33 563
logit 0 | 0.002 1077 68 1097 65 562 35
logit 1 | 0.002 716 1139 696 1142 433 571
rf 0 | 0.076 1792 0 1709 100 948 44
rf 1 | 0.076 1 1207 84 1107 47 562
gradboost 0 | 0.752 1792 2 1726 75 960 43
gradboost 1 | 0.752 1 1205 67 1132 35 563
nnet 0 | 0.081 1758 175 1671 113 957 100
nnet 1 | 0.081 35 1032 122 1094 38 506
nnet 0 | 0.057 1669 77 1650 125 910 55
nnet 1 | 0.057 124 1130 143 1082 85 551
nnet 0 | 0.031 1679 86 1654 116 911 54
nnet 1 | 0.031 114 1121 139 1091 84 552
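Reading off the holdout columns: stacking misclassifies (43 + 33)/1601, or about 4.7%, of the holdout observations, marginally better than gradient boosting alone at (43 + 35)/1601, or about 4.9%, and far better than the logit learner at (35 + 433)/1601, or about 29%.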
ROC curve and AUC
pystacked supports ROC curves, which allow us to assess classification performance across varying discrimination thresholds. The y-axis of an ROC plot corresponds to sensitivity (the true positive rate) and the x-axis to 1-specificity (the false positive rate). The Area Under the Curve (AUC), displayed below each ROC plot, is a common evaluation metric for classification problems.
. pystacked, graph(subtitle(Spam data)) ///
> lgraph(plotopts(msymbol(i) ylabel(0 1, format(%3.1f)))) ///
> holdout
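The AUC can also be obtained directly with Stata's built-in roctab command applied to the predicted probabilities; a minimal sketch, assuming that pystacked's predict supports the pr option for predicted probabilities (phat is a hypothetical variable name):

. predict double phat, pr
. roctab v58 phat if _n>3000   // nonparametric AUC on the holdout sample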