Machine Learning Workflow: Model Stacking (Heart Data)

Biostat 203B

Author

Dr. Hua Zhou @ UCLA

Published

February 27, 2024

1 Setup

Display system information for reproducibility.

sessionInfo()
R version 4.3.2 (2023-10-31)
Platform: aarch64-unknown-linux-gnu (64-bit)
Running under: Ubuntu 22.04.3 LTS

Matrix products: default
BLAS:   /usr/lib/aarch64-linux-gnu/openblas-pthread/libblas.so.3 
LAPACK: /usr/lib/aarch64-linux-gnu/openblas-pthread/libopenblasp-r0.3.20.so;  LAPACK version 3.10.0

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
 [3] LC_TIME=en_US.UTF-8        LC_COLLATE=en_US.UTF-8    
 [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
 [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                 
 [9] LC_ADDRESS=C               LC_TELEPHONE=C            
[11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       

time zone: Etc/UTC
tzcode source: system (glibc)

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

loaded via a namespace (and not attached):
 [1] htmlwidgets_1.6.4 compiler_4.3.2    fastmap_1.1.1     cli_3.6.2        
 [5] tools_4.3.2       htmltools_0.5.7   rstudioapi_0.15.0 yaml_2.3.8       
 [9] rmarkdown_2.25    knitr_1.45        jsonlite_1.8.8    xfun_0.42        
[13] digest_0.6.34     rlang_1.1.3       evaluate_0.23    
Python:
import IPython
print(IPython.sys_info())
Julia:
using InteractiveUtils
versioninfo()

2 Overview

We illustrate the typical machine learning workflow for model stacking on the Heart data set, using the stacks package. The outcome is AHD (Yes or No).

Model stacking is an ensembling method that takes the outputs of many models and combines them into a new model (called an ensemble in the stacks package) whose predictions are informed by each of its members.

  1. Initial splitting to test and non-test sets.

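  2. Pre-processing of data: imputing missing values, dummy coding categorical variables, removing zero-variance predictors, …

  3. Tune each candidate base model (penalized logistic regression, random forest, and a single-layer neural network) using 5-fold cross-validation (CV) on the non-test data, saving the out-of-fold predictions.

  4. Blend the candidates' out-of-fold predictions with a penalized regression to choose the stacking coefficients, then refit the members with nonzero coefficients on the whole non-test data.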

  5. Final classification on the test data.
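The core idea of steps 3 and 4 can be sketched in a few lines of base R on simulated data: each base model produces out-of-fold predictions, and a second-level model learns how to weight them. This is a simplified illustration under stated assumptions (two hypothetical logistic base models on simulated predictors `x1` and `x2`, and an unpenalized meta-model); the stacks package instead fits a penalized regression with non-negative coefficients.

```r
# Toy stacking sketch on simulated data (not the stacks implementation)
set.seed(1)
n  <- 200
x1 <- rnorm(n)
x2 <- rnorm(n)
y  <- rbinom(n, 1, plogis(1.5 * x1 - x2))
dat <- data.frame(y, x1, x2)

# Randomly assign each row to one of 5 folds
folds <- sample(rep(1:5, length.out = n))

# Out-of-fold predicted probabilities from two hypothetical base models
oof <- matrix(NA_real_, n, 2, dimnames = list(NULL, c("m1", "m2")))
for (k in 1:5) {
  train <- dat[folds != k, ]
  test  <- dat[folds == k, ]
  fit1 <- glm(y ~ x1, family = binomial, data = train)  # base model 1
  fit2 <- glm(y ~ x2, family = binomial, data = train)  # base model 2
  oof[folds == k, "m1"] <- predict(fit1, test, type = "response")
  oof[folds == k, "m2"] <- predict(fit2, test, type = "response")
}

# Meta-model: learn how to combine the out-of-fold predictions
# (unpenalized here; stacks uses a penalized, non-negative fit)
meta <- glm(y ~ m1 + m2, family = binomial, data = data.frame(y, oof))
coef(meta)
```

Because every prediction fed to the meta-model comes from a fold the base model did not see, the stacking weights are not fit to in-sample (overly optimistic) predictions.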

3 Heart data

The goal is to predict the binary outcome AHD (Yes or No) for each patient.

# Load libraries
library(GGally)
Loading required package: ggplot2
Registered S3 method overwritten by 'GGally':
  method from   
  +.gg   ggplot2
library(gtsummary)
library(keras)
library(ranger)
library(stacks)
library(tidyverse)
── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
✔ dplyr     1.1.4     ✔ readr     2.1.5
✔ forcats   1.0.0     ✔ stringr   1.5.1
✔ lubridate 1.9.3     ✔ tibble    3.2.1
✔ purrr     1.0.2     ✔ tidyr     1.3.1
── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
✖ dplyr::filter() masks stats::filter()
✖ dplyr::lag()    masks stats::lag()
ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(tidymodels)
── Attaching packages ────────────────────────────────────── tidymodels 1.1.1 ──
✔ broom        1.0.5      ✔ rsample      1.2.0 
✔ dials        1.2.1      ✔ tune         1.1.2 
✔ infer        1.0.6      ✔ workflows    1.1.4 
✔ modeldata    1.3.0      ✔ workflowsets 1.0.1 
✔ parsnip      1.2.0      ✔ yardstick    1.3.0 
✔ recipes      1.0.10     
── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
✖ recipes::all_double()    masks gtsummary::all_double()
✖ recipes::all_factor()    masks gtsummary::all_factor()
✖ recipes::all_integer()   masks gtsummary::all_integer()
✖ recipes::all_logical()   masks gtsummary::all_logical()
✖ recipes::all_numeric()   masks gtsummary::all_numeric()
✖ scales::discard()        masks purrr::discard()
✖ dplyr::filter()          masks stats::filter()
✖ recipes::fixed()         masks stringr::fixed()
✖ yardstick::get_weights() masks keras::get_weights()
✖ dplyr::lag()             masks stats::lag()
✖ yardstick::spec()        masks readr::spec()
✖ recipes::step()          masks stats::step()
• Learn how to get started at https://www.tidymodels.org/start/
library(xgboost)

Attaching package: 'xgboost'

The following object is masked from 'package:dplyr':

    slice
# Load the `Heart.csv` data.
Heart <- read_csv("Heart.csv") |> 
  # first column is patient ID, which we don't need
  select(-1) |>
  # RestECG is categorical with values 0, 1, 2
  mutate(RestECG = as.character(RestECG)) |>
  mutate(AHD = as.factor(AHD)) |>
  print(width = Inf)
New names:
• `` -> `...1`
Rows: 303 Columns: 15
── Column specification ────────────────────────────────────────────────────────
Delimiter: ","
chr  (3): ChestPain, Thal, AHD
dbl (12): ...1, Age, Sex, RestBP, Chol, Fbs, RestECG, MaxHR, ExAng, Oldpeak,...
ℹ Use `spec()` to retrieve the full column specification for this data.
ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# A tibble: 303 × 14
     Age   Sex ChestPain    RestBP  Chol   Fbs RestECG MaxHR ExAng Oldpeak Slope
   <dbl> <dbl> <chr>         <dbl> <dbl> <dbl> <chr>   <dbl> <dbl>   <dbl> <dbl>
 1    63     1 typical         145   233     1 2         150     0     2.3     3
 2    67     1 asymptomatic    160   286     0 2         108     1     1.5     2
 3    67     1 asymptomatic    120   229     0 2         129     1     2.6     2
 4    37     1 nonanginal      130   250     0 0         187     0     3.5     3
 5    41     0 nontypical      130   204     0 2         172     0     1.4     1
 6    56     1 nontypical      120   236     0 0         178     0     0.8     1
 7    62     0 asymptomatic    140   268     0 2         160     0     3.6     3
 8    57     0 asymptomatic    120   354     0 0         163     1     0.6     1
 9    63     1 asymptomatic    130   254     0 2         147     0     1.4     2
10    53     1 asymptomatic    140   203     1 2         155     1     3.1     3
      Ca Thal       AHD  
   <dbl> <chr>      <fct>
 1     0 fixed      No   
 2     3 normal     Yes  
 3     2 reversable Yes  
 4     0 normal     No   
 5     0 normal     No   
 6     0 normal     No   
 7     2 normal     Yes  
 8     0 normal     No   
 9     1 reversable Yes  
10     0 reversable Yes  
# ℹ 293 more rows
# Numerical summaries stratified by the outcome `AHD`.
Heart |> tbl_summary(by = AHD)
Characteristic        No, N = 164¹         Yes, N = 139¹
Age                   52 (45, 59)          58 (52, 62)
Sex                   92 (56%)             114 (82%)
ChestPain
    asymptomatic      39 (24%)             105 (76%)
    nonanginal        68 (41%)             18 (13%)
    nontypical        41 (25%)             9 (6.5%)
    typical           16 (9.8%)            7 (5.0%)
RestBP                130 (120, 140)       130 (120, 145)
Chol                  235 (209, 267)       249 (218, 284)
Fbs                   23 (14%)             22 (16%)
RestECG
    0                 95 (58%)             56 (40%)
    1                 1 (0.6%)             3 (2.2%)
    2                 68 (41%)             80 (58%)
MaxHR                 161 (149, 172)       142 (125, 157)
ExAng                 23 (14%)             76 (55%)
Oldpeak               0.20 (0.00, 1.03)    1.40 (0.55, 2.50)
Slope
    1                 106 (65%)            36 (26%)
    2                 49 (30%)             91 (65%)
    3                 9 (5.5%)             12 (8.6%)
Ca
    0                 130 (81%)            46 (33%)
    1                 21 (13%)             44 (32%)
    2                 7 (4.3%)             31 (22%)
    3                 3 (1.9%)             17 (12%)
    Unknown           3                    1
Thal
    fixed             6 (3.7%)             12 (8.7%)
    normal            129 (79%)            37 (27%)
    reversable        28 (17%)             89 (64%)
    Unknown           1                    1
¹ Median (IQR); n (%)
# Graphical summary:
# Heart |> ggpairs()

TODO

3.1 Julia

TODO

4 Initial split into test and non-test sets

We randomly split the data into 25% test data and 75% non-test data, stratifying on AHD.

# For reproducibility
set.seed(203)

data_split <- initial_split(
  Heart, 
  # stratify by AHD
  strata = "AHD", 
  prop = 0.75
  )
data_split
<Training/Testing/Total>
<227/76/303>
Heart_other <- training(data_split)
dim(Heart_other)
[1] 227  14
Heart_test <- testing(data_split)
dim(Heart_test)
[1] 76 14
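Stratification means sampling the same fraction of rows within each outcome class, so that the test and non-test sets keep the class balance of the full data. A minimal base-R sketch of that idea, using a hypothetical helper `stratified_split()` and a toy factor with the Heart class counts (164 No, 139 Yes); rsample's `initial_split()` follows the same principle but its exact row selection differs.

```r
# Hypothetical helper (not rsample's algorithm): sample the same fraction
# of rows within each outcome class so class proportions are preserved
stratified_split <- function(y, prop = 0.75) {
  idx <- split(seq_along(y), y)  # row indices grouped by class
  train <- lapply(idx, function(i) i[sample.int(length(i), floor(prop * length(i)))])
  sort(unlist(train, use.names = FALSE))
}

set.seed(203)
ahd <- factor(rep(c("No", "Yes"), times = c(164, 139)))  # Heart class counts
train_rows <- stratified_split(ahd)
length(train_rows)                   # 123 No + 104 Yes = 227 non-test rows
prop.table(table(ahd[train_rows]))   # close to 164/303 and 139/303
```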

TODO

TODO

5 Recipe (R) and Preprocessing (Python)

  • A rough data dictionary is available at https://keras.io/examples/structured_data/structured_data_classification_with_feature_space/.

  • We have the following features:

    • Numerical features: Age, RestBP, Chol, Slope (1, 2 or 3), MaxHR, ExAng, Oldpeak, Ca (0, 1, 2 or 3).

    • Categorical features coded as integer: Sex (0 or 1), Fbs (0 or 1), RestECG (0, 1 or 2).

    • Categorical features coded as string: ChestPain, Thal.

  • There are missing values in Ca and Thal. Since the missing proportion is low (4 missing values in Ca and 2 in Thal, per the summary table above), we use simple mean imputation for the numerical feature Ca and mode imputation for the categorical feature Thal.
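In a recipe, the imputation statistics are estimated from the training data and then applied to any new data. A base-R sketch of mean and mode imputation under that convention; the helpers `impute_mean()` and `impute_mode()` and the toy vectors are hypothetical, not the recipes implementation.

```r
# Sketch: mean / mode imputation whose statistics come from training data only
impute_mean <- function(x, train_x) {
  x[is.na(x)] <- mean(train_x, na.rm = TRUE)
  x
}
impute_mode <- function(x, train_x) {
  counts <- table(train_x)  # table() drops NA by default
  x[is.na(x)] <- names(counts)[which.max(counts)]
  x
}

ca_train   <- c(0, 3, NA, 1)  # toy values, not the Heart data
thal_train <- c("normal", NA, "reversable", "normal")
impute_mean(ca_train, ca_train)      # NA becomes mean(c(0, 3, 1)) = 4/3
impute_mode(thal_train, thal_train)  # NA becomes "normal"
```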

heart_recipe <- 
  recipe(
    AHD ~ ., 
    data = Heart_other
  ) |>
  # mean imputation for Ca
  step_impute_mean(Ca) |>
  # mode imputation for Thal
  step_impute_mode(Thal) |>
  # create traditional dummy variables (glmnet and keras need a numeric predictor matrix)
  step_dummy(all_nominal_predictors()) |>
  # zero-variance filter
  step_zv(all_predictors())
  # optionally, estimate the recipe parameters on the training data
  # prep(training = Heart_other, retain = TRUE)
heart_recipe
── Recipe ──────────────────────────────────────────────────────────────────────
── Inputs 
Number of variables by role
outcome:    1
predictor: 13
── Operations 
• Mean imputation for: Ca
• Mode imputation for: Thal
• Dummy variables from: all_nominal_predictors()
• Zero variance filter on: all_predictors()
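What `step_dummy()` and `step_zv()` do can be sketched with base R's `model.matrix()`: a factor with k levels becomes k - 1 indicator columns (the first level is the reference), and constant columns are dropped. The toy data frame and its constant column `Const` are hypothetical; note that recipes names the dummies differently (for example `ChestPain_nonanginal`).

```r
# Sketch of dummy coding plus a zero-variance filter in base R
df <- data.frame(
  ChestPain = factor(c("typical", "asymptomatic", "nonanginal", "nontypical")),
  Const     = c(1, 1, 1, 1)  # a zero-variance column
)
X <- model.matrix(~ ChestPain + Const, data = df)[, -1]  # drop the intercept
keep <- apply(X, 2, function(col) length(unique(col)) > 1)
X <- X[, keep, drop = FALSE]  # Const is removed
colnames(X)  # "asymptomatic" is the reference level, so it has no column
```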

TODO

5.1 Julia

TODO

6 Base models

We use three model types as candidate base models for classifying the outcome AHD: penalized logistic regression, random forest, and a single-layer neural network.

First we set up the cross-validation folds to be shared by all models (stacks requires all candidates to be evaluated on the same resamples).

set.seed(203)
folds <- vfold_cv(Heart_other, v = 5)
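Conceptually, 5-fold CV assigns each of the 227 non-test rows a fold label; fold k is held out for assessment while the other four folds are used for fitting. A base-R sketch of a balanced random assignment (not vfold_cv's exact algorithm, so the rows will differ) reproduces the fold sizes seen in the tuning output below:

```r
# Sketch: balanced random assignment of 227 rows to 5 folds
set.seed(203)
n <- 227
fold_id <- sample(rep(1:5, length.out = n))
table(fold_id)    # 46 46 45 45 45 rows held out per fold
sum(fold_id != 1) # 181 rows used for fitting when fold 1 is held out
```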

6.1 Logistic regression

Set up model.

logit_mod <- 
  logistic_reg(
    penalty = tune(), 
    mixture = tune()
  ) |> 
  set_engine("glmnet", standardize = TRUE)
logit_mod
Logistic Regression Model Specification (classification)

Main Arguments:
  penalty = tune()
  mixture = tune()

Engine-Specific Arguments:
  standardize = TRUE

Computational engine: glmnet 

TODO

TODO

Bundle the recipe (R) and model into a workflow.

logit_wf <- workflow() |>
  add_recipe(heart_recipe) |>
  add_model(logit_mod)
logit_wf
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: logistic_reg()

── Preprocessor ────────────────────────────────────────────────────────────────
4 Recipe Steps

• step_impute_mean()
• step_impute_mode()
• step_dummy()
• step_zv()

── Model ───────────────────────────────────────────────────────────────────────
Logistic Regression Model Specification (classification)

Main Arguments:
  penalty = tune()
  mixture = tune()

Engine-Specific Arguments:
  standardize = TRUE

Computational engine: glmnet 

TODO

TODO

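Set up the tuning grid and tune over the CV folds. control_stack_grid() saves the out-of-fold predictions and workflows that stacking needs.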

logit_grid <- grid_regular(
  penalty(range = c(-6, 3)), 
  mixture(),
  levels = c(100, 5)
  )

logit_res <- 
  tune_grid(
    object = logit_wf, 
    resamples = folds, 
    grid = logit_grid,
    control = control_stack_grid()
  )
logit_res
# Tuning results
# 5-fold cross-validation 
# A tibble: 5 × 5
  splits           id    .metrics             .notes           .predictions
  <list>           <chr> <list>               <list>           <list>      
1 <split [181/46]> Fold1 <tibble [1,000 × 6]> <tibble [0 × 3]> <tibble>    
2 <split [181/46]> Fold2 <tibble [1,000 × 6]> <tibble [0 × 3]> <tibble>    
3 <split [182/45]> Fold3 <tibble [1,000 × 6]> <tibble [0 × 3]> <tibble>    
4 <split [182/45]> Fold4 <tibble [1,000 × 6]> <tibble [0 × 3]> <tibble>    
5 <split [182/45]> Fold5 <tibble [1,000 × 6]> <tibble [0 × 3]> <tibble>    
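The `penalty()` range c(-6, 3) is on the log10 scale, so the regular grid spans 1e-6 to 1e3 in 100 log-spaced steps, crossed with 5 mixture values. A base-R reconstruction of the grid size is below; it matches the 1,000-row `.metrics` tibbles above if two metrics (by default, accuracy and roc_auc) are computed per candidate.

```r
# Reconstruct the size of the regular logistic-regression grid in base R
penalty <- 10^seq(-6, 3, length.out = 100)  # evenly spaced on the log10 scale
mixture <- seq(0, 1, length.out = 5)        # 0 = ridge, 1 = lasso
logit_grid_sketch <- expand.grid(penalty = penalty, mixture = mixture)
nrow(logit_grid_sketch)      # 500 candidate models
2 * nrow(logit_grid_sketch)  # 1,000 metric rows per fold with two metrics
```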

TODO

TODO

6.2 Random forest

Set up model.

rf_mod <- 
  rand_forest(
    mode = "classification",
    # Number of predictors randomly sampled in each split
    mtry = tune(),
    # Number of trees in ensemble
    trees = tune()
  ) |>
  set_engine("ranger")
rf_mod
Random Forest Model Specification (classification)

Main Arguments:
  mtry = tune()
  trees = tune()

Computational engine: ranger 

TODO

TODO

Bundle the recipe (R) and model into a workflow.

rf_wf <- workflow() |>
  add_recipe(heart_recipe) |>
  add_model(rf_mod)
rf_wf
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: rand_forest()

── Preprocessor ────────────────────────────────────────────────────────────────
4 Recipe Steps

• step_impute_mean()
• step_impute_mode()
• step_dummy()
• step_zv()

── Model ───────────────────────────────────────────────────────────────────────
Random Forest Model Specification (classification)

Main Arguments:
  mtry = tune()
  trees = tune()

Computational engine: ranger 

TODO

TODO

Set up the tuning grid and tune over the CV folds.

rf_grid <- grid_regular(
  trees(range = c(100L, 500L)), 
  mtry(range = c(1L, 5L)),
  levels = c(5, 5)
  )

rf_res <- 
  tune_grid(
    object = rf_wf, 
    resamples = folds, 
    grid = rf_grid,
    control = control_stack_grid()
  )
rf_res
# Tuning results
# 5-fold cross-validation 
# A tibble: 5 × 5
  splits           id    .metrics          .notes           .predictions        
  <list>           <chr> <list>            <list>           <list>              
1 <split [181/46]> Fold1 <tibble [50 × 6]> <tibble [0 × 3]> <tibble [1,150 × 8]>
2 <split [181/46]> Fold2 <tibble [50 × 6]> <tibble [0 × 3]> <tibble [1,150 × 8]>
3 <split [182/45]> Fold3 <tibble [50 × 6]> <tibble [0 × 3]> <tibble [1,125 × 8]>
4 <split [182/45]> Fold4 <tibble [50 × 6]> <tibble [0 × 3]> <tibble [1,125 × 8]>
5 <split [182/45]> Fold5 <tibble [50 × 6]> <tibble [0 × 3]> <tibble [1,125 × 8]>

TODO

TODO

6.3 Neural network

Set up model.

mlp_mod <- 
  mlp(
    mode = "classification",
    hidden_units = tune(),
    dropout = tune(),
    epochs = 50
  ) |> 
  set_engine("keras", verbose = 0)
mlp_mod
Single Layer Neural Network Model Specification (classification)

Main Arguments:
  hidden_units = tune()
  dropout = tune()
  epochs = 50

Engine-Specific Arguments:
  verbose = 0

Computational engine: keras 

TODO

TODO

Bundle the recipe (R) and model into a workflow.

mlp_wf <- workflow() |>
  add_recipe(heart_recipe) |>
  add_model(mlp_mod)
mlp_wf
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: mlp()

── Preprocessor ────────────────────────────────────────────────────────────────
4 Recipe Steps

• step_impute_mean()
• step_impute_mode()
• step_dummy()
• step_zv()

── Model ───────────────────────────────────────────────────────────────────────
Single Layer Neural Network Model Specification (classification)

Main Arguments:
  hidden_units = tune()
  dropout = tune()
  epochs = 50

Engine-Specific Arguments:
  verbose = 0

Computational engine: keras 

TODO

TODO

Set up the tuning grid and tune over the CV folds.

mlp_grid <- grid_regular(
  hidden_units(range = c(1, 20)),
  dropout(range = c(0, 0.6)),
  levels = 5
  )

mlp_res <- 
  tune_grid(
    object = mlp_wf, 
    resamples = folds, 
    grid = mlp_grid,
    control = control_stack_grid()
  )
mlp_res
# Tuning results
# 5-fold cross-validation 
# A tibble: 5 × 5
  splits           id    .metrics          .notes           .predictions        
  <list>           <chr> <list>            <list>           <list>              
1 <split [181/46]> Fold1 <tibble [50 × 6]> <tibble [0 × 3]> <tibble [1,150 × 8]>
2 <split [181/46]> Fold2 <tibble [50 × 6]> <tibble [0 × 3]> <tibble [1,150 × 8]>
3 <split [182/45]> Fold3 <tibble [50 × 6]> <tibble [0 × 3]> <tibble [1,125 × 8]>
4 <split [182/45]> Fold4 <tibble [50 × 6]> <tibble [0 × 3]> <tibble [1,125 × 8]>
5 <split [182/45]> Fold5 <tibble [50 × 6]> <tibble [0 × 3]> <tibble [1,125 × 8]>
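The `dropout` parameter tuned above is the fraction of hidden units randomly zeroed at each training step. A minimal base-R sketch of inverted dropout, the variant Keras uses (kept units are scaled by 1 / (1 - rate) so the expected activation is unchanged); the function `dropout_layer()` and the activations `a` are hypothetical.

```r
# Sketch of inverted dropout applied to one layer's activations during training
dropout_layer <- function(a, rate) {
  keep <- runif(length(a)) >= rate  # keep each unit with probability 1 - rate
  a * keep / (1 - rate)             # rescale kept units to preserve the mean
}

set.seed(1)
a <- c(0.5, 1.2, 0.3, 2.0, 0.8)  # made-up activations
dropout_layer(a, rate = 0.45)
```

At prediction time no units are dropped, so larger dropout values act as stronger regularization only while fitting.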

TODO

TODO

7 Model stacking

Build the stacked ensemble.

heart_model_st <- 
  # initialize the stack
  stacks() |>
  # add candidate members
  add_candidates(logit_res) |>
  add_candidates(rf_res) |>
  add_candidates(mlp_res) |>
  # determine how to combine their predictions
  blend_predictions(
    penalty = 10^(-6:2),
    # the metric must be passed as a yardstick metric set
    metric = metric_set(roc_auc)
    ) |>
  # fit the candidates with nonzero stacking coefficients
  fit_members()
Warning: Predictions from 602 candidates were identical to those from existing
candidates and were removed from the data stack.
heart_model_st
── A stacked ensemble model ─────────────────────────────────────


Out of 249 possible candidate members, the ensemble retained 9.

Penalty: 0.1.

Mixture: 1.


The 9 highest weighted member classes are:
# A tibble: 9 × 3
  member                    type           weight
  <chr>                     <chr>           <dbl>
1 .pred_Yes_rf_res_1_01     rand_forest  3.48    
2 .pred_Yes_rf_res_1_11     rand_forest  0.844   
3 .pred_Yes_logit_res_1_433 logistic_reg 0.612   
4 .pred_Yes_logit_res_1_053 logistic_reg 0.129   
5 .pred_Yes_logit_res_1_054 logistic_reg 0.122   
6 .pred_Yes_logit_res_1_434 logistic_reg 0.0363  
7 .pred_Yes_logit_res_1_055 logistic_reg 0.0149  
8 .pred_Yes_rf_res_1_04     rand_forest  0.0139  
9 .pred_Yes_logit_res_1_432 logistic_reg 0.000163

Plot the cross-validated performance of the blended ensemble across the candidate penalty values.

autoplot(heart_model_st)

To show the relationship between the number of retained members and performance more directly:

autoplot(heart_model_st, type = "members")

To see the stacking coefficients of the top members:

autoplot(heart_model_st, type = "weights")

To identify which model configurations were assigned which stacking coefficients, we can use the collect_parameters() function:

collect_parameters(heart_model_st, "rf_res")
# A tibble: 25 × 5
   member       mtry trees terms                   coef
   <chr>       <int> <int> <chr>                  <dbl>
 1 rf_res_1_01     1   100 .pred_Yes_rf_res_1_01 3.48  
 2 rf_res_1_02     1   200 .pred_Yes_rf_res_1_02 0     
 3 rf_res_1_03     1   300 .pred_Yes_rf_res_1_03 0     
 4 rf_res_1_04     1   400 .pred_Yes_rf_res_1_04 0.0139
 5 rf_res_1_05     1   500 .pred_Yes_rf_res_1_05 0     
 6 rf_res_1_06     2   100 .pred_Yes_rf_res_1_06 0     
 7 rf_res_1_07     2   200 .pred_Yes_rf_res_1_07 0     
 8 rf_res_1_08     2   300 .pred_Yes_rf_res_1_08 0     
 9 rf_res_1_09     2   400 .pred_Yes_rf_res_1_09 0     
10 rf_res_1_10     2   500 .pred_Yes_rf_res_1_10 0     
# ℹ 15 more rows
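The member suffixes follow the row order of the regular grid, with `trees` varying fastest. A base-R reconstruction of that order agrees with the ten rows printed above; assuming the pattern continues, the retained member rf_res_1_11 presumably corresponds to mtry = 3 and trees = 100 (worth confirming in the full collect_parameters() output).

```r
# Rebuild the 5 x 5 random-forest grid with trees varying fastest and
# label rows in the rf_res_1_XX style (an assumption about the naming)
rf_grid_sketch <- expand.grid(trees = seq(100, 500, by = 100), mtry = 1:5)
rf_grid_sketch$member <- sprintf("rf_res_1_%02d", seq_len(nrow(rf_grid_sketch)))
rf_grid_sketch[rf_grid_sketch$member %in% c("rf_res_1_01", "rf_res_1_04", "rf_res_1_11"), ]
```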

TODO

8 Final classification

heart_pred <- Heart_test |>
  bind_cols(predict(heart_model_st, Heart_test, type = "prob")) |>
  print(width = Inf)
# A tibble: 76 × 16
     Age   Sex ChestPain    RestBP  Chol   Fbs RestECG MaxHR ExAng Oldpeak Slope
   <dbl> <dbl> <chr>         <dbl> <dbl> <dbl> <chr>   <dbl> <dbl>   <dbl> <dbl>
 1    63     1 typical         145   233     1 2         150     0     2.3     3
 2    67     1 asymptomatic    120   229     0 2         129     1     2.6     2
 3    37     1 nonanginal      130   250     0 0         187     0     3.5     3
 4    56     1 nontypical      120   236     0 0         178     0     0.8     1
 5    44     1 nontypical      120   263     0 0         173     0     0       1
 6    52     1 nonanginal      172   199     1 0         162     0     0.5     1
 7    48     1 nontypical      110   229     0 0         168     0     1       3
 8    64     1 typical         110   211     0 2         144     1     1.8     2
 9    60     1 asymptomatic    130   206     0 2         132     1     2.4     2
10    43     1 asymptomatic    150   247     0 0         171     0     1.5     1
      Ca Thal       AHD   .pred_No .pred_Yes
   <dbl> <chr>      <fct>    <dbl>     <dbl>
 1     0 fixed      No       0.478     0.522
 2     2 reversable Yes      0.125     0.875
 3     0 normal     No       0.750     0.250
 4     0 normal     No       0.855     0.145
 5     0 reversable No       0.704     0.296
 6     0 reversable No       0.697     0.303
 7     0 reversable Yes      0.617     0.383
 8     0 normal     No       0.527     0.473
 9     2 reversable Yes      0.107     0.893
10     0 normal     No       0.775     0.225
# ℹ 66 more rows

Computing the ROC AUC for the model:

yardstick::roc_auc(
  heart_pred,
  truth = AHD,
  # "No" is the first factor level, which roc_auc() treats as the event by default
  contains(".pred_No")
  )
# A tibble: 1 × 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 roc_auc binary         0.853
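The ROC AUC equals the probability that a randomly chosen event case receives a higher score than a randomly chosen non-event case, which gives a rank-based formula. The base-R sketch below (hypothetical helper `auc_rank()` and made-up probabilities) also shows why scoring `.pred_No` with "No" as the event gives the same AUC as scoring `.pred_Yes` with "Yes" as the event.

```r
# ROC AUC via ranks: (sum of event ranks - n1 (n1 + 1) / 2) / (n1 * n0)
auc_rank <- function(score, is_event) {
  r  <- rank(score)  # average ranks handle ties
  n1 <- sum(is_event)
  n0 <- sum(!is_event)
  (sum(r[is_event]) - n1 * (n1 + 1) / 2) / (n1 * n0)
}

p_no  <- c(0.9, 0.8, 0.4, 0.7, 0.3, 0.2)  # hypothetical .pred_No values
truth <- c("No", "No", "No", "Yes", "Yes", "Yes")
auc_rank(p_no, truth == "No")        # 8 of 9 pairs ordered correctly: 8/9
auc_rank(1 - p_no, truth == "Yes")   # same AUC using .pred_Yes
```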

Setting members = TRUE in predict() also returns class predictions from each of the ensemble members.

heart_pred <-
  Heart_test |>
  select(AHD) |>
  bind_cols(
    predict(
      heart_model_st,
      Heart_test,
      type = "class",
      members = TRUE
      )
    ) |>
  print(width = Inf)
# A tibble: 76 × 11
   AHD   .pred_class .pred_class_logit_res_1_432 .pred_class_logit_res_1_433
   <fct> <fct>       <fct>                       <fct>                      
 1 No    Yes         No                          No                         
 2 Yes   Yes         Yes                         Yes                        
 3 No    No          No                          No                         
 4 No    No          No                          No                         
 5 No    No          No                          No                         
 6 No    No          No                          No                         
 7 Yes   No          Yes                         Yes                        
 8 No    No          No                          No                         
 9 Yes   Yes         Yes                         Yes                        
10 No    No          No                          No                         
   .pred_class_logit_res_1_434 .pred_class_logit_res_1_053
   <fct>                       <fct>                      
 1 No                          Yes                        
 2 Yes                         Yes                        
 3 No                          No                         
 4 No                          No                         
 5 No                          No                         
 6 No                          No                         
 7 Yes                         No                         
 8 No                          No                         
 9 Yes                         Yes                        
10 No                          No                         
   .pred_class_logit_res_1_054 .pred_class_logit_res_1_055
   <fct>                       <fct>                      
 1 Yes                         Yes                        
 2 Yes                         Yes                        
 3 No                          No                         
 4 No                          No                         
 5 No                          No                         
 6 No                          No                         
 7 No                          No                         
 8 No                          No                         
 9 Yes                         Yes                        
10 No                          No                         
   .pred_class_rf_res_1_01 .pred_class_rf_res_1_04 .pred_class_rf_res_1_11
   <fct>                   <fct>                   <fct>                  
 1 Yes                     Yes                     No                     
 2 Yes                     Yes                     Yes                    
 3 No                      No                      No                     
 4 No                      No                      No                     
 5 No                      No                      No                     
 6 No                      No                      No                     
 7 No                      No                      No                     
 8 Yes                     Yes                     Yes                    
 9 Yes                     Yes                     Yes                    
10 No                      No                      No                     
# ℹ 66 more rows
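# test-set accuracy of the ensemble (.pred_class) and of each member;
# the AHD column compares AHD with itself, so it is always 1
map(
  colnames(heart_pred),
  ~mean(heart_pred$AHD == pull(heart_pred, .x))
  ) |>
  set_names(colnames(heart_pred)) |>
  as_tibble() |>
  pivot_longer(c(everything(), -AHD))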
# A tibble: 10 × 3
     AHD name                        value
   <dbl> <chr>                       <dbl>
 1     1 .pred_class                 0.737
 2     1 .pred_class_logit_res_1_432 0.789
 3     1 .pred_class_logit_res_1_433 0.789
 4     1 .pred_class_logit_res_1_434 0.789
 5     1 .pred_class_logit_res_1_053 0.763
 6     1 .pred_class_logit_res_1_054 0.776
 7     1 .pred_class_logit_res_1_055 0.763
 8     1 .pred_class_rf_res_1_01     0.737
 9     1 .pred_class_rf_res_1_04     0.737
10     1 .pred_class_rf_res_1_11     0.763
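The same accuracies can be computed in base R by comparing only the prediction columns with the truth, which avoids the spurious AHD column. A sketch on a small hypothetical data frame (the member column name is made up):

```r
# Accuracy of each prediction column against the truth column AHD
pred <- data.frame(
  AHD                  = factor(c("No", "Yes", "No", "No")),
  .pred_class          = factor(c("Yes", "Yes", "No", "No")),
  .pred_class_member_a = factor(c("No", "Yes", "No", "No"))
)
pred_cols <- grep("^\\.pred_class", names(pred), value = TRUE)
acc <- sapply(pred_cols, function(col) {
  mean(as.character(pred[[col]]) == as.character(pred$AHD))
})
acc  # .pred_class 0.75, .pred_class_member_a 1
```

With 76 test cases, one case moves accuracy by 1/76 ≈ 0.013, so the gap between the ensemble (0.737, 56 correct) and the best logistic members (0.789, 60 correct) amounts to four patients; the blend was also tuned on ROC AUC rather than accuracy.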