Runtime type annotations for the shape, dtype etc. of PyTorch Tensors.

Overview

torchtyping

Type annotations for a tensor's shape, dtype, names, ...

Turn this:

def batch_outer_product(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # x has shape (batch, x_channels)
    # y has shape (batch, y_channels)
    # return has shape (batch, x_channels, y_channels)

    return x.unsqueeze(-1) * y.unsqueeze(-2)

into this:

def batch_outer_product(x:   TensorType["batch", "x_channels"],
                        y:   TensorType["batch", "y_channels"]
                        ) -> TensorType["batch", "x_channels", "y_channels"]:

    return x.unsqueeze(-1) * y.unsqueeze(-2)

with programmatic checking that the shape (dtype, ...) specification is met.

Bye-bye bugs! Say hello to enforced, clear documentation of your code.

If (like me) you find yourself littering your code with comments like # x has shape (batch, hidden_state) or statements like assert x.shape == y.shape , just to keep track of what shape everything is, then this is for you.


Installation

pip install torchtyping

Requires Python 3.7+ and PyTorch 1.7.0+.

Usage

torchtyping allows for type annotating:

  • shape: size, number of dimensions;
  • dtype (float, integer, etc.);
  • layout (dense, sparse);
  • names of dimensions as per named tensors;
  • arbitrary number of batch dimensions with ...;
  • ...plus anything else you like, as torchtyping is highly extensible.

If typeguard is (optionally) installed then at runtime the types can be checked to ensure that the tensors really are of the advertised shape, dtype, etc.

# EXAMPLE

from torch import rand
from torchtyping import TensorType, patch_typeguard
from typeguard import typechecked

patch_typeguard()  # use before @typechecked

@typechecked
def func(x: TensorType["batch"],
         y: TensorType["batch"]) -> TensorType["batch"]:
    return x + y

func(rand(3), rand(3))  # works
func(rand(3), rand(1))
# TypeError: Dimension 'batch' of inconsistent size. Got both 1 and 3.

typeguard also has an import hook that can be used to automatically test an entire module, without needing to manually add @typeguard.typechecked decorators.

If you're not using typeguard then torchtyping.patch_typeguard() can be omitted altogether, and torchtyping just used for documentation purposes. If you're not already using typeguard for your regular Python programming, then strongly consider using it. It's a great way to squash bugs. Both typeguard and torchtyping also integrate with pytest, so if you're concerned about any performance penalty then they can be enabled during tests only.

API

torchtyping.TensorType[shape, dtype, layout, details]

The core of the library.

Each of shape, dtype, layout, details are optional.

  • The shape argument can be any of:
    • An int: the dimension must be of exactly this size. If it is -1 then any size is allowed.
    • A str: the size of the dimension passed at runtime will be bound to this name, and all tensors checked that the sizes are consistent.
    • A ...: An arbitrary number of dimensions of any sizes.
    • A str: int pair (technically it's a slice), combining both str and int behaviour. (Just a str on its own is equivalent to str: -1.)
    • A str: ... pair, in which case the multiple dimensions corresponding to ... will be bound to the name specified by str, and again checked for consistency between arguments.
    • None, which when used in conjunction with is_named below, indicates a dimension that must not have a name in the sense of named tensors.
    • A None: int pair, combining both None and int behaviour. (Just a None on its own is equivalent to None: -1.)
    • A typing.Any: Any size is allowed for this dimension (equivalent to -1).
    • Any tuple of the above. For example.TensorType["batch": ..., "length": 10, "channels", -1]. If you just want to specify the number of dimensions then use for example TensorType[-1, -1, -1] for a three-dimensional tensor.
  • The dtype argument can be any of:
    • torch.float32, torch.float64 etc.
    • int, bool, float, which are converted to their corresponding PyTorch types. float is specifically interpreted as torch.get_default_dtype(), which is usually float32.
  • The layout argument can be either torch.strided or torch.sparse_coo, for dense and sparse tensors respectively.
  • The details argument offers a way to pass an arbitrary number of additional flags that customise and extend torchtyping. Two flags are built-in by default. torchtyping.is_named causes the names of tensor dimensions to be checked, and torchtyping.is_float can be used to check that arbitrary floating point types are passed in. (Rather than just a specific one as with e.g. TensorType[torch.float32].) For discussion on how to customise torchtyping with your own details, see the further documentation.

Check multiple things at once by just putting them all together inside a single []. For example TensorType["batch": ..., "length", "channels", float, is_named].

torchtyping.patch_typeguard()

torchtyping integrates with typeguard to perform runtime type checking. torchtyping.patch_typeguard() should be called at the global level, and will patch typeguard to check TensorTypes.

This function is safe to run multiple times. (It does nothing after the first run).

  • If using @typeguard.typechecked, then torchtyping.patch_typeguard() should be called any time before using @typeguard.typechecked. For example you could call it at the start of each file using torchtyping.
  • If using typeguard.importhook.install_import_hook, then torchtyping.patch_typeguard() should be called any time before defining the functions you want checked. For example you could call torchtyping.patch_typeguard() just once, at the same time as the typeguard import hook. (The order of the hook and the patch doesn't matter.)
  • If you're not using typeguard then torchtyping.patch_typeguard() can be omitted altogether, and torchtyping just used for documentation purposes.
pytest --torchtyping-patch-typeguard

torchtyping offers a pytest plugin to automatically run torchtyping.patch_typeguard() before your tests. pytest will automatically discover the plugin, you just need to pass the --torchtyping-patch-typeguard flag to enable it. Packages can then be passed to typeguard as normal, either by using @typeguard.typechecked, typeguard's import hook, or the pytest flag --typeguard-packages="your_package_here".

Further documentation

See the further documentation for:

  • FAQ;
    • Including flake8 and mypy compatibility;
  • How to write custom extensions to torchtyping;
  • Resources and links to other libraries and materials on this topic;
  • More examples.
Owner
Patrick Kidger
Maths+ML PhD student at Oxford. Neural ODEs+SDEs+CDEs, time series, rough analysis. (Also ice skating, martial arts and scuba diving!)
Patrick Kidger
Source code for "FastBERT: a Self-distilling BERT with Adaptive Inference Time".

FastBERT Source code for "FastBERT: a Self-distilling BERT with Adaptive Inference Time". Good News 2021/10/29 - Code: Code of FastPLM is released on

Weijie Liu 584 Jan 02, 2023
Official implementation of the Neurips 2021 paper Searching Parameterized AP Loss for Object Detection.

Parameterized AP Loss By Chenxin Tao, Zizhang Li, Xizhou Zhu, Gao Huang, Yong Liu, Jifeng Dai This is the official implementation of the Neurips 2021

46 Jul 06, 2022
List of content farm sites like g.penzai.com.

内容农场网站清单 Google 中文搜索结果包含了相当一部分的内容农场式条目,比如「小 X 知识网」「小 X 百科网」。此种链接常会 302 重定向其主站,页面内容为自动生成,大量堆叠关键字,揉杂一些爬取到的内容,完全不具可读性和参考价值。 尤为过分的是,该类网站可能有成千上万个分身域名被 Goog

WDMPA 541 Jan 03, 2023
Robotics with GPU computing

Robotics with GPU computing Cupoch is a library that implements rapid 3D data processing for robotics using CUDA. The goal of this library is to imple

Shirokuma 625 Jan 07, 2023
Mmdet benchmark with python

mmdet_benchmark 本项目是为了研究 mmdet 推断性能瓶颈,并且对其进行优化。 配置与环境 机器配置 CPU:Intel(R) Core(TM) i9-10900K CPU @ 3.70GHz GPU:NVIDIA GeForce RTX 3080 10GB 内存:64G 硬盘:1T

杨培文 (Yang Peiwen) 24 May 21, 2022
rliable is an open-source Python library for reliable evaluation, even with a handful of runs, on reinforcement learning and machine learnings benchmarks.

Open-source library for reliable evaluation on reinforcement learning and machine learning benchmarks. See NeurIPS 2021 oral for details.

Google Research 529 Jan 01, 2023
City-Scale Multi-Camera Vehicle Tracking Guided by Crossroad Zones Code

City-Scale Multi-Camera Vehicle Tracking Guided by Crossroad Zones Requirements Python 3.8 or later with all requirements.txt dependencies installed,

88 Dec 12, 2022
Neural network for recognizing the gender of people in photos

Neural Network For Gender Recognition How to test it? Install requirements.txt file using pip install -r requirements.txt command Run nn.py using pyth

Valery Chapman 1 Sep 18, 2022
Hidden-Fold Networks (HFN): Random Recurrent Residuals Using Sparse Supermasks

Hidden-Fold Networks (HFN): Random Recurrent Residuals Using Sparse Supermasks by Ángel López García-Arias, Masanori Hashimoto, Masato Motomura, and J

Ángel López García-Arias 4 May 19, 2022
Independent and minimal implementations of some reinforcement learning algorithms using PyTorch (including PPO, A3C, A2C, ...).

PyTorch RL Minimal Implementations There are implementations of some reinforcement learning algorithms, whose characteristics are as follow: Less pack

Gemini Light 4 Dec 31, 2022
Simulation-based performance analysis of server-less Blockchain-enabled Federated Learning

Blockchain-enabled Server-less Federated Learning Repository containing the files used to reproduce the results of the publication "Blockchain-enabled

Francesc Wilhelmi 9 Sep 27, 2022
Code to reproduce the results in "Visually Grounded Reasoning across Languages and Cultures", EMNLP 2021.

marvl-code [WIP] This is the implementation of the approaches described in the paper: Fangyu Liu*, Emanuele Bugliarello*, Edoardo M. Ponti, Siva Reddy

25 Nov 15, 2022
A simple, fast, and efficient object detector without FPN

You Only Look One-level Feature (YOLOF), CVPR2021 A simple, fast, and efficient object detector without FPN. This repo provides an implementation for

789 Jan 09, 2023
Unofficial PyTorch Implementation of "DOLG: Single-Stage Image Retrieval with Deep Orthogonal Fusion of Local and Global Features"

Pytorch Implementation of Deep Orthogonal Fusion of Local and Global Features (DOLG) This is the unofficial PyTorch Implementation of "DOLG: Single-St

DK 96 Jan 06, 2023
This dlib-based facial login system

Facial-Login-System This dlib-based facial login system is a technology capable of matching a human face from a digital webcam frame capture against a

Mushahid Ali 3 Apr 23, 2022
Code for the paper "PortraitNet: Real-time portrait segmentation network for mobile device" @ CAD&Graphics2019

PortraitNet Code for the paper "PortraitNet: Real-time portrait segmentation network for mobile device". @ CAD&Graphics 2019 Introduction We propose a

265 Dec 01, 2022
A lightweight Python-based 3D network multi-agent simulator. Uses a cell-based congestion model. Calculates risk, loudness and battery capacities of the agents. Suitable for 3D network optimization tasks.

AMAZ3DSim AMAZ3DSim is a lightweight python-based 3D network multi-agent simulator. It uses a cell-based congestion model. It calculates risk, battery

Daniel Hirsch 13 Nov 04, 2022
PyTorch implementation for COMPLETER: Incomplete Multi-view Clustering via Contrastive Prediction (CVPR 2021)

Completer: Incomplete Multi-view Clustering via Contrastive Prediction This repo contains the code and data of the following paper accepted by CVPR 20

XLearning Group 72 Dec 07, 2022
Face Alignment using python

Face Alignment Face Alignment using python Input Image Aligned Face Aligned Face Aligned Face Input Image Aligned Face Input Image Aligned Face Instal

Sajjad Aemmi 28 Nov 23, 2022
Implementation of Segformer, Attention + MLP neural network for segmentation, in Pytorch

Segformer - Pytorch Implementation of Segformer, Attention + MLP neural network for segmentation, in Pytorch. Install $ pip install segformer-pytorch

Phil Wang 208 Dec 25, 2022