Let’s see an example of using the DALEX package with a classification model for the survival problem on the Titanic dataset. Here we use the titanic dataset available in the DALEX package. Note that this data was copied from the stablelearner package.
library("DALEX")
head(titanic)
#>   gender age class    embarked       country  fare sibsp parch survived
#> 1   male  42   3rd Southampton United States  7.11     0     0       no
#> 2   male  13   3rd Southampton United States 20.05     0     2       no
#> 3   male  16   3rd Southampton United States 20.05     1     1       no
#> 4 female  39   3rd Southampton       England 20.05     1     1      yes
#> 5 female  16   3rd Southampton        Norway  7.13     0     0      yes
#> 6   male  25   3rd Southampton United States  7.13     0     0      yes
OK, now it’s time to create a model. Let’s use a random forest.
# prepare model
library("randomForest")
titanic <- na.omit(titanic)
model_titanic_rf <- randomForest(survived == "yes" ~ gender + age + class + embarked +
                                   fare + sibsp + parch, data = titanic)
model_titanic_rf
#>
#> Call:
#> randomForest(formula = survived == "yes" ~ gender + age + class + embarked + fare + sibsp + parch, data = titanic)
#> Type of random forest: regression
#> Number of trees: 500
#> No. of variables tried at each split: 2
#>
#> Mean of squared residuals: 0.143236
#> % Var explained: 34.65
The third step (optional but useful) is to create a DALEX explainer for the random forest model.
library("DALEX")
explain_titanic_rf <- explain(model_titanic_rf,
                              data = titanic[,-9],
                              y = titanic$survived == "yes",
                              label = "Random Forest v7")
#> Preparation of a new explainer is initiated
#> -> model label : Random Forest v7
#> -> data : 2099 rows 8 cols
#> -> target variable : 2099 values
#> -> model_info : package randomForest , ver. 4.6.14 , task regression ( default )
#> -> predict function : yhat.randomForest will be used ( default )
#> -> predicted values : numerical, min = 0.01286123 , mean = 0.3248356 , max = 0.9912115
#> -> residual function : difference between y and yhat ( default )
#> -> residuals : numerical, min = -0.779851 , mean = -0.0003954087 , max = 0.9085878
#> A new explainer has been created!
Use the feature_importance() function to assess the importance of particular features. Note that type = "difference" normalizes the drop-out losses so that they all start at 0.
library("ingredients")
fi_rf <- feature_importance(explain_titanic_rf)
head(fi_rf)
#> variable mean_dropout_loss label
#> 1 _full_model_ 0.3332983 Random Forest v7
#> 2 country 0.3332983 Random Forest v7
#> 3 parch 0.3440449 Random Forest v7
#> 4 sibsp 0.3451616 Random Forest v7
#> 5 embarked 0.3503033 Random Forest v7
#> 6 fare 0.3733943 Random Forest v7
plot(fi_rf)
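The importance above is reported as the raw drop-out loss. To get the type = "difference" version mentioned earlier, in which every variable starts at 0, the type argument can be passed explicitly. A minimal sketch, keeping the default loss function:
# same importance measure, reported as the difference from the full-model loss
fi_rf_diff <- feature_importance(explain_titanic_rf, type = "difference")
plot(fi_rf_diff)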
As we can see, the most important feature is gender. The next three important features are class, age and fare. Let’s see the link between the model response and these features.
Such a univariate relation can be calculated with partial_dependence().
Children 5 years old and younger have a much higher survival probability.
pp_age <- partial_dependence(explain_titanic_rf, variables = c("age", "fare"))
head(pp_age)
#> Top profiles :
#> _vname_ _label_ _x_ _yhat_ _ids_
#> 1 fare Random Forest v7 0.0000000 0.3241036 0
#> 2 age Random Forest v7 0.1666667 0.5364253 0
#> 3 age Random Forest v7 2.0000000 0.5607931 0
#> 4 age Random Forest v7 4.0000000 0.5750886 0
#> 5 fare Random Forest v7 6.1904000 0.3111265 0
#> 6 age Random Forest v7 7.0000000 0.5414633 0
plot(pp_age)
Similar aggregated profiles can be computed with conditional_dependence() and accumulated_dependence().
cp_age <- conditional_dependence(explain_titanic_rf, variables = c("age", "fare"))
plot(cp_age)
ap_age <- accumulated_dependence(explain_titanic_rf, variables = c("age", "fare"))
plot(ap_age)
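The three kinds of aggregated profiles can also be compared on a single plot. A minimal sketch, assuming the profiles are relabelled through their _label_ column so that the curves are distinguishable in the legend:
# relabel the profiles so each type gets its own colour
pp_age$`_label_` <- "partial dependence"
cp_age$`_label_` <- "conditional dependence"
ap_age$`_label_` <- "accumulated dependence"
plot(pp_age, cp_age, ap_age)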
Let’s see a break-down explanation of the model prediction for an 8-year-old boy from the 1st class who embarked in Southampton.
First, Ceteris Paribus profiles for the numerical variables:
new_passenger <- data.frame(
  class = factor("1st", levels = c("1st", "2nd", "3rd", "deck crew", "engineering crew", "restaurant staff", "victualling crew")),
  gender = factor("male", levels = c("female", "male")),
  age = 8,
  sibsp = 0,
  parch = 0,
  fare = 72,
  embarked = factor("Southampton", levels = c("Belfast", "Cherbourg", "Queenstown", "Southampton"))
)
sp_rf <- ceteris_paribus(explain_titanic_rf, new_passenger)
plot(sp_rf) +
  show_observations(sp_rf)
And now for selected categorical variables. Note that sibsp is numerical, but here it is presented as a categorical variable.
plot(sp_rf,
     variables = c("class", "embarked", "gender", "sibsp"),
     variable_type = "categorical")
It looks like the most important features for this passenger are age and gender. His odds of survival are higher than for the average passenger, mainly because of his young age and despite being male.
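The break-down explanation mentioned earlier decomposes this particular prediction into additive contributions of the variables. A minimal sketch, assuming the iBreakDown package is installed:
library("iBreakDown")
# additive attributions for the single passenger defined above
bd_rf <- break_down(explain_titanic_rf, new_passenger)
plot(bd_rf)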
passengers <- select_sample(titanic, n = 100)
sp_rf <- ceteris_paribus(explain_titanic_rf, passengers)
clust_rf <- cluster_profiles(sp_rf, k = 3)
head(clust_rf)
#> Top profiles :
#> _vname_ _label_ _x_ _cluster_ _yhat_ _ids_
#> 1 fare Random Forest v7_1 0.0000000 1 0.1935959 0
#> 2 sibsp Random Forest v7_1 0.0000000 1 0.1695383 0
#> 3 parch Random Forest v7_1 0.0000000 1 0.1672070 0
#> 4 age Random Forest v7_1 0.1666667 1 0.4664651 0
#> 5 parch Random Forest v7_1 0.2800000 1 0.1671393 0
#> 6 sibsp Random Forest v7_1 1.0000000 1 0.1608335 0
plot(sp_rf, alpha = 0.1) +
  show_aggregated_profiles(clust_rf, color = "_label_", size = 2)
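Instead of unsupervised clustering, the profiles can also be aggregated within groups defined by an observed variable, for example gender. A sketch using the groups argument of aggregate_profiles():
# partial dependence of age computed separately for female and male passengers
gp_rf <- aggregate_profiles(sp_rf, variables = "age", groups = "gender")
plot(sp_rf, variables = "age", alpha = 0.1) +
  show_aggregated_profiles(gp_rf, color = "_label_", size = 2)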
sessionInfo()
#> R version 3.6.1 (2019-07-05)
#> Platform: x86_64-apple-darwin15.6.0 (64-bit)
#> Running under: macOS Catalina 10.15.3
#>
#> Matrix products: default
#> BLAS: /Library/Frameworks/R.framework/Versions/3.6/Resources/lib/libRblas.0.dylib
#> LAPACK: /Library/Frameworks/R.framework/Versions/3.6/Resources/lib/libRlapack.dylib
#>
#> locale:
#> [1] C/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#>
#> attached base packages:
#> [1] stats graphics grDevices utils datasets methods base
#>
#> other attached packages:
#> [1] ggplot2_3.2.1 randomForest_4.6-14 ingredients_1.1
#> [4] DALEX_1.0
#>
#> loaded via a namespace (and not attached):
#> [1] Rcpp_1.0.3 pillar_1.4.3 compiler_3.6.1 tools_3.6.1
#> [5] digest_0.6.23 evaluate_0.14 lifecycle_0.1.0 tibble_2.1.3
#> [9] gtable_0.3.0 pkgconfig_2.0.3 rlang_0.4.2 yaml_2.2.0
#> [13] xfun_0.11 withr_2.1.2 stringr_1.4.0 dplyr_0.8.3
#> [17] knitr_1.28 grid_3.6.1 tidyselect_0.2.5 glue_1.3.1
#> [21] R6_2.4.1 rmarkdown_1.16 purrr_0.3.3 farver_2.0.3
#> [25] magrittr_1.5 scales_1.1.0 htmltools_0.4.0 assertthat_0.2.1
#> [29] colorspace_1.4-1 labeling_0.3 stringi_1.4.5 lazyeval_0.2.2
#> [33] munsell_0.5.0 crayon_1.3.4