Code for "Optimizing risk-based breast cancer screening policies with reinforcement learning"

Tempo: Optimizing risk-based breast cancer screening policies with reinforcement learning

Introduction

This repository was used to develop Tempo, as described in: Optimizing risk-based breast cancer screening policies with reinforcement learning.

Screening programs must balance the benefits of early detection against the costs of over-screening. Here, we introduce a novel reinforcement learning-based framework for personalized screening, Tempo, and demonstrate its efficacy in the context of breast cancer. We trained our risk-based screening policies on a large screening mammography dataset from Massachusetts General Hospital (MGH), USA, and validated them on held-out patients from MGH and on external datasets from Emory (USA), Karolinska (Sweden) and Chang Gung Memorial Hospital (CGMH, Taiwan). Across all test sets, we found that a Tempo policy combined with an image-based AI risk model, Mirai [1], was significantly more efficient than current regimes used in clinical practice in terms of simulated early detection per screen frequency. Moreover, we showed that the same Tempo policy can be easily adapted to a wide range of possible screening preferences, allowing clinicians to select their desired trade-off between early detection and screening cost without training new policies. Finally, we demonstrated that Tempo policies based on AI-based risk models outperformed Tempo policies based on less accurate clinical risk models. Altogether, our results show that pairing AI-based risk models with agile AI-designed screening policies has the potential to improve screening programs, advancing early detection while reducing over-screening.

This code base is meant to provide exact implementation details for the development of Tempo.

Aside on Software Dependencies

This code assumes Python 3.6 and a Linux environment. The package requirements can be installed with pip:

pip install -r requirements.txt

Tempo-Mirai assumes access to Mirai risk assessments. Resources for using Mirai are shown here.

Method

(Figure: overview of the Tempo framework.)

Our full framework, named Tempo, is depicted above. We first train a risk progression neural network to predict future risk assessments given previous assessments. This model is then used to estimate patient risk at unobserved timepoints, which enables us to simulate risk-based screening policies. Next, we train our screening policy, which is implemented as a neural network, to maximize the reward (i.e., a combination of early detection and screening cost) on our retrospective training set. We train our screening policy to support all possible early detection vs. screening cost trade-offs using envelope Q-learning [2], an RL algorithm designed to balance multiple objectives. The inputs to our screening policy are the patient's risk assessment and the desired weighting between rewards (i.e., the screening preference). The output of the policy is a recommendation for when to return for the next screen, ranging from six months to three years in the future, in multiples of six months.

Our reward balances two competing aspects: one reflecting the imaging cost, i.e., the average number of mammograms per year recommended by the policy, and one modeling the early detection benefit relative to the retrospective screening trajectory. Our early detection reward measures the time difference, in months, between each patient's recommended screening date (if it falls after their last negative mammogram) and their actual diagnosis date. We evaluate screening policies by simulating their recommendations for held-out patients.
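To make the action space and the two reward terms concrete, here is a minimal sketch in plain Python. It is not the repository's implementation: the function names (`mammograms_per_year`, `early_detection_months`, `scalarize`) are hypothetical, times are measured in months, and two simplifications are assumptions of ours: a recommendation at or after the diagnosis date earns zero (not negative) early detection credit, and the two rewards are combined by a linear preference weighting.

```python
from typing import Sequence

# Follow-up recommendations described above: 6 to 36 months in 6-month steps.
ACTIONS_MONTHS = [6 * k for k in range(1, 7)]  # [6, 12, 18, 24, 30, 36]


def mammograms_per_year(intervals_months: Sequence[int]) -> float:
    """Imaging cost: number of recommended screens divided by the years they span."""
    if not intervals_months:
        raise ValueError("need at least one recommended interval")
    years = sum(intervals_months) / 12.0
    return len(intervals_months) / years


def early_detection_months(rec_month: float, last_negative_month: float,
                           diagnosis_month: float) -> float:
    """Months by which a recommended screen precedes the actual diagnosis.

    Simplifying assumptions: a recommendation only earns credit if it falls
    after the patient's last negative mammogram, and a recommendation at or
    after the diagnosis date earns nothing (no negative reward).
    """
    if rec_month <= last_negative_month:
        return 0.0
    return max(0.0, diagnosis_month - rec_month)


def scalarize(reward_vec: Sequence[float], weights: Sequence[float]) -> float:
    """Linear scalarization w . r of the reward vector [early detection, -cost]."""
    return sum(w * r for w, r in zip(weights, reward_vec))


# Toy patient: negative screen at month 0, diagnosed at month 30.
# A 24-month follow-up screens at month 0 + 24 = 24, i.e. 6 months earlier.
gain = early_detection_months(rec_month=24, last_negative_month=0, diagnosis_month=30)
cost = mammograms_per_year([24])              # 0.5 screens per year
score = scalarize([gain, -cost], [0.5, 0.5])  # 2.75
```

Keeping the reward as a vector and applying the weighting only at the end is what lets a single policy serve many screening preferences: changing `weights` changes the trade-off without changing the reward definitions.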

Training Risk Progression Models

We experimented with different learning rates, hidden sizes, numbers of layers and dropout, and chose the model that obtained the lowest validation KL divergence on the MGH validation set. Our final risk progression RNN had two layers, a hidden dimension size of 100, a dropout of 0.25, and was trained for 30 epochs with a learning rate of 1e-3 using the Adam optimizer.
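The model-selection step (train one model per hyperparameter combination, keep the one with the lowest validation KL divergence) can be sketched as follows. This is an illustration under stated assumptions, not the code behind `scripts/dispatcher.py`: the grid values are hypothetical and do not reproduce the contents of `configs/risk_progression/gru.json`, and we assume each risk assessment is a vector of per-horizon probabilities compared horizon by horizon as Bernoulli distributions.

```python
import itertools
import math
from typing import Any, Dict, List, Sequence, Tuple

# Hypothetical search space over the hyperparameters named above
# (learning rate, hidden size, number of layers, dropout). Values are illustrative.
GRID: Dict[str, List[Any]] = {
    "lr": [1e-3, 1e-4],
    "hidden_dim": [50, 100],
    "num_layers": [1, 2],
    "dropout": [0.0, 0.25],
}


def expand_grid(grid: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
    """Cartesian product of the grid: one config dict per training run."""
    keys = sorted(grid)
    return [dict(zip(keys, combo)) for combo in itertools.product(*(grid[k] for k in keys))]


def bernoulli_kl(p: float, q: float, eps: float = 1e-7) -> float:
    """KL(Bernoulli(p) || Bernoulli(q)); probabilities are clipped to [eps, 1 - eps]."""
    p = min(max(p, eps), 1.0 - eps)
    q = min(max(q, eps), 1.0 - eps)
    return p * math.log(p / q) + (1.0 - p) * math.log((1.0 - p) / (1.0 - q))


def mean_risk_kl(observed: Sequence[Sequence[float]],
                 predicted: Sequence[Sequence[float]]) -> float:
    """Mean per-horizon Bernoulli KL between observed and predicted risk vectors.

    Assumption: each assessment is a vector of per-horizon risk probabilities,
    and each horizon is treated as an independent Bernoulli variable.
    """
    terms = [bernoulli_kl(o, p)
             for obs_vec, pred_vec in zip(observed, predicted)
             for o, p in zip(obs_vec, pred_vec)]
    return sum(terms) / len(terms)


def select_best(results: List[Tuple[Dict[str, Any], float]]) -> Dict[str, Any]:
    """Return the config whose validation KL divergence is lowest."""
    return min(results, key=lambda item: item[1])[0]


runs = expand_grid(GRID)  # 2 * 2 * 2 * 2 = 16 training runs
```

In practice each entry of `runs` would be trained, `mean_risk_kl` evaluated on the validation set, and the `(config, kl)` pairs passed to `select_best`.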

To reproduce our grid search for our Mirai risk progression model, you can run:

python scripts/dispatcher.py --experiment_config_path configs/risk_progression/gru.json

Given a trained risk progression model, we can now estimate unobserved risk assessments auto-regressively. At each time step, the model takes as input the previous risk assessment (or its own previous prediction if the real assessment is not available) together with the prior hidden state, and predicts the risk assessment at the next time step.
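The auto-regressive filling of gaps can be sketched independently of the network itself. In the minimal sketch below, `rollout` and `toy_step` are hypothetical names: `step` is any function mapping (previous assessment, hidden state) to (predicted next assessment, new hidden state), and `toy_step` is a stand-in for the trained RNN that simply grows risk by 10% per step. Real assessments always take precedence over predictions when they exist.

```python
from typing import Any, Callable, List, Optional, Sequence, Tuple

Risk = List[float]
# step(previous_risk, hidden_state) -> (predicted_next_risk, new_hidden_state)
StepFn = Callable[[Risk, Any], Tuple[Risk, Any]]


def rollout(observed: Sequence[Optional[Risk]], step: StepFn, init_state: Any) -> List[Risk]:
    """Complete a risk trajectory auto-regressively.

    observed[t] is the real assessment at step t, or None when unobserved.
    At each step the model receives the previous assessment (real if
    available, otherwise its own prediction) and the prior hidden state.
    """
    if not observed or observed[0] is None:
        raise ValueError("the first risk assessment must be observed")
    trajectory: List[Risk] = [list(observed[0])]
    state = init_state
    for t in range(1, len(observed)):
        predicted, state = step(trajectory[-1], state)
        real = observed[t]
        trajectory.append(list(real) if real is not None else predicted)
    return trajectory


def toy_step(prev_risk: Risk, state: int) -> Tuple[Risk, int]:
    """Hypothetical stand-in for the trained RNN: risk grows 10% per step.

    The integer state only counts steps; a real model would carry its
    recurrent hidden vector here.
    """
    return [min(1.0, r * 1.1) for r in prev_risk], state + 1


# Steps 1 and 3 are unobserved and get filled in by the toy model.
trajectory = rollout([[0.02, 0.04], None, [0.03, 0.05], None], toy_step, init_state=0)
```

Note that the hidden state keeps advancing through observed steps too; only the input at each step switches between real and predicted assessments.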

Training Tempo Personalized Screening Policies

We implemented our personalized screening policy as a multilayer perceptron, which took as input a risk assessment and a weighting between rewards, and predicted the Q-value for each action (i.e., each follow-up recommendation) across the rewards. This network was trained using envelope Q-learning [2]. We experimented with different numbers of layers, hidden dimension sizes, learning rates, dropout rates, exploration epsilons, target network reset rates and weight decay rates.
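The core of envelope Q-learning [2] is a vector-valued Q-function Q(s, a, w), one entry per reward, and a target that, for the current preference w, takes the best scalarized Q-vector over all next actions and sampled preferences w'. The sketch below shows only this target and preference-conditioned epsilon-greedy action selection; it omits the replay buffer, target network and loss, and `toy_q` is a hypothetical hand-written Q-function (ignoring state and preference), not the trained MLP.

```python
import random
from typing import Callable, List, Optional, Sequence

Vec = List[float]
# q_fn(state, action, preference) -> vector of Q-values, one entry per reward
QFn = Callable[[object, int, Vec], Vec]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def act(q_fn: QFn, state: object, w: Vec, actions: Sequence[int],
        epsilon: float = 0.0, rng: Optional[random.Random] = None) -> int:
    """Epsilon-greedy action under preference w: argmax_a of w . Q(s, a, w)."""
    rng = rng or random.Random()
    if rng.random() < epsilon:
        return rng.choice(list(actions))
    return max(actions, key=lambda a: dot(w, q_fn(state, a, w)))


def envelope_target(reward: Vec, next_state: object, w: Vec, q_fn: QFn,
                    actions: Sequence[int], candidate_prefs: Sequence[Vec],
                    gamma: float, done: bool) -> Vec:
    """Vector-valued envelope target, following the update of Yang et al. [2].

    Among all next actions a' and sampled preferences w', pick the Q-vector
    Q(s', a', w') maximizing the scalarized value w . Q, then return
    r + gamma * that vector. Terminal transitions return r unchanged.
    """
    if done:
        return list(reward)
    best = max((q_fn(next_state, a, wp) for a in actions for wp in candidate_prefs),
               key=lambda vec: dot(w, vec))
    return [r + gamma * v for r, v in zip(reward, best)]


# Hypothetical toy Q-function over [early detection, -mammograms per year]:
# shorter follow-ups detect earlier but cost more screens per year.
ACTIONS = [6, 12, 18, 24, 30, 36]


def toy_q(state: object, action: int, w: Vec) -> Vec:
    return [36.0 - action, -12.0 / action]


# Same Q-function, different preferences, different recommendations.
detection_first = act(toy_q, None, [1.0, 0.0], ACTIONS)  # 6 months
cost_first = act(toy_q, None, [0.0, 1.0], ACTIONS)       # 36 months
```

The last two lines illustrate why a preference-conditioned policy can be re-targeted without retraining: only the weighting `w` passed at inference time changes.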

To reproduce our grid search for our Tempo screening policy, you can run:

python scripts/dispatcher.py --experiment_config_path configs/screening/neural.json

Data availability

All datasets were used under license to the respective hospital system for the current study and are not publicly available. To access the MGH dataset, investigators should reach out to C.L. to apply for an IRB-approved research collaboration and obtain an appropriate Data Use Agreement. To access the Karolinska dataset, investigators should reach out to F.S. to apply for an approved research collaboration and sign a Data Use Agreement. To access the CGMH dataset, investigators should contact G.L. to apply for an IRB-approved research collaboration. To access the Emory dataset, investigators should reach out to H.T. to apply for an approved collaboration.

References

[1] Yala, Adam, et al. "Toward robust mammography-based models for breast cancer risk." Science Translational Medicine 13.578 (2021).

[2] Yang, Runzhe, Xingyuan Sun, and Karthik Narasimhan. "A generalized algorithm for multi-objective reinforcement learning and policy adaptation." arXiv preprint arXiv:1908.08342 (2019).

Citing Tempo

@article{yala2021optimizing,
  title={Optimizing risk-based breast cancer screening policies with reinforcement learning},
  author={Yala, Adam and Mikhael, Peter and Lehman, Constance and Lin, Gigin and Strand, Fredrik and Wang, Yung-Liang and Hughes, Kevin and Satuluru, Siddharth and Kim, Thomas and Banerjee, Imon and others},
  year={2021}
}

Release: v1.0. Repository owner: Adam Yala, PhD Candidate at MIT CSAIL.