Feature extraction made simple with torchextractor

Overview

torchextractor: PyTorch Intermediate Feature Extraction


Introduction

Too often, model definitions get remorselessly copy-pasted just because the forward function does not return what the person expects. You provide module names and torchextractor takes care of the extraction for you. It's never been easier to extract features, add an extra loss or plug another head into a network. Let us know what amazing things you build with torchextractor!

Installation

pip install torchextractor  # stable
pip install git+https://github.com/antoinebrl/torchextractor.git  # latest

Requirements:

  • Python >= 3.6
  • torch >= 1.4.0

Usage

import torch
import torchvision
import torchextractor as tx

model = torchvision.models.resnet18(pretrained=True)
model = tx.Extractor(model, ["layer1", "layer2", "layer3", "layer4"])
dummy_input = torch.rand(7, 3, 224, 224)
model_output, features = model(dummy_input)
feature_shapes = {name: f.shape for name, f in features.items()}
print(feature_shapes)

# {
#   'layer1': torch.Size([7, 64, 56, 56]),
#   'layer2': torch.Size([7, 128, 28, 28]),
#   'layer3': torch.Size([7, 256, 14, 14]),
#   'layer4': torch.Size([7, 512, 7, 7]),
# }
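The FAQ below notes that extraction relies on module hooks. As a rough illustration of that idea, here is a minimal sketch built only on PyTorch's `register_forward_hook`. It is not torchextractor's implementation; the helper `extract_features` and the toy model are hypothetical and only mimic the `(output, features)` return shape of the example above.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


def extract_features(model, module_names, x):
    """Hypothetical helper: run model(x) and return (output, features)."""
    modules = dict(model.named_modules())
    features = {}
    handles = []
    for name in module_names:
        # Bind `name` as a default argument so each hook writes to its own key.
        def hook(module, inputs, output, name=name):
            features[name] = output

        handles.append(modules[name].register_forward_hook(hook))
    try:
        output = model(x)
    finally:
        # Remove the hooks so later forward passes are unaffected.
        for handle in handles:
            handle.remove()
    return output, features


# Toy CNN whose submodules are named like the ResNet stages above.
model = nn.Sequential(OrderedDict([
    ("layer1", nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1)),
    ("layer2", nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1)),
    ("pool", nn.AdaptiveAvgPool2d(1)),
]))

output, features = extract_features(model, ["layer1", "layer2"], torch.rand(7, 3, 32, 32))
print({name: tuple(f.shape) for name, f in features.items()})
# {'layer1': (7, 8, 16, 16), 'layer2': (7, 16, 8, 8)}
```

Removing the hooks in a `finally` block keeps the toy model clean for later calls; a persistent extractor would instead keep the handles around until it is discarded.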

See more examples, which can be run on Binder or in Google Colab.

Read the documentation

FAQ

• How do I know the names of the modules?

You can print all module names like this:

tx.list_module_names(model)

# OR

for name, module in model.named_modules():
    print(name)
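Names are dotted paths that follow the attribute hierarchy, and the root module itself is reported as an empty string. A small sketch with a hypothetical nested model shows what to expect:

```python
import torch.nn as nn


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.conv(x))


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2d(3, 3, kernel_size=1)
        # Children of an nn.Sequential are named by their index: "0", "1", ...
        self.blocks = nn.Sequential(Block(), Block())

    def forward(self, x):
        return self.blocks(self.stem(x))


names = [name for name, _ in Net().named_modules()]
print(names)
# ['', 'stem', 'blocks', 'blocks.0', 'blocks.0.conv', 'blocks.0.act',
#  'blocks.1', 'blocks.1.conv', 'blocks.1.act']
```

A full path such as `"blocks.1.conv"` is the kind of string you would pass in the list of module names.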

• Why do some operations not get listed?

It is not possible to add hooks to operations that are not defined as modules. Therefore, F.relu cannot be captured, but nn.ReLU() can.
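The difference is visible directly in `named_modules()`: an `nn.ReLU` attribute is a registered submodule that can carry a forward hook, while a call to `F.relu` inside `forward` leaves no module behind. The two toy models below are hypothetical and use only plain PyTorch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FunctionalNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

    def forward(self, x):
        return F.relu(self.fc(x))  # functional call: no module, nothing to hook


class ModuleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)
        self.act = nn.ReLU()  # registered submodule: can be hooked as "act"

    def forward(self, x):
        return self.act(self.fc(x))


functional_names = [n for n, _ in FunctionalNet().named_modules()]
module_names = [n for n, _ in ModuleNet().named_modules()]
print(functional_names)  # ['', 'fc']
print(module_names)      # ['', 'fc', 'act']

# A forward hook on the nn.ReLU submodule receives the activation output.
captured = {}
net = ModuleNet()
net.act.register_forward_hook(lambda module, inputs, output: captured.update(act=output))
net(torch.randn(2, 4))
print(bool((captured["act"] >= 0).all()))  # True: ReLU outputs are non-negative
```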

• How can I avoid listing all relevant modules?

You can specify a custom filtering function to hook the relevant modules:

# Hook everything!
module_filter_fn = lambda module, name: True

# Capture all modules inside the first layer
module_filter_fn = lambda module, name: name.startswith("layer1")

# Focus on all convolutions
module_filter_fn = lambda module, name: isinstance(module, torch.nn.Conv2d)

model = tx.Extractor(model, module_filter_fn=module_filter_fn)
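To preview which names a filter selects before building the extractor, you can apply it to `named_modules()` yourself. This is a sketch under an assumption, not taken from the library's source: it supposes the filter is evaluated once per `(module, name)` pair, in the argument order used above. The helper `select_module_names` and the toy model are hypothetical.

```python
import torch.nn as nn


def select_module_names(model, module_filter_fn):
    """Hypothetical helper: names of submodules accepted by module_filter_fn(module, name)."""
    return [name for name, module in model.named_modules() if module_filter_fn(module, name)]


# Toy model with two "layers", each an nn.Sequential of conv + ReLU.
model = nn.Sequential()
model.add_module("layer1", nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()))
model.add_module("layer2", nn.Sequential(nn.Conv2d(8, 16, 3), nn.ReLU()))

print(select_module_names(model, lambda module, name: name.startswith("layer1")))
# ['layer1', 'layer1.0', 'layer1.1']
print(select_module_names(model, lambda module, name: isinstance(module, nn.Conv2d)))
# ['layer1.0', 'layer2.0']
```

With this helper, the "hook everything" filter also matches the root module, whose name is the empty string.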

• Is it compatible with ONNX?

tx.Extractor is compatible with ONNX! This means you can also access intermediate feature maps after the export.

Pro-tip: name the output nodes by using output_names when calling torch.onnx.export.

• Is it compatible with TorchScript?

Not yet, but we are working on it. Support for compiling a module's registered hooks was only recently added in PyTorch v1.8.0.

• "One more thing!" 😉

By default we capture the latest output of the relevant modules, but you can specify your own custom operations.

For example, to accumulate features over 10 forward passes you can do the following:

import torch
import torchvision
import torchextractor as tx

model = torchvision.models.resnet18(pretrained=True)

def capture_fn(module, input, output, module_name, feature_maps):
    if module_name not in feature_maps:
        feature_maps[module_name] = []
    feature_maps[module_name].append(output)

extractor = tx.Extractor(model, ["layer3", "layer4"], capture_fn=capture_fn)

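for _ in range(20):
    # Accumulate features over 10 forward passes. The hooks are registered on
    # the submodules of `model`, so calling it directly triggers the capture.
    for _ in range(10):
        x = torch.rand(7, 3, 224, 224)
        model(x)
    feature_maps = extractor.collect()

    # Do your stuff here

    # Discard collected elements
    extractor.clear_placeholder()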
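To see how a custom function with the `capture_fn` signature above can sit behind an ordinary forward hook, here is a plain-PyTorch sketch. It assumes, without checking torchextractor's source, that the extra `module_name` and `feature_maps` arguments are bound per module; here `functools.partial` does that binding. The variant `pooled_capture_fn` is hypothetical: it detaches and average-pools each output, so the accumulated lists hold small `(N, C)` tensors instead of full activation maps with their autograd graphs.

```python
import functools

import torch
import torch.nn as nn


def pooled_capture_fn(module, input, output, module_name, feature_maps):
    # Keep a detached (N, C) summary instead of the full (N, C, H, W) map.
    feature_maps.setdefault(module_name, []).append(output.detach().mean(dim=(2, 3)))


model = nn.Sequential()
model.add_module("layer3", nn.Conv2d(3, 8, kernel_size=3, padding=1))
model.add_module("layer4", nn.Conv2d(8, 16, kernel_size=3, padding=1))

feature_maps = {}
for name in ["layer3", "layer4"]:
    # PyTorch calls the hook as hook(module, input, output); partial supplies the rest.
    hook = functools.partial(pooled_capture_fn, module_name=name, feature_maps=feature_maps)
    getattr(model, name).register_forward_hook(hook)

for _ in range(10):  # accumulate over 10 forward passes
    model(torch.rand(7, 3, 16, 16))

stacked = {name: torch.cat(maps) for name, maps in feature_maps.items()}
print({name: tuple(t.shape) for name, t in stacked.items()})
# {'layer3': (70, 8), 'layer4': (70, 16)}
```

Pooling inside the capture step trades spatial detail for memory; drop the `.mean(...)` call if you need the full maps.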

Contributing

All feedback and contributions are welcome. Feel free to report an issue or to create a pull request!

If you want to get hands-on:

  1. (Fork and) clone the repo.
  2. Create a virtual environment: virtualenv -p python3 .venv && source .venv/bin/activate
  3. Install dependencies: pip install -r requirements.txt && pip install -r requirements-dev.txt
  4. Hook auto-formatting tools: pre-commit install
  5. Hack as much as you want!
  6. Run tests: python -m unittest discover -vs ./tests/
  7. Share your work and create a pull request.

To build the documentation:

cd docs
pip install -r requirements.txt
make html

You might also like...
Deep Image Search is an AI-based image search engine that includes deep transfer learning feature extraction and tree-based vectorized search.

Deep Image Search - AI-Based Image Search Engine Deep Image Search is an AI-based image search engine that includes deep transfer learning features Ex

Cross-media Structured Common Space for Multimedia Event Extraction (ACL2020)

Cross-media Structured Common Space for Multimedia Event Extraction Table of Contents Overview Requirements Data Quickstart Citation Overview The code

Source code for paper "Document-Level Relation Extraction with Adaptive Thresholding and Localized Context Pooling", AAAI 2021

ATLOP Code for AAAI 2021 paper Document-Level Relation Extraction with Adaptive Thresholding and Localized Context Pooling. If you make use of this co

Training data extraction on GPT-2

Training data extraction from GPT-2 This repository contains code for extracting training data from GPT-2, following the approach outlined in the foll

This repository contains the code for our fast polygonal building extraction from overhead images pipeline.

Polygonal Building Segmentation by Frame Field Learning We add a frame field output to an image segmentation neural network to improve segmentation qu

Adversarial Robustness Toolbox (ART) - Python Library for Machine Learning Security - Evasion, Poisoning, Extraction, Inference - Red and Blue Teams

Adversarial Robustness Toolbox (ART) is a Python library for Machine Learning Security. ART provides tools that enable developers and researchers to defend and evaluate Machine Learning models and applications against the adversarial threats of Evasion, Poisoning, Extraction, and Inference. ART supports all popular machine learning frameworks (TensorFlow, Keras, PyTorch, MXNet, scikit-learn, XGBoost, LightGBM, CatBoost, GPy, etc.), all data types (images, tables, audio, video, etc.) and machine learning tasks (classification, object detection, speech recognition, generation, certification, etc.).

Implementation for our AAAI2021 paper (Entity Structure Within and Throughout: Modeling Mention Dependencies for Document-Level Relation Extraction).

SSAN Introduction This is the pytorch implementation of the SSAN model (see our AAAI2021 paper: Entity Structure Within and Throughout: Modeling Menti

An Efficient Implementation of Analytic Mesh Algorithm for 3D Iso-surface Extraction from Neural Networks

AnalyticMesh Analytic Marching is an exact meshing solution from neural networks. Compared to standard methods, it completely avoids geometric and top

[ACL 20] Probing Linguistic Features of Sentence-level Representations in Neural Relation Extraction

REval Table of Contents Introduction Overview Requirements Installation Probing Usage Citation License 🎓 Introduction REval is a simple framework for

Comments
  • Only extracting part of the intermediate feature with DataParallel

    Hi @antoinebrl,

    I am using torch.nn.DataParallel on a 2-GPU machine with a batch size of N. Data parallel training splits the input batch into 2 pieces sequentially and sends them to the GPUs.

    When using torchextractor to obtain the intermediate feature, the input data size and the output size are both N as expected, but the feature size becomes N/2. Does this mean we only extract the features of one GPU? I'm not sure because I didn't find an exact match.

    Can you please explain why this happens? Perhaps the expected behavior would be to return features from all GPUs, or from a specified one?

    A minimal example to reproduce:

    import torch
    import torchvision
    import torchextractor as tx
    
    model = torchvision.models.resnet18(pretrained=True)
    model_gpu = torch.nn.DataParallel(torchvision.models.resnet18(pretrained=True))
    model_gpu.cuda()
    
    model = tx.Extractor(model, ["layer1"])
    model_gpu = tx.Extractor(model_gpu, ["module.layer1"])
    dummy_input = torch.rand(8, 3, 224, 224)
    _, features = model(dummy_input)
    _, features_gpu = model_gpu(dummy_input)
    feature_shapes = {name: f.shape for name, f in features.items()}
    print(feature_shapes)
    feature_shapes_gpu = {name: f.shape for name, f in features_gpu.items()}
    print(feature_shapes_gpu)
    
    # {'layer1': torch.Size([8, 64, 56, 56])}
    # {'module.layer1': torch.Size([4, 64, 56, 56])}
    
    opened by wydwww
Releases(v0.3.0)
Learning Logic Rules for Document-Level Relation Extraction

LogiRE Learning Logic Rules for Document-Level Relation Extraction We propose to introduce logic rules to tackle the challenges of doc-level RE. Equip

41 Dec 26, 2022
ICNet for Real-Time Semantic Segmentation on High-Resolution Images, ECCV2018

ICNet for Real-Time Semantic Segmentation on High-Resolution Images by Hengshuang Zhao, Xiaojuan Qi, Xiaoyong Shen, Jianping Shi, Jiaya Jia, details a

Hengshuang Zhao 594 Dec 31, 2022
Implementations of CNNs, RNNs, GANs, etc

Tensorflow Programs and Tutorials This repository will contain Tensorflow tutorials on a lot of the most popular deep learning concepts. It'll also co

Adit Deshpande 1k Dec 30, 2022
Pytorch implementation of Learning Rate Dropout.

Learning-Rate-Dropout Pytorch implementation of Learning Rate Dropout. Paper Link: https://arxiv.org/pdf/1912.00144.pdf Train ResNet-34 for Cifar10: r

42 Nov 25, 2022
This repository contains the source code for the paper First Order Motion Model for Image Animation

!!! Check out our new paper and framework improved for articulated objects First Order Motion Model for Image Animation This repository contains the s

13k Jan 09, 2023
Automatic detection and classification of Covid severity degree in LUS (lung ultrasound) scans

Final-Project Final project in the Technion, Biomedical faculty, by Mor Ventura, Dekel Brav & Omri Magen. Subproject 1: Automatic Detection of LUS Cha

Mor Ventura 1 Dec 18, 2021
TransMorph: Transformer for Medical Image Registration

TransMorph: Transformer for Medical Image Registration keywords: Vision Transformer, Swin Transformer, convolutional neural networks, image registrati

Junyu Chen 180 Jan 07, 2023
Official implementation of Rethinking Graph Neural Architecture Search from Message-passing (CVPR2021)

Rethinking Graph Neural Architecture Search from Message-passing Intro The GNAS can automatically learn better architecture with the optimal depth of

Shaofei Cai 48 Sep 30, 2022
Final term project for Bayesian Machine Learning Lecture (XAI-623)

Mixquality_AL Final Term Project For Bayesian Machine Learning Lecture (XAI-623) Youtube Link The presentation is given in YoutubeLink Problem Formula

JeongEun Park 3 Jan 18, 2022
Self-supervised Multi-modal Hybrid Fusion Network for Brain Tumor Segmentation

JBHI-Pytorch This repository contains a reference implementation of the algorithms described in our paper "Self-supervised Multi-modal Hybrid Fusion N

FeiyiFANG 5 Dec 13, 2021
A large-scale database for graph representation learning

A large-scale database for graph representation learning

Scott Freitas 29 Nov 25, 2022
Code for ACM MM 2020 paper "NOH-NMS: Improving Pedestrian Detection by Nearby Objects Hallucination"

NOH-NMS: Improving Pedestrian Detection by Nearby Objects Hallucination The offical implementation for the "NOH-NMS: Improving Pedestrian Detection by

Tencent YouTu Research 64 Nov 11, 2022
ADB-IP-ROTATION - Use your mobile phone to gain a temporary IP address using ADB and data tethering

ADB IP ROTATE This is a Python script based on Android Debug Bridge (adb) shell sc

Dor Bismuth 2 Jul 12, 2022
Code for the paper SphereRPN: Learning Spheres for High-Quality Region Proposals on 3D Point Clouds Object Detection, ICIP 2021.

SphereRPN Code for the paper SphereRPN: Learning Spheres for High-Quality Region Proposals on 3D Point Clouds Object Detection, ICIP 2021. Authors: Th

Thang Vu 15 Dec 02, 2022
Instantaneous Motion Generation for Robots and Machines.

Ruckig Instantaneous Motion Generation for Robots and Machines. Ruckig generates trajectories on-the-fly, allowing robots and machines to react instan

Berscheid 374 Dec 23, 2022
A simple rest api that classifies pneumonia infection whether it is Normal, Pneumonia Virus or Pneumonia Bacteria from a chest-x-ray image.

This is a simple rest api that classifies pneumonia infection whether it is Normal, Pneumonia Virus or Pneumonia Bacteria from a chest-x-ray image.

crispengari 3 Jan 08, 2022
StocksMA is a package to facilitate access to financial and economic data of Moroccan stocks.

Creating easier access to the Moroccan stock market data What is StocksMA ? StocksMA is a package to facilitate access to financial and economic data

Salah Eddine LABIAD 28 Jan 04, 2023
Bio-Computing Platform Featuring Large-Scale Representation Learning and Multi-Task Deep Learning (“PaddleHelix” bio-computing toolkit)

Latest News 2021.10.25 Paper "Docking-based Virtual Screening with Multi-Task Learning" is accepted by BIBM 2021. 2021.07.29 PaddleHeli

633 Jan 04, 2023
The code and data for "Measuring Fine-Grained Domain Relevance of Terms: A Hierarchical Core-Fringe Approach" (ACL '21)

We propose a hierarchical core-fringe learning framework to measure fine-grained domain relevance of terms – the degree that a term is relevant to a broad (e.g., computer science) or narrow (e.g., de

Jie Huang 14 Oct 21, 2022
Deep Anomaly Detection with Outlier Exposure (ICLR 2019)

Outlier Exposure This repository contains the essential code for the paper Deep Anomaly Detection with Outlier Exposure (ICLR 2019). Requires Python 3

Dan Hendrycks 464 Dec 27, 2022