Model-agnostic breakDown plots for randomForest

Przemyslaw Biecek

2024-03-11

Here we use the HR churn data (https://www.kaggle.com/) to show how the breakDown package explains predictions from randomForest models.

The data set ships with the breakDown package as HR_data.

set.seed(1313)

library(breakDown)
head(HR_data, 3)
#>   satisfaction_level last_evaluation number_project average_montly_hours
#> 1               0.38            0.53              2                  157
#> 2               0.80            0.86              5                  262
#> 3               0.11            0.88              7                  272
#>   time_spend_company Work_accident left promotion_last_5years sales salary
#> 1                  3             0    1                     0 sales    low
#> 2                  6             0    1                     0 sales medium
#> 3                  4             0    1                     0 sales medium

Now let’s create a random forest classification model for churn, the left variable.

library("randomForest")
#> randomForest 4.7-1.1
#> Type rfNews() to see new features/changes/bug fixes.
#> 
#> Attaching package: 'randomForest'
#> The following object is masked from 'package:ggplot2':
#> 
#>     margin
# randomForest() has no `family` argument; wrapping the response in factor() already requests classification
model <- randomForest(factor(left) ~ ., data = HR_data, maxnodes = 5)

But how can we tell which factors drive the prediction for a single observation?

With the breakDown package!

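Explanations for the predicted probability of leaving (left = 1). A random forest has no linear predictor, so we explain the class probability instead.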

library(ggplot2)

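# Wrapper passed to broken(): returns the probability of the second class (left = 1)
predict.function <- function(model, new_observation) {
  predict(model, new_observation, type = "prob")[, 2]
}
# Prediction for observation 11; column 7 (left) is the response, so it is dropped
predict.function(model, HR_data[11, -7])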
#> [1] 0.888

explain_1 <- broken(model, HR_data[11,-7], data = HR_data[,-7],
                    predict.function = predict.function, 
                    direction = "down")
explain_1
#>                              contribution
#> (Intercept)                         0.148
#> - satisfaction_level = 0.45         0.133
#> - number_project = 2                0.201
#> - last_evaluation = 0.54            0.182
#> - average_montly_hours = 135        0.141
#> - time_spend_company = 3            0.068
#> - Work_accident = 0                 0.010
#> - salary = low                      0.005
#> - sales = sales                     0.000
#> - promotion_last_5years = 0         0.000
#> final_prognosis                     0.888
#> baseline:  0
plot(explain_1) + ggtitle("breakDown plot (direction=down) for randomForest model")
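
The rows of the table printed above are meant to be additive: the intercept plus the per-variable contributions should reproduce final_prognosis. As a quick sanity check, the sketch below retypes the printed (rounded) values by hand and sums them. The vector name `contributions` is only illustrative, and the check works on the rounded output rather than on the `explain_1` object itself.

```r
# Values copied by hand from the printed explain_1 table (rounded to 3 digits)
contributions <- c(
  "(Intercept)"           = 0.148,
  "satisfaction_level"    = 0.133,
  "number_project"        = 0.201,
  "last_evaluation"       = 0.182,
  "average_montly_hours"  = 0.141,
  "time_spend_company"    = 0.068,
  "Work_accident"         = 0.010,
  "salary"                = 0.005,
  "sales"                 = 0.000,
  "promotion_last_5years" = 0.000
)

# The contributions should add up to the final prognosis reported in the table
total <- sum(contributions)
total
isTRUE(all.equal(total, 0.888))
```

If the sum and the reported final prognosis differ by more than rounding error, the explanation and the model prediction are out of sync, for example because a different predict function was used.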


explain_2 <- broken(model, HR_data[11,-7], data = HR_data[,-7],
                    predict.function = predict.function, 
                    direction = "up")
explain_2
#>                              contribution
#> (Intercept)                         0.148
#> + satisfaction_level = 0.45         0.133
#> + number_project = 2                0.201
#> + last_evaluation = 0.54            0.182
#> + average_montly_hours = 135        0.141
#> + time_spend_company = 3            0.068
#> + Work_accident = 0                 0.010
#> + salary = low                      0.005
#> + promotion_last_5years = 0         0.000
#> + sales = sales                     0.000
#> final_prognosis                     0.888
#> baseline:  0
plot(explain_2) + ggtitle("breakDown plot (direction=up) for randomForest model")
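
To give some intuition for how a model-agnostic decomposition of this kind can be computed, here is a minimal base-R sketch of a greedy "step up" procedure. It is a simplified illustration, not the code used by breakDown, and whether it matches the package's internals in every detail is an assumption. The names `step_up`, `toy_data` and `toy_predict` are hypothetical and exist only for this example. The idea: start from the mean prediction over the data, then repeatedly fix the variable whose value from the explained observation shifts the mean prediction the most.

```r
# Toy data and a toy "model": any function that maps a data frame to numeric predictions
set.seed(1)
toy_data <- data.frame(x1 = runif(200), x2 = runif(200), x3 = runif(200))
toy_predict <- function(model, newdata) with(newdata, 2 * x1 + x1 * x2 - 0.5 * x3)

# Hypothetical helper: greedy step-up decomposition of a single prediction
step_up <- function(model, new_observation, data, predict_function) {
  current <- data
  last_mean <- mean(predict_function(model, current))  # baseline: mean prediction
  contributions <- c("(Intercept)" = last_mean)
  remaining <- colnames(data)
  while (length(remaining) > 0) {
    # Mean prediction after fixing each remaining variable at the observation's value
    candidate_means <- sapply(remaining, function(v) {
      tmp <- current
      tmp[[v]] <- new_observation[[v]]
      mean(predict_function(model, tmp))
    })
    # Keep the variable that moves the mean prediction the most
    best <- remaining[which.max(abs(candidate_means - last_mean))]
    current[[best]] <- new_observation[[best]]
    contributions[best] <- candidate_means[[best]] - last_mean
    last_mean <- candidate_means[[best]]
    remaining <- setdiff(remaining, best)
  }
  contributions
}

obs <- toy_data[1, ]
contrib <- step_up(NULL, obs, toy_data, toy_predict)
contrib

# By construction the contributions telescope to the prediction for the observation
all.equal(sum(contrib), toy_predict(NULL, obs))
```

Because only `predict_function` touches the model, the same loop would work for any model that can score a data frame, which is what makes the approach model agnostic. The cost is one pass over the data per remaining variable per step, so it grows quadratically with the number of variables.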