In the following, we explain the counterfactuals
workflow for both a classification and a regression task using concrete
use cases.
library("counterfactuals")
library("iml")
library("rpart")
The Predictor class of the iml package provides the necessary
flexibility to cover classification and regression models fitted with
diverse R packages. In the introduction vignette, we saw models fitted
with the mlr3 and randomForest packages. In the following, we show
extensions to a regression tree fitted with the caret package, with
tidymodels, with mlr (a predecessor of mlr3), and with rpart directly.
For each model, we generate counterfactuals for the 100th row of the
plasma dataset of the gamlss.data package using the NICE method
(NICERegr).
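The flexibility mentioned above can also be made explicit: for a model class without dedicated support, iml's Predictor accepts a user-supplied prediction function through its `predict.function` argument. The following is only a minimal sketch under stated assumptions. The linear model on the built-in mtcars data and the wrapper name `predict_fun` are hypothetical and not part of this vignette.

```r
library("iml")

# Hypothetical example model (not from this vignette): a linear model
# fitted on the built-in mtcars data.
fit_lm = lm(mpg ~ wt + hp, data = mtcars)

# Hypothetical wrapper: takes the fitted model and a data.frame of new
# observations and returns the predictions.
predict_fun = function(model, newdata) {
  predict(model, newdata = newdata)
}

# Pass the wrapper to the Predictor via predict.function.
pred_lm = Predictor$new(model = fit_lm, data = mtcars, y = "mpg",
  predict.function = predict_fun)

# Predict for a single observation, as done for x_interest in this vignette.
out = pred_lm$predict(mtcars[1L, ])
out
```

A Predictor built this way can then be passed to a counterfactual method in the same way as the predictors shown below.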
data(plasma, package = "gamlss.data")
x_interest = plasma[100L,]
library("caret")
treecaret = caret::train(retplasma ~ ., data = plasma[-100L,], method = "rpart",
  tuneGrid = data.frame(cp = 0.01))
predcaret = Predictor$new(model = treecaret, data = plasma[-100L,], y = "retplasma")
predcaret$predict(x_interest)
#>   .prediction
#> 1    342.9231
nicecaret = NICERegr$new(predcaret, optimization = "proximity",
  margin_correct = 0.5, return_multiple = FALSE)
nicecaret$find_counterfactuals(x_interest, desired_outcome = c(500, Inf))
#> 1 Counterfactual(s)
#>
#> Desired outcome range: [500, Inf]
#>
#> Head:
#> age sex smokstat bmi vituse calories fat fiber alcohol cholesterol betadiet retdiet betaplasma
#> <int> <fctr> <fctr> <num> <fctr> <num> <num> <num> <num> <num> <int> <int> <int>
#> 1: 46 1 3 35.25969 3 2667.5 131.6 10.1 0 550.5 1210 1291 218
library("tidymodels")
treetm = decision_tree(mode = "regression", engine = "rpart") %>%
  fit(retplasma ~ ., data = plasma[-100L,])
predtm = Predictor$new(model = treetm, data = plasma[-100L,], y = "retplasma")
predtm$predict(x_interest)
#>      .pred
#> 1 342.9231
nicetm = NICERegr$new(predtm, optimization = "proximity",
  margin_correct = 0.5, return_multiple = FALSE)
nicetm$find_counterfactuals(x_interest, desired_outcome = c(500, Inf))
#> 1 Counterfactual(s)
#>
#> Desired outcome range: [500, Inf]
#>
#> Head:
#> age sex smokstat bmi vituse calories fat fiber alcohol cholesterol betadiet retdiet betaplasma
#> <int> <fctr> <fctr> <num> <fctr> <num> <num> <num> <num> <num> <int> <int> <int>
#> 1: 46 1 3 35.25969 3 2667.5 131.6 10.1 0 550.5 1210 1291 218
library("mlr")
task = mlr::makeRegrTask(data = plasma[-100L,], target = "retplasma")
mod = mlr::makeLearner("regr.rpart")
treemlr = mlr::train(mod, task)
predmlr = Predictor$new(model = treemlr, data = plasma[-100L,], y = "retplasma")
predmlr$predict(x_interest)
#>   .prediction
#> 1    342.9231
nicemlr = NICERegr$new(predmlr, optimization = "proximity",
  margin_correct = 0.5, return_multiple = FALSE)
nicemlr$find_counterfactuals(x_interest, desired_outcome = c(500, Inf))
#> 1 Counterfactual(s)
#>
#> Desired outcome range: [500, Inf]
#>
#> Head:
#> age sex smokstat bmi vituse calories fat fiber alcohol cholesterol betadiet retdiet betaplasma
#> <int> <fctr> <fctr> <num> <fctr> <num> <num> <num> <num> <num> <int> <int> <int>
#> 1: 46 1 3 35.25969 3 2667.5 131.6 10.1 0 550.5 1210 1291 218
treerpart = rpart(retplasma ~ ., data = plasma[-100L,])
predrpart = Predictor$new(model = treerpart, data = plasma[-100L,], y = "retplasma")
predrpart$predict(x_interest)
#>       pred
#> 1 342.9231
nicerpart = NICERegr$new(predrpart, optimization = "proximity",
  margin_correct = 0.5, return_multiple = FALSE)
nicerpart$find_counterfactuals(x_interest, desired_outcome = c(500, Inf))
#> 1 Counterfactual(s)
#>
#> Desired outcome range: [500, Inf]
#>
#> Head:
#> age sex smokstat bmi vituse calories fat fiber alcohol cholesterol betadiet retdiet betaplasma
#> <int> <fctr> <fctr> <num> <fctr> <num> <num> <num> <num> <num> <int> <int> <int>
#> 1: 46 1 3 35.25969 3 2667.5 131.6 10.1 0 550.5 1210 1291 218
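In the examples above, the result of `find_counterfactuals()` is only printed. As a minimal sketch, the result can instead be stored and inspected. The variable names `cfactuals` and `cf_pred` are introduced here for illustration only. The sketch assumes that the returned object exposes a `$data` field and a `$predict()` method, as used in the introduction vignette, and it repeats the rpart setup so that it runs on its own.

```r
library("counterfactuals")
library("iml")
library("rpart")

# Same setup as in the rpart example of this vignette.
data(plasma, package = "gamlss.data")
x_interest = plasma[100L, ]
treerpart = rpart(retplasma ~ ., data = plasma[-100L, ])
predrpart = Predictor$new(model = treerpart, data = plasma[-100L, ], y = "retplasma")
nicerpart = NICERegr$new(predrpart, optimization = "proximity",
  margin_correct = 0.5, return_multiple = FALSE)

# Store the result instead of only printing it
# (cfactuals is a name chosen for this sketch).
cfactuals = nicerpart$find_counterfactuals(x_interest, desired_outcome = c(500, Inf))

# The counterfactual feature values.
cfactuals$data

# Predictions of the model for the counterfactuals; these should fall
# in the desired outcome range [500, Inf].
cf_pred = cfactuals$predict()
cf_pred
```

Comparing `cf_pred` with `predrpart$predict(x_interest)` shows how far the prediction moved from its original value towards the desired range.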