Setup
Display system information for reproducibility.
sessionInfo()
R version 4.3.2 (2023-10-31)
Platform: aarch64-unknown-linux-gnu (64-bit)
Running under: Ubuntu 22.04.3 LTS
Matrix products: default
BLAS: /usr/lib/aarch64-linux-gnu/openblas-pthread/libblas.so.3
LAPACK: /usr/lib/aarch64-linux-gnu/openblas-pthread/libopenblasp-r0.3.20.so; LAPACK version 3.10.0
locale:
[1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C
[3] LC_TIME=en_US.UTF-8 LC_COLLATE=en_US.UTF-8
[5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8
[7] LC_PAPER=en_US.UTF-8 LC_NAME=C
[9] LC_ADDRESS=C LC_TELEPHONE=C
[11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C
time zone: Etc/UTC
tzcode source: system (glibc)
attached base packages:
[1] stats graphics grDevices utils datasets methods base
loaded via a namespace (and not attached):
[1] htmlwidgets_1.6.4 compiler_4.3.2 fastmap_1.1.1 cli_3.6.2
[5] tools_4.3.2 htmltools_0.5.7 rstudioapi_0.15.0 yaml_2.3.8
[9] rmarkdown_2.25 knitr_1.45 jsonlite_1.8.8 xfun_0.42
[13] digest_0.6.34 rlang_1.1.3 evaluate_0.23
The Python and Julia equivalents:
import IPython
print(IPython.sys_info())

using InteractiveUtils
versioninfo()
Overview
We illustrate the typical machine learning workflow for model stacking using the Heart data set. The outcome is AHD (Yes or No). We use the stacks package.
Model stacking is an ensembling method that takes the outputs of many models and combines them to generate a new model—referred to as an ensemble in this package—that generates predictions informed by each of its members.
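To make the mechanism concrete, here is a minimal self-contained sketch of the idea behind blending (not the stacks implementation itself; the data below are simulated): the stack is a non-negative, lasso-penalized logistic regression fit to the candidates' out-of-fold predictions, so only a few candidates receive nonzero weights and become members.
library(glmnet)

set.seed(1)
n <- 200
# Simulated out-of-fold class probabilities from three hypothetical candidates
candidate_preds <- cbind(
  cand_1 = runif(n),
  cand_2 = runif(n),
  cand_3 = runif(n)
)
# Simulated binary outcome, loosely driven by the first candidate
y <- rbinom(n, 1, plogis(3 * candidate_preds[, "cand_1"] - 1.5))
# The "stack": lasso-penalized logistic regression with non-negative weights;
# candidates whose coefficient is shrunk to zero drop out of the ensemble
meta_fit <- glmnet(candidate_preds, y, family = "binomial", lower.limits = 0)
coef(meta_fit, s = 0.01)  # stacking weights at one value of the penalty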
The main steps of the workflow are:
1. Initial splitting into test and non-test sets.
2. Pre-processing of data: dummy coding categorical variables, standardizing numerical variables, imputing missing values, ...
3. Tune the candidate models (logistic regression, random forest, neural network) by 5-fold cross-validation (CV) on the non-test data.
4. Blend the tuned candidates into a stacked ensemble and refit the retained members on the whole non-test data.
5. Final classification on the test data.
Heart data
The goal is to predict the binary outcome AHD (Yes or No) of patients.
# Load libraries
library(GGally)
Loading required package: ggplot2
Registered S3 method overwritten by 'GGally':
method from
+.gg ggplot2
library(gtsummary)
library(keras)
library(ranger)
library(stacks)
library(tidyverse)
── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
✔ dplyr 1.1.4 ✔ readr 2.1.5
✔ forcats 1.0.0 ✔ stringr 1.5.1
✔ lubridate 1.9.3 ✔ tibble 3.2.1
✔ purrr 1.0.2 ✔ tidyr 1.3.1
── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
✖ dplyr::filter() masks stats::filter()
✖ dplyr::lag() masks stats::lag()
ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(tidymodels)
── Attaching packages ────────────────────────────────────── tidymodels 1.1.1 ──
✔ broom 1.0.5 ✔ rsample 1.2.0
✔ dials 1.2.1 ✔ tune 1.1.2
✔ infer 1.0.6 ✔ workflows 1.1.4
✔ modeldata 1.3.0 ✔ workflowsets 1.0.1
✔ parsnip 1.2.0 ✔ yardstick 1.3.0
✔ recipes 1.0.10
── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
✖ recipes::all_double() masks gtsummary::all_double()
✖ recipes::all_factor() masks gtsummary::all_factor()
✖ recipes::all_integer() masks gtsummary::all_integer()
✖ recipes::all_logical() masks gtsummary::all_logical()
✖ recipes::all_numeric() masks gtsummary::all_numeric()
✖ scales::discard() masks purrr::discard()
✖ dplyr::filter() masks stats::filter()
✖ recipes::fixed() masks stringr::fixed()
✖ yardstick::get_weights() masks keras::get_weights()
✖ dplyr::lag() masks stats::lag()
✖ yardstick::spec() masks readr::spec()
✖ recipes::step() masks stats::step()
• Learn how to get started at https://www.tidymodels.org/start/
library(xgboost)
Attaching package: 'xgboost'
The following object is masked from 'package:dplyr':
slice
# Load the `Heart.csv` data.
Heart <- read_csv("Heart.csv") |>
  # first column is patient ID, which we don't need
  select(-1) |>
  # RestECG is categorical with values 0, 1, 2
  mutate(RestECG = as.character(RestECG)) |>
  mutate(AHD = as.factor(AHD)) |>
  print(width = Inf)
New names:
• `` -> `...1`
Rows: 303 Columns: 15
── Column specification ────────────────────────────────────────────────────────
Delimiter: ","
chr (3): ChestPain, Thal, AHD
dbl (12): ...1, Age, Sex, RestBP, Chol, Fbs, RestECG, MaxHR, ExAng, Oldpeak,...
ℹ Use `spec()` to retrieve the full column specification for this data.
ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# A tibble: 303 × 14
Age Sex ChestPain RestBP Chol Fbs RestECG MaxHR ExAng Oldpeak Slope
<dbl> <dbl> <chr> <dbl> <dbl> <dbl> <chr> <dbl> <dbl> <dbl> <dbl>
1 63 1 typical 145 233 1 2 150 0 2.3 3
2 67 1 asymptomatic 160 286 0 2 108 1 1.5 2
3 67 1 asymptomatic 120 229 0 2 129 1 2.6 2
4 37 1 nonanginal 130 250 0 0 187 0 3.5 3
5 41 0 nontypical 130 204 0 2 172 0 1.4 1
6 56 1 nontypical 120 236 0 0 178 0 0.8 1
7 62 0 asymptomatic 140 268 0 2 160 0 3.6 3
8 57 0 asymptomatic 120 354 0 0 163 1 0.6 1
9 63 1 asymptomatic 130 254 0 2 147 0 1.4 2
10 53 1 asymptomatic 140 203 1 2 155 1 3.1 3
Ca Thal AHD
<dbl> <chr> <fct>
1 0 fixed No
2 3 normal Yes
3 2 reversable Yes
4 0 normal No
5 0 normal No
6 0 normal No
7 2 normal Yes
8 0 normal No
9 1 reversable Yes
10 0 reversable Yes
# ℹ 293 more rows
# Numerical summaries stratified by the outcome `AHD`.
Heart |> tbl_summary(by = AHD)
Characteristic     No, N = 164          Yes, N = 139
Age                52 (45, 59)          58 (52, 62)
Sex                92 (56%)             114 (82%)
ChestPain
  asymptomatic     39 (24%)             105 (76%)
  nonanginal       68 (41%)             18 (13%)
  nontypical       41 (25%)             9 (6.5%)
  typical          16 (9.8%)            7 (5.0%)
RestBP             130 (120, 140)       130 (120, 145)
Chol               235 (209, 267)       249 (218, 284)
Fbs                23 (14%)             22 (16%)
RestECG
  0                95 (58%)             56 (40%)
  1                1 (0.6%)             3 (2.2%)
  2                68 (41%)             80 (58%)
MaxHR              161 (149, 172)       142 (125, 157)
ExAng              23 (14%)             76 (55%)
Oldpeak            0.20 (0.00, 1.03)    1.40 (0.55, 2.50)
Slope
  1                106 (65%)            36 (26%)
  2                49 (30%)             91 (65%)
  3                9 (5.5%)             12 (8.6%)
Ca
  0                130 (81%)            46 (33%)
  1                21 (13%)             44 (32%)
  2                7 (4.3%)             31 (22%)
  3                3 (1.9%)             17 (12%)
  Unknown          3                    1
Thal
  fixed            6 (3.7%)             12 (8.7%)
  normal           129 (79%)            37 (27%)
  reversable       28 (17%)             89 (64%)
  Unknown          1                    1
# Graphical summary:
# Heart |> ggpairs()
Initial split into test and non-test sets
We randomly split the data into 25% test data and 75% non-test data, stratified on AHD.
# For reproducibility
set.seed(203)
data_split <- initial_split(
  Heart,
  # stratify by AHD
  strata = "AHD",
  prop = 0.75
)
data_split
<Training/Testing/Total>
<227/76/303>
Heart_other <- training(data_split)
dim(Heart_other)
[1] 227  14
Heart_test <- testing(data_split)
dim(Heart_test)
[1] 76 14
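As a quick sanity check on the stratified split (our addition, not in the original output), the AHD class proportions in the two sets should be nearly identical:
# Compare AHD class proportions between the non-test and test sets
bind_rows(
  Heart_other |> count(AHD) |> mutate(prop = n / sum(n), set = "non-test"),
  Heart_test |> count(AHD) |> mutate(prop = n / sum(n), set = "test")
)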
Recipe (R) and Preprocessing (Python)
A data dictionary (roughly) is at https://keras.io/examples/structured_data/structured_data_classification_with_feature_space/.
We have the following features:
- Numerical features: Age, RestBP, Chol, Slope (1, 2, or 3), MaxHR, ExAng, Oldpeak, Ca (0, 1, 2, or 3).
- Categorical features coded as integers: Sex (0 or 1), Fbs (0 or 1), RestECG (0, 1, or 2).
- Categorical features coded as strings: ChestPain, Thal.
There are missing values in Ca and Thal. Since the missing proportions are not high, we will use simple mean imputation (for the numerical feature Ca) and mode imputation (for the categorical feature Thal).
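As a quick check (not shown in the original), we can count missing values per column; only Ca (4 missing) and Thal (2 missing) should show up, matching the Unknown rows in the summary table above.
# Count missing values in each column
Heart |>
  summarise(across(everything(), ~ sum(is.na(.x)))) |>
  pivot_longer(everything(), names_to = "variable", values_to = "n_missing") |>
  filter(n_missing > 0)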
heart_recipe <-
  recipe(
    AHD ~ .,
    data = Heart_other
  ) |>
  # mean imputation for Ca
  step_impute_mean(Ca) |>
  # mode imputation for Thal
  step_impute_mode(Thal) |>
  # create traditional dummy variables (necessary for xgboost)
  step_dummy(all_nominal_predictors()) |>
  # zero-variance filter
  step_zv(all_predictors())
  # estimate the means and standard deviations
  # prep(training = Heart_other, retain = TRUE)
heart_recipe
── Recipe ──────────────────────────────────────────────────────────────────────
── Inputs
Number of variables by role
outcome:    1
predictor: 13
── Operations
• Mean imputation for: Ca
• Mode imputation for: Thal
• Dummy variables from: all_nominal_predictors()
• Zero variance filter on: all_predictors()
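To inspect what the recipe produces, one can prep it on the non-test data and bake it (an extra check, not in the original document; the resulting column names depend on the dummy coding):
# Preview the preprocessed design matrix
heart_recipe |>
  prep(training = Heart_other) |>
  bake(new_data = NULL) |>
  glimpse()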
Base models
We will use three different model definitions to try to classify the outcome AHD: logistic regression, random forest, and neural network.
First we set up the cross-validation folds to be shared by all models.
set.seed(203)
folds <- vfold_cv(Heart_other, v = 5)
Logistic regression
Set up model.
logit_mod <-
  logistic_reg(
    penalty = tune(),
    mixture = tune()
  ) |>
  set_engine("glmnet", standardize = TRUE)
logit_mod
Logistic Regression Model Specification (classification)
Main Arguments:
penalty = tune()
mixture = tune()
Engine-Specific Arguments:
standardize = TRUE
Computational engine: glmnet
Bundle the recipe (R) and model into a workflow.
logit_wf <- workflow() |>
  add_recipe(heart_recipe) |>
  add_model(logit_mod)
logit_wf
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: logistic_reg()
── Preprocessor ────────────────────────────────────────────────────────────────
4 Recipe Steps
• step_impute_mean()
• step_impute_mode()
• step_dummy()
• step_zv()
── Model ───────────────────────────────────────────────────────────────────────
Logistic Regression Model Specification (classification)
Main Arguments:
penalty = tune()
mixture = tune()
Engine-Specific Arguments:
standardize = TRUE
Computational engine: glmnet
Set up the tuning grid and tune.
logit_grid <- grid_regular(
  penalty(range = c(-6, 3)),
  mixture(),
  levels = c(100, 5)
)

logit_res <-
  tune_grid(
    object = logit_wf,
    resamples = folds,
    grid = logit_grid,
    control = control_stack_grid()
  )
logit_res
# Tuning results
# 5-fold cross-validation
# A tibble: 5 × 5
splits id .metrics .notes .predictions
<list> <chr> <list> <list> <list>
1 <split [181/46]> Fold1 <tibble [1,000 × 6]> <tibble [0 × 3]> <tibble>
2 <split [181/46]> Fold2 <tibble [1,000 × 6]> <tibble [0 × 3]> <tibble>
3 <split [182/45]> Fold3 <tibble [1,000 × 6]> <tibble [0 × 3]> <tibble>
4 <split [182/45]> Fold4 <tibble [1,000 × 6]> <tibble [0 × 3]> <tibble>
5 <split [182/45]> Fold5 <tibble [1,000 × 6]> <tibble [0 × 3]> <tibble>
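Before stacking, it can be helpful to peek at the strongest logistic-regression configurations by CV AUC (an extra check, not part of the original output):
# Top 5 penalty/mixture combinations by cross-validated ROC AUC
show_best(logit_res, metric = "roc_auc", n = 5)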
Random forest
Set up model.
rf_mod <-
  rand_forest(
    mode = "classification",
    # Number of predictors randomly sampled in each split
    mtry = tune(),
    # Number of trees in ensemble
    trees = tune()
  ) |>
  set_engine("ranger")
rf_mod
Random Forest Model Specification (classification)
Main Arguments:
mtry = tune()
trees = tune()
Computational engine: ranger
Bundle the recipe (R) and model into a workflow.
rf_wf <- workflow() |>
  add_recipe(heart_recipe) |>
  add_model(rf_mod)
rf_wf
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: rand_forest()
── Preprocessor ────────────────────────────────────────────────────────────────
4 Recipe Steps
• step_impute_mean()
• step_impute_mode()
• step_dummy()
• step_zv()
── Model ───────────────────────────────────────────────────────────────────────
Random Forest Model Specification (classification)
Main Arguments:
mtry = tune()
trees = tune()
Computational engine: ranger
Set up the tuning grid and tune.
rf_grid <- grid_regular(
  trees(range = c(100L, 500L)),
  mtry(range = c(1L, 5L)),
  levels = c(5, 5)
)

rf_res <-
  tune_grid(
    object = rf_wf,
    resamples = folds,
    grid = rf_grid,
    control = control_stack_grid()
  )
rf_res
# Tuning results
# 5-fold cross-validation
# A tibble: 5 × 5
splits id .metrics .notes .predictions
<list> <chr> <list> <list> <list>
1 <split [181/46]> Fold1 <tibble [50 × 6]> <tibble [0 × 3]> <tibble [1,150 × 8]>
2 <split [181/46]> Fold2 <tibble [50 × 6]> <tibble [0 × 3]> <tibble [1,150 × 8]>
3 <split [182/45]> Fold3 <tibble [50 × 6]> <tibble [0 × 3]> <tibble [1,125 × 8]>
4 <split [182/45]> Fold4 <tibble [50 × 6]> <tibble [0 × 3]> <tibble [1,125 × 8]>
5 <split [182/45]> Fold5 <tibble [50 × 6]> <tibble [0 × 3]> <tibble [1,125 × 8]>
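The analogous checks apply here; for instance, autoplot() on the tuning results (not shown in the original) draws CV performance against mtry and trees:
# Visualize CV metrics across the random-forest tuning grid
autoplot(rf_res)
# Top configurations by cross-validated ROC AUC
show_best(rf_res, metric = "roc_auc", n = 5)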
Neural network
Set up model.
mlp_mod <-
  mlp(
    mode = "classification",
    hidden_units = tune(),
    dropout = tune(),
    epochs = 50
  ) |>
  set_engine("keras", verbose = 0)
mlp_mod
Single Layer Neural Network Model Specification (classification)
Main Arguments:
hidden_units = tune()
dropout = tune()
epochs = 50
Engine-Specific Arguments:
verbose = 0
Computational engine: keras
Bundle the recipe (R) and model into a workflow.
mlp_wf <- workflow() |>
  add_recipe(heart_recipe) |>
  add_model(mlp_mod)
mlp_wf
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: mlp()
── Preprocessor ────────────────────────────────────────────────────────────────
4 Recipe Steps
• step_impute_mean()
• step_impute_mode()
• step_dummy()
• step_zv()
── Model ───────────────────────────────────────────────────────────────────────
Single Layer Neural Network Model Specification (classification)
Main Arguments:
hidden_units = tune()
dropout = tune()
epochs = 50
Engine-Specific Arguments:
verbose = 0
Computational engine: keras
Set up the tuning grid and tune.
mlp_grid <- grid_regular(
  hidden_units(range = c(1, 20)),
  dropout(range = c(0, 0.6)),
  levels = 5
)

mlp_res <-
  tune_grid(
    object = mlp_wf,
    resamples = folds,
    grid = mlp_grid,
    control = control_stack_grid()
  )
mlp_res
# Tuning results
# 5-fold cross-validation
# A tibble: 5 × 5
splits id .metrics .notes .predictions
<list> <chr> <list> <list> <list>
1 <split [181/46]> Fold1 <tibble [50 × 6]> <tibble [0 × 3]> <tibble [1,150 × 8]>
2 <split [181/46]> Fold2 <tibble [50 × 6]> <tibble [0 × 3]> <tibble [1,150 × 8]>
3 <split [182/45]> Fold3 <tibble [50 × 6]> <tibble [0 × 3]> <tibble [1,125 × 8]>
4 <split [182/45]> Fold4 <tibble [50 × 6]> <tibble [0 × 3]> <tibble [1,125 × 8]>
5 <split [182/45]> Fold5 <tibble [50 × 6]> <tibble [0 × 3]> <tibble [1,125 × 8]>
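Before blending, a compact comparison of the best CV AUC from each candidate family can be assembled (our own addition; show_best() returns one row per configuration, and bind_rows() pads the differing tuning-parameter columns with NA):
# Best cross-validated ROC AUC for each of the three model families
bind_rows(
  show_best(logit_res, metric = "roc_auc", n = 1) |> mutate(model = "logistic_reg"),
  show_best(rf_res, metric = "roc_auc", n = 1) |> mutate(model = "rand_forest"),
  show_best(mlp_res, metric = "roc_auc", n = 1) |> mutate(model = "mlp")
) |>
  select(model, mean, std_err)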
Model stacking
Build the stacked ensemble.
heart_model_st <-
  # initialize the stack
  stacks() |>
  # add candidate members
  add_candidates(logit_res) |>
  add_candidates(rf_res) |>
  add_candidates(mlp_res) |>
  # determine how to combine their predictions
  blend_predictions(
    penalty = 10^(-6:2),
    metrics = c("roc_auc")
  ) |>
  # fit the candidates with nonzero stacking coefficients
  fit_members()
heart_model_st
Warning: Predictions from 602 candidates were identical to those from existing
candidates and were removed from the data stack.
Warning: The `...` are not used in this function but one or more arguments were
passed: 'metrics'
── A stacked ensemble model ─────────────────────────────────────
Out of 249 possible candidate members, the ensemble retained 9.
Penalty: 0.1.
Mixture: 1.
The 9 highest weighted member classes are:
# A tibble: 9 × 3
member type weight
<chr> <chr> <dbl>
1 .pred_Yes_rf_res_1_01 rand_forest 3.48
2 .pred_Yes_rf_res_1_11 rand_forest 0.844
3 .pred_Yes_logit_res_1_433 logistic_reg 0.612
4 .pred_Yes_logit_res_1_053 logistic_reg 0.129
5 .pred_Yes_logit_res_1_054 logistic_reg 0.122
6 .pred_Yes_logit_res_1_434 logistic_reg 0.0363
7 .pred_Yes_logit_res_1_055 logistic_reg 0.0149
8 .pred_Yes_rf_res_1_04 rand_forest 0.0139
9 .pred_Yes_logit_res_1_432 logistic_reg 0.000163
Plot the result.
autoplot(heart_model_st)
To show the relationship more directly:
autoplot(heart_model_st, type = "members")
To see the top results:
autoplot(heart_model_st, type = "weights")
To identify which model configurations were assigned what stacking coefficients, we can make use of the collect_parameters() function:
collect_parameters(heart_model_st, "rf_res")
# A tibble: 25 × 5
member mtry trees terms coef
<chr> <int> <int> <chr> <dbl>
1 rf_res_1_01 1 100 .pred_Yes_rf_res_1_01 3.48
2 rf_res_1_02 1 200 .pred_Yes_rf_res_1_02 0
3 rf_res_1_03 1 300 .pred_Yes_rf_res_1_03 0
4 rf_res_1_04 1 400 .pred_Yes_rf_res_1_04 0.0139
5 rf_res_1_05 1 500 .pred_Yes_rf_res_1_05 0
6 rf_res_1_06 2 100 .pred_Yes_rf_res_1_06 0
7 rf_res_1_07 2 200 .pred_Yes_rf_res_1_07 0
8 rf_res_1_08 2 300 .pred_Yes_rf_res_1_08 0
9 rf_res_1_09 2 400 .pred_Yes_rf_res_1_09 0
10 rf_res_1_10 2 500 .pred_Yes_rf_res_1_10 0
# ℹ 15 more rows
Final classification
heart_pred <- Heart_test %>%
  bind_cols(predict(heart_model_st, ., type = "prob")) %>%
  print(width = Inf)
# A tibble: 76 × 16
Age Sex ChestPain RestBP Chol Fbs RestECG MaxHR ExAng Oldpeak Slope
<dbl> <dbl> <chr> <dbl> <dbl> <dbl> <chr> <dbl> <dbl> <dbl> <dbl>
1 63 1 typical 145 233 1 2 150 0 2.3 3
2 67 1 asymptomatic 120 229 0 2 129 1 2.6 2
3 37 1 nonanginal 130 250 0 0 187 0 3.5 3
4 56 1 nontypical 120 236 0 0 178 0 0.8 1
5 44 1 nontypical 120 263 0 0 173 0 0 1
6 52 1 nonanginal 172 199 1 0 162 0 0.5 1
7 48 1 nontypical 110 229 0 0 168 0 1 3
8 64 1 typical 110 211 0 2 144 1 1.8 2
9 60 1 asymptomatic 130 206 0 2 132 1 2.4 2
10 43 1 asymptomatic 150 247 0 0 171 0 1.5 1
Ca Thal AHD .pred_No .pred_Yes
<dbl> <chr> <fct> <dbl> <dbl>
1 0 fixed No 0.478 0.522
2 2 reversable Yes 0.125 0.875
3 0 normal No 0.750 0.250
4 0 normal No 0.855 0.145
5 0 reversable No 0.704 0.296
6 0 reversable No 0.697 0.303
7 0 reversable Yes 0.617 0.383
8 0 normal No 0.527 0.473
9 2 reversable Yes 0.107 0.893
10 0 normal No 0.775 0.225
# ℹ 66 more rows
Computing the ROC AUC for the model:
yardstick::roc_auc(
  heart_pred,
  truth = AHD,
  contains(".pred_No")
)
# A tibble: 1 × 3
.metric .estimator .estimate
<chr> <chr> <dbl>
1 roc_auc binary 0.853
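For a complementary metric (a sketch we add here, not in the original), test accuracy at the default 0.5 threshold can be computed from the same probability predictions:
# Classify with a 0.5 cutoff on .pred_Yes and compute test accuracy
heart_pred |>
  mutate(
    pred_class = factor(if_else(.pred_Yes > 0.5, "Yes", "No"), levels = levels(AHD))
  ) |>
  yardstick::accuracy(truth = AHD, estimate = pred_class)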
We can use the members argument to generate predictions from each of the ensemble members.
heart_pred <-
  Heart_test |>
  select(AHD) |>
  bind_cols(
    predict(
      heart_model_st,
      Heart_test,
      type = "class",
      members = TRUE
    )
  ) |>
  print(width = Inf)
# A tibble: 76 × 11
AHD .pred_class .pred_class_logit_res_1_432 .pred_class_logit_res_1_433
<fct> <fct> <fct> <fct>
1 No Yes No No
2 Yes Yes Yes Yes
3 No No No No
4 No No No No
5 No No No No
6 No No No No
7 Yes No Yes Yes
8 No No No No
9 Yes Yes Yes Yes
10 No No No No
.pred_class_logit_res_1_434 .pred_class_logit_res_1_053
<fct> <fct>
1 No Yes
2 Yes Yes
3 No No
4 No No
5 No No
6 No No
7 Yes No
8 No No
9 Yes Yes
10 No No
.pred_class_logit_res_1_054 .pred_class_logit_res_1_055
<fct> <fct>
1 Yes Yes
2 Yes Yes
3 No No
4 No No
5 No No
6 No No
7 No No
8 No No
9 Yes Yes
10 No No
.pred_class_rf_res_1_01 .pred_class_rf_res_1_04 .pred_class_rf_res_1_11
<fct> <fct> <fct>
1 Yes Yes No
2 Yes Yes Yes
3 No No No
4 No No No
5 No No No
6 No No No
7 No No No
8 Yes Yes Yes
9 Yes Yes Yes
10 No No No
# ℹ 66 more rows
map(
  colnames(heart_pred),
  ~ mean(heart_pred$AHD == pull(heart_pred, .x))
) |>
  set_names(colnames(heart_pred)) |>
  as_tibble() |>
  pivot_longer(c(everything(), -AHD))
# A tibble: 10 × 3
AHD name value
<dbl> <chr> <dbl>
1 1 .pred_class 0.737
2 1 .pred_class_logit_res_1_432 0.789
3 1 .pred_class_logit_res_1_433 0.789
4 1 .pred_class_logit_res_1_434 0.789
5 1 .pred_class_logit_res_1_053 0.763
6 1 .pred_class_logit_res_1_054 0.776
7 1 .pred_class_logit_res_1_055 0.763
8 1 .pred_class_rf_res_1_01 0.737
9 1 .pred_class_rf_res_1_04 0.737
10 1 .pred_class_rf_res_1_11 0.763
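Note that several individual members (accuracy up to 0.789) slightly outperform the blended ensemble's .pred_class (0.737) on this test set; with only 76 test observations, such differences are well within sampling noise.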