1 Introduction

DALEX is designed to work with various black-box models such as tree ensembles, linear models, and neural networks. Unfortunately, R packages that create such models are very inconsistent: different tools use different interfaces to train, validate and use models.

In this vignette we will show explanations for models from mlr (Bischl et al. 2016).

2 Regression use case - apartments data

library(DALEX)
library(DALEXtra)
library(mlr)
library(breakDown)

To illustrate applications of DALEX to regression problems we will use an artificial dataset apartments available in the DALEX package. Our goal is to predict the price per square meter of an apartment based on selected features such as construction year, surface, floor, number of rooms, and district. Note that four of these variables are continuous while the fifth, district, is categorical. Prices are given in Euro.

data(apartments)
head(apartments)
##   m2.price construction.year surface floor no.rooms    district
## 1     5897              1953      25     3        1 Srodmiescie
## 2     1818              1992     143     9        5     Bielany
## 3     3643              1937      56     1        2       Praga
## 4     3517              1995      93     7        3      Ochota
## 5     3013              1992     144     6        5     Mokotow
## 6     5795              1926      61     6        2 Srodmiescie

2.1 The explain() function

The first step of using the DALEX package is to wrap the black-box model with meta-data that unifies model interfacing. To work with mlr models we use the DALEXtra package, an extension to the DALEX package.

In this vignette we will use three models for regression: a random forest, a gradient boosting machine, and a neural network.

Following the semantics of the mlr package, we first create our regression task with the makeRegrTask() function and build learners for our models with the makeLearner() function.

set.seed(123)
regr_task <- makeRegrTask(id = "ap", data = apartments, target = "m2.price")
regr_lrn_rf <- makeLearner("regr.randomForest")
regr_lrn_nn <- makeLearner("regr.nnet")
regr_lrn_gbm <- makeLearner("regr.gbm", par.vals = list(n.trees = 500))

Additionally, for the neural network model we set extra hyperparameters and preprocess the data (centering and scaling).

regr_lrn_nn <- setHyperPars(regr_lrn_nn, par.vals = list(maxit=500, size=2))
regr_lrn_nn <- makePreprocWrapperCaret(regr_lrn_nn, ppc.scale=TRUE, ppc.center=TRUE)

Below, we use the mlr function train() to fit our models.

regr_rf <- train(regr_lrn_rf, regr_task)
regr_nn <- train(regr_lrn_nn, regr_task)
regr_gbm <- train(regr_lrn_gbm, regr_task)

To create explainers for these models it is enough to use the explain_mlr() function with the model, data and y parameters. The validation dataset for the models is the apartmentsTest data from the DALEX package. Because predict() from mlr returns not just predictions but an object with additional information, the explainer needs a predict function that takes two arguments (model and newdata) and returns a numeric vector of predictions; explain_mlr() supplies such a function (yhat.WrappedModel) by default.
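
For illustration, an equivalent predict function could be written by hand and passed to the generic DALEX::explain() via the predict_function argument. This is a minimal sketch, not required when using explain_mlr(); the helper name custom_predict is our own:

custom_predict <- function(object, newdata) {
  # mlr's predict() returns a Prediction object; extract the numeric response
  pred <- predict(object, newdata = newdata)
  pred$data$response
}

With this helper, explain(regr_rf, data = apartmentsTest, y = apartmentsTest$m2.price, predict_function = custom_predict) would behave like explain_mlr().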

data(apartmentsTest)
explainer_regr_rf <- DALEXtra::explain_mlr(regr_rf, data=apartmentsTest, y=apartmentsTest$m2.price, 
                                           label="rf", verbose = FALSE)
explainer_regr_nn <- DALEXtra::explain_mlr(regr_nn, data=apartmentsTest, y=apartmentsTest$m2.price,
                                           label="nn") # verbose defaults to TRUE, so the diagnostics below are printed
## Preparation of a new explainer is initiated
##   -> model label       :  nn 
##   -> data              :  9000  rows  6  cols 
##   -> target variable   :  9000  values 
##   -> predict function  :  yhat.WrappedModel  will be used (  default  )
##   -> predicted values  :  numerical, min =  3042.321 , mean =  3490.524 , max =  3754.98  
##   -> model_info        :  package mlr , ver. 2.17.0 , task regression (  default  ) 
##   -> residual function :  difference between y and yhat (  default  )
##   -> residuals         :  numerical, min =  -2036.98 , mean =  20.99976 , max =  2924.02  
##   A new explainer has been created! 
explainer_regr_gbm <- DALEXtra::explain_mlr(regr_gbm, data=apartmentsTest, y=apartmentsTest$m2.price,
                                            label="gbm", verbose = FALSE)

2.2 Model performance

The model_performance() function calculates predictions and residuals for the validation dataset.

mp_regr_rf <- model_performance(explainer_regr_rf)
mp_regr_gbm <- model_performance(explainer_regr_gbm)
mp_regr_nn <- model_performance(explainer_regr_nn)

The generic print() function shows performance measures and quantiles of the residuals.

mp_regr_rf
## Measures for:  regression
## mse        : 79113.32 
## rmse       : 281.2709 
## r2         : 0.9024261 
## mad        : 170.6715
## 
## Residuals:
##          0%         10%         20%         30%         40%         50% 
## -720.367763 -292.280918 -216.723678 -158.327152 -105.918427  -53.924498 
##         60%         70%         80%         90%        100% 
##    7.442933   88.250267  198.604201  398.788140 1244.948784

The generic plot() function shows the reversed empirical cumulative distribution function (ECDF) of the absolute residuals. Plots can be generated for one or more models.
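
The reversed ECDF at a point x gives the share of observations whose absolute residual exceeds x. For a single model this share can be computed directly from the residuals stored in the explainer (the threshold 200 below is an arbitrary example value):

mean(abs(explainer_regr_rf$residuals) > 200)  # share of |residuals| above 200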

plot(mp_regr_rf, mp_regr_nn, mp_regr_gbm)

The figure above shows that the majority of residuals for the random forest are smaller than the residuals for the neural network and gbm.

We can also use the plot() function to get an alternative comparison of residuals. Setting the geom = "boxplot" parameter lets us compare the distribution of residuals for the selected models.

plot(mp_regr_rf, mp_regr_nn, mp_regr_gbm, geom = "boxplot")

2.3 Variable importance

Using the DALEX package we are able to better understand which variables are important.

Model-agnostic variable importance is calculated by means of permutations: we compare the loss function calculated on the validation dataset after permuting the values of a single variable with the loss calculated on the original dataset. The increase in loss measures how much the model relies on that variable.

This method is implemented in the model_parts() function.
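
To make the idea concrete, here is a minimal hand-rolled sketch of permutation importance for the single variable district, using the predict function and data stored in the explainer (model_parts() additionally repeats the permutation several times and averages the results):

rmse <- function(y, yhat) sqrt(mean((y - yhat)^2))
expl <- explainer_regr_rf
loss_full <- rmse(expl$y, expl$predict_function(expl$model, expl$data))
perturbed <- expl$data
perturbed$district <- sample(perturbed$district)  # break the link between district and the target
loss_permuted <- rmse(expl$y, expl$predict_function(expl$model, perturbed))
loss_permuted - loss_full                         # increase in loss = importance of district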

vi_regr_rf <- model_parts(explainer_regr_rf, loss_function = loss_root_mean_square)
vi_regr_gbm <- model_parts(explainer_regr_gbm, loss_function = loss_root_mean_square)
vi_regr_nn <- model_parts(explainer_regr_nn, loss_function = loss_root_mean_square)

We can compare all models using the generic plot() function.

plot(vi_regr_rf, vi_regr_gbm, vi_regr_nn)

The length of each interval corresponds to the variable's importance: a longer interval means a larger loss, so the variable is more important.

For a better comparison of the models we can anchor the variable importance bars at 0 using type = "difference".

vi_regr_rf <- model_parts(explainer_regr_rf, loss_function = loss_root_mean_square, type = "difference")
vi_regr_gbm <- model_parts(explainer_regr_gbm, loss_function = loss_root_mean_square, type = "difference")
vi_regr_nn <- model_parts(explainer_regr_nn, loss_function = loss_root_mean_square, type = "difference")

plot(vi_regr_rf, vi_regr_gbm, vi_regr_nn)

We see that for the random forest and gbm models the most important variable is district.

2.4 Model profile

Explainers presented in this section are designed to better understand the relation between a variable and model output.

For more details on the methods described in this section see:
- Partial-dependence Profiles
- Local-dependence and Accumulated-local Profiles

2.4.1 Partial Dependence Plot

Partial Dependence Plots (PDP) are one of the most popular methods for exploration of the relation between a continuous variable and the model outcome.
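
The PDP value at a point z is the average model prediction when the variable of interest is set to z for every observation. A minimal sketch for construction.year (the helper pdp_point and the grid are our own illustration; model_profile() below does this, and more, for us):

pdp_point <- function(explainer, variable, z) {
  d <- explainer$data
  d[[variable]] <- z                       # fix the variable at z for all observations
  mean(explainer$predict_function(explainer$model, d))
}
grid <- seq(1920, 2010, by = 10)
sapply(grid, function(z) pdp_point(explainer_regr_rf, "construction.year", z))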

Function model_profile() with the parameter type = "partial" to calculate PDP response.

pdp_regr_rf  <- model_profile(explainer_regr_rf, variable =  "construction.year", type = "partial")
pdp_regr_gbm  <- model_profile(explainer_regr_gbm, variable =  "construction.year", type = "partial")
pdp_regr_nn  <- model_profile(explainer_regr_nn, variable =  "construction.year", type = "partial")

plot(pdp_regr_rf, pdp_regr_gbm, pdp_regr_nn)

We use PDP plots to compare our three models. As we can see above, the shape of the random forest profile suggests a non-linear relation in the data. It looks like the neural network and gbm did not capture that relation.

2.4.2 Accumulated Local Effects plot

The Accumulated Local Effects (ALE) plot is an extension of the PDP that is better suited for highly correlated variables.
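
Instead of averaging predictions over the whole dataset, ALE averages local differences in predictions within small intervals of the variable and accumulates them. A simplified, uncentered sketch (the helper ale_sketch is our own; model_profile() handles binning and centering properly):

ale_sketch <- function(explainer, variable, K = 10) {
  x <- explainer$data[[variable]]
  breaks <- unique(quantile(x, probs = seq(0, 1, length.out = K + 1)))
  bin <- cut(x, breaks, include.lowest = TRUE, labels = FALSE)
  local_effects <- sapply(seq_len(length(breaks) - 1), function(k) {
    d <- explainer$data[bin == k, , drop = FALSE]
    lo <- d; lo[[variable]] <- breaks[k]      # move observations to the interval ends
    hi <- d; hi[[variable]] <- breaks[k + 1]
    mean(explainer$predict_function(explainer$model, hi) -
         explainer$predict_function(explainer$model, lo))
  })
  cumsum(local_effects)  # accumulate the average local differences
}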

Function model_profile() with the parameter type = "accumulated" to calculate the ALE curve for the variable construction.year.

ale_regr_rf  <- model_profile(explainer_regr_rf, variable =  "construction.year", type = "accumulated")
ale_regr_gbm  <- model_profile(explainer_regr_gbm, variable =  "construction.year", type = "accumulated")
ale_regr_nn  <- model_profile(explainer_regr_nn, variable =  "construction.year", type = "accumulated")

plot(ale_regr_rf, ale_regr_gbm, ale_regr_nn)

2.4.3 Partial Dependence Profile for a categorical variable

We call model_profile() with the parameter type = "partial" for the categorical variable district.

mpp_regr_rf  <- model_profile(explainer_regr_rf, variable =  "district", type = "partial")
mpp_regr_gbm  <- model_profile(explainer_regr_gbm, variable =  "district", type = "partial")
mpp_regr_nn  <- model_profile(explainer_regr_nn, variable =  "district", type = "partial")

plot(mpp_regr_rf, mpp_regr_gbm, mpp_regr_nn)

We can note three clusters: the city center (Srodmiescie), districts with good transport connections to the city center (Ochota, Mokotow, Zoliborz - for the random forest and gbm), and the remaining districts, closer to the city boundaries.

3 Classification use case - wine data

To illustrate applications of DALEX to classification problems we will use the wine dataset available in the breakDown package. We want to classify the quality of wine. Originally this variable has 7 levels, but in our example it will be reduced to a binary classification. Our classification will be based on eleven features from this dataset.

White wine quality data is related to variants of the Portuguese “Vinho Verde” wine. For more details, consult: http://www.vinhoverde.pt/en/.

data(wine)
wine$quality <- ifelse(wine$quality>5, 1, 0)

First, we create train and test indices, which are needed to train the mlr models since we don't have a separate test set; the held-out rows will form the validation dataset wineTest.

wine$quality <- factor(wine$quality)
train_index <- sample(1:nrow(wine), 0.6 * nrow(wine))
test_index <- setdiff(1:nrow(wine), train_index)

wineTest <- wine[test_index,]

In this vignette we will use three models for classification: a random forest, logistic regression, and a support vector machine.

Following the semantics of the mlr package, we first create our classification task with the makeClassifTask() function and build learners for our models with the makeLearner() function, setting the parameter predict.type = "prob".

classif_task <- makeClassifTask(id = "ap", data = wine, target = "quality")
classif_lrn_rf <- makeLearner("classif.randomForest", predict.type = "prob")
classif_lrn_glm <- makeLearner("classif.binomial", predict.type = "prob")
classif_lrn_svm <- makeLearner("classif.ksvm", predict.type = "prob")

Next, we use train() to fit our three models.

classif_rf <- train(classif_lrn_rf, classif_task, subset=train_index)
classif_glm <- train(classif_lrn_glm, classif_task, subset=train_index)
classif_svm <- train(classif_lrn_svm, classif_task, subset=train_index)

As previously, we create explainers for these models with the explain_mlr() function. The validation dataset for the models is wineTest.

In this case we consider the differences between the observed class and the predicted probabilities to be residuals. The predict function therefore has to take two arguments (model and newdata) and return a numeric vector of probabilities; again, explain_mlr() supplies a suitable function by default.
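
For illustration, such a predict function for the mlr classifiers could look like this (a sketch of what the default does for us; the helper name is our own and it assumes the positive class is labelled "1", as in our recoded quality variable):

custom_predict_classif <- function(object, newdata) {
  pred <- predict(object, newdata = newdata)   # mlr Prediction object with per-class probabilities
  pred$data[["prob.1"]]                        # probability of the positive class "1"
}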

y_test <- as.numeric(as.character(wineTest$quality))

explainer_classif_rf <- DALEXtra::explain_mlr(classif_rf, data = wineTest, y = y_test, 
                                              label= "rf", verbose = FALSE)
explainer_classif_glm <- DALEXtra::explain_mlr(classif_glm, data = wineTest, y = y_test,
                                               label = "glm", verbose = FALSE)
explainer_classif_svm <- DALEXtra::explain_mlr(classif_svm, data = wineTest, y = y_test,
                                               label = "svm", verbose = FALSE)

3.1 Model performance

The model_performance() function calculates predictions and residuals for the validation dataset wineTest.

We use the generic plot() function to get a comparison of models.

mp_classif_rf <- model_performance(explainer_classif_rf)
mp_classif_glm <- model_performance(explainer_classif_glm)
mp_classif_svm <- model_performance(explainer_classif_svm)

plot(mp_classif_rf, mp_classif_glm, mp_classif_svm)

Setting the geom = "boxplot" parameter we can compare the distribution of residuals for selected models.

plot(mp_classif_rf, mp_classif_glm, mp_classif_svm, geom = "boxplot")

3.2 Variable importance

The model_parts() function computes variable importance measures, which may be plotted.

vi_classif_rf <- model_parts(explainer_classif_rf, loss_function = loss_root_mean_square)
vi_classif_glm <- model_parts(explainer_classif_glm, loss_function = loss_root_mean_square)
vi_classif_svm <- model_parts(explainer_classif_svm, loss_function = loss_root_mean_square)

plot(vi_classif_rf, vi_classif_glm, vi_classif_svm)

Left edges of the intervals start at the loss of the full model. The length of each interval corresponds to the variable's importance: a longer interval means a larger loss, so the variable is more important.
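
Since our models return probabilities, a loss function tailored to classification can also be used; assuming your DALEX version exports loss_one_minus_auc, the same comparison based on 1 - AUC would be:

vi_classif_rf_auc <- model_parts(explainer_classif_rf, loss_function = loss_one_minus_auc)
plot(vi_classif_rf_auc)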

3.3 Model profile

As previously, we create explanations designed to better understand the relation between a variable and the model output: PDP and ALE plots.

For more details on the methods described in this section see:
- Partial-dependence Profiles
- Local-dependence and Accumulated-local Profiles

3.3.1 Partial Dependence Plot

pdp_classif_rf  <- model_profile(explainer_classif_rf, variable = "pH", type = "partial")
pdp_classif_glm  <- model_profile(explainer_classif_glm, variable = "pH", type = "partial")
pdp_classif_svm  <- model_profile(explainer_classif_svm, variable = "pH", type = "partial")

plot(pdp_classif_rf, pdp_classif_glm, pdp_classif_svm)

3.3.2 Accumulated Local Effects plot

ale_classif_rf  <- model_profile(explainer_classif_rf, variable = "alcohol", type = "accumulated")
ale_classif_glm  <- model_profile(explainer_classif_glm, variable = "alcohol", type = "accumulated")
ale_classif_svm  <- model_profile(explainer_classif_svm, variable = "alcohol", type = "accumulated")

plot(ale_classif_rf, ale_classif_glm, ale_classif_svm)

4 Session info

sessionInfo()
## R version 3.6.3 (2020-02-29)
## Platform: x86_64-w64-mingw32/x64 (64-bit)
## Running under: Windows 10 x64 (build 18363)
## 
## Matrix products: default
## 
## locale:
## [1] LC_COLLATE=Polish_Poland.1250  LC_CTYPE=Polish_Poland.1250   
## [3] LC_MONETARY=Polish_Poland.1250 LC_NUMERIC=C                  
## [5] LC_TIME=Polish_Poland.1250    
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
## [1] breakDown_0.2.0   mlr_2.17.0        ParamHelpers_1.13 DALEXtra_2.0     
## [5] DALEX_2.0.1      
## 
## loaded via a namespace (and not attached):
##  [1] jsonlite_1.6.1       splines_3.6.3        foreach_1.4.8       
##  [4] prodlim_2019.11.13   stats4_3.6.3         ingredients_2.0     
##  [7] yaml_2.2.1           ipred_0.9-9          pillar_1.4.3        
## [10] backports_1.1.5      lattice_0.20-38      glue_1.3.2          
## [13] reticulate_1.14      pROC_1.16.1          digest_0.6.25       
## [16] checkmate_2.0.0      randomForest_4.6-14  colorspace_1.4-1    
## [19] recipes_0.1.10       gbm_2.1.5            htmltools_0.4.0     
## [22] Matrix_1.2-18        plyr_1.8.5           timeDate_3043.102   
## [25] XML_3.99-0.3         pkgconfig_2.0.3      caret_6.0-86        
## [28] purrr_0.3.3          scales_1.1.0         parallelMap_1.4     
## [31] gower_0.2.1          lava_1.6.7           tibble_2.1.3        
## [34] generics_0.0.2       farver_2.0.3         ggplot2_3.3.0       
## [37] withr_2.1.2          nnet_7.3-12          survival_3.1-8      
## [40] magrittr_1.5         crayon_1.3.4         evaluate_0.14       
## [43] nlme_3.1-144         MASS_7.3-51.5        class_7.3-15        
## [46] tools_3.6.3          data.table_1.12.8    lifecycle_0.2.0     
## [49] BBmisc_1.11          stringr_1.4.0        kernlab_0.9-29      
## [52] munsell_0.5.0        compiler_3.6.3       rlang_0.4.6         
## [55] grid_3.6.3           iterators_1.0.12     rappdirs_0.3.1      
## [58] labeling_0.3         rmarkdown_2.1        gtable_0.3.0        
## [61] ModelMetrics_1.2.2.2 codetools_0.2-16     reshape2_1.4.3      
## [64] R6_2.4.1             gridExtra_2.3        lubridate_1.7.4     
## [67] knitr_1.28           dplyr_1.0.0          fastmatch_1.1-0     
## [70] stringi_1.4.6        parallel_3.6.3       Rcpp_1.0.4          
## [73] vctrs_0.3.1          rpart_4.1-15         tidyselect_1.1.0    
## [76] xfun_0.12