TorchShard is a lightweight engine for slicing a PyTorch tensor into parallel shards

Overview


TorchShard is a lightweight engine for slicing a PyTorch tensor into parallel shards. It can reduce GPU memory usage and scale up training when a model has massive linear layers (e.g., ViT, BERT, and GPT) or a huge number of classes (millions). Its API design follows PyTorch's.
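The idea of slicing a linear layer into shards can be illustrated in plain PyTorch on a single device. This is a minimal sketch and not TorchShard's implementation: the shard count, shapes, and variable names are assumptions made for illustration, and the concatenation and summation below stand in for the collective communication that a real multi-GPU setup would need.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

num_shards = 2                      # hypothetical: stands in for the GPUs in one group
x = torch.randn(4, 20)              # batch of 4 inputs with 20 features
weight = torch.randn(30, 20)        # same layout as nn.Linear.weight: (out, in)
bias = torch.randn(30)

# Reference result from the unsharded layer.
full = F.linear(x, weight, bias)

# Split the weight along dim 0 (output features): each shard computes a
# slice of the output, and the slices are concatenated (an all-gather).
w_out = torch.chunk(weight, num_shards, dim=0)
b_out = torch.chunk(bias, num_shards, dim=0)
out_split = torch.cat(
    [F.linear(x, w, b) for w, b in zip(w_out, b_out)], dim=1)

# Split the weight along dim 1 (input features): each shard sees a slice
# of the input, partial results are summed (an all-reduce), and the bias
# is added once at the end.
w_in = torch.chunk(weight, num_shards, dim=1)
x_in = torch.chunk(x, num_shards, dim=1)
in_split = sum(xs @ w.t() for xs, w in zip(x_in, w_in)) + bias

# Both sharded variants reproduce the unsharded output.
assert torch.allclose(full, out_split, atol=1e-4)
assert torch.allclose(full, in_split, atol=1e-4)
```

In both variants each shard stores only 1/num_shards of the weight, which is where the memory saving comes from.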

Installation

pip install torchshard

See INSTALL.md for more options.

Usage

import torch
import torchshard as ts

ts.init_process_group(group_size=2)                       # init parallel groups

m = torch.nn.Sequential(
    torch.nn.Linear(20, 30, bias=True),
    ts.nn.ParallelLinear(30, 30, bias=True, dim=None),    # equal to nn.Linear()
    ts.nn.ParallelLinear(30, 30, bias=True, dim=0),       # parallel in row dimension
    ts.nn.ParallelLinear(30, 30, bias=True, dim=1),       # parallel in column dimension
).cuda()

# x (inputs) and y (targets) are assumed to be CUDA tensors defined elsewhere
x = m(x)                                                  # forward
loss = ts.nn.functional.parallel_cross_entropy(x, y)      # parallel loss function
loss.backward()                                           # backward

torch.save(
    ts.collect_state_dict(m, m.state_dict()), 'm.pt')     # save model state
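
A cross-entropy loss can be computed without ever gathering the full logits when the class dimension is sharded. The sketch below simulates this on one device in plain PyTorch. It is an illustration of the general technique under stated assumptions, not the code behind ts.nn.functional.parallel_cross_entropy: the shard count, shapes, and variable names are hypothetical, and each reduction across shards stands in for an all-reduce between GPUs.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

num_shards = 4                                  # hypothetical group size
logits = torch.randn(8, 1000)                   # (batch, classes)
target = torch.randint(0, 1000, (8,))

# Reference loss computed on the full logits.
reference = F.cross_entropy(logits, target)

# Each shard holds a contiguous block of 250 classes.
shards = torch.chunk(logits, num_shards, dim=1)
shard_size = shards[0].shape[1]

# Step 1: per-sample global max (all-reduce MAX) for numerical stability.
global_max = torch.stack(
    [s.max(dim=1).values for s in shards]).max(dim=0).values

# Step 2: per-sample sum of exponentials (all-reduce SUM).
sum_exp = sum((s - global_max[:, None]).exp().sum(dim=1) for s in shards)

# Step 3: the target logit; only the shard that owns the target class
# contributes a nonzero value (all-reduce SUM).
target_logit = torch.zeros(logits.shape[0])
for i, s in enumerate(shards):
    local = target - i * shard_size             # class index inside this shard
    owned = (local >= 0) & (local < shard_size)
    idx = local.clamp(0, shard_size - 1)        # keep gather indices valid
    picked = s.gather(1, idx[:, None]).squeeze(1)
    target_logit += torch.where(owned, picked, torch.zeros_like(picked))

# loss = logsumexp(logits) - logit[target], averaged over the batch.
loss = (sum_exp.log() + global_max - target_logit).mean()
assert torch.allclose(loss, reference, atol=1e-4)
```

Only three per-sample vectors cross shard boundaries, so the communication cost does not grow with the number of classes.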

Performance

The following figure shows training ResNet-50 on 8 NVIDIA TITAN-XP (12196 MiB) GPUs while scaling the number of classes from 1,000 to 1 million. The input size is 224 x 224 and the batch size is 256. Training uses 8-way data parallelism and 8-way model parallelism.
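
A rough calculation shows why a one-million-class classifier needs sharding on these GPUs. The sketch below is back-of-the-envelope arithmetic, not a measurement: it assumes fp32 parameters (4 bytes each), counts only the weight of the final fully connected layer, and uses 2048 input features, the width of the standard ResNet-50 classifier input. The helper name is hypothetical.

```python
def fc_weight_mib(in_features, num_classes, bytes_per_param=4, shards=1):
    """Size in MiB of one shard of a fully connected layer's weight."""
    params = in_features * num_classes
    return params * bytes_per_param / shards / 2**20

full_mib = fc_weight_mib(2048, 1_000_000)               # unsharded weight
per_gpu_mib = fc_weight_mib(2048, 1_000_000, shards=8)  # 8-way model parallel

print(f"{full_mib:.1f} MiB full, {per_gpu_mib:.1f} MiB per GPU")
# prints "7812.5 MiB full, 976.6 MiB per GPU"
```

Under these assumptions the unsharded weight and its gradient together would need about 15625 MiB, more than the 12196 MiB available on a single GPU, while an 8-way shard needs under 1 GiB for the weight.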

The following figure shows training minGPT on 8 NVIDIA TITAN-XP (12196 MiB) GPUs while scaling the number of parameters from 10 million to 808 million. The input size is 32 x 32 and the batch size is 16. Training uses 1-way data parallelism and 8-way model parallelism.

Contributing

TorchShard welcomes your expertise and enthusiasm!

If you are interested in TorchShard, you are welcome to help:

  • polish code and develop new features
  • develop high-quality tutorials, projects, and advanced materials

Direct pull requests are welcome. Contact: kaiyuyue [at] umd.edu.

Citing TorchShard

If you find TorchShard helpful in your research and would like to cite it, please use the following BibTeX entry.

@misc{torchshard2021,
  author =       {Kaiyu Yue},
  title =        {TorchShard},
  howpublished = {\url{https://github.com/KaiyuYue/torchshard}},
  year =         {2021}
}
Comments
  • Future planning for this project.

    Hello Kaiyu, I love this awesome project. The API design is elegant and simple, and the software is lightweight and user-friendly. My understanding is that this project implements a series of PyTorch wrappers for tensor slicing.

    1. I am curious about the future plans for this project.
    2. Is there some overlap in functionality between torchshard and the N-D parallelism proposed in ColossalAI?
    3. How is compatibility with ZeRO? According to the am+zero example, the memory footprint changes only slightly after combining torchshard with ZeRO.
    opened by feifeibear 2
  • Which one is faster?

    Thanks for contributing this great lib. I have one question: which one is faster (in speed), dim=0 or dim=1? The documentation seems to contain only accuracy results.

    opened by NOBLES5E 2
  • The 8-GPU test example raises an error.

    The unit tests pass when I use two GPU devices and run the command below:

    CUDA_VISIBLE_DEVICES=0,1 python3 -m unittest discover -v -s tests

    But when I run the unit tests with eight GPU devices, they raise ncclSystemError. The command is:

    CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m unittest discover -v -s tests

    The error is:

    RuntimeError: NCCL error in ../torch/lib/c10d/ProcessGroupNCCL.cpp:825, unhandled system error, NCCL version 2.7.8 ncclSystemError: System call (socket, malloc, munmap, etc) failed.

    Is it necessary for the unit tests to pass on eight GPU devices?

    opened by JiaquanYe 1
  • Error?

    Hi, thanks for the excellent job! When I install it from pip and run

    import torchshard as ts
    ts.init_process_group(group_size=2)

    an AttributeError occurs:

    AttributeError: module 'torchshard' has no attribute 'init_process_group'
    
    opened by WangWenhao0716 1
  • Multi-node setting?

    https://github.com/KaiyuYue/torchshard/blob/89e21def180bf6063ceb2e312a61631173abc7e7/projects/minGPT/main.py#L150

    I have noticed that group_size is set to world_size in the examples, but as I understand it, group_size can be set to other numbers.

    https://github.com/KaiyuYue/torchshard/blob/main/torchshard/distributed/core.py#L18

    I have also found that get_world_size() returns the number of all processes.

    These two findings confuse me in a multi-node setting, say 2 nodes with 2 processes each.

    If group_size is 2, there are 2 distinct groups besides the default group (with overlap). However, calling get_world_size() without specifying a group can cause a layer to be split into 4 parts, whereas 2 parts are expected in this case.

    Correct me if I am wrong.

    Good Issue 
    opened by GeneZC 1
  • Is it possible to collect the state dict on the CPU?

    When one epoch of training finishes, the main_worker function calls ts.collect_state_dict(model, state_dict). Because of limited GPU resources, this call raises an out-of-memory error on my machine. I found that it gathers the state_dict on the GPU. Is there any way to gather it on the CPU instead?

    Good Issue 
    opened by JiaquanYe 2
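
The multi-node question above (2 nodes with 2 processes each, group_size=2) can be made concrete with a small sketch in plain Python. It assumes that ranks are assigned to groups contiguously, which is an assumption of this sketch and not something the source confirms about TorchShard. The helper name and the feature count are hypothetical.

```python
def build_groups(world_size, group_size):
    """Partition ranks 0..world_size-1 into contiguous groups (assumed layout)."""
    if world_size % group_size != 0:
        raise ValueError("world_size must be divisible by group_size")
    return [list(range(start, start + group_size))
            for start in range(0, world_size, group_size)]

world_size = 4        # 2 nodes x 2 processes
group_size = 2        # model-parallel width
groups = build_groups(world_size, group_size)
print(groups)         # prints [[0, 1], [2, 3]]

out_features = 1024   # hypothetical layer width
per_group = out_features // group_size   # 512: shard size if the group size is used
per_world = out_features // world_size   # 256: shard size if the world size is used
print(per_group, per_world)              # prints 512 256
```

The difference between the two shard sizes is the mismatch the issue describes: dividing by the size of the default group splits the layer into 4 parts instead of the expected 2.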
Releases(v0.1)
Owner
Kaiyu Yue