Making Predictions Using Tidymodels in Medicine

Kamarul Imran Musa (Assoc Prof)

School of Medical Sciences, Universiti Sains Malaysia

Motivation

Physicians usually have these objectives when looking at patient data:

  • to understand the factors that explain why patients develop certain diseases (inference)
  • to guess the most probable diagnosis or outcome for a patient (prediction)

This talk is about the second objective.

For example,

  • physicians working at the emergency department (ED) want to assign the correct triage category when a patient is brought in with chest pain, sweating, and a history of diabetes mellitus

  • an epidemiologist may want to predict the most likely health outcome for patients who smoke cigarettes, practice a sedentary lifestyle, and have diabetes mellitus.

Slides

Limitation

However, physicians’ minds, no matter how bright or experienced, will not be able

  • to store, recall, and correctly analyze large amounts of medical information simultaneously
  • to optimally and accurately guess the diagnosis or outcome in complicated medical problems

Predictive analytics helps physicians make more accurate guesses:

  • using machine learning (ML) methods
  • by allowing them to augment their educated guesses with computer-intensive methods
  • by systematically analyzing data such as complaints, signs and symptoms, and clinical conditions (humans rely only on their brains and experience, and are hence less reproducible).

The end result is a more accurate guess or prediction of a diagnosis or outcome.

About me

R Book with Dr Wan Nor Arifin

About this talk

  • It provides a very high-level view of doing predictive analytics in medicine using the R language
  • It very briefly demonstrates the tidymodels framework
  • It helps physicians understand the concepts of a machine learning workflow in the RStudio IDE

Tidymodels

Inference

In a clinical and medical research scenario:

  • physicians get a set of data
  • with the help of disease modellers, they use statistical modelling (most of the time) to make inferences.
  • This is inferential statistics.

Inferential statistics is formally defined as

  • a field of statistics that uses analytical tools for drawing conclusions about a population by examining random samples.

Inference

The goal of inferential statistics is to make generalizations about a population. For example, physicians want to understand the relationship between certain variables (aka risk factors) and having a disease or a certain outcome of the disease.

Read more here
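As a minimal sketch of the inference objective (using the dead data frame created in the demo below, with two of its variables, hpt2 for hypertension and gcs for the Glasgow Coma Scale, chosen only for illustration), one could fit a plain logistic regression and inspect the estimated associations rather than the predictions:

# Inference: the focus is on the estimated coefficients for the
# risk factors, not on predicting outcomes for new patients
fit_inf <- glm(status2 ~ hpt2 + gcs, data = dead, family = binomial)
summary(fit_inf)       # coefficients and p-values
exp(coef(fit_inf))     # odds ratios
exp(confint(fit_inf))  # 95% confidence intervals for the odds ratios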

Is this inference?

Prediction

In prediction, physicians take an existing dataset and choose models or algorithms so that, in the end, they can reliably predict the correct diagnosis or outcome of a disease.

The outcome can be categorical

  • such as fatality (alive or dead) or complications (yes or no).
  • These are examples of classification problems.

The outcome can be numerical

  • such as values of fasting blood sugar, quality of life scores, disability scores, or expression.
  • These are examples of regression problems.

Prediction

To perform prediction (predictive analytics), physicians use machine learning methods. For example:

  • a support vector machine classifier to predict clinical deterioration from magnetic resonance imaging,
  • a random forest to predict cancer diagnosis, and
  • deep learning to classify mammogram images as cancerous or non-cancerous.

Predictive analytics can usually be grouped into

  • supervised learning
  • unsupervised learning
  • reinforcement learning

Supervised learning

Supervised learning is a machine learning approach that’s defined by its use of labeled datasets.

  • Regression problems or models: for models predicting a numeric outcome. This type of supervised learning uses an algorithm to understand the relationship between dependent and independent variables. Regression models are helpful for predicting numerical values based on different data points, such as sales revenue projections for a given business.

  • Classification problems or models: for models predicting a categorical response. A classification algorithm assigns test data into specific categories, such as separating apples from oranges. In the real world, supervised learning algorithms can be used to classify spam into a folder separate from your inbox.
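As a minimal sketch (using the parsnip package introduced below), the two problem types map onto two model specifications:

library(tidymodels)

# Regression: a numeric outcome, e.g. fasting blood sugar
reg_spec <- linear_reg() %>% 
  set_engine("lm")

# Classification: a categorical outcome, e.g. alive vs dead
cls_spec <- logistic_reg() %>% 
  set_engine("glm")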

Tidymodels

The tidymodels framework is a collection of packages for modeling and machine learning using tidyverse principles.

Tidymodels

  • rsample : to split the data (e.g. into training/testing sets or cross-validation folds)
  • recipes : for pre-processing
  • workflows : to bundle pre-processing, modeling, and post-processing together


Tidymodels

  • tune : to optimize the hyperparameters of a model and its pre-processing steps
  • parsnip : to specify models
  • yardstick : to evaluate models

Demo

Preparation

Open a new R project, then load the packages:

  • tidyverse : for data wrangling and data visualization
  • haven : to read statistical data
  • tidymodels : the modeling and machine learning framework used in this demo
  • caret : another machine learning framework (loaded here but not used in this demo)
  • gtsummary : to produce statistical tables
library(tidyverse)
library(haven)
library(tidymodels)
library(caret)
library(gtsummary)

Read data

  • Dataset named stroke_fatality.dta (in Stata format).
  • Read it, then convert it to an R object of class data.frame
dead <- read_dta('stroke_fatality.dta') %>% 
  data.frame()
  • Convert labelled variables to factors
dead <- dead %>%
  mutate(across(where(is.labelled), as_factor))

Variables :

glimpse(dead)
Rows: 226
Columns: 19
$ sex       <fct> female, male, female, female, male, female, female, female, …
$ race      <fct> malay, malay, malay, malay, malay, chinese, malay, malay, ma…
$ status2   <fct> dead, alive, alive, alive, alive, dead, alive, dead, alive, …
$ gcs       <dbl> 15, 15, 15, 11, 15, 7, 5, 13, 15, 15, 10, 15, 14, 9, 15, 15,…
$ sbp       <dbl> 150, 152, 231, 110, 199, 190, 145, 161, 222, 161, 149, 153, …
$ dbp       <dbl> 87, 108, 117, 79, 134, 101, 102, 96, 129, 107, 90, 61, 95, 1…
$ hr        <dbl> 92, 87, 64, 90, 72, 63, 102, 81, 72, 94, 59, 81, 61, 120, 67…
$ hb        <dbl> 10.4, 13.0, 11.0, 14.3, 15.7, 11.7, 13.4, 11.8, 12.6, 16.4, …
$ plat      <dbl> 249, 156, 179, 233, 351, 133, 290, 251, 196, 188, 139, 306, …
$ wbc       <dbl> 12.5, 7.4, 22.4, 9.6, 18.7, 11.3, 15.8, 8.5, 9.0, 9.5, 11.0,…
$ na        <dbl> 138, 132, 135, 132, 138, 140, 134, 135, 129, 137, 141, 137, …
$ potas     <dbl> 3.6, 4.1, 4.7, 3.8, 3.8, 3.0, 4.1, 4.0, 4.0, 3.7, 4.1, 4.2, …
$ gluc      <dbl> NA, 5.1, NA, NA, NA, NA, NA, NA, NA, 5.5, NA, NA, NA, NA, NA…
$ cbs       <dbl> 11.4, NA, 18.6, 6.8, 6.5, 8.4, 13.4, 14.9, 6.6, 5.8, 7.7, 6.…
$ chol      <dbl> NA, 4.7, NA, NA, NA, NA, NA, NA, NA, 6.9, NA, NA, NA, NA, NA…
$ tg        <dbl> NA, 1.1, NA, NA, NA, NA, NA, NA, NA, 1.0, NA, NA, NA, NA, NA…
$ urea      <dbl> 5.1, 8.4, 11.0, 7.4, 3.8, 4.0, 3.3, 5.6, 5.8, 5.8, 7.3, 7.6,…
$ hpt2      <fct> yes, yes, yes, yes, yes, no, yes, yes, yes, no, yes, no, yes…
$ icd10cat2 <fct> "CI,Others", "CI,Others", "Haemorrhagic", "Haemorrhagic", "H…
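Several variables (e.g. gluc, chol, tg) contain many missing values. A quick check, not in the original demo, is to count the NAs in each column:

# Count missing values per variable
dead %>% 
  summarise(across(everything(), ~ sum(is.na(.x))))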

Outcome variable :

dead %>%
  count(status2)
  status2   n
1   alive 173
2    dead  53
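gtsummary was loaded earlier but does not appear elsewhere in the demo; one possible use is a descriptive table stratified by the outcome (the variables selected here are arbitrary):

# Descriptive statistics by survival status
dead %>% 
  select(status2, sex, gcs, sbp, dbp) %>% 
  tbl_summary(by = status2)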

Split data

  • A stratified random sample conducts the 60/40 split within each class of the outcome (status2).

  • It then pools the results together.

  • In rsample, this is achieved using the strata argument. Usually the split is done 80/20 or 70/30

Resource here

Training and testing data

  • Split
set.seed(123)
dead_split <- initial_split(dead, 
                            prop = 0.6, 
                            strata = status2)
dead_split
<Training/Testing/Total>
<134/92/226>
  • Extract training set from the split
dead_train <- training(dead_split)
  • Extract testing set from the split
dead_test <- testing(dead_split)
  • Create a validation split from the training set
set.seed(234)
dead_val <- validation_split(dead_train, 
                            strata = status2, 
                            prop = 0.60)

Model

As the outcome is categorical, we will use a penalized logistic regression model from the glmnet package

lr_mod <- 
  logistic_reg(penalty = tune(), mixture = 1) %>% 
  set_engine("glmnet")
  • tune() is a placeholder: the best penalty value will be found during tuning
  • mixture = 1 specifies a pure lasso penalty, which allows the glmnet model to remove irrelevant predictors and choose a simpler model
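To see how this specification maps to the underlying glmnet call, parsnip provides translate():

# Show the glmnet call template behind the parsnip specification
lr_mod %>% 
  translate()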

glmnet package:

  • fits generalized linear and similar models via penalized maximum likelihood
  • computes the regularization path for the lasso or elastic net penalty
  • uses a grid of values (on the log scale) for the regularization parameter lambda.

Resource is here
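For intuition, here is a minimal sketch of glmnet used directly, outside tidymodels, on simulated data (all objects here are illustrative):

library(glmnet)

# Simulated data: 100 patients, 5 numeric predictors, binary outcome
set.seed(1)
x <- matrix(rnorm(100 * 5), nrow = 100)
y <- rbinom(100, size = 1, prob = 0.3)

# alpha = 1 requests the lasso penalty; glmnet fits the whole
# regularization path over a log-scale grid of lambda values
fit <- glmnet(x, y, family = "binomial", alpha = 1)
plot(fit, xvar = "lambda")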

Recipe

  • Define preprocessing steps using recipe()
lr_recipe <- 
  recipe(status2 ~ ., data = dead_train) %>% 
  step_dummy(all_nominal_predictors()) %>% 
  step_other(threshold = 0.20) %>%
  step_impute_knn(all_predictors()) %>%
  step_zv(all_predictors()) %>% 
  step_normalize(all_predictors()) 
  • step_dummy() : converts characters or factors into numeric binary model terms
  • step_other() : pools infrequent factor levels into an “other” category
  • step_impute_knn() : imputes missing values using each sample’s k nearest neighbours
  • step_zv() : removes variables that contain only a single unique value (e.g. all zeros).
  • step_normalize() : centers and scales numeric variables
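To verify what the recipe does (a quick check, assuming the recipe preps cleanly; it is not part of the modeling workflow itself), prep() estimates the steps from the training data and bake() applies them:

lr_recipe %>% 
  prep(training = dead_train) %>% 
  bake(new_data = NULL) %>%   # returns the processed training set
  glimpse()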

Workflow and Tuning

Create a workflow for the ML algorithm

lr_workflow <- 
  workflow() %>% 
  add_model(lr_mod) %>% 
  add_recipe(lr_recipe)

Perform fine tuning

  • dials::grid_regular() can create an expanded grid from combinations of hyperparameters
  • here, we set the grid up manually as a one-column tibble with 30 candidate penalty values
lr_reg_grid <- 
  tibble(penalty = 10^seq(-4, -1, length.out = 30))
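A similar grid could come from dials itself; note that penalty() defaults to a wider log-10 range than the manual grid above:

# Alternative: a regular grid of 30 penalty values from dials
lr_reg_grid2 <- grid_regular(penalty(), levels = 30)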

Train and tune model

  • tune::tune_grid() : trains and evaluates the 30 penalized logistic regression models
  • control_grid(save_pred = TRUE) saves the validation set predictions
  • the area under the ROC curve will be used to quantify how well the model performs
lr_res <- 
  lr_workflow %>% 
  tune_grid(dead_val,
            grid = lr_reg_grid,
            control = control_grid(save_pred = TRUE),
            metrics = metric_set(roc_auc))

Get the validation set metrics:

lr_plot <- 
  lr_res %>% 
  collect_metrics() %>% 
  ggplot(aes(x = penalty, y = mean)) + 
  geom_point() + 
  geom_line() + 
  ylab("Area under the ROC Curve") +
  scale_x_log10(labels = scales::label_number())
lr_plot

  • Model performance is generally better at penalty values above 0.01, suggesting that only some of the predictors are important to the model.
  • The roc_auc metric alone could lead us to multiple options for the best value of this hyperparameter:
top_models <-
  lr_res %>% 
  show_best("roc_auc", n = 15) %>% 
  arrange(penalty) 
top_models
# A tibble: 15 × 7
   penalty .metric .estimator  mean     n std_err .config              
     <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                
 1 0.00137 roc_auc binary     0.716     1      NA Preprocessor1_Model12
 2 0.00452 roc_auc binary     0.718     1      NA Preprocessor1_Model17
 3 0.00574 roc_auc binary     0.720     1      NA Preprocessor1_Model18
 4 0.00728 roc_auc binary     0.733     1      NA Preprocessor1_Model19
 5 0.00924 roc_auc binary     0.758     1      NA Preprocessor1_Model20
 6 0.0117  roc_auc binary     0.777     1      NA Preprocessor1_Model21
 7 0.0149  roc_auc binary     0.797     1      NA Preprocessor1_Model22
 8 0.0189  roc_auc binary     0.811     1      NA Preprocessor1_Model23
 9 0.0240  roc_auc binary     0.817     1      NA Preprocessor1_Model24
10 0.0304  roc_auc binary     0.804     1      NA Preprocessor1_Model25
11 0.0386  roc_auc binary     0.814     1      NA Preprocessor1_Model26
12 0.0489  roc_auc binary     0.807     1      NA Preprocessor1_Model27
13 0.0621  roc_auc binary     0.798     1      NA Preprocessor1_Model28
14 0.0788  roc_auc binary     0.761     1      NA Preprocessor1_Model29
15 0.1     roc_auc binary     0.745     1      NA Preprocessor1_Model30

We prefer to choose a penalty value further along the x-axis, closer to where we start to see the decline in model performance.

lr_best <- 
  lr_res %>% 
  collect_metrics() %>% 
  arrange(penalty) %>% 
  slice(24)
lr_best
# A tibble: 1 × 7
  penalty .metric .estimator  mean     n std_err .config              
    <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                
1  0.0240 roc_auc binary     0.817     1      NA Preprocessor1_Model24

Visualize the ROC curve using the best penalty value

lr_auc <- 
  lr_res %>% 
  collect_predictions(parameters = lr_best) %>% 
  roc_curve(status2, .pred_alive) %>% 
  mutate(model = "Logistic Regression")
autoplot(lr_auc)
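The demo stops at the validation set. As a sketch of the remaining step (assuming the objects created above), one would finalize the workflow with the chosen penalty, refit on the full training set, and evaluate once on the held-out test set:

# Finalize with the best penalty (select_best() here coincides with
# the row-24 choice above), fit on dead_train, evaluate on dead_test
final_wf <- lr_workflow %>% 
  finalize_workflow(select_best(lr_res, metric = "roc_auc"))
final_fit <- final_wf %>% 
  last_fit(dead_split)
collect_metrics(final_fit)  # test-set performance metrics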

Suggested workflow

Workflow

  • Clearly understand the objective of the analysis: inference or prediction
  • Identify the data
  • Assess quality of data
  • Split data
  • Pre-processing
  • Tuning
  • Test accuracy

To be effective

  • create a team consisting of at least a subject matter expert and a programming expert (and an epidemiologist or a biostatistician)
  • vet the workflow thoroughly
  • assess the quality of the data
  • assess the quality of the prediction
  • vet the accuracies (run multiple ML models)
  • do not peek at the data

Bias in medical ML projects

  • ML methods have systematic errors:
    • errors in classifying subgroups of patients

    • errors in estimating risk levels

    • errors in predictions.

  • Conflicting results : the accuracy of artificial intelligence (AI) models derived from medical research vs. their accuracy in real clinical settings, due to
    • ethical differences

    • societal variations

    • for example, differences in skin color (when using a machine to detect skin abnormalities)

Read more from this source

Bias

Statistical bias

  • cases in which the distribution of a given dataset does not reflect the true distribution of the population.

Social bias

  • inequities that may result in suboptimal outcomes for certain groups of the human population

Mitigating bias

Strategies for mitigating bias across the different steps in machine learning systems development

Source

Summary

At the 2018 R User Conference with Max Kuhn

  • Prediction is not inference
  • Use prediction to guess the diagnosis or outcome from data
  • The aim of prediction is to guess, not to understand the relationship between certain independent variables (risk factors) and the outcome of interest
  • The predicted outcome is a single variable. It can be categorical (classification) or numerical (regression).
  • Most prediction models use machine learning methods
  • There are many ML packages; tidymodels is among the most used in R
  • Better quality data leads to more valid predictions