PD and ICE curves with ORSF

Partial dependence (PD)

Partial dependence (PD) shows the expected prediction from a model as a function of a single predictor or multiple predictors. The expectation is marginalized over the values of all other predictors, giving something like a multivariable adjusted estimate of the model’s prediction.
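The definition above becomes a short loop. For each value of interest, overwrite the predictor with that value in every row, predict, and summarize the predictions. The base-R sketch below does this by hand. It is only an illustration and rests on a few assumptions. A linear model on the built-in mtcars data stands in for an ORSF model, pd_manual() is a hypothetical helper that is not part of aorsf, and the lwr, medn, and upr columns are assumed to be the 2.5th, 50th, and 97.5th percentiles of the individual predictions.

```r
# Hypothetical helper: partial dependence computed by hand in base R.
# A linear model on mtcars stands in for an ORSF model (assumption).
fit_lm <- lm(mpg ~ wt * hp + cyl, data = mtcars)

pd_manual <- function(model, data, variable, values,
                      probs = c(0.025, 0.5, 0.975)) {
  rows <- lapply(values, function(v) {
    data_mod <- data
    data_mod[[variable]] <- v  # same value of `variable` in every row
    preds <- predict(model, newdata = data_mod)
    q <- quantile(preds, probs = probs, names = FALSE)
    # PD is the mean; the quantiles summarize the spread of predictions
    data.frame(value = v, mean = mean(preds),
               lwr = q[1], medn = q[2], upr = q[3])
  })
  do.call(rbind, rows)
}

pd_wt <- pd_manual(fit_lm, mtcars, variable = "wt", values = c(2.5, 3.5))
pd_wt
```

Averaging over the observed rows is what makes PD a marginalized estimate. The individual predictions before averaging are what ICE curves show.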

library(aorsf)
library(ggplot2)

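You can compute PD and individual conditional expectation (ICE) in three ways:

  1. using in-bag predictions for the training data (orsf_pd_inb(), orsf_ice_inb()),
  2. using out-of-bag predictions for the training data (orsf_pd_oob(), orsf_ice_oob()),
  3. using predictions for a new set of data (orsf_pd_new(), orsf_ice_new()).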

Classification

Begin by fitting an oblique classification random forest:


set.seed(329)

index_train <- sample(nrow(penguins_orsf), 150) 

penguins_orsf_train <- penguins_orsf[index_train, ]
penguins_orsf_test <- penguins_orsf[-index_train, ]

fit_clsf <- orsf(data = penguins_orsf_train, 
                 formula = species ~ .)

Compute PD using out-of-bag data for flipper_length_mm = c(190, 210).


pred_spec <- list(flipper_length_mm = c(190, 210))

pd_oob <- orsf_pd_oob(fit_clsf, pred_spec = pred_spec)

pd_oob
#> Key: <class>
#>        class flipper_length_mm      mean         lwr       medn       upr
#>       <fctr>             <num>     <num>       <num>      <num>     <num>
#> 1:    Adelie               190 0.6176908 0.202278109 0.75856417 0.9810614
#> 2:    Adelie               210 0.4338528 0.019173811 0.56489202 0.8648110
#> 3: Chinstrap               190 0.2114979 0.017643385 0.15211271 0.7215181
#> 4: Chinstrap               210 0.1803019 0.020108201 0.09679464 0.7035053
#> 5:    Gentoo               190 0.1708113 0.001334861 0.02769695 0.5750201
#> 6:    Gentoo               210 0.3858453 0.068685035 0.20717073 0.9532853

Note that predicted probabilities are returned for each class. At any given value of the pred_spec variables, the class probabilities in the mean column sum to 1. For example,


sum(pd_oob[flipper_length_mm == 190, mean])
#> [1] 1

But this isn’t the case for the median predicted probability!


sum(pd_oob[flipper_length_mm == 190, medn])
#> [1] 0.9383738
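Why the difference? Each observation's class probabilities sum to 1, and averaging is linear, so the class-wise means also sum to 1. The median is not linear, so class-wise medians generally do not. A small base-R check with made-up probabilities (not model output) shows this:

```r
# Made-up predicted probabilities: 3 observations (rows) x 3 classes.
probs <- rbind(c(0.70, 0.20, 0.10),
               c(0.10, 0.30, 0.60),
               c(0.40, 0.50, 0.10))
colnames(probs) <- c("Adelie", "Chinstrap", "Gentoo")

rowSums(probs)                # each observation's probabilities sum to 1
sum(colMeans(probs))          # class-wise means also sum to 1
sum(apply(probs, 2, median))  # class-wise medians need not (0.8 here)
```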

Regression

Begin by fitting an oblique regression random forest:


set.seed(329)

index_train <- sample(nrow(penguins_orsf), 150) 

penguins_orsf_train <- penguins_orsf[index_train, ]
penguins_orsf_test <- penguins_orsf[-index_train, ]

fit_regr <- orsf(data = penguins_orsf_train, 
                 formula = bill_length_mm ~ .)

Compute PD using new data for flipper_length_mm = c(190, 210).


pred_spec <- list(flipper_length_mm = c(190, 210))

pd_new <- orsf_pd_new(fit_regr, 
                      pred_spec = pred_spec,
                      new_data = penguins_orsf_test)

pd_new
#>    flipper_length_mm     mean      lwr     medn      upr
#>                <num>    <num>    <num>    <num>    <num>
#> 1:               190 42.96571 37.09805 43.69769 48.72301
#> 2:               210 45.66012 40.50693 46.31577 51.65163

You can also let pred_spec_auto pick reasonable values like so:


pred_spec = pred_spec_auto(species, island, body_mass_g)

pd_new <- orsf_pd_new(fit_regr, 
                      pred_spec = pred_spec,
                      new_data = penguins_orsf_test)

pd_new
#>       species    island body_mass_g     mean      lwr     medn      upr
#>        <fctr>    <fctr>       <num>    <num>    <num>    <num>    <num>
#>  1:    Adelie    Biscoe        3200 40.31374 37.24373 40.31967 44.22824
#>  2: Chinstrap    Biscoe        3200 45.10582 42.63342 45.10859 47.60119
#>  3:    Gentoo    Biscoe        3200 42.81649 40.19221 42.55664 46.84035
#>  4:    Adelie     Dream        3200 40.16219 36.95895 40.34633 43.90681
#>  5: Chinstrap     Dream        3200 46.21778 43.53954 45.90929 49.19173
#>  6:    Gentoo     Dream        3200 42.60465 39.89647 42.63520 46.28769
#>  7:    Adelie Torgersen        3200 39.91652 36.80227 39.79806 43.68842
#>  8: Chinstrap Torgersen        3200 44.27807 41.95470 44.40742 46.68848
#>  9:    Gentoo Torgersen        3200 42.09510 39.49863 41.80049 45.81833
#> 10:    Adelie    Biscoe        3550 40.77971 38.04027 40.59561 44.57505
#> 11: Chinstrap    Biscoe        3550 45.81304 43.52102 45.73116 48.36366
#> 12:    Gentoo    Biscoe        3550 43.31233 40.77355 43.03077 47.22936
#> 13:    Adelie     Dream        3550 40.77741 38.07399 40.78175 44.37273
#> 14: Chinstrap     Dream        3550 47.30926 44.80493 46.77540 50.47092
#> 15:    Gentoo     Dream        3550 43.26955 40.86119 43.16204 46.89190
#> 16:    Adelie Torgersen        3550 40.25780 37.35251 40.07871 44.04576
#> 17: Chinstrap Torgersen        3550 44.77911 42.60161 44.81944 47.14986
#> 18:    Gentoo Torgersen        3550 42.49520 39.95866 42.14160 46.26237
#> 19:    Adelie    Biscoe        3975 41.61744 38.94515 41.36634 45.38752
#> 20: Chinstrap    Biscoe        3975 46.59363 44.59970 46.44923 49.11457
#> 21:    Gentoo    Biscoe        3975 44.07857 41.60792 43.74562 47.85109
#> 22:    Adelie     Dream        3975 41.50511 39.06187 41.24741 45.13027
#> 23: Chinstrap     Dream        3975 48.14978 45.87390 47.54867 51.50683
#> 24:    Gentoo     Dream        3975 44.01928 41.70577 43.84099 47.50470
#> 25:    Adelie Torgersen        3975 40.94764 38.12519 40.66759 44.73689
#> 26: Chinstrap Torgersen        3975 45.44820 43.49986 45.44036 47.63243
#> 27:    Gentoo Torgersen        3975 43.13791 40.70628 42.70627 46.87306
#> 28:    Adelie    Biscoe        4700 42.93914 40.48463 42.44768 46.81756
#> 29: Chinstrap    Biscoe        4700 47.18534 45.40866 47.07739 49.55747
#> 30:    Gentoo    Biscoe        4700 45.32541 43.08173 44.93498 49.23391
#> 31:    Adelie     Dream        4700 42.73806 40.44229 42.22226 46.49936
#> 32: Chinstrap     Dream        4700 48.37354 46.34335 48.00781 51.18955
#> 33:    Gentoo     Dream        4700 45.09132 42.88328 44.79530 48.82180
#> 34:    Adelie Torgersen        4700 42.09349 39.72074 41.56168 45.68838
#> 35: Chinstrap Torgersen        4700 46.17045 44.39042 46.09525 48.35127
#> 36:    Gentoo Torgersen        4700 44.31621 42.18968 43.81773 47.98024
#> 37:    Adelie    Biscoe        5300 43.89769 41.43335 43.28504 48.10892
#> 38: Chinstrap    Biscoe        5300 47.53721 45.66038 47.52770 49.88701
#> 39:    Gentoo    Biscoe        5300 46.16115 43.81722 45.59309 50.57469
#> 40:    Adelie     Dream        5300 43.59846 41.25825 43.24518 47.46193
#> 41: Chinstrap     Dream        5300 48.48139 46.36282 48.25679 51.02996
#> 42:    Gentoo     Dream        5300 45.91819 43.62832 45.54110 49.91622
#> 43:    Adelie Torgersen        5300 42.92879 40.66576 42.31072 46.76406
#> 44: Chinstrap Torgersen        5300 46.59576 44.80400 46.49196 49.03906
#> 45:    Gentoo Torgersen        5300 45.11384 42.95190 44.51289 49.27629
#>       species    island body_mass_g     mean      lwr     medn      upr
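The 45 rows above are every combination of 3 species, 3 islands, and 5 values of body_mass_g. The base-R sketch below builds a grid of the same shape with expand.grid(). It assumes, without confirming it here, that pred_spec_auto() uses every level of a factor and the 10th, 25th, 50th, 75th, and 90th percentiles of a numeric variable. The body mass values are simulated, so the numbers will not match the output above.

```r
# Simulated stand-ins for the penguin predictors (not the real data).
set.seed(329)
species_levels <- c("Adelie", "Chinstrap", "Gentoo")
island_levels  <- c("Biscoe", "Dream", "Torgersen")
body_mass_sim  <- rnorm(150, mean = 4200, sd = 800)

# Assumed rule: all factor levels, plus five percentiles for numerics.
mass_values <- quantile(body_mass_sim,
                        probs = c(0.10, 0.25, 0.50, 0.75, 0.90),
                        names = FALSE)

# expand.grid() varies the first column fastest, like the output above
grid <- expand.grid(species = species_levels,
                    island = island_levels,
                    body_mass_g = mass_values)
nrow(grid)  # 3 * 3 * 5 = 45 combinations
head(grid, 4)
```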

By default, PD is computed for every combination of the variables in pred_spec (here, 3 species × 3 islands × 5 body masses = 45 rows). However, you can also look at each variable separately by setting expand_grid = FALSE:


pd_new <- orsf_pd_new(fit_regr, 
                      expand_grid = FALSE,
                      pred_spec = pred_spec,
                      new_data = penguins_orsf_test)

pd_new
#>        variable value     level     mean      lwr     medn      upr
#>          <char> <num>    <char>    <num>    <num>    <num>    <num>
#>  1:     species    NA    Adelie 41.90271 37.10417 41.51723 48.51478
#>  2:     species    NA Chinstrap 47.11314 42.40419 46.96478 51.51392
#>  3:     species    NA    Gentoo 44.37038 39.87306 43.89889 51.21635
#>  4:      island    NA    Biscoe 44.21332 37.22711 45.27862 51.21635
#>  5:      island    NA     Dream 44.43354 37.01471 45.57261 51.51392
#>  6:      island    NA Torgersen 43.29539 37.01513 44.26924 49.84391
#>  7: body_mass_g  3200      <NA> 42.84625 37.03978 43.95991 49.19173
#>  8: body_mass_g  3550      <NA> 43.53326 37.56730 44.43756 50.47092
#>  9: body_mass_g  3975      <NA> 44.30431 38.31567 45.22089 51.50683
#> 10: body_mass_g  4700      <NA> 45.22559 39.88199 46.34680 51.18955
#> 11: body_mass_g  5300      <NA> 45.91412 40.84742 46.95327 51.48851

And you can also bypass all the bells and whistles by using your own data.frame for a pred_spec. (Just make sure you request values that exist in the training data.)


custom_pred_spec <- data.frame(species = 'Adelie', 
                               island = 'Biscoe')

pd_new <- orsf_pd_new(fit_regr, 
                      pred_spec = custom_pred_spec,
                      new_data = penguins_orsf_test)

pd_new
#>    species island     mean      lwr     medn      upr
#>     <fctr> <fctr>    <num>    <num>    <num>    <num>
#> 1:  Adelie Biscoe 41.98024 37.22711 41.65252 48.51478

Survival

Begin by fitting an oblique survival random forest:


set.seed(329)

index_train <- sample(nrow(pbc_orsf), 150) 

pbc_orsf_train <- pbc_orsf[index_train, ]
pbc_orsf_test <- pbc_orsf[-index_train, ]

fit_surv <- orsf(data = pbc_orsf_train, 
                 formula = Surv(time, status) ~ . - id,
                 oobag_pred_horizon = 365.25 * 5)

Compute PD using in-bag data for bili = c(1,2,3,4,5):

pd_train <- orsf_pd_inb(fit_surv, pred_spec = list(bili = 1:5))
pd_train
#>    pred_horizon  bili      mean        lwr      medn       upr
#>           <num> <num>     <num>      <num>     <num>     <num>
#> 1:      1826.25     1 0.2566200 0.02234786 0.1334170 0.8918909
#> 2:      1826.25     2 0.3121392 0.06853733 0.1896849 0.9204338
#> 3:      1826.25     3 0.3703242 0.11409793 0.2578505 0.9416791
#> 4:      1826.25     4 0.4240692 0.15645214 0.3331057 0.9591581
#> 5:      1826.25     5 0.4663670 0.20123406 0.3841700 0.9655296

If you don’t have specific values of a variable in mind, let pred_spec_auto pick for you:

pd_train <- orsf_pd_inb(fit_surv, pred_spec_auto(bili))
pd_train
#>    pred_horizon  bili      mean        lwr      medn       upr
#>           <num> <num>     <num>      <num>     <num>     <num>
#> 1:      1826.25  0.55 0.2481444 0.02035041 0.1242215 0.8801444
#> 2:      1826.25  0.70 0.2502831 0.02045039 0.1271039 0.8836536
#> 3:      1826.25  1.50 0.2797763 0.03964900 0.1601715 0.9041584
#> 4:      1826.25  3.50 0.3959349 0.13431288 0.2920400 0.9501230
#> 5:      1826.25  7.25 0.5351935 0.28064629 0.4652185 0.9783000

Specify pred_horizon to get PD at each value:


pd_train <- orsf_pd_inb(fit_surv, pred_spec_auto(bili),
                        pred_horizon = seq(500, 3000, by = 500))
pd_train
#>     pred_horizon  bili       mean         lwr        medn       upr
#>            <num> <num>      <num>       <num>       <num>     <num>
#>  1:          500  0.55 0.06171990 0.000443399 0.008654190 0.5907104
#>  2:         1000  0.55 0.14185009 0.005793742 0.055728527 0.7360749
#>  3:         1500  0.55 0.20825053 0.013609478 0.091745579 0.8556319
#>  4:         2000  0.55 0.26790167 0.023047689 0.145741690 0.8910549
#>  5:         2500  0.55 0.31796166 0.063797305 0.202544999 0.9017710
#>  6:         3000  0.55 0.39108086 0.090852131 0.301804690 0.9234812
#>  7:          500  0.70 0.06240527 0.000443399 0.008934806 0.5980510
#>  8:         1000  0.70 0.14313570 0.006159694 0.056348007 0.7432448
#>  9:         1500  0.70 0.21012128 0.013717586 0.092461532 0.8597396
#> 10:         2000  0.70 0.27013021 0.023169510 0.146344595 0.8935664
#> 11:         2500  0.70 0.31880954 0.062506113 0.201979102 0.9068170
#> 12:         3000  0.70 0.39286323 0.089707173 0.308392927 0.9252028
#> 13:          500  1.50 0.06679162 0.001271788 0.011028398 0.6241228
#> 14:         1000  1.50 0.15727919 0.011478962 0.068332010 0.7678732
#> 15:         1500  1.50 0.23316655 0.028732095 0.117289745 0.8789647
#> 16:         2000  1.50 0.30139227 0.046792721 0.180096425 0.9144202
#> 17:         2500  1.50 0.35260943 0.084586675 0.238015966 0.9266065
#> 18:         3000  1.50 0.43512074 0.131110330 0.346025144 0.9438562
#> 19:          500  3.50 0.08638646 0.005208753 0.028239001 0.6740930
#> 20:         1000  3.50 0.22353655 0.051917978 0.139604845 0.8283986
#> 21:         1500  3.50 0.32700976 0.090198324 0.217982772 0.9371150
#> 22:         2000  3.50 0.41618105 0.144532860 0.311508093 0.9566091
#> 23:         2500  3.50 0.49248461 0.219511094 0.402095677 0.9636221
#> 24:         3000  3.50 0.56008108 0.263569896 0.503253258 0.9734948
#> 25:          500  7.25 0.12585007 0.022092057 0.063550987 0.7543806
#> 26:         1000  7.25 0.32646274 0.135343689 0.259567907 0.8884333
#> 27:         1500  7.25 0.46412653 0.218208755 0.387874346 0.9702903
#> 28:         2000  7.25 0.55117610 0.293367409 0.484277295 0.9812413
#> 29:         2500  7.25 0.62002385 0.371965247 0.569543990 0.9845058
#> 30:         3000  7.25 0.68034820 0.425128031 0.646423180 0.9888637
#>     pred_horizon  bili       mean         lwr        medn       upr

One variable, moving horizon

For the next few sections, we update fit_surv to include all the data in pbc_orsf instead of just the training sample:

# a rare case of modify_in_place = TRUE
orsf_update(fit_surv, 
            data = pbc_orsf, 
            modify_in_place = TRUE)

fit_surv
#> ---------- Oblique random survival forest
#> 
#>      Linear combinations: Accelerated Cox regression
#>           N observations: 276
#>                 N events: 111
#>                  N trees: 500
#>       N predictors total: 17
#>    N predictors per node: 5
#>  Average leaves per tree: 21.038
#> Min observations in leaf: 5
#>       Min events in leaf: 1
#>           OOB stat value: 0.84
#>            OOB stat type: Harrell's C-index
#>      Variable importance: anova
#> 
#> -----------------------------------------

What if the effect of a predictor varies over time? Partial dependence can show this.


pd_sex_tv <- orsf_pd_oob(fit_surv, 
                         pred_spec = pred_spec_auto(sex),
                         pred_horizon = seq(365, 365*5))

ggplot(pd_sex_tv) +
 aes(x = pred_horizon, y = mean, color = sex) + 
 geom_line() +
 labs(x = 'Time since baseline',
      y = 'Expected risk')

From inspection, we can see that males have higher risk than females and the difference in that risk grows over time. This can also be seen by viewing the ratio of expected risk over time:


library(data.table)

ratio_tv <- pd_sex_tv[
 , .(ratio = mean[sex == 'm'] / mean[sex == 'f']), by = pred_horizon
]

ggplot(ratio_tv, aes(x = pred_horizon, y = ratio)) + 
 geom_line(color = 'grey') + 
 geom_smooth(color = 'black', se = FALSE) + 
 labs(x = 'time since baseline',
      y = 'ratio in expected risk for males versus females')

To get a view of PD for any number of variables in the training data, use orsf_summarize_uni(). This function computes out-of-bag PD for the most important n_variables and returns a nicely formatted view of the output:

pd_smry <- orsf_summarize_uni(fit_surv, n_variables = 4)

pd_smry
#> 
#> -- ascites (VI Rank: 1) -------------------------
#> 
#>         |---------------- Risk ----------------|
#>   Value      Mean    Median     25th %    75th %
#>  <char>     <num>     <num>      <num>     <num>
#>       0 0.3083611 0.1989535 0.06581247 0.5269629
#>       1 0.4702604 0.3975953 0.27481738 0.6564321
#> 
#> -- bili (VI Rank: 2) ----------------------------
#> 
#>         |---------------- Risk ----------------|
#>   Value      Mean    Median     25th %    75th %
#>  <char>     <num>     <num>      <num>     <num>
#>    0.60 0.2357183 0.1548597 0.05872720 0.3722014
#>    0.80 0.2398514 0.1612673 0.06160845 0.3783445
#>    1.40 0.2612711 0.1808661 0.07893386 0.4068599
#>    3.55 0.3710045 0.3148286 0.17270635 0.5451942
#>    7.30 0.4787515 0.4404399 0.29634561 0.6427858
#> 
#> -- edema (VI Rank: 3) ---------------------------
#> 
#>         |---------------- Risk ----------------|
#>   Value      Mean    Median     25th %    75th %
#>  <char>     <num>     <num>      <num>     <num>
#>       0 0.3036004 0.1840849 0.06509174 0.5228237
#>     0.5 0.3558595 0.2643993 0.11132293 0.5833002
#>       1 0.4694189 0.3977797 0.28211662 0.6332457
#> 
#> -- copper (VI Rank: 4) --------------------------
#> 
#>         |---------------- Risk ----------------|
#>   Value      Mean    Median     25th %    75th %
#>  <char>     <num>     <num>      <num>     <num>
#>    25.0 0.2630999 0.1617276 0.05581251 0.4308429
#>    42.5 0.2706567 0.1703028 0.05887747 0.4418590
#>    74.0 0.2908956 0.1940176 0.07155433 0.4768302
#>     130 0.3446359 0.2656935 0.11918406 0.5574967
#>     217 0.4272771 0.3615510 0.22018120 0.6261011
#> 
#>  Predicted risk at time t = 1826.25 for top 4 predictors

This ‘summary’ object can be converted into a data.table for downstream plotting and tables.

head(as.data.table(pd_smry))
#>    variable importance  Value      Mean    Median     25th %    75th %
#>      <char>      <num> <char>     <num>     <num>      <num>     <num>
#> 1:  ascites  0.4965517      0 0.3083611 0.1989535 0.06581247 0.5269629
#> 2:  ascites  0.4965517      1 0.4702604 0.3975953 0.27481738 0.6564321
#> 3:     bili  0.4153488   0.60 0.2357183 0.1548597 0.05872720 0.3722014
#> 4:     bili  0.4153488   0.80 0.2398514 0.1612673 0.06160845 0.3783445
#> 5:     bili  0.4153488   1.40 0.2612711 0.1808661 0.07893386 0.4068599
#> 6:     bili  0.4153488   3.55 0.3710045 0.3148286 0.17270635 0.5451942
#>    pred_horizon  level
#>           <num> <char>
#> 1:      1826.25      0
#> 2:      1826.25      1
#> 3:      1826.25   <NA>
#> 4:      1826.25   <NA>
#> 5:      1826.25   <NA>
#> 6:      1826.25   <NA>

Multiple variables, jointly

Partial dependence can show the expected value of a model’s predictions as a function of a specific predictor, or as a function of multiple predictors. For instance, we can estimate predicted risk as a joint function of bili, edema, and trt:


pred_spec = pred_spec_auto(bili, edema, trt)

pd_bili_edema <- orsf_pd_oob(fit_surv, pred_spec)

ggplot(pd_bili_edema) + 
 aes(x = bili, y = medn, col = trt, linetype = edema) + 
 geom_line() + 
 labs(y = 'Median predicted risk')

From inspection,

Find interactions using PD

Random forests are good at using interactions, but less good at telling you about them. Use orsf_vint() to apply the method for scoring variable interactions with PD described by Greenwell et al. (2018). This can take a while if you have many predictors, and it seems to work best for continuous-by-continuous interactions. Interactions involving categorical variables are sometimes over- or under-scored.
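As we understand the Greenwell et al. (2018) approach, it computes a two-way PD grid for a pair of predictors, measures how much PD varies over one predictor (its standard deviation) at each fixed value of the other, and then takes the standard deviation of those values, averaging both directions. Without an interaction, the effect of one predictor does not depend on the other, so the score is near zero. The base-R sketch below applies this idea to a linear model on simulated data where x1 and x2 interact and x3 enters additively. The helpers pd_2way() and interaction_score() are hypothetical and are not how orsf_vint() is implemented internally.

```r
# Simulated data: x1 and x2 interact, x3 is additive.
set.seed(1)
n <- 200
dat <- data.frame(x1 = runif(n), x2 = runif(n), x3 = runif(n))
dat$y <- 2 * dat$x1 * dat$x2 + dat$x3 + rnorm(n, sd = 0.1)
fit_int <- lm(y ~ x1 * x2 + x3, data = dat)

# Hypothetical helper: two-way PD on a 5 x 5 grid of percentiles.
pd_2way <- function(model, data, var_1, var_2,
                    probs = c(0.10, 0.25, 0.50, 0.75, 0.90)) {
  grid <- expand.grid(
    var_1_value = quantile(data[[var_1]], probs, names = FALSE),
    var_2_value = quantile(data[[var_2]], probs, names = FALSE)
  )
  grid$mean <- mapply(function(v1, v2) {
    data_mod <- data
    data_mod[[var_1]] <- v1
    data_mod[[var_2]] <- v2
    mean(predict(model, newdata = data_mod))
  }, grid$var_1_value, grid$var_2_value)
  grid
}

# Hypothetical helper: interaction score from a two-way PD grid.
interaction_score <- function(pd) {
  # spread of PD over var_1 at each fixed var_2 value, and vice versa
  imp_1_given_2 <- tapply(pd$mean, pd$var_2_value, sd)
  imp_2_given_1 <- tapply(pd$mean, pd$var_1_value, sd)
  # how much those spreads change, averaged over both directions
  mean(c(sd(imp_1_given_2), sd(imp_2_given_1)))
}

score_x1_x2 <- interaction_score(pd_2way(fit_int, dat, "x1", "x2"))
score_x1_x3 <- interaction_score(pd_2way(fit_int, dat, "x1", "x3"))
c(x1_x2 = score_x1_x2, x1_x3 = score_x1_x3)
```

The score for x1 and x2 is clearly positive, while the score for x1 and x3 is zero up to floating-point error because the model has no x1:x3 term.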


# use just the continuous variables
preds <- names(fit_surv$get_means())

vint_scores <- orsf_vint(fit_surv, predictors = preds)

vint_scores
#>            interaction      score          pd_values
#>                 <char>      <num>             <list>
#>  1:   albumin..protime 1.16208007 <data.table[25x9]>
#>  2:    copper..protime 0.79309473 <data.table[25x9]>
#>  3:       bili..copper 0.75913271 <data.table[25x9]>
#>  4:         bili..chol 0.75512999 <data.table[25x9]>
#>  5:          age..bili 0.73391673 <data.table[25x9]>
#>  6:      bili..albumin 0.68318871 <data.table[25x9]>
#>  7:       albumin..ast 0.59701935 <data.table[25x9]>
#>  8:      bili..protime 0.59350261 <data.table[25x9]>
#>  9:     bili..platelet 0.57346860 <data.table[25x9]>
#> 10:       ast..protime 0.56164263 <data.table[25x9]>
#> 11:    albumin..copper 0.54634520 <data.table[25x9]>
#> 12:         bili..trig 0.50536831 <data.table[25x9]>
#> 13:       copper..trig 0.48646180 <data.table[25x9]>
#> 14:       age..protime 0.46142701 <data.table[25x9]>
#> 15:           age..ast 0.44014148 <data.table[25x9]>
#> 16:      age..platelet 0.42738851 <data.table[25x9]>
#> 17:  albumin..platelet 0.41637733 <data.table[25x9]>
#> 18:      chol..albumin 0.39936841 <data.table[25x9]>
#> 19:  platelet..protime 0.38308225 <data.table[25x9]>
#> 20:        age..copper 0.36384070 <data.table[25x9]>
#> 21:        copper..ast 0.34949201 <data.table[25x9]>
#> 22:      trig..protime 0.30198694 <data.table[25x9]>
#> 23:     bili..alk.phos 0.26162716 <data.table[25x9]>
#> 24:      chol..protime 0.25007961 <data.table[25x9]>
#> 25:   copper..alk.phos 0.22516988 <data.table[25x9]>
#> 26:         chol..trig 0.21034041 <data.table[25x9]>
#> 27:          bili..ast 0.20985458 <data.table[25x9]>
#> 28:     trig..platelet 0.18693124 <data.table[25x9]>
#> 29:      age..alk.phos 0.18147838 <data.table[25x9]>
#> 30:       chol..copper 0.17937060 <data.table[25x9]>
#> 31:   copper..platelet 0.17098297 <data.table[25x9]>
#> 32:       age..albumin 0.15870462 <data.table[25x9]>
#> 33:     alk.phos..trig 0.14064898 <data.table[25x9]>
#> 34:          age..trig 0.13026236 <data.table[25x9]>
#> 35:          chol..ast 0.11979451 <data.table[25x9]>
#> 36:  albumin..alk.phos 0.11805809 <data.table[25x9]>
#> 37:     chol..alk.phos 0.11323838 <data.table[25x9]>
#> 38:      ast..platelet 0.09153564 <data.table[25x9]>
#> 39:  alk.phos..protime 0.08366434 <data.table[25x9]>
#> 40:      alk.phos..ast 0.08197152 <data.table[25x9]>
#> 41:          ast..trig 0.07332730 <data.table[25x9]>
#> 42:          age..chol 0.05520846 <data.table[25x9]>
#> 43:     chol..platelet 0.04921464 <data.table[25x9]>
#> 44: alk.phos..platelet 0.04728775 <data.table[25x9]>
#> 45:      albumin..trig 0.04584439 <data.table[25x9]>
#>            interaction      score          pd_values

The scores include partial dependence values that you can pull out and plot:


# top scoring interaction
pd_top <- vint_scores$pd_values[[1]]

# center pd values so it's easier to see the interaction effect
pd_top[, mean := mean - mean[1], by = var_2_value]

ggplot(pd_top) + 
 aes(x = var_1_value, 
     y = mean, 
     color = factor(var_2_value), 
     group = factor(var_2_value)) + 
 geom_line() + 
 labs(x = "albumin", 
      y = "predicted mortality (centered)",
      color = "protime")

As a sanity check, we fit a Cox proportional hazards model with coxph() to see whether a standard test also detects this interaction:


library(survival)  # provides coxph()

# test the top score (expect strong interaction)
fit_cph <- coxph(Surv(time, status) ~ albumin * protime, 
                 data = pbc_orsf)

anova(fit_cph)
#> Analysis of Deviance Table
#>  Cox model: response is Surv(time, status)
#> Terms added sequentially (first to last)
#> 
#>                  loglik  Chisq Df Pr(>|Chi|)    
#> NULL            -550.19                         
#> albumin         -526.29 47.801  1  4.717e-12 ***
#> protime         -514.89 22.806  1  1.792e-06 ***
#> albumin:protime -511.76  6.252  1    0.01241 *  
#> ---
#> Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Note: Caution is warranted when interpreting statistical hypotheses that are motivated by the same data they are tested with. Results like the p-values for interaction shown above should be interpreted as exploratory.

Individual conditional expectations (ICE)

Unlike partial dependence, which shows the expected prediction as a function of one or multiple predictors, individual conditional expectations (ICE) show the prediction for an individual observation as a function of a predictor.
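Put differently, an ICE curve is the set of predictions for one observation as the predictor is varied, and the PD curve is the average of all ICE curves at each value. A base-R sketch, again assuming a linear model on mtcars as a stand-in for an ORSF model:

```r
fit_lm <- lm(mpg ~ wt * hp + cyl, data = mtcars)
wt_values <- c(2.5, 3.0, 3.5)

# ICE: one row per observation, one column per value of wt
ice_matrix <- sapply(wt_values, function(v) {
  data_mod <- mtcars
  data_mod$wt <- v  # set wt to v for every car
  predict(fit_lm, newdata = data_mod)
})
colnames(ice_matrix) <- wt_values

dim(ice_matrix)                      # 32 observations x 3 values of wt
pd_from_ice <- colMeans(ice_matrix)  # averaging the ICE curves gives PD
pd_from_ice
```

Because this model includes a wt:hp interaction, the ICE curves for different cars have different slopes. That kind of heterogeneity is visible in ICE curves but hidden by their average.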

Classification

Compute ICE using out-of-bag data for flipper_length_mm = c(190, 210).


pred_spec <- list(flipper_length_mm = c(190, 210))

ice_oob <- orsf_ice_oob(fit_clsf, pred_spec = pred_spec)

ice_oob
#> Key: <class>
#>      id_variable id_row  class flipper_length_mm       pred
#>            <int> <char> <fctr>             <num>      <num>
#>   1:           1      1 Adelie               190 0.92169247
#>   2:           1      2 Adelie               190 0.80944657
#>   3:           1      3 Adelie               190 0.85172955
#>   4:           1      4 Adelie               190 0.93559327
#>   5:           1      5 Adelie               190 0.97708693
#>  ---                                                       
#> 896:           2    146 Gentoo               210 0.26092984
#> 897:           2    147 Gentoo               210 0.04798334
#> 898:           2    148 Gentoo               210 0.07927359
#> 899:           2    149 Gentoo               210 0.84779971
#> 900:           2    150 Gentoo               210 0.11105143

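There are two identifiers in the output:

  1. id_variable identifies the value (or combination of values) of the pred_spec variable(s); here, 1 corresponds to flipper_length_mm = 190 and 2 to flipper_length_mm = 210.
  2. id_row identifies the observation in the data used to compute ICE.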

Note that predicted probabilities are returned for each class and each observation in the data. Predicted probabilities for a given observation and given variable value sum to 1. For example,


# sum the class probabilities for observation 1 at flipper_length_mm = 190
sum(ice_oob[flipper_length_mm == 190 & id_row == '1', pred])
#> [1] 1

Regression

Compute ICE using new data for flipper_length_mm = c(190, 210).


pred_spec <- list(flipper_length_mm = c(190, 210))

ice_new <- orsf_ice_new(fit_regr, 
                        pred_spec = pred_spec,
                        new_data = penguins_orsf_test)

ice_new
#>      id_variable id_row flipper_length_mm     pred
#>            <int> <char>             <num>    <num>
#>   1:           1      1               190 37.94483
#>   2:           1      2               190 37.61595
#>   3:           1      3               190 37.53681
#>   4:           1      4               190 39.49476
#>   5:           1      5               190 38.95635
#>  ---                                              
#> 362:           2    179               210 51.80471
#> 363:           2    180               210 47.27183
#> 364:           2    181               210 47.05031
#> 365:           2    182               210 50.39028
#> 366:           2    183               210 48.44774

You can also let pred_spec_auto pick reasonable values like so:


pred_spec = pred_spec_auto(species, island, body_mass_g)

ice_new <- orsf_ice_new(fit_regr, 
                        pred_spec = pred_spec,
                        new_data = penguins_orsf_test)

ice_new
#>       id_variable id_row species    island body_mass_g     pred
#>             <int> <char>  <fctr>    <fctr>       <num>    <num>
#>    1:           1      1  Adelie    Biscoe        3200 37.78339
#>    2:           1      2  Adelie    Biscoe        3200 37.73273
#>    3:           1      3  Adelie    Biscoe        3200 37.71248
#>    4:           1      4  Adelie    Biscoe        3200 40.25782
#>    5:           1      5  Adelie    Biscoe        3200 40.04074
#>   ---                                                          
#> 8231:          45    179  Gentoo Torgersen        5300 46.14559
#> 8232:          45    180  Gentoo Torgersen        5300 43.98050
#> 8233:          45    181  Gentoo Torgersen        5300 44.59837
#> 8234:          45    182  Gentoo Torgersen        5300 44.85146
#> 8235:          45    183  Gentoo Torgersen        5300 44.23710

By default, every combination of the variables in pred_spec is used, so each of the 183 test observations gets 45 rows. However, you can also look at each variable separately by setting expand_grid = FALSE:


ice_new <- orsf_ice_new(fit_regr, 
                        expand_grid = FALSE,
                        pred_spec = pred_spec,
                        new_data = penguins_orsf_test)

ice_new
#>       id_variable id_row    variable value  level     pred
#>             <int> <char>      <char> <num> <char>    <num>
#>    1:           1      1     species    NA Adelie 37.74136
#>    2:           1      2     species    NA Adelie 37.42367
#>    3:           1      3     species    NA Adelie 37.04598
#>    4:           1      4     species    NA Adelie 39.89602
#>    5:           1      5     species    NA Adelie 39.14848
#>   ---                                                     
#> 2009:           5    179 body_mass_g  5300   <NA> 51.50196
#> 2010:           5    180 body_mass_g  5300   <NA> 47.27055
#> 2011:           5    181 body_mass_g  5300   <NA> 48.34064
#> 2012:           5    182 body_mass_g  5300   <NA> 48.75828
#> 2013:           5    183 body_mass_g  5300   <NA> 48.11020

And you can also bypass all the bells and whistles by using your own data.frame for a pred_spec. (Just make sure you request values that exist in the training data.)


custom_pred_spec <- data.frame(species = 'Adelie', 
                               island = 'Biscoe')

ice_new <- orsf_ice_new(fit_regr, 
                        pred_spec = custom_pred_spec,
                        new_data = penguins_orsf_test)

ice_new
#>      id_variable id_row species island     pred
#>            <int> <char>  <fctr> <fctr>    <num>
#>   1:           1      1  Adelie Biscoe 38.52327
#>   2:           1      2  Adelie Biscoe 38.32073
#>   3:           1      3  Adelie Biscoe 37.71248
#>   4:           1      4  Adelie Biscoe 41.68380
#>   5:           1      5  Adelie Biscoe 40.91140
#>  ---                                           
#> 179:           1    179  Adelie Biscoe 43.09493
#> 180:           1    180  Adelie Biscoe 38.79455
#> 181:           1    181  Adelie Biscoe 39.37734
#> 182:           1    182  Adelie Biscoe 40.71952
#> 183:           1    183  Adelie Biscoe 39.34501

Survival

Compute ICE using in-bag data for bili = c(1,2,3,4,5):

ice_train <- orsf_ice_inb(fit_surv, pred_spec = list(bili = 1:5))
ice_train
#>       id_variable id_row pred_horizon  bili      pred
#>             <int> <char>        <num> <num>     <num>
#>    1:           1      1      1826.25     1 0.9015162
#>    2:           1      2      1826.25     1 0.1018093
#>    3:           1      3      1826.25     1 0.6810217
#>    4:           1      4      1826.25     1 0.3609523
#>    5:           1      5      1826.25     1 0.1354010
#>   ---                                                
#> 1376:           5    272      1826.25     5 0.2651225
#> 1377:           5    273      1826.25     5 0.3036780
#> 1378:           5    274      1826.25     5 0.3468740
#> 1379:           5    275      1826.25     5 0.1653363
#> 1380:           5    276      1826.25     5 0.3543087

If you don’t have specific values of a variable in mind, let pred_spec_auto pick for you:

ice_train <- orsf_ice_inb(fit_surv, pred_spec_auto(bili))
ice_train
#>       id_variable id_row pred_horizon  bili       pred
#>             <int> <char>        <num> <num>      <num>
#>    1:           1      1      1826.25   0.6 0.89210440
#>    2:           1      2      1826.25   0.6 0.09173543
#>    3:           1      3      1826.25   0.6 0.65389145
#>    4:           1      4      1826.25   0.6 0.34483859
#>    5:           1      5      1826.25   0.6 0.13107816
#>   ---                                                 
#> 1376:           5    272      1826.25   7.3 0.31509470
#> 1377:           5    273      1826.25   7.3 0.35307247
#> 1378:           5    274      1826.25   7.3 0.41603645
#> 1379:           5    275      1826.25   7.3 0.25370259
#> 1380:           5    276      1826.25   7.3 0.45088467

Specify pred_horizon to get ICE at each value:


ice_train <- orsf_ice_inb(fit_surv, pred_spec_auto(bili),
                          pred_horizon = seq(500, 3000, by = 500))
ice_train
#>       id_variable id_row pred_horizon  bili      pred
#>             <int> <char>        <num> <num>     <num>
#>    1:           1      1          500   0.6 0.5950043
#>    2:           1      1         1000   0.6 0.7652137
#>    3:           1      1         1500   0.6 0.8751746
#>    4:           1      1         2000   0.6 0.9057135
#>    5:           1      1         2500   0.6 0.9231915
#>   ---                                                
#> 8276:           5    276         1000   7.3 0.2153098
#> 8277:           5    276         1500   7.3 0.3700953
#> 8278:           5    276         2000   7.3 0.4903015
#> 8279:           5    276         2500   7.3 0.5774981
#> 8280:           5    276         3000   7.3 0.6268579

Computing ICE at multiple prediction horizons adds minimal computational cost, so you can use a fine grid of time values to assess whether predictors have time-varying effects.

Visualizing ICE curves

Inspecting the ICE curves for each observation can help identify heterogeneity in a model’s predictions. That is, does the effect of the variable follow the same pattern for all the data, or are there groups where the variable impacts risk differently?

I am going to turn off boundary checking in orsf_ice_oob by setting boundary_checks = FALSE, and this will allow me to generate ICE curves that go beyond the 90th percentile of bili.


pred_spec <- list(bili = seq(1, 10, length.out = 25))

ice_oob <- orsf_ice_oob(fit_surv, pred_spec, boundary_checks = FALSE)

ice_oob
#>       id_variable id_row pred_horizon  bili      pred
#>             <int> <char>        <num> <num>     <num>
#>    1:           1      1      1826.25     1 0.8790861
#>    2:           1      2      1826.25     1 0.8132035
#>    3:           1      3      1826.25     1 0.6240238
#>    4:           1      4      1826.25     1 0.7461603
#>    5:           1      5      1826.25     1 0.5754091
#>   ---                                                
#> 6896:          25    272      1826.25    10 0.7018976
#> 6897:          25    273      1826.25    10 0.4606246
#> 6898:          25    274      1826.25    10 0.3351786
#> 6899:          25    275      1826.25    10 0.6040355
#> 6900:          25    276      1826.25    10 0.2789017

For plots, it is helpful to scale the ICE data. I subtract the initial value of predicted risk (i.e., when bili = 1) from each observation’s conditional expectation values, so every curve starts at 0 and the y-axis shows the change in predicted risk as bili increases.

Now we can visualize the curves.


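# center each observation's curve at its prediction for bili = 1
ice_oob[, pred := pred - pred[which.min(bili)], by = id_row]

ggplot(ice_oob, aes(x = bili, 
                    y = pred, 
                    group = id_row)) + 
 geom_line(alpha = 0.15) + 
 labs(y = 'Change in predicted risk') +
 geom_smooth(se = FALSE, aes(group = 1))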

From inspection of the figure,

Limitations of PD

Partial dependence has a number of known limitations and assumptions that users should be aware of (see Hooker, 2021). In particular, partial dependence is less intuitive when more than two predictors are examined jointly, and it assumes that the feature(s) for which partial dependence is computed are not correlated with other features (an assumption that is likely violated in many cases). Accumulated local effect (ALE) plots can be used when feature independence is not a valid assumption.
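To see why correlated features are a problem, note that PD overwrites one feature while keeping the others as observed, which can create combinations that never occur in the data. A small base-R illustration with simulated, strongly correlated features (x1 and x2 are made-up names):

```r
set.seed(42)
n  <- 500
x1 <- rnorm(n)
x2 <- x1 + rnorm(n, sd = 0.2)  # x2 is strongly correlated with x1

# Largest |x2 - x1| gap among the observed rows
max_observed_gap <- max(abs(x2 - x1))

# PD overwrites x1 with one value (here its 90th percentile) in every row
x1_pd  <- quantile(x1, 0.90, names = FALSE)
gap_pd <- abs(x2 - x1_pd)

# Share of PD rows whose (x1, x2) gap exceeds anything seen in the data
share_extrapolated <- mean(gap_pd > max_observed_gap)
share_extrapolated
```

Here most of the rows that PD averages over lie outside the observed data, so the PD value at that point leans on the model's extrapolation rather than on patterns it learned.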

References

  1. Hooker G, Mentch L, Zhou S (2021). “Unrestricted permutation forces extrapolation: variable importance requires at least one more model, or there is no free variable importance.” Statistics and Computing, 31, 1-16.