Pytorch-3dunet - 3D U-Net model for volumetric semantic segmentation written in pytorch

Overview


pytorch-3dunet

PyTorch implementation of 3D U-Net and its variants.

The code allows training the U-Net for both semantic segmentation (binary and multi-class) and regression problems (e.g. denoising, learning deconvolutions).

2D U-Net

Training the standard 2D U-Net is also possible; see 2DUnet_dsb2018 for an example configuration. Just make sure to keep the singleton z-dimension in your H5 dataset (i.e. (1, Y, X) instead of (Y, X)), because data loading and data augmentation always require tensors of rank 3.
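A minimal NumPy sketch of adding the singleton z-dimension before writing 2D data to HDF5. The helper names are hypothetical, and the dataset names "raw" and "label" are an assumption chosen to mirror the raw/label data sets mentioned under Train; use whatever internal paths your config points to.

```python
import numpy as np


def add_singleton_z(image_2d: np.ndarray) -> np.ndarray:
    """Turn a (Y, X) image or label into a (1, Y, X) volume (hypothetical helper)."""
    if image_2d.ndim != 2:
        raise ValueError(f"expected a (Y, X) array, got shape {image_2d.shape}")
    # prepend a z (depth) axis of size 1 so the loaders see a rank-3 array
    return image_2d[np.newaxis, ...]


def write_2d_sample(path: str, raw: np.ndarray, label: np.ndarray) -> None:
    """Write raw/label as (1, Y, X) datasets; assumes h5py is installed."""
    import h5py  # imported lazily so add_singleton_z works without h5py

    with h5py.File(path, "w") as f:
        # dataset names are an assumption; match them to your config
        f.create_dataset("raw", data=add_singleton_z(raw))
        f.create_dataset("label", data=add_singleton_z(label))
```

For example, add_singleton_z applied to a (256, 256) image returns an array of shape (1, 256, 256).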

Prerequisites

  • Linux
  • NVIDIA GPU
  • CUDA and cuDNN

Running on Windows

The package has not been tested on Windows, although some users have reported running it there. One thing to keep in mind: when training with CrossEntropyLoss, the label type in the config file should be changed from long to int64; otherwise training fails with: RuntimeError: Expected object of scalar type Long but got scalar type Int for argument #2 'target'.

Supported Loss Functions

Semantic Segmentation

  • BCEWithLogitsLoss (binary cross-entropy)
  • DiceLoss (standard DiceLoss defined as 1 - DiceCoefficient, used for binary semantic segmentation; when more than 2 classes are present in the ground truth, it computes the DiceLoss per channel and averages the values)
  • BCEDiceLoss (linear combination of BCE and Dice losses, i.e. alpha * BCE + beta * Dice; alpha and beta can be specified in the loss section of the config)
  • CrossEntropyLoss (one can specify class weights via weight: [w_1, ..., w_k] in the loss section of the config)
  • PixelWiseCrossEntropyLoss (one can specify not only class weights but also per-pixel weights in order to give more gradient to important (or under-represented) regions in the ground truth)
  • WeightedCrossEntropyLoss (see 'Weighted cross-entropy (WCE)' in the paper below for a detailed explanation; one can specify class weights via weight: [w_1, ..., w_k] in the loss section of the config)
  • GeneralizedDiceLoss (see 'Generalized Dice Loss (GDL)' in the paper below for a detailed explanation; one can specify class weights via weight: [w_1, ..., w_k] in the loss section of the config). Note: use this loss function only if the labels in the training dataset are very imbalanced, e.g. one class has at least 3 orders of magnitude more voxels than the others. Otherwise use the standard DiceLoss.

For a detailed explanation of some of the supported loss functions, see: Carole H. Sudre, Wenqi Li, Tom Vercauteren, Sebastien Ourselin, M. Jorge Cardoso, "Generalised Dice overlap as a deep learning loss function for highly unbalanced segmentations".
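To make the GDL weighting concrete, here is a NumPy sketch of the Generalized Dice Loss as formulated in the Sudre et al. paper: each class l gets the weight w_l = 1 / (sum of its ground-truth voxels)^2, so rare classes count more. This is an illustration of the formula, not the package's own implementation; the function name and the eps clamping are assumptions.

```python
import numpy as np


def generalized_dice_loss(probs: np.ndarray, target: np.ndarray, eps: float = 1e-6) -> float:
    """GDL for probs/target of shape (C, D, H, W); target is one-hot (binary per channel).

    probs must already be normalized (e.g. by Sigmoid or Softmax).
    """
    c = probs.shape[0]
    p = probs.reshape(c, -1)
    r = target.reshape(c, -1).astype(p.dtype)
    # w_l = 1 / (sum_n r_ln)^2; clamped so an empty class does not divide by zero
    w = 1.0 / np.clip(r.sum(axis=1) ** 2, eps, None)
    intersect = (w * (r * p).sum(axis=1)).sum()
    denominator = (w * (r + p).sum(axis=1)).sum()
    return float(1.0 - 2.0 * intersect / np.clip(denominator, eps, None))
```

A perfect prediction gives a loss of 0, and a prediction that is disjoint from the ground truth gives 1.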

IMPORTANT: if one wants to use their own loss function, bear in mind that the current model implementation always outputs logits, and it is up to the loss implementation to normalize them correctly, e.g. by applying Sigmoid or Softmax.
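A NumPy sketch of what "normalizing the logits inside the loss" means, using a per-channel Dice loss as the example (1 - Dice, averaged over channels, as described for DiceLoss above). It assumes logits and target of shape (N, C, D, H, W); the function names and the normalization argument are hypothetical.

```python
import numpy as np


def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))


def softmax(x: np.ndarray, axis: int = 1) -> np.ndarray:
    # subtract the max along the channel axis for numerical stability
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def dice_loss_from_logits(logits, target, normalization="sigmoid", eps=1e-6):
    """1 - mean per-channel Dice; the raw network logits are normalized first."""
    if normalization == "sigmoid":
        probs = sigmoid(logits)
    elif normalization == "softmax":
        probs = softmax(logits, axis=1)
    else:
        raise ValueError(f"unknown normalization: {normalization}")
    # flatten everything except the channel axis: (C, N*D*H*W)
    c = probs.shape[1]
    p = np.moveaxis(probs, 1, 0).reshape(c, -1)
    t = np.moveaxis(target, 1, 0).reshape(c, -1).astype(p.dtype)
    dice = 2.0 * (p * t).sum(axis=1) / np.clip(p.sum(axis=1) + t.sum(axis=1), eps, None)
    return float(1.0 - dice.mean())
```

Computing the Dice directly on unnormalized logits would mix unbounded values into a formula that expects probabilities in [0, 1], which is why the normalization step matters.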

Regression

  • MSELoss
  • L1Loss
  • SmoothL1Loss
  • WeightedSmoothL1Loss - extension of the SmoothL1Loss which allows to weight the voxel values above (below) a given threshold differently
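The idea behind WeightedSmoothL1Loss can be sketched in NumPy as below. This only illustrates "weight the voxels above (below) a threshold differently"; the parameter names threshold, weight and apply_below are hypothetical and may not match the package's config keys.

```python
import numpy as np


def smooth_l1(diff: np.ndarray, beta: float = 1.0) -> np.ndarray:
    """Element-wise SmoothL1, same form as torch.nn.SmoothL1Loss with reduction='none'."""
    abs_diff = np.abs(diff)
    return np.where(abs_diff < beta, 0.5 * abs_diff ** 2 / beta, abs_diff - 0.5 * beta)


def weighted_smooth_l1(pred, target, threshold, weight, apply_below=False):
    """Mean SmoothL1 where target voxels above (or below) threshold are scaled by weight.

    threshold, weight and apply_below are illustrative names, not the package's API.
    """
    per_voxel = smooth_l1(pred - target)
    mask = target < threshold if apply_below else target > threshold
    per_voxel = np.where(mask, per_voxel * weight, per_voxel)
    return float(per_voxel.mean())
```

With pred = [0, 0, 0, 0], target = [0, 0, 2, 2], threshold=1.0 and weight=3.0, the two voxels above the threshold each contribute 1.5 * 3, giving a mean of 2.25.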

Supported Evaluation Metrics

Semantic Segmentation

  • MeanIoU - Mean intersection over union
  • DiceCoefficient - Dice Coefficient (computes the per-channel Dice coefficient and returns the average)

If a 3D U-Net was trained to predict cell boundaries, one can use the following instance segmentation metrics (they are computed by running connected components on the thresholded boundary map and comparing the resulting instances to the ground-truth instance segmentation):
  • BoundaryAveragePrecision - Average Precision applied to the boundary probability maps: thresholds the boundary maps given by the network, runs connected components to get the segmentation and computes AP between the resulting segmentation and the ground truth
  • AdaptedRandError - Adapted Rand Error (see http://brainiac2.mit.edu/SNEMI3D/evaluation for a detailed explanation)
  • AveragePrecision - see https://www.kaggle.com/stkbailey/step-by-step-explanation-of-scoring-metric

If no metric is specified, MeanIoU is used by default.
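A NumPy sketch of a per-channel mean IoU on thresholded probability maps, to show what the default metric measures. The 0.5 threshold and the choice to score a channel that is empty in both prediction and target as 1.0 are assumptions of this sketch, not necessarily the package's behaviour.

```python
import numpy as np


def mean_iou(probs, target, threshold=0.5):
    """Mean IoU over channels for (C, D, H, W) probability maps and binary targets."""
    pred = probs > threshold
    tgt = target.astype(bool)
    ious = []
    for c in range(pred.shape[0]):
        inter = np.logical_and(pred[c], tgt[c]).sum()
        union = np.logical_or(pred[c], tgt[c]).sum()
        # assumption: a channel empty in both prediction and target counts as perfect
        ious.append(1.0 if union == 0 else inter / union)
    return float(np.mean(ious))
```

For a single channel with prediction [1, 1, 0, 0] and target [1, 0, 0, 0], the intersection is 1 voxel and the union 2 voxels, so the IoU is 0.5.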

Regression

  • PSNR - peak signal-to-noise ratio
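PSNR is defined as 10 * log10(MAX^2 / MSE), where MAX is the data range of the signal. A NumPy sketch follows; taking the data range from the target when it is not given is an assumption of this example.

```python
import numpy as np


def psnr(pred, target, data_range=None):
    """Peak signal-to-noise ratio in dB between a prediction and its target."""
    mse = np.mean((pred.astype(np.float64) - target.astype(np.float64)) ** 2)
    if data_range is None:
        # assumption: use the dynamic range of the target as MAX
        data_range = float(target.max() - target.min())
    if mse == 0:
        return float("inf")  # identical signals
    return float(10.0 * np.log10(data_range ** 2 / mse))
```

For target [0, 1, 0, 1] and prediction [0.1, 1, 0, 1], the MSE is 0.0025 and the data range 1, so the PSNR is 10 * log10(400), about 26.02 dB.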

Installation

  • The easiest way to install the pytorch-3dunet package is via conda:
conda create -n 3dunet -c conda-forge -c awolny pytorch-3dunet
conda activate 3dunet

After installation the following commands are accessible within the conda environment: train3dunet for training the network and predict3dunet for prediction (see below).

  • One can also install directly from source:
python setup.py install

Installation tips

Make sure that the installed pytorch is compatible with your CUDA version, otherwise the training/prediction will fail to run on GPU. You can re-install pytorch compatible with your CUDA in the 3dunet env by:

conda install -c pytorch torchvision cudatoolkit=<YOUR_CUDA_VERSION> pytorch

Train

Given that the pytorch-3dunet package was installed via conda as described above, one can train the network by simply invoking:

train3dunet --config <CONFIG>

where CONFIG is the path to a YAML configuration file, which specifies all aspects of the training procedure.

In order to train on your own data just provide the paths to your HDF5 training and validation datasets in the config.

The HDF5 files should contain the raw/label data sets in the following axis order: DHW (for 3D data) or CDHW (for 4D data).
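A small NumPy sketch for checking arrays against the DHW / CDHW convention before writing them to HDF5, including moving a trailing channel axis to the front. Both helpers are hypothetical, and the check that raw and label share the same spatial (D, H, W) shape is an assumption about voxel-aligned training pairs.

```python
import numpy as np


def channels_last_to_first(arr):
    """Move a trailing channel axis to the front: (D, H, W, C) -> (C, D, H, W)."""
    return np.moveaxis(arr, -1, 0)


def check_axis_order(raw, label):
    """Raise ValueError if raw/label do not look like DHW or CDHW arrays."""
    for name, arr in (("raw", raw), ("label", label)):
        if arr.ndim not in (3, 4):
            raise ValueError(f"{name}: expected DHW or CDHW, got shape {arr.shape}")
    # assumption: raw and label are voxel-aligned, so the trailing D, H, W must agree
    if raw.shape[-3:] != label.shape[-3:]:
        raise ValueError(f"spatial shapes differ: {raw.shape[-3:]} vs {label.shape[-3:]}")
```

A channels-last array such as (8, 16, 16, 2) fails the check against an (8, 16, 16) label, while its channels_last_to_first version, (2, 8, 16, 16), passes.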

One can monitor the training progress with TensorBoard (you need tensorflow installed in your conda env):

tensorboard --logdir <checkpoint_dir>/logs/

where checkpoint_dir is the path to the checkpoint directory specified in the config.

Training tips

  1. When training with binary-based losses, i.e. BCEWithLogitsLoss, DiceLoss, BCEDiceLoss, GeneralizedDiceLoss, the target data has to be 4D (one binary target mask per channel). If you have 3D binary data (foreground/background), you can simply change the ToTensor transform for the label to contain expand_dims: true, see e.g. train_config_dice.yaml. When training with WeightedCrossEntropyLoss, CrossEntropyLoss or PixelWiseCrossEntropyLoss, the target dataset has to be 3D; see also the PyTorch documentation for the CE loss: https://pytorch.org/docs/master/generated/torch.nn.CrossEntropyLoss.html
  2. final_sigmoid in the model config section applies only at inference time. When training with cross-entropy-based losses (WeightedCrossEntropyLoss, CrossEntropyLoss, PixelWiseCrossEntropyLoss), set final_sigmoid=False so that Softmax normalization is applied to the output. When training with BCEWithLogitsLoss, DiceLoss, BCEDiceLoss or GeneralizedDiceLoss, set final_sigmoid=True.
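The target shapes from tip 1 can be illustrated with NumPy: binary losses expect (C, D, H, W) binary masks, while cross-entropy losses expect a (D, H, W) integer label volume. These helpers are a sketch under that assumption; within the package itself, the expand_dims: true option of ToTensor is the documented way to handle the binary case.

```python
import numpy as np


def binary_target_4d(mask_3d):
    """(D, H, W) foreground mask -> (1, D, H, W), the same shape change as expand_dims: true."""
    return np.expand_dims(mask_3d, axis=0)


def one_hot_4d(labels_3d, num_classes):
    """(D, H, W) integer labels -> (C, D, H, W) binary masks, one channel per class."""
    classes = np.arange(num_classes)[:, None, None, None]
    return (classes == labels_3d[None]).astype(np.uint8)


def ce_target(labels_3d):
    """Cross-entropy targets stay 3D; int64 matches PyTorch's expected Long dtype."""
    return labels_3d.astype(np.int64)
```

For a label volume containing classes 0, 1 and 2, one_hot_4d(labels, 3) produces exactly one active channel per voxel, while ce_target keeps the original (D, H, W) shape.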

Prediction

Given that the pytorch-3dunet package was installed via conda as described above, one can run the prediction via:

predict3dunet --config <CONFIG>

In order to predict on your own data, just provide the path to your model as well as paths to HDF5 test files (see test_config_dice.yaml).

Prediction tips

In order to avoid checkerboard artifacts in the output prediction masks, the patch predictions are averaged, so make sure that the patch/stride params lead to overlapping blocks, e.g. patch: [64 128 128] with stride: [32 96 96] gives a 'halo' (overlap) of 32 voxels in each direction.
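The patch averaging can be sketched in NumPy as a sliding window that accumulates predictions and divides by how often each voxel was covered; the overlap per axis is patch - stride, e.g. [64, 128, 128] - [32, 96, 96] = [32, 32, 32]. This is an illustration, not the package's predictor: the function names are hypothetical, and placing a final patch flush with the volume border is an assumption of this sketch.

```python
import itertools

import numpy as np


def patch_starts(size, patch, stride):
    """Start indices along one axis; a last patch is added flush with the border."""
    if patch >= size:
        return [0]
    starts = list(range(0, size - patch + 1, stride))
    if starts[-1] != size - patch:
        starts.append(size - patch)
    return starts


def predict_with_averaging(volume, predict_fn, patch, stride):
    """Predict overlapping patches of a (D, H, W) volume and average the overlaps.

    Assumes stride <= patch on every axis, so each voxel is covered at least once.
    """
    out = np.zeros(volume.shape, dtype=np.float64)
    count = np.zeros(volume.shape, dtype=np.float64)
    axes = [patch_starts(s, p, st) for s, p, st in zip(volume.shape, patch, stride)]
    for z, y, x in itertools.product(*axes):
        sl = (slice(z, z + patch[0]), slice(y, y + patch[1]), slice(x, x + patch[2]))
        out[sl] += predict_fn(volume[sl])
        count[sl] += 1
    return out / count
```

With predict_fn set to the identity, the averaged output reproduces the input volume exactly, which is a quick sanity check that every voxel is covered.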

Data Parallelism

By default, if multiple GPUs are available, training/prediction will be run on all of them using DataParallel. If training/prediction on all available GPUs is not desirable, restrict the number of GPUs using CUDA_VISIBLE_DEVICES, e.g.

CUDA_VISIBLE_DEVICES=0,1 train3dunet --config <CONFIG>

or

CUDA_VISIBLE_DEVICES=0,1 predict3dunet --config <CONFIG>

Examples

Cell boundary predictions for lightsheet images of Arabidopsis thaliana lateral root

The data can be downloaded from the following OSF project:

Training and inference configs can be found in 3DUnet_lightsheet_boundary.

Sample z-slice predictions on the test set (top: raw input, bottom: boundary predictions):

Cell boundary predictions for confocal images of Arabidopsis thaliana ovules

The data can be downloaded from the following OSF project:

Training and inference configs can be found in 3DUnet_confocal_boundary.

Sample z-slice predictions on the test set (top: raw input, bottom: boundary predictions):

Nuclei predictions for lightsheet images of Arabidopsis thaliana lateral root

The training and validation sets can be downloaded from the following OSF project: https://osf.io/thxzn/

Training and inference configs can be found in 3DUnet_lightsheet_nuclei.

Sample z-slice predictions on the test set (top: raw input, bottom: nuclei predictions):

2D nuclei predictions for Kaggle DSB2018

The data can be downloaded from: https://www.kaggle.com/c/data-science-bowl-2018/data

Training and inference configs can be found in 2DUnet_dsb2018.

Sample predictions on the test image (top: raw input, bottom: nuclei predictions):

Contribute

If you want to contribute back, please make a pull request.

Cite

If you use this code for your research, please cite as:

@article {10.7554/eLife.57613,
article_type = {journal},
title = {Accurate and versatile 3D segmentation of plant tissues at cellular resolution},
author = {Wolny, Adrian and Cerrone, Lorenzo and Vijayan, Athul and Tofanelli, Rachele and Barro, Amaya Vilches and Louveaux, Marion and Wenzl, Christian and Strauss, Sören and Wilson-Sánchez, David and Lymbouridou, Rena and Steigleder, Susanne S and Pape, Constantin and Bailoni, Alberto and Duran-Nebreda, Salva and Bassel, George W and Lohmann, Jan U and Tsiantis, Miltos and Hamprecht, Fred A and Schneitz, Kay and Maizel, Alexis and Kreshuk, Anna},
editor = {Hardtke, Christian S and Bergmann, Dominique C and Bergmann, Dominique C and Graeff, Moritz},
volume = 9,
year = 2020,
month = {jul},
pub_date = {2020-07-29},
pages = {e57613},
citation = {eLife 2020;9:e57613},
doi = {10.7554/eLife.57613},
url = {https://doi.org/10.7554/eLife.57613},
abstract = {Quantitative analysis of plant and animal morphogenesis requires accurate segmentation of individual cells in volumetric images of growing organs. In the last years, deep learning has provided robust automated algorithms that approach human performance, with applications to bio-image analysis now starting to emerge. Here, we present PlantSeg, a pipeline for volumetric segmentation of plant tissues into cells. PlantSeg employs a convolutional neural network to predict cell boundaries and graph partitioning to segment cells based on the neural network predictions. PlantSeg was trained on fixed and live plant organs imaged with confocal and light sheet microscopes. PlantSeg delivers accurate results and generalizes well across different tissues, scales, acquisition settings even on non plant samples. We present results of PlantSeg applications in diverse developmental contexts. PlantSeg is free and open-source, with both a command line and a user-friendly graphical interface.},
keywords = {instance segmentation, cell segmentation, deep learning, image analysis},
journal = {eLife},
issn = {2050-084X},
publisher = {eLife Sciences Publications, Ltd},
}
Comments
  • fix weights unsqueeze in PixelWiseCrossEntropy

    First off, thanks for the great library, @wolny! It has really accelerated my work to be able to start with a nice implementation of 3D U-Nets.

    I think there might be a small bug in the PixelWiseCrossEntropy loss. It seems that the weights get passed in as a NxDxHxW tensor and in the "expand weights" code block they should be expanded to NxCxDxHxW tensor to match the target (which has been converted to a one hot encoding). Thus, I think the unsqueeze should be applied to axis 1, not axis 0. In this case the weights would become Nx1xDxHxW, then NxCxDxHxW in the subsequent weights.expand_as(input).

    Without this change, I get the following error when I train with batch size > 1.

    2021-03-12 17:05:50,156 [MainThread] INFO UNet3DTrainer - Training iteration [1/100000]. Epoch [0/99]
    Traceback (most recent call last):
      File "/cluster/home/kyamauch/.local/lib/python3.8/site-packages/pytorch3dunet/train.py", line 33, in <module>
        main()
      File "/cluster/home/kyamauch/.local/lib/python3.8/site-packages/pytorch3dunet/train.py", line 29, in main
        trainer.fit()
      File "/cluster/home/kyamauch/.local/lib/python3.8/site-packages/pytorch3dunet/unet3d/trainer.py", line 246, in fit
        should_terminate = self.train()
      File "/cluster/home/kyamauch/.local/lib/python3.8/site-packages/pytorch3dunet/unet3d/trainer.py", line 273, in train
        output, loss = self._forward_pass(input, target, weight)
      File "/cluster/home/kyamauch/.local/lib/python3.8/site-packages/pytorch3dunet/unet3d/trainer.py", line 408, in _forward_pass
        loss = self.loss_criterion(output, target, weight)
      File "/cluster/apps/nss/gcc-6.3.0/python_gpu/3.8.5/torch/nn/modules/module.py", line 727, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/cluster/home/kyamauch/.local/lib/python3.8/site-packages/pytorch3dunet/unet3d/losses.py", line 220, in forward
        weights = weights.expand_as(input)
    RuntimeError: The expanded size of the tensor (3) must match the existing size (12) at non-singleton dimension 1.  Target sizes: [12, 3, 70, 70, 70].  Tensor sizes: [1, 12, 70, 70, 70]
    

    Does this change seem right?

    opened by kevinyamauchi 2
  • Update environment.yaml

    The pytorch channel should have higher priority than conda-forge; otherwise the pytorch build from conda-forge will be used (and this causes issues with GPU installations).

    opened by constantinpape 1
  • Create command-lines (i.e. console_scripts) when installing from source

    Hi,

    I know that the command lines are installed into the conda environment.

    This code adds commands when installing from source (i.e. python setup.py install). I needed to do this as I ultimately want to call pytorch-3dunet within mpi2/LAMA and don't want to use the conda env due to install issues etc.

    Feel free to merge if it doesn't cause conflicts.

    Kind Regards, Kyle Drover

    opened by dorkylever 1
  • Read data path config as a directory

    There may be many HDF5 data files, and it is common to put all data files in a single directory. Specifying every path in the config file is inconvenient and makes the config hard to read.

    opened by songxiaocheng 1
  • Add Squeeze and Excitation and UNETR as an option

    Squeeze and Excitation UNet and UNETR can be selected as an option to train in config.yml.

    Example (UNETR):

    # use a fixed random seed to guarantee that when you run the code twice you will get the same outcome
    manual_seed: 0
    model:
      name: UNETR
      # number of input channels to the model
      in_channels: 1
      ...
    

    Example (SE UNet):

    # use a fixed random seed to guarantee that when you run the code twice you will get the same outcome
    manual_seed: 0
    model:
      name: ResidualUNetSE3D
      # number of input channels to the model
      in_channels: 1
      ...
    

    Credits for UNETR code.

    opened by imadtoubal 0
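The shape problem described in the first comment above (the PixelWiseCrossEntropy weight expansion) can be reproduced with NumPy, using np.expand_dims in place of unsqueeze and np.broadcast_to in place of expand_as. The batch size 12 and the 3 classes come from the traceback; the small spatial size and the helper name are illustrative.

```python
import numpy as np


def expand_weights(weights, axis, shape):
    """Insert a singleton axis into (N, D, H, W) weights and broadcast to shape."""
    return np.broadcast_to(np.expand_dims(weights, axis), shape)


N, C, D, H, W = 12, 3, 4, 4, 4  # batch 12 and 3 classes as in the traceback
weights = np.random.rand(N, D, H, W)  # one per-voxel weight map per sample
input_shape = (N, C, D, H, W)

# axis=1 gives (N, 1, D, H, W), which broadcasts to (N, C, D, H, W)
expanded = expand_weights(weights, 1, input_shape)

# axis=0 gives (1, N, D, H, W); N=12 cannot be broadcast to C=3, so this raises
try:
    expand_weights(weights, 0, input_shape)
except ValueError as err:
    print("axis=0 fails:", err)
```

With batch size 1, axis 0 and axis 1 produce the same (1, 1, D, H, W) shape, which is consistent with the error only appearing for batch size > 1. If the batch size happened to equal the number of classes, axis 0 would broadcast without an error but pair weights with the wrong samples.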
Owner
Adrian Wolny
PhD student in Machine Learning @HCIHeidelberg