Multimodal Co-Attention Transformer (MCAT) for Survival Prediction in Gigapixel Whole Slide Images

Related tags

Deep LearningMCAT
Overview

Multimodal Co-Attention Transformer (MCAT) for Survival Prediction in Gigapixel Whole Slide Images

[ICCV 2021]

© Mahmood Lab - This code is made available under the GPLv3 License and is available for non-commercial academic purposes.

If you find our work useful in your research or if you use parts of this code please consider citing our paper:

@inproceedings{chen2021multimodal,
  title={Multimodal Co-Attention Transformer for Survival Prediction in Gigapixel Whole Slide Images},
  author={Chen, Richard J and Lu, Ming Y and Weng, Wei-Hung and Chen, Tiffany Y and Williamson, Drew FK and Manz, Trevor and Shady, Maha and Mahmood, Faisal},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  pages={4015--4025},
  year={2021}
}

Updates:

  • 11/12/2021: Several users have raised concerns about the low c-Index for GBMLGG in SNN (Genomic Only). In using the gene families from MSigDB as gene signatures, IDH1 mutation was not included (key biomarker in distinguishing GBM and LGG).
  • 06/18/2021: Updated data preprocessing section for reproducibility.
  • 06/17/2021: Uploaded predicted risk scores on the validation folds for each models, and the evaluation script to compute the c-Index and Integrated AUC (I-AUC) validation metrics, found using the following Jupyter Notebook. Model checkpoints for MCAT are uploaded in the results directory.
  • 06/17/2021: Uploaded notebook detailing the MCAT network architecture, with sample input in the following following Jupyter Notebook, in which we print the shape of the tensors at each stage of MCAT.

Pre-requisites:

  • Linux (Tested on Ubuntu 18.04)
  • NVIDIA GPU (Tested on Nvidia GeForce RTX 2080 Ti x 16) with CUDA 11.0 and cuDNN 7.5
  • Python (3.7.7), h5py (2.10.0), matplotlib (3.1.1), numpy (1.18.1), opencv-python (4.1.1), openslide-python (1.1.1), openslide (3.4.1), pandas (1.1.3), pillow (7.0.0), PyTorch (1.6.0), scikit-learn (0.22.1), scipy (1.4.1), tensorflow (1.13.1), tensorboardx (1.9), torchvision (0.7.0), captum (0.2.0), shap (0.35.0)

Installation Guide for Linux (using anaconda)

1. Downloading TCGA Data

To download diagnostic WSIs (formatted as .svs files), molecular feature data and other clinical metadata, please refer to the NIH Genomic Data Commons Data Portal and the cBioPortal. WSIs for each cancer type can be downloaded using the GDC Data Transfer Tool.

2. Processing Whole Slide Images

To process WSIs, first, the tissue regions in each biopsy slide are segmented using Otsu's Segmentation on a downsampled WSI using OpenSlide. The 256 x 256 patches without spatial overlapping are extracted from the segmented tissue regions at the desired magnification. Consequently, a pretrained truncated ResNet50 is used to encode raw image patches into 1024-dim feature vectors, which we then save as .pt files for each WSI. The extracted features then serve as input (in a .pt file) to the network. The following folder structure is assumed for the extracted features vectors:

DATA_ROOT_DIR/
    └──TCGA_BLCA/
        ├── slide_1.pt
        ├── slide_2.pt
        └── ...
    └──TCGA_BRCA/
        ├── slide_1.pt
        ├── slide_2.pt
        └── ...
    └──TCGA_GBMLGG/
        ├── slide_1.pt
        ├── slide_2.pt
        └── ...
    └──TCGA_LUAD/
        ├── slide_1.ptd
        ├── slide_2.pt
        └── ...
    └──TCGA_UCEC/
        ├── slide_1.pt
        ├── slide_2.pt
        └── ...
    ...

DATA_ROOT_DIR is the base directory of all datasets / cancer type(e.g. the directory to your SSD). Within DATA_ROOT_DIR, each folder contains a list of .pt files for that dataset / cancer type.

3. Molecular Features and Genomic Signatures

Processed molecular profile features containing mutation status, copy number variation, and RNA-Seq abundance can be downloaded from the cBioPortal, which we include as CSV files in the following directory. For ordering gene features into gene embeddings, we used the following categorization of gene families (categorized via common features such as homology or biochemical activity) from MSigDB. Gene sets for homeodomain proteins and translocated cancer genes were not used due to overlap with transcription factors and oncogenes respectively. The curation of "genomic signatures" can be modified to curate genomic embedding that reflect unique biological functions.

4. Training-Validation Splits

For evaluating the algorithm's performance, we randomly partitioned each dataset using 5-fold cross-validation. Splits for each cancer type are found in the splits/5foldcv folder, which each contain splits_{k}.csv for k = 1 to 5. In each splits_{k}.csv, the first column corresponds to the TCGA Case IDs used for training, and the second column corresponds to the TCGA Case IDs used for validation. Alternatively, one could define their own splits, however, the files would need to be defined in this format. The dataset loader for using these train-val splits are defined in the get_split_from_df function in the Generic_WSI_Survival_Dataset class (inherited from the PyTorch Dataset class).

5. Running Experiments

To run experiments using the SNN, AMIL, and MMF networks defined in this repository, experiments can be run using the following generic command-line:

CUDA_VISIBLE_DEVICES=<DEVICE ID> python main.py --which_splits <SPLIT FOLDER PATH> --split_dir <SPLITS FOR CANCER TYPE> --mode <WHICH MODALITY> --model_type <WHICH MODEL>

Commands for all experiments / models can be found in the Commands.md file.

Owner
Mahmood Lab @ Harvard/BWH
AI for Pathology Image Analysis Lab @ HMS / BWH
Mahmood Lab @ Harvard/BWH
CLOCs: Camera-LiDAR Object Candidates Fusion for 3D Object Detection

CLOCs is a novel Camera-LiDAR Object Candidates fusion network. It provides a low-complexity multi-modal fusion framework that improves the performance of single-modality detectors. CLOCs operates on

Su Pang 254 Dec 16, 2022
FEDn is an open-source, modular and ML-framework agnostic framework for Federated Machine Learning

FEDn is an open-source, modular and ML-framework agnostic framework for Federated Machine Learning (FedML) developed and maintained by Scaleout Systems. FEDn enables highly scalable cross-silo and cr

Scaleout 75 Nov 09, 2022
(NeurIPS 2021) Pytorch implementation of paper "Re-ranking for image retrieval and transductive few-shot classification"

SSR (NeurIPS 2021) Pytorch implementation of paper "Re-ranking for image retrieval and transductivefew-shot classification" [Paper] [Project webpage]

xshen 29 Dec 06, 2022
Theano is a Python library that allows you to define, optimize, and evaluate mathematical expressions involving multi-dimensional arrays efficiently. It can use GPUs and perform efficient symbolic differentiation.

============================================================================================================ `MILA will stop developing Theano https:

9.6k Jan 06, 2023
[CVPR'21 Oral] Seeing Out of tHe bOx: End-to-End Pre-training for Vision-Language Representation Learning

Seeing Out of tHe bOx: End-to-End Pre-training for Vision-Language Representation Learning [CVPR'21, Oral] By Zhicheng Huang*, Zhaoyang Zeng*, Yupan H

Multimedia Research 196 Dec 13, 2022
Vision Transformer and MLP-Mixer Architectures

Vision Transformer and MLP-Mixer Architectures Update (2.7.2021): Added the "When Vision Transformers Outperform ResNets..." paper, and SAM (Sharpness

Google Research 6.4k Jan 04, 2023
🤗 Transformers: State-of-the-art Natural Language Processing for Pytorch, TensorFlow, and JAX.

English | 简体中文 | 繁體中文 | 한국어 State-of-the-art Natural Language Processing for Jax, PyTorch and TensorFlow 🤗 Transformers provides thousands of pretrai

Hugging Face 77.4k Jan 05, 2023
This is an official PyTorch implementation of Task-Adaptive Neural Network Search with Meta-Contrastive Learning (NeurIPS 2021, Spotlight).

NeurIPS 2021 (Spotlight): Task-Adaptive Neural Network Search with Meta-Contrastive Learning This is an official PyTorch implementation of Task-Adapti

Wonyong Jeong 15 Nov 21, 2022
CTF challenges and write-ups for MicroCTF 2021.

MicroCTF 2021 Qualifications About This repository contains CTF challenges and official write-ups for MicroCTF 2021 Qualifications. License Distribute

Shellmates 12 Dec 27, 2022
Source code for our paper "Molecular Mechanics-Driven Graph Neural Network with Multiplex Graph for Molecular Structures"

Molecular Mechanics-Driven Graph Neural Network with Multiplex Graph for Molecular Structures Code for the Multiplex Molecular Graph Neural Network (M

shzhang 59 Dec 10, 2022
The Medical Detection Toolkit contains 2D + 3D implementations of prevalent object detectors such as Mask R-CNN, Retina Net, Retina U-Net, as well as a training and inference framework focused on dealing with medical images.

The Medical Detection Toolkit contains 2D + 3D implementations of prevalent object detectors such as Mask R-CNN, Retina Net, Retina U-Net, as well as a training and inference framework focused on dea

MIC-DKFZ 1.2k Jan 04, 2023
Code for our SIGCOMM'21 paper "Network Planning with Deep Reinforcement Learning".

0. Introduction This repository contains the source code for our SIGCOMM'21 paper "Network Planning with Deep Reinforcement Learning". Notes The netwo

NetX Group 68 Nov 24, 2022
Explicable Reward Design for Reinforcement Learning Agents [NeurIPS'21]

Explicable Reward Design for Reinforcement Learning Agents [NeurIPS'21]

3 May 12, 2022
3D dataset of humans Manipulating Objects in-the-Wild (MOW)

MOW dataset [Website] This repository maintains our 3D dataset of humans Manipulating Objects in-the-Wild (MOW). The dataset contains 512 images in th

Zhe Cao 28 Nov 06, 2022
Computational Pathology Toolbox developed by TIA Centre, University of Warwick.

TIA Toolbox Computational Pathology Toolbox developed at the TIA Centre Getting Started All Users This package is for those interested in digital path

Tissue Image Analytics (TIA) Centre 156 Jan 08, 2023
PyTorch implementation of Towards Accurate Alignment in Real-time 3D Hand-Mesh Reconstruction (ICCV 2021).

Towards Accurate Alignment in Real-time 3D Hand-Mesh Reconstruction Introduction This is official PyTorch implementation of Towards Accurate Alignment

TANG Xiao 96 Dec 27, 2022
Code for the paper "Graph Attention Tracking". (CVPR2021)

SiamGAT 1. Environment setup This code has been tested on Ubuntu 16.04, Python 3.5, Pytorch 1.2.0, CUDA 9.0. Please install related libraries before r

122 Dec 24, 2022
Machine Translation Implement By Bi-GRU And Transformer

Seq2Seq Translation Implement By Bidirectional GRU And Transformer In Pytorch Before You Run The Code You should download the data through the link be

He Wang 2 Oct 27, 2021
Generic image compressor for machine learning. Pytorch code for our paper "Lossy compression for lossless prediction".

Lossy Compression for Lossless Prediction Using: Training: This repostiory contains our implementation of the paper: Lossy Compression for Lossless Pr

Yann Dubois 84 Jan 02, 2023
A clean and extensible PyTorch implementation of Masked Autoencoders Are Scalable Vision Learners

A clean and extensible PyTorch implementation of Masked Autoencoders Are Scalable Vision Learners A PyTorch re-implementation of Mask Autoencoder trai

Tianyu Hua 23 Dec 13, 2022