Introduction

This vignette describes how to use the DeepPatientLevelPrediction package for transfer learning. Transfer learning is a machine learning technique where a model trained on one task is used as a starting point for training a model on a different task. This can be useful when you have a small dataset for the new task, but a large dataset for a related task. In this vignette, we will show how to use the DeepPatientLevelPrediction package to perform transfer learning on a patient-level prediction task.

Training initial model

The first step in transfer learning is to train an initial model. In this example, we will train a model to predict the risk of a patient developing a certain condition based on their electronic health record data. We will use the Eunomia package to access a dataset to train the model. We start by loading the package and creating the Eunomia cohorts:

library(DeepPatientLevelPrediction)

# Get connection details for the Eunomia dataset and create the cohorts
connectionDetails <- Eunomia::getEunomiaConnectionDetails()
Eunomia::createCohorts(connectionDetails)

The default Eunomia package includes four cohorts: gastrointestinal bleeding (GiBleed) and the use of three different drugs (celecoxib, diclofenac and NSAIDs). Usually we would use one of the three drug cohorts as our target cohort and predict the risk of gastrointestinal bleeding. Their cohort_definition_ids are: celecoxib: 1, diclofenac: 2, GiBleed: 3 and NSAIDs: 4.

After creating the cohorts we can see that the NSAIDs cohort has the most patients, so we will use it as the target cohort for the initial model. Excluding GiBleed, the diclofenac cohort has the fewest patients, so we will use it as the target cohort for the transfer learning model.
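The cohort sizes can be checked directly in the database. The following is a minimal sketch, assuming the cohorts sit in the default `main.cohort` table with the standard `cohort_definition_id` and `subject_id` columns. The helper name `countCohorts` is hypothetical and not part of any package.

```r
# Hypothetical helper: count distinct subjects per cohort in main.cohort.
# Assumes Eunomia::createCohorts() has already populated the table.
countCohorts <- function(connectionDetails) {
  connection <- DatabaseConnector::connect(connectionDetails)
  # Always close the connection, even if the query fails
  on.exit(DatabaseConnector::disconnect(connection))
  DatabaseConnector::querySql(
    connection,
    "SELECT cohort_definition_id, COUNT(DISTINCT subject_id) AS n_subjects
     FROM main.cohort
     GROUP BY cohort_definition_id
     ORDER BY cohort_definition_id;"
  )
}

# Only run the query if the connection details from above are available
if (exists("connectionDetails")) {
  print(countCohorts(connectionDetails))
}
```

The result should have one row per cohort_definition_id, which makes it easy to pick the largest and smallest drug cohorts.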

# Create some simple covariate settings using sex, age, and long-term
# (last year) conditions and drug use.
covariateSettings <- FeatureExtraction::createCovariateSettings(
  useDemographicsGender = TRUE,
  useDemographicsAge = TRUE,
  useConditionOccurrenceLongTerm = TRUE,
  useDrugEraLongTerm = TRUE,
  endDays = 0
)

# Information about the database. The Eunomia SQLite database has only one
# schema, `main`, and the cohorts are in a table named `cohort`, which is the default.
databaseDetails <- PatientLevelPrediction::createDatabaseDetails(
  connectionDetails = connectionDetails,
  cdmDatabaseId = "2", # Eunomia version used
  cdmDatabaseSchema = "main",
  targetId = 4,
  outcomeIds = 3,
  cdmDatabaseName = "eunomia"
)

# Let's now extract the plpData object from the database
plpData <- PatientLevelPrediction::getPlpData(
  databaseDetails = databaseDetails,
  covariateSettings = covariateSettings,
  restrictPlpDataSettings = PatientLevelPrediction::createRestrictPlpDataSettings()
)

Now we can set up our initial model development. We will use a simple ResNet.

modelSettings <- setResNet(numLayers = c(2),
                           sizeHidden = 128,
                           hiddenFactor = 1,
                           residualDropout = 0.1,
                           hiddenDropout = 0.1,
                           sizeEmbedding = 128,
                           estimatorSettings = setEstimator(
                             learningRate = 3e-4,
                             weightDecay = 0,
                             device = "cpu", # use cuda here if you have a gpu
                             batchSize = 256,
                             epochs = 5,
                             seed = 42
                           ),
                           hyperParamSearch = "random",
                           randomSample = 1)

plpResults <- PatientLevelPrediction::runPlp(
  plpData = plpData,
  outcomeId = 3, # 3 is the id of GiBleed
  modelSettings = modelSettings,
  analysisName = "Nsaids_GiBleed",
  analysisId = "1",
  # Let's predict the risk of GiBleed in the year following start of NSAIDs use
  populationSettings = PatientLevelPrediction::createStudyPopulationSettings(
    requireTimeAtRisk = FALSE,
    firstExposureOnly = TRUE,
    riskWindowStart = 1,
    riskWindowEnd = 365
  ),
  splitSettings = PatientLevelPrediction::createDefaultSplitSetting(splitSeed = 42),
  saveDirectory = "./output" # save in a folder in the current directory
)

This should take a few minutes on a CPU. Now that we have developed a model, we can finetune it on the diclofenac cohort. First we need to extract the data for that cohort.

databaseDetails <- PatientLevelPrediction::createDatabaseDetails(
  connectionDetails = connectionDetails,
  cdmDatabaseId = "2", # Eunomia version used
  cdmDatabaseSchema = "main",
  targetId = 2, # diclofenac cohort
  outcomeIds = 3,
  cdmDatabaseName = "eunomia"
)

plpDataTransfer <- PatientLevelPrediction::getPlpData(
  databaseDetails = databaseDetails,
  covariateSettings = covariateSettings, # same as for the developed model
  restrictPlpDataSettings = PatientLevelPrediction::createRestrictPlpDataSettings()
)

Now we can set up our transfer learning model development. For this we need a different modelSettings function: setFinetuner. We also need to know the path to the previously developed model. This should be of the form outputDir/analysisId/plpResult/model, where outputDir is the directory specified when we developed our model and analysisId is the id we gave the analysis. In this case the id is 1 and the path to the model is: ./output/1/plpResult/model.
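Rather than typing the path by hand, it can be assembled from its parts. This is a small sketch using only base R. The variable names `outputDir`, `analysisId` and `modelPath` are chosen here for illustration.

```r
# Build the model path from the save directory and analysis id used above
outputDir <- "./output"
analysisId <- "1"
modelPath <- file.path(outputDir, analysisId, "plpResult", "model")

# Warn early if the initial model has not been trained and saved yet
if (!dir.exists(modelPath)) {
  message("No model found at ", modelPath, "; train the initial model first.")
}
```

The resulting `modelPath` can then be passed as the modelPath argument of setFinetuner.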

modelSettingsTransfer <- setFinetuner(modelPath = './output/1/plpResult/model',
                                      estimatorSettings = setEstimator(
                                        learningRate = 3e-4,
                                        weightDecay = 0,
                                        device = "cpu", # use cuda here if you have a gpu
                                        batchSize = 256,
                                        epochs = 5,
                                        seed = 42
                                      ))

Currently the basic transfer learning works by loading the previously trained model and resetting its last layer, often called the prediction head. It then trains only the parameters in this last layer. The hope is that the other layers have learned some generalizable representations of our data, and that by modifying the last layer we can mix those representations to suit the new task.
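To illustrate the idea of training only the prediction head, here is a conceptual sketch in base R on made-up toy data. It is not the package's implementation: the "pretrained" body is simply a fixed random hidden layer, and all names (`wBody`, `wHead`, `predictHead`, `logLoss`) are hypothetical.

```r
set.seed(42)

# Toy data: 200 rows, 5 input features, binary outcome
n <- 200
x <- matrix(rnorm(n * 5), nrow = n)
y <- as.numeric(x[, 1] + x[, 2] > 0)

# Stand-in for the pretrained body: a hidden layer whose weights stay frozen
wBody <- matrix(rnorm(5 * 8), nrow = 5)
wBodyBefore <- wBody
hidden <- pmax(x %*% wBody, 0) # ReLU activations, n x 8

# Freshly reset prediction head: the only trainable parameters
wHead <- rep(0, 8)
bHead <- 0

# Predicted probabilities from the head (logistic output)
predictHead <- function(hidden, wHead, bHead) {
  as.vector(1 / (1 + exp(-(hidden %*% wHead + bHead))))
}

# Binary cross-entropy, clipped for numerical safety
logLoss <- function(p, y) {
  p <- pmin(pmax(p, 1e-12), 1 - 1e-12)
  -mean(y * log(p) + (1 - y) * log(1 - p))
}

lossBefore <- logLoss(predictHead(hidden, wHead, bHead), y)

# Gradient descent on the head only; wBody is never updated
learningRate <- 0.01
for (step in seq_len(500)) {
  p <- predictHead(hidden, wHead, bHead)
  gradHead <- as.vector(t(hidden) %*% (p - y)) / n
  gradBias <- mean(p - y)
  wHead <- wHead - learningRate * gradHead
  bHead <- bHead - learningRate * gradBias
}

lossAfter <- logLoss(predictHead(hidden, wHead, bHead), y)
```

Because the body is frozen, its activations only need to be computed once, and each update touches just the head's weights and bias. This is why finetuning the head alone is cheap compared with training the whole network.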

plpResultsTransfer <- PatientLevelPrediction::runPlp(
  plpData = plpDataTransfer,
  outcomeId = 3,
  modelSettings = modelSettingsTransfer,
  analysisName = "Diclofenac_GiBleed",
  analysisId = "2",
  populationSettings = PatientLevelPrediction::createStudyPopulationSettings(
    requireTimeAtRisk = FALSE,
    firstExposureOnly = TRUE,
    riskWindowStart = 1,
    riskWindowEnd = 365
  ),
  splitSettings = PatientLevelPrediction::createDefaultSplitSetting(splitSeed = 42),
  saveDirectory = "./outputTransfer" # save in a folder in the current directory
)

This should be much faster since only the last layer is trained. Unfortunately, the results are poor, but this is a toy example on synthetic data. The process on large observational data is exactly the same.

Conclusion

Now you have finetuned a model on a new cohort using transfer learning. This can be useful when you have a small dataset for the new task, but a large dataset for a related task or from a different database. The DeepPatientLevelPrediction package makes it easy to perform transfer learning on patient-level prediction tasks.

Acknowledgments

Considerable work has been dedicated to providing the DeepPatientLevelPrediction package.

citation("DeepPatientLevelPrediction")
## To cite package 'DeepPatientLevelPrediction' in publications use:
## 
##   Fridgeirsson E, Reps J, Chan You S, Kim C, John H (8).
##   _DeepPatientLevelPrediction: Deep Learning For Patient Level
##   Prediction Using Data In The OMOP Common Data Model_. R package
##   version 2.1.0, <https://github.com/OHDSI/DeepPatientLevelPrediction>.
## 
## A BibTeX entry for LaTeX users is
## 
##   @Manual{,
##     title = {DeepPatientLevelPrediction: Deep Learning For Patient Level Prediction Using Data In The
## OMOP Common Data Model},
##     author = {Egill Fridgeirsson and Jenna Reps and Seng {Chan You} and Chungsoo Kim and Henrik John},
##     year = {8},
##     note = {R package version 2.1.0},
##     url = {https://github.com/OHDSI/DeepPatientLevelPrediction},
##   }

Please reference this paper if you use the PLP Package in your work:

Reps JM, Schuemie MJ, Suchard MA, Ryan PB, Rijnbeek PR. Design and implementation of a standardized framework to generate and evaluate patient-level prediction models using observational healthcare data. J Am Med Inform Assoc. 2018;25(8):969-975.