Torch-mutable-modules - Use in-place and assignment operations on PyTorch module parameters with support for autograd

Overview

Torch Mutable Modules

Use in-place and assignment operations on PyTorch module parameters with support for autograd.

Publish to PyPI Run tests PyPI version Number of downloads from PyPI per month Python version support Code Style: Black

Why does this exist?

PyTorch does not allow in-place operations on module parameters (usually desirable):

linear_layer = torch.nn.Linear(1, 1)
linear_layer.weight.data += 69
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# Valid, but will NOT store grad_fn=<AddBackward0>
linear_layer.weight += 420
# ^^^^^^^^^^^^^^^^^^^^^^^^
# RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

In some cases, however, it is useful to be able to modify module parameters in-place. For example, if we have a neural network (net_1) that predicts the parameter values to another neural network (net_2), we need to be able to modify the weights of net_2 in-place and backpropagate the gradients to net_1.

# create a parameter predictor network (net_1)
net_1 = torch.nn.Linear(1, 2)

# predict the weights and biases of net_2 using net_1
p_weight_and_bias = net_1(input_0).unsqueeze(2)
p_weight, p_bias = p_weight_and_bias[:, 0], p_weight_and_bias[:, 1]

# create a mutable network (net_2)
net_2 = to_mutable_module(torch.nn.Linear(1, 1))

# hot-swap the weights and biases of net_2 with the predicted values
net_2.weight = p_weight
net_2.bias = p_bias

# compute the output and backpropagate the gradients to net_1
output = net_2(input_1)
loss = criterion(output, label)
loss.backward()
optimizer.step()

This library provides a way to easily convert PyTorch modules into mutable modules with the to_mutable_module function.

Installation

You can install torch-mutable-modules from PyPI.

pip install torch-mutable-modules

To upgrade an existing installation of torch-mutable-modules, use the following command:

pip install --upgrade --no-cache-dir torch-mutable-modules

Importing

You can use wildcard imports or import specific functions directly:

# import all functions
from torch_mutable_modules import *

# ... or import the function manually
from torch_mutable_modules import to_mutable_module

Usage

To convert an existing PyTorch module into a mutable module, use the to_mutable_module function:

converted_module = to_mutable_module(
    torch.nn.Linear(1, 1)
) # type of converted_module is still torch.nn.Linear

converted_module.weight *= 0
convreted_module.weight += 69
convreted_module.weight # tensor([[69.]], grad_fn=<AddBackward0>)

You can also declare your own PyTorch module classes as mutable, and all child modules will be recursively converted into mutable modules:

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)
    
    def forward(self, x):
        return self.linear(x)

my_module = to_mutable_module(MyModule())
my_module.linear.weight *= 0
my_module.linear.weight += 69
my_module.linear.weight # tensor([[69.]], grad_fn=<AddBackward0>)

Usage with CUDA

To create a module on the GPU, simply pass a PyTorch module that is already on the GPU to the to_mutable_module function:

converted_module = to_mutable_module(
    torch.nn.Linear(1, 1).cuda()
) # converted_module is now a mutable module on the GPU

Moving a module to the GPU with .to() and .cuda() after instanciation is NOT supported. Instead, hot-swap the module parameter tensors with their CUDA counterparts.

# both of these are valid
converted_module.weight = converted_module.weight.cuda()
converted_module.bias = converted_module.bias.to("cuda")

Detailed examples

Please check out example.py to see more detailed example usages of the to_mutable_module function.

Contributing

Please feel free to submit issues or pull requests!

You might also like...
A machine learning library for spiking neural networks. Supports training with both torch and jax pipelines, and deployment to neuromorphic hardware.
A machine learning library for spiking neural networks. Supports training with both torch and jax pipelines, and deployment to neuromorphic hardware.

Rockpool Rockpool is a Python package for developing signal processing applications with spiking neural networks. Rockpool allows you to build network

Implements Stacked-RNN in numpy and torch with manual forward and backward functions

Recurrent Neural Networks Implements simple recurrent network and a stacked recurrent network in numpy and torch respectively. Both flavours implement

A torch.Tensor-like DataFrame library supporting multiple execution runtimes and Arrow as a common memory format

TorchArrow (Warning: Unstable Prototype) This is a prototype library currently under heavy development. It does not currently have stable releases, an

A complete end-to-end demonstration in which we collect training data in Unity and use that data to train a deep neural network to predict the pose of a cube. This model is then deployed in a simulated robotic pick-and-place task.
A complete end-to-end demonstration in which we collect training data in Unity and use that data to train a deep neural network to predict the pose of a cube. This model is then deployed in a simulated robotic pick-and-place task.

Object Pose Estimation Demo This tutorial will go through the steps necessary to perform pose estimation with a UR3 robotic arm in Unity. You’ll gain

Python implementation of MULTIseq barcode alignment using fuzzy string matching and GMM barcode assignment

Python implementation of MULTIseq barcode alignment using fuzzy string matching and GMM barcode assignment.

 MM1 and MMC Queue Simulation using python - Results and parameters in excel and csv files
MM1 and MMC Queue Simulation using python - Results and parameters in excel and csv files

implementation of MM1 and MMC Queue on randomly generated data and evaluate simulation results then compare with analytical results and draw a plot curve for them, simulate some integrals and compare results and run monte carlo algorithm with them

Torch-based tool for quantizing high-dimensional vectors using additive codebooks

Trainable multi-codebook quantization This repository implements a utility for use with PyTorch, and ideally GPUs, for training an efficient quantizer

Torch implementation of
Torch implementation of "Enhanced Deep Residual Networks for Single Image Super-Resolution"

NTIRE2017 Super-resolution Challenge: SNU_CVLab Introduction This is our project repository for CVPR 2017 Workshop (2nd NTIRE). We, Team SNU_CVLab, (B

Automatic number plate recognition using tech:  Yolo, OCR, Scene text detection, scene text recognation, flask, torch
Automatic number plate recognition using tech: Yolo, OCR, Scene text detection, scene text recognation, flask, torch

Automatic Number Plate Recognition Automatic Number Plate Recognition (ANPR) is the process of reading the characters on the plate with various optica

Releases(v1.1.2)
Owner
Kento Nishi
17-year-old programmer at Lynbrook High School, with strong interests in AI/Machine Learning. Open source developer and researcher at the Four Eyes Lab.
Kento Nishi
Optimizes image files by converting them to webp while also updating all references.

About Optimizes images by (re-)saving them as webp. For every file it replaced it automatically updates all references. Works on single files as well

Watermelon Wolverine 18 Dec 23, 2022
Relative Human dataset, CVPR 2022

Relative Human (RH) contains multi-person in-the-wild RGB images with rich human annotations, including: Depth layers (DLs): relative depth relationsh

Yu Sun 112 Dec 02, 2022
Code for ICLR 2020 paper "VL-BERT: Pre-training of Generic Visual-Linguistic Representations".

VL-BERT By Weijie Su, Xizhou Zhu, Yue Cao, Bin Li, Lewei Lu, Furu Wei, Jifeng Dai. This repository is an official implementation of the paper VL-BERT:

Weijie Su 698 Dec 18, 2022
A-ESRGAN aims to provide better super-resolution images by using multi-scale attention U-net discriminators.

A-ESRGAN: Training Real-World Blind Super-Resolution with Attention-based U-net Discriminators The authors are hidden for the purpose of double blind

77 Dec 16, 2022
PyTorch Implementation of [1611.06440] Pruning Convolutional Neural Networks for Resource Efficient Inference

PyTorch implementation of [1611.06440 Pruning Convolutional Neural Networks for Resource Efficient Inference] This demonstrates pruning a VGG16 based

Jacob Gildenblat 836 Dec 26, 2022
D2LV: A Data-Driven and Local-Verification Approach for Image Copy Detection

Facebook AI Image Similarity Challenge: Matching Track —— Team: imgFp This is the source code of our 3rd place solution to matching track of Image Sim

16 Dec 25, 2022
Contour-guided image completion with perceptual grouping (BMVC 2021 publication)

Contour-guided Image Completion with Perceptual Grouping Authors Morteza Rezanejad*, Sidharth Gupta*, Chandra Gummaluru, Ryan Marten, John Wilder, Mic

Sid Gupta 6 Dec 27, 2022
Paddle pit - Rethinking Spatial Dimensions of Vision Transformers

基于Paddle实现PiT ——Rethinking Spatial Dimensions of Vision Transformers,arxiv 官方原版代

Hongtao Wen 4 Jan 15, 2022
This repository contains the code for the paper Neural RGB-D Surface Reconstruction

Neural RGB-D Surface Reconstruction Paper | Project Page | Video Neural RGB-D Surface Reconstruction Dejan Azinović, Ricardo Martin-Brualla, Dan B Gol

Dejan 406 Jan 04, 2023
PyG (PyTorch Geometric) - A library built upon PyTorch to easily write and train Graph Neural Networks (GNNs)

PyG (PyTorch Geometric) is a library built upon PyTorch to easily write and train Graph Neural Networks (GNNs) for a wide range of applications related to structured data.

PyG 16.5k Jan 08, 2023
LogAvgExp - Pytorch Implementation of LogAvgExp

LogAvgExp - Pytorch Implementation of LogAvgExp for Pytorch Install $ pip instal

Phil Wang 31 Oct 14, 2022
这是一个yolox-keras的源码,可以用于训练自己的模型。

YOLOX:You Only Look Once目标检测模型在Keras当中的实现 目录 性能情况 Performance 实现的内容 Achievement 所需环境 Environment 小技巧的设置 TricksSet 文件下载 Download 训练步骤 How2train 预测步骤 Ho

Bubbliiiing 64 Nov 10, 2022
PyTorch Code of "Memory In Memory: A Predictive Neural Network for Learning Higher-Order Non-Stationarity from Spatiotemporal Dynamics"

Memory In Memory Networks It is based on the paper Memory In Memory: A Predictive Neural Network for Learning Higher-Order Non-Stationarity from Spati

Yang Li 12 May 30, 2022
Implementation of the paper Scalable Intervention Target Estimation in Linear Models (NeurIPS 2021), and the code to generate simulation results.

Scalable Intervention Target Estimation in Linear Models Implementation of the paper Scalable Intervention Target Estimation in Linear Models (NeurIPS

0 Oct 25, 2021
Real-Time Multi-Contact Model Predictive Control via ADMM

Here, you can find the code for the paper 'Real-Time Multi-Contact Model Predictive Control via ADMM'. Code is currently being cleared up and optimize

17 Dec 28, 2022
Unofficial PyTorch implementation of "RTM3D: Real-time Monocular 3D Detection from Object Keypoints for Autonomous Driving" (ECCV 2020)

RTM3D-PyTorch The PyTorch Implementation of the paper: RTM3D: Real-time Monocular 3D Detection from Object Keypoints for Autonomous Driving (ECCV 2020

Nguyen Mau Dzung 271 Nov 29, 2022
Combining Reinforcement Learning and Constraint Programming for Combinatorial Optimization

Hybrid solving process for combinatorial optimization problems Combinatorial optimization has found applications in numerous fields, from aerospace to

117 Dec 13, 2022
Predicting the duration of arrival delays for commercial flights.

Flight Delay Prediction Our objective is to predict arrival delays of commercial flights. According to the US Department of Transportation, about 21%

Jordan Silke 1 Jan 11, 2022
Official project repository for 'Normality-Calibrated Autoencoder for Unsupervised Anomaly Detection on Data Contamination'

NCAE_UAD Official project repository of 'Normality-Calibrated Autoencoder for Unsupervised Anomaly Detection on Data Contamination' Abstract In this p

Jongmin Andrew Yu 2 Feb 10, 2022
Automatic voice-synthetised summaries of latest research papers on arXiv

PaperWhisperer PaperWhisperer is a Python application that keeps you up-to-date with research papers. How? It retrieves the latest articles from arXiv

Valerio Velardo 124 Dec 20, 2022