This repository contains the code used for Predicting Patient Outcomes with Graph Representation Learning (https://arxiv.org/abs/2101.03940).

Overview

Predicting Patient Outcomes with Graph Representation Learning

This repository contains the code used for Predicting Patient Outcomes with Graph Representation Learning. You can watch a video of the spotlight talk at W3PHIAI (AAAI workshop) here:

Watch the video

Citation

If you use this code or the models in your research, please cite the following:

@misc{rocheteautong2021,
      title={Predicting Patient Outcomes with Graph Representation Learning}, 
      author={Emma Rocheteau and Catherine Tong and Petar Veličković and Nicholas Lane and Pietro Liò},
      year={2021},
      eprint={2101.03940},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Motivation

Recent work on predicting patient outcomes in the Intensive Care Unit (ICU) has focused heavily on physiological time series data, largely ignoring sparse data such as diagnoses and medications. When these are included, they are usually concatenated in the late stages of a model, which may then struggle to learn from rarer disease patterns. Instead, we propose a strategy to exploit diagnoses as relational information by connecting similar patients in a graph. To this end, we propose LSTM-GNN for patient outcome prediction tasks: a hybrid model combining Long Short-Term Memory networks (LSTMs) for extracting temporal features and Graph Neural Networks (GNNs) for extracting the patient neighbourhood information. We demonstrate that LSTM-GNNs outperform the LSTM-only baseline on length of stay prediction tasks on the eICU database. More generally, our results indicate that exploiting information from neighbouring patient cases using graph neural networks is a promising research direction, yielding tangible returns in supervised learning performance on Electronic Health Records.

Pre-Processing Instructions

eICU Pre-Processing

  1. To run the SQL files, you must have the eICU database set up: https://physionet.org/content/eicu-crd/2.0/.

  2. Follow the instructions at https://eicu-crd.mit.edu/tutorials/install_eicu_locally/ to ensure the correct connection configuration.

  3. Set eICU_path in paths.json to a convenient location on your computer, and make the same change in eICU_preprocessing/create_all_tables.sql by finding and replacing '/Users/emmarocheteau/PycharmProjects/eICU-GNN-LSTM/eICU_data/' with that location. Keep the trailing '/'.

  4. In your terminal, navigate to the project directory, then type the following commands:

    psql 'dbname=eicu user=eicu options=--search_path=eicu'
    

    Inside the psql console:

    \i eICU_preprocessing/create_all_tables.sql
    

    This step might take a couple of hours.

    To quit the psql console:

    \q
    
  5. Then run the pre-processing scripts in your terminal. This will need to run overnight:

    python3 -m eICU_preprocessing.run_all_preprocessing
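Step 3 can also be scripted rather than done by hand in an editor. The following is a minimal sketch, assuming paths.json is a flat JSON object with an "eICU_path" key; the helper names (read_eicu_path, retarget_sql, retarget_file) are hypothetical and not part of the repository.

```python
import json
from pathlib import Path

# Hard-coded prefix shipped in eICU_preprocessing/create_all_tables.sql (see step 3).
OLD_PREFIX = "/Users/emmarocheteau/PycharmProjects/eICU-GNN-LSTM/eICU_data/"


def read_eicu_path(paths_json: str = "paths.json") -> str:
    # Assumption: paths.json is a flat JSON object containing an "eICU_path" key.
    with open(paths_json) as f:
        return json.load(f)["eICU_path"]


def retarget_sql(sql_text: str, new_prefix: str) -> str:
    # Step 3 asks to keep the trailing '/', so add it if it is missing.
    if not new_prefix.endswith("/"):
        new_prefix += "/"
    return sql_text.replace(OLD_PREFIX, new_prefix)


def retarget_file(sql_path: str, new_prefix: str) -> None:
    # Hypothetical helper: rewrites the SQL file in place (keep a backup).
    path = Path(sql_path)
    path.write_text(retarget_sql(path.read_text(), new_prefix))
```

Run from the project directory, retarget_file("eICU_preprocessing/create_all_tables.sql", read_eicu_path()) would perform the find-and-replace before you open the psql console in step 4.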
    

Graph Construction

To make the graphs, you can use the following scripts:

The first script makes most of the graphs that we use. You can alter the arguments given to it.

python3 -m graph_construction.create_graph --freq_adjust --penalise_non_shared --k 3 --mode k_closest
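To give an idea of what a k-closest diagnosis graph involves, here is a hypothetical sketch, not the scoring used in graph_construction.create_graph. It assumes that --freq_adjust weights each diagnosis by the inverse of the number of patients who have it, and that --penalise_non_shared subtracts a penalty (the strength 0.1 is invented) for diagnoses held by only one of the two patients. Each patient is then linked to its k highest-scoring peers.

```python
from collections import Counter
from typing import Dict, List, Set, Tuple


def diagnosis_weights(patients: List[Set[str]], freq_adjust: bool) -> Dict[str, float]:
    # Assumed --freq_adjust: weight = 1 / (number of patients with the diagnosis).
    counts = Counter(d for diags in patients for d in diags)
    return {d: (1.0 / c if freq_adjust else 1.0) for d, c in counts.items()}


def similarity(a: Set[str], b: Set[str], w: Dict[str, float], penalty: float) -> float:
    # Reward shared diagnoses; subtract a penalty for those only one patient has.
    shared = sum(w[d] for d in a & b)
    non_shared = sum(w[d] for d in a ^ b)
    return shared - penalty * non_shared


def k_closest_edges(
    patients: List[Set[str]],
    k: int = 3,
    freq_adjust: bool = True,
    penalise_non_shared: bool = True,
    penalty: float = 0.1,  # hypothetical penalty strength
) -> List[Tuple[int, int]]:
    w = diagnosis_weights(patients, freq_adjust)
    pen = penalty if penalise_non_shared else 0.0
    edges: List[Tuple[int, int]] = []
    for i, a in enumerate(patients):
        scored = [(similarity(a, b, w, pen), j) for j, b in enumerate(patients) if j != i]
        # Highest score first; the lower patient index wins ties.
        scored.sort(key=lambda s: (-s[0], s[1]))
        edges.extend((i, j) for _, j in scored[:k])
    return edges
```

This pairwise loop is quadratic in the number of patients and only suits toy inputs; a cohort the size of eICU would need vectorised or sparse operations. The edges are directed (i to each of its k closest), and the real graph may be symmetrised.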

Write the diagnosis strings into the eICU_data folder:

python3 -m graph_construction.get_diagnosis_strings

Get the BERT embeddings:

python3 -m graph_construction.bert

Create the graph from the BERT embeddings:

python3 -m graph_construction.create_bert_graph --k 3 --mode k_closest
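Building a k-closest graph from embeddings follows the same pattern, with a vector similarity in place of diagnosis overlap. The sketch below assumes cosine similarity as the metric; the metric and the function names are illustrative choices, not taken from graph_construction.create_bert_graph.

```python
import math
from typing import List, Sequence, Tuple


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    # Cosine similarity; returns 0.0 if either vector is all zeros.
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def k_closest_by_embedding(embeddings: List[Sequence[float]], k: int = 3) -> List[Tuple[int, int]]:
    edges: List[Tuple[int, int]] = []
    for i, u in enumerate(embeddings):
        # Rank every other patient by similarity; lower index wins ties.
        ranked = sorted(
            ((cosine(u, v), j) for j, v in enumerate(embeddings) if j != i),
            key=lambda s: (-s[0], s[1]),
        )
        edges.extend((i, j) for _, j in ranked[:k])
    return edges
```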

Alternatively, you can request access to download our graphs using this link: https://drive.google.com/drive/folders/1yWNLhGOTPhu6mxJRjKCgKRJCJjuToBS4?usp=sharing

Training the ML Models

Before proceeding to training the ML models, do the following.

  1. Set data_dir, graph_dir, log_path and ray_dir in paths.json to convenient locations.

  2. Run the following to unpack the processed eICU data into mmap files for easy loading during training. The mmap files will be saved in data_dir.

    python3 -m src.dataloader.convert
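The benefit of mmap files is that a training loop can read individual rows without loading the whole dataset into memory. As a general illustration (the file name, float32 row layout and helpers are assumptions, not the format written by src.dataloader.convert), the standard-library mmap and struct modules are enough; numpy.memmap offers the same idea for arrays.

```python
import mmap
import struct
from typing import Sequence, Tuple

FLOAT = struct.Struct("=f")  # native byte order, standard 4-byte float32


def write_matrix(path: str, rows: Sequence[Sequence[float]]) -> None:
    # Store rows contiguously as raw float32 values (hypothetical layout).
    with open(path, "wb") as f:
        for row in rows:
            f.write(struct.pack(f"={len(row)}f", *row))


def read_row(path: str, row: int, n_cols: int) -> Tuple[float, ...]:
    # Map the file and decode just one row at its byte offset.
    offset = row * n_cols * FLOAT.size
    with open(path, "rb") as f, mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
        return struct.unpack_from(f"={n_cols}f", mm, offset)
```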
    

The following commands train and evaluate the models introduced in our paper.

N.B.

  • The models are structured using pytorch-lightning. Graph neural networks and neighbourhood sampling are implemented using pytorch-geometric.

  • Our models assume a default graph made with k=3 under a k-closest scheme. If you wish to use other graphs, refer to read_graph_edge_list in src/dataloader/pyg_reader.py and add a reference handle for your graph to version2filename.

  • The default task is In-Hospital Mortality Prediction (ihm); add --task los to the command to perform the Length-of-Stay Prediction (los) task instead.

  • These commands use the best set of hyperparameters; to use other hyperparameters, remove --read_best from the command and refer to src/args.py.
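Registering a new graph amounts to mapping a handle to a file and parsing that file into edges. The sketch below is only a stand-in for what happens in src/dataloader/pyg_reader.py: the dictionary entries, file names, one-pair-per-line edge format and function names are all hypothetical. It returns source and target lists, matching the 2 x num_edges edge_index layout that pytorch-geometric uses.

```python
import os
from typing import Dict, List, Tuple

# Hypothetical stand-in for version2filename; real handles and files differ.
version2filename_example: Dict[str, str] = {
    "default": "k_closest_k3_edges.txt",
    "my_graph": "my_graph_edges.txt",
}


def parse_edge_list(text: str) -> Tuple[List[int], List[int]]:
    # Assumed format: one "source target" pair of integer node ids per line.
    src: List[int] = []
    dst: List[int] = []
    for line in text.splitlines():
        if line.strip():
            u, v = line.split()
            src.append(int(u))
            dst.append(int(v))
    return src, dst


def read_edges_sketch(graph_dir: str, version: str) -> Tuple[List[int], List[int]]:
    # Look up the file for a graph handle and parse it.
    path = os.path.join(graph_dir, version2filename_example[version])
    with open(path) as f:
        return parse_edge_list(f.read())
```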

a. LSTM-GNN

The following runs the training and evaluation for LSTM-GNN models. --gnn_name can be set as gat, sage, or mpnn. When mpnn is used, add --ns_sizes 10 to the command.

python3 -m train_ns_lstmgnn --bilstm --ts_mask --add_flat --class_weights --gnn_name gat --add_diag --read_best
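Neighbourhood sampling keeps each training batch small by drawing at most a fixed number of neighbours per node at each hop; --ns_sizes 10 presumably sets that fanout. In the repository this is handled by pytorch-geometric (whose NeighborLoader, for example, takes a num_neighbors list), so the pure-Python version below only illustrates the idea and is not the code the training scripts run.

```python
import random
from typing import Dict, List, Sequence, Tuple


def sample_neighbours(
    adj: Dict[int, List[int]],
    seeds: Sequence[int],
    fanouts: Sequence[int],
    rng: random.Random,
) -> List[List[Tuple[int, int]]]:
    """Uniformly sample up to fanouts[h] neighbours per node at hop h.

    Returns one list of (node, neighbour) edges per hop.
    """
    frontier = list(seeds)
    hops: List[List[Tuple[int, int]]] = []
    for fanout in fanouts:
        edges: List[Tuple[int, int]] = []
        reached: List[int] = []
        for node in frontier:
            neigh = adj.get(node, [])
            # Keep every neighbour if there are few enough; otherwise sample.
            chosen = neigh if len(neigh) <= fanout else rng.sample(neigh, fanout)
            for nb in chosen:
                edges.append((node, nb))
                reached.append(nb)
        hops.append(edges)
        frontier = list(dict.fromkeys(reached))  # de-duplicate, keep order
    return hops
```

For example, sample_neighbours(adj, [0], [10], random.Random(0)) keeps at most 10 of patient 0's neighbours however many it has, which bounds the cost of each GNN layer.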

The following runs a hyperparameter search.

python3 -m src.hyperparameters.lstmgnn_search --bilstm --ts_mask --add_flat --class_weights --gnn_name gat --add_diag

b. Dynamic LSTM-GNN

The following runs the training & evaluation for dynamic LSTM-GNN models. --gnn_name can be set as gcn, gat, or mpnn.

python3 -m train_dynamic --bilstm --random_g --ts_mask --add_flat --class_weights --gnn_name mpnn --read_best

The following runs a hyperparameter search.

python3 -m src.hyperparameters.dynamic_lstmgnn_search --bilstm --random_g --ts_mask --add_flat --class_weights --gnn_name mpnn

c. GNN

The following runs the GNN models (with neighbourhood sampling). --gnn_name can be set as gat, sage, or mpnn. When mpnn is used, add --ns_sizes 10 to the command.

python3 -m train_ns_gnn --ts_mask --add_flat --class_weights --gnn_name gat --add_diag --read_best

The following runs a hyperparameter search.

python3 -m src.hyperparameters.ns_gnn_search --ts_mask --add_flat --class_weights --gnn_name gat --add_diag

d. LSTM (Baselines)

The following runs the baseline bi-LSTMs. To remove diagnoses from the input vector, remove --add_diag from the command.

python3 -m train_ns_lstm --bilstm --ts_mask --add_flat --class_weights --num_workers 0 --add_diag --read_best
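The effect of --add_flat and --add_diag can be pictured as appending optional feature groups to the representation the model classifies. The sketch below assumes simple concatenation of a time-series representation with flat (static) features and diagnosis features; the function name and that concatenation scheme are assumptions, and the actual model may combine these groups at a different stage.

```python
from typing import List, Sequence


def build_input(
    ts_repr: Sequence[float],
    flat: Sequence[float],
    diagnoses: Sequence[float],
    add_flat: bool,
    add_diag: bool,
) -> List[float]:
    # Start from the time-series representation (e.g. an LSTM output).
    vec = list(ts_repr)
    if add_flat:
        vec += flat  # assumed meaning of --add_flat
    if add_diag:
        vec += diagnoses  # assumed meaning of --add_diag
    return vec
```

Dropping --add_diag in this picture simply shortens the vector by the diagnosis features, which is how the baseline without diagnoses is obtained.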

The following runs a hyperparameter search.

python3 -m src.hyperparameters.lstm_search --bilstm --ts_mask --add_flat --class_weights --num_workers 0 --add_diag