Code for "AutoMTL: A Programming Framework for Automated Multi-Task Learning"

Overview

AutoMTL: A Programming Framework for Automated Multi-Task Learning

This is the website for our paper "AutoMTL: A Programming Framework for Automated Multi-Task Learning", submitted to MLSys 2022. The arXiv version will be made public on Tue, 26 Oct 2021.

Abstract

Multi-task learning (MTL) jointly learns a set of tasks. It is a promising approach to reducing training and inference time and storage costs while improving prediction accuracy and generalization performance for many computer vision tasks. However, a major barrier to the widespread adoption of MTL is the lack of systematic support for developing compact multi-task models for a given set of tasks. In this paper, we aim to remove this barrier by developing AutoMTL, the first programming framework that automates MTL model development. AutoMTL takes as input an arbitrary backbone convolutional neural network and a set of tasks to learn, and then automatically produces a multi-task model that simultaneously achieves high accuracy and has a small memory footprint. As a programming framework, AutoMTL could facilitate the development of MTL-enabled computer vision applications and even further improve task performance.

[Figure: overview of AutoMTL]

Cite

If you find our work helpful in your research, please consider citing it. [TODO: cite info]

Description

Environment

conda install pytorch==1.6.0 torchvision==0.7.0 -c pytorch # Or higher
conda install protobuf
pip install opencv-python
pip install scikit-learn
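The "# Or higher" comment suggests that PyTorch releases newer than 1.6.0 (and torchvision newer than 0.7.0) are also acceptable. A minimal sketch for checking what is installed follows; the helper names parse_version and version_at_least are hypothetical, and it assumes version strings of the usual major.minor.patch form, possibly with a local suffix such as +cu101. Plain string comparison would wrongly rank 1.10.0 below 1.6.0, so the sketch compares numeric tuples instead.

```python
import re
from importlib import metadata


def parse_version(version: str) -> tuple:
    """Turn '1.10.2+cu113' into (1, 10, 2), ignoring local/pre-release suffixes."""
    core = version.split("+")[0]
    parts = []
    for piece in core.split(".")[:3]:
        match = re.match(r"\d+", piece)
        parts.append(int(match.group()) if match else 0)
    # Pad so that '1.6' compares equal to '1.6.0'.
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def version_at_least(installed: str, required: str) -> bool:
    """Compare numerically, so '1.10.0' correctly counts as newer than '1.6.0'."""
    return parse_version(installed) >= parse_version(required)


if __name__ == "__main__":
    for package, minimum in [("torch", "1.6.0"), ("torchvision", "0.7.0")]:
        try:
            installed = metadata.version(package)
        except metadata.PackageNotFoundError:
            print(f"{package}: not installed")
            continue
        status = "ok" if version_at_least(installed, minimum) else "too old"
        print(f"{package} {installed}: {status} (need >= {minimum})")
```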

Datasets

We conducted experiments on three popular multi-task learning (MTL) datasets: CityScapes [1], NYUv2 [2], and Tiny-Taskonomy [3]. You can download them here. For Tiny-Taskonomy, you will need to contact the authors directly; see their official website.

File Structure

├── data
│   ├── dataloader
│   │   ├── *_dataloader.py
│   ├── heads
│   │   ├── pixel2pixel.py
│   ├── metrics
│   │   ├── pixel2pixel_loss/metrics.py
├── framework
│   ├── layer_containers.py
│   ├── base_node.py
│   ├── layer_node.py
│   ├── mtl_model.py
│   ├── trainer.py
├── models
│   ├── *.prototxt
└── utils
    └── pytorch_to_caffe.py

Code Description

Our code is divided into three parts: code for data, the AutoMTL framework itself, and other utilities.

  • For Data

    • Dataloaders *_dataloader.py: For each dataset, we provide a corresponding PyTorch dataloader that takes a task argument selecting which task's data to load.
    • Heads pixel2pixel.py: The ASPP head [4] is implemented for the pixel-to-pixel vision tasks.
    • Metrics pixel2pixel_loss/metrics.py: Each task has its own criterion (loss) and evaluation metric.
  • AutoMTL

    • Multi-Task Model Generator mtl_model.py: Transforms the given backbone model (in prototxt format) and the dictionary of task-specific heads into a multi-task supermodel.
    • Trainer Tools trainer.py: Implements a three-stage training pipeline that searches for a good multi-task model for the given tasks. [Figure: training pipeline]
  • Others

    • Input Backbone *.prototxt: Typical vision backbone models, including Deeplab-ResNet34 [4], MobileNetV2, and MNasNet.
    • Transfer to Prototxt pytorch_to_caffe.py: If you define your own custom backbone model with the PyTorch API, this tool converts it into a prototxt file.

How to Use

Set up Data

Each task has its own dataloaders for training and validation, a task-specific criterion (loss), an evaluation metric, and a model head. Here we use CityScapes as an example.

tasks = ['segment_semantic', 'depth_zbuffer']
task_cls_num = {'segment_semantic': 19, 'depth_zbuffer': 1} # the number of classes in each task

You can also define your own dataloader, criterion, and evaluation metrics. Refer to the files in data/ to make sure your custom classes produce the same output format as ours so that they fit into the framework.

dataloader dictionary

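from torch.utils.data import DataLoader

# CityScapes is the dataset class from data/dataloader/; dataroot is the path to your copy of the dataset
trainDataloaderDict = {}
valDataloaderDict = {}
for task in tasks:
    dataset = CityScapes(dataroot, 'train', task, crop_h=224, crop_w=224)
    trainDataloaderDict[task] = DataLoader(dataset, <batch_size>, shuffle=True)

    dataset = CityScapes(dataroot, 'test', task)
    valDataloaderDict[task] = DataLoader(dataset, <batch_size>, shuffle=False)  # evaluation does not need shuffling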
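Because each task has its own dataloader, a multi-task training loop has to draw batches from several loaders at once. One common way to do this is to take one batch per task at every iteration and restart a loader when it runs out. The sketch below shows that idea under stated assumptions: it is not necessarily how AutoMTL's Trainer consumes trainDataloaderDict, the helper multi_task_batches is hypothetical, and plain lists stand in for DataLoader objects.

```python
from typing import Dict, Iterable, Iterator, Tuple


def multi_task_batches(
    loaders: Dict[str, Iterable], num_iters: int
) -> Iterator[Tuple[int, Dict[str, object]]]:
    """Yield (iteration, {task: batch}) pairs, one batch per task per step.

    A task whose loader is exhausted is restarted with a fresh iterator, so
    tasks with smaller datasets simply cycle more often. Assumes every loader
    yields at least one batch.
    """
    iterators = {task: iter(loader) for task, loader in loaders.items()}
    for step in range(num_iters):
        batches = {}
        for task, loader in loaders.items():
            try:
                batches[task] = next(iterators[task])
            except StopIteration:
                # Start a new pass over this task's data (a DataLoader with
                # shuffle=True would reshuffle at this point).
                iterators[task] = iter(loader)
                batches[task] = next(iterators[task])
        yield step, batches


if __name__ == "__main__":
    toy_loaders = {
        "segment_semantic": ["seg0", "seg1", "seg2"],
        "depth_zbuffer": ["depth0", "depth1"],
    }
    for step, batch in multi_task_batches(toy_loaders, num_iters=4):
        print(step, batch)
```

Cycling the shorter loaders keeps every task represented in every step, at the cost of the smaller datasets being seen more often per epoch of the larger ones.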

criterion dictionary

criterionDict = {}
for task in tasks:
    criterionDict[task] = CityScapesCriterions(task)

evaluation metric dictionary

metricDict = {}
for task in tasks:
    metricDict[task] = CityScapesMetrics(task)
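What each metric object computes internally is defined in data/metrics. As a hypothetical illustration of the kind of per-task metric this dictionary holds, the class below accumulates a confusion matrix over batches and reports pixel accuracy and mean IoU for semantic segmentation. The class name, its update/compute methods, and the ignore label 255 (a common CityScapes convention) are all assumptions, not AutoMTL's interface; check data/metrics for the output format your own metric must match.

```python
from typing import Dict, List, Sequence


class SegmentationMetricSketch:
    """Hypothetical per-task metric that accumulates a confusion matrix."""

    def __init__(self, num_classes: int, ignore_index: int = 255):
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        # confusion[gt][pred] counts pixels with that (label, prediction) pair.
        self.confusion: List[List[int]] = [
            [0] * num_classes for _ in range(num_classes)
        ]

    def update(self, preds: Sequence[int], labels: Sequence[int]) -> None:
        """Add one batch of flattened per-pixel predictions and labels."""
        for pred, label in zip(preds, labels):
            if label == self.ignore_index:
                continue  # unlabeled pixels do not count
            self.confusion[label][pred] += 1

    def compute(self) -> Dict[str, float]:
        """Return pixel accuracy and mean IoU over the classes that occur."""
        total = sum(map(sum, self.confusion))
        correct = sum(self.confusion[c][c] for c in range(self.num_classes))
        ious = []
        for c in range(self.num_classes):
            tp = self.confusion[c][c]
            fn = sum(self.confusion[c]) - tp
            fp = sum(row[c] for row in self.confusion) - tp
            if tp + fp + fn > 0:  # skip classes absent from both preds and labels
                ious.append(tp / (tp + fp + fn))
        return {
            "pixel_acc": correct / total if total else 0.0,
            "mIoU": sum(ious) / len(ious) if ious else 0.0,
        }
```

Accumulating counts across batches, instead of averaging per-batch scores, makes the final mIoU independent of how the validation set is split into batches.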

task-specific heads dictionary

import torch.nn as nn

headsDict = nn.ModuleDict()  # must be an nn.ModuleDict, not a plain Python dict, so the heads' parameters are registered with the model
for task in tasks:
    headsDict[task] = ASPPHeadNode(<feature_dim>, task_cls_num[task])
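The tasks list, task_cls_num, and the dataloader, criterion, metric, and head dictionaries all have to be keyed by the same task names, and a mismatch (for example a typo such as 'depth_zbufer') may only surface later during training. A small hypothetical helper, check_task_keys, which is not part of AutoMTL, can catch this before the model is built. It relies only on each dictionary having a .keys() method, which both Python dicts and nn.ModuleDict provide.

```python
from typing import Iterable, Mapping


def check_task_keys(tasks: Iterable[str], **per_task_dicts: Mapping) -> None:
    """Raise ValueError if any per-task dictionary is missing a task or has an extra one."""
    expected = set(tasks)
    problems = []
    for name, mapping in per_task_dicts.items():
        keys = set(mapping.keys())
        if keys != expected:
            missing = sorted(expected - keys)
            extra = sorted(keys - expected)
            problems.append(f"{name}: missing {missing}, unexpected {extra}")
    if problems:
        raise ValueError("; ".join(problems))
```

For example: check_task_keys(tasks, trainDataloaderDict=trainDataloaderDict, valDataloaderDict=valDataloaderDict, criterionDict=criterionDict, metricDict=metricDict, headsDict=headsDict, task_cls_num=task_cls_num).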

Construct Multi-Task Supermodel

prototxt = 'models/deeplab_resnet34_adashare.prototxt' # can be any CNN backbone in prototxt format
mtlmodel = MTLModel(prototxt, headsDict)
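MTLModel reads the backbone from a Caffe-style prototxt file, in which each layer is a layer { ... } block with fields such as name, type, bottom, and top. To inspect which layers a backbone file contains, a rough pure-Python sketch like the one below can list them. It is not AutoMTL's own parser; list_layers is a hypothetical helper, and it assumes quoted values and one field per line, which is how prototxt files are usually written.

```python
import re
from typing import List, Tuple

FIELD = re.compile(r'^\s*(name|type)\s*:\s*"([^"]*)"')
LAYER_START = re.compile(r"\s*layers?\s*\{")


def list_layers(prototxt_text: str) -> List[Tuple[str, str]]:
    """Return (name, type) for each top-level `layer { ... }` block.

    Only fields sitting directly inside a layer block (depth 1) are read, so
    nested blocks such as `param { name: ... }` are ignored.
    """
    layers = []
    depth = 0
    in_layer = False
    current = {}
    for raw in prototxt_text.splitlines():
        line = raw.split("#", 1)[0]  # drop comments
        if in_layer and depth == 1:
            match = FIELD.match(line)
            if match:
                current.setdefault(match.group(1), match.group(2))
        if depth == 0 and LAYER_START.match(line):
            in_layer, current = True, {}
        depth += line.count("{") - line.count("}")
        if in_layer and depth == 0:
            layers.append((current.get("name", ""), current.get("type", "")))
            in_layer = False
    return layers


if __name__ == "__main__":
    with open("models/deeplab_resnet34_adashare.prototxt") as f:
        for name, layer_type in list_layers(f.read()):
            print(f"{name:30s} {layer_type}")
```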

3-stage Training

define the trainer

trainer = Trainer(mtlmodel, trainDataloaderDict, valDataloaderDict, criterionDict, metricDict)

pre-train phase

trainer.pre_train(iters=<total_iter>, lr=<init_lr>, savePath=<save_path>)

policy-train phase

loss_lambda = {'segment_semantic': 1, 'depth_zbuffer': 1, 'policy':0.0005} # the weights for each task and the policy regularization term from the paper
trainer.alter_train_with_reg(iters=<total_iter>, policy_network_iters=<alter_iters>, policy_lr=<policy_lr>, network_lr=<network_lr>, 
                             loss_lambda=loss_lambda, savePath=<save_path>)
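The loss_lambda dictionary weights each task's loss and the policy regularization term. Assuming these are combined as a plain weighted sum (the exact form of the regularizer is defined in the paper and in trainer.py), the objective for one iteration can be sketched as follows. weighted_total_loss is a hypothetical helper, and the loss values in the usage example are made-up floats standing in for tensors.

```python
from typing import Dict


def weighted_total_loss(losses: Dict[str, float], loss_lambda: Dict[str, float]) -> float:
    """Sum each term (task losses plus 'policy') scaled by its weight.

    Raises KeyError if a term has no weight, which catches typos in task names.
    """
    missing = set(losses) - set(loss_lambda)
    if missing:
        raise KeyError(f"no weight given for: {sorted(missing)}")
    return sum(loss_lambda[name] * value for name, value in losses.items())


if __name__ == "__main__":
    loss_lambda = {"segment_semantic": 1, "depth_zbuffer": 1, "policy": 0.0005}
    # Made-up loss values, only to show the arithmetic.
    losses = {"segment_semantic": 0.8, "depth_zbuffer": 0.4, "policy": 12.0}
    print(weighted_total_loss(losses, loss_lambda))
```

The same expression also works with PyTorch scalar tensors, since sum starts from 0 and tensors support addition, so backpropagating through the result would reach every weighted term.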

Note that when training the policy and the model weights together, we alternate between them, training each for the number of iterations specified in policy_network_iters.
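The alternation can be pictured as a repeating schedule. As a sketch, assuming policy_network_iters is a pair (policy_iters, network_iters) (check trainer.py for its exact meaning), each cycle updates the policy for policy_iters steps with the network weights frozen, then updates the network weights for network_iters steps with the policy frozen. Which phase comes first is also an assumption, and phase_at is a hypothetical helper.

```python
from typing import Tuple


def phase_at(step: int, policy_network_iters: Tuple[int, int]) -> str:
    """Return which parameters are trained at a given global step.

    Assumes each cycle spends policy_iters steps on the policy, then
    network_iters steps on the network weights, and then repeats.
    """
    policy_iters, network_iters = policy_network_iters
    if policy_iters <= 0 or network_iters <= 0:
        raise ValueError("both phase lengths must be positive")
    position = step % (policy_iters + network_iters)
    return "policy" if position < policy_iters else "network"


if __name__ == "__main__":
    print([phase_at(s, (2, 3)) for s in range(10)])
```

Computing the phase from the global step with a modulo keeps the schedule stateless, so resuming from a checkpoint at any iteration lands in the correct phase.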

post-train phase

trainer.post_train(iters=<total_iter>, lr=<init_lr>,
                   loss_lambda=loss_lambda, savePath=<save_path>, reload=<policy_train_model_name>)

Note: Please refer to Example.ipynb for more details.

References

[1] Cordts, Marius and Omran, Mohamed and Ramos, Sebastian and Rehfeld, Timo and Enzweiler, Markus and Benenson, Rodrigo and Franke, Uwe and Roth, Stefan and Schiele, Bernt. The Cityscapes dataset for semantic urban scene understanding. CVPR, 3213-3223, 2016.

[2] Silberman, Nathan and Hoiem, Derek and Kohli, Pushmeet and Fergus, Rob. Indoor segmentation and support inference from RGBD images. ECCV, 746-760, 2012.

[3] Zamir, Amir R and Sax, Alexander and Shen, William and Guibas, Leonidas J and Malik, Jitendra and Savarese, Silvio. Taskonomy: Disentangling task transfer learning. CVPR, 3712-3722, 2018.

[4] Chen, Liang-Chieh and Papandreou, George and Kokkinos, Iasonas and Murphy, Kevin and Yuille, Alan L. DeepLab: Semantic image segmentation with deep convolutional nets, atrous convolution, and fully connected CRFs. PAMI, 834-848, 2017.

Owner
Ivy Zhang
The most simple and minimalistic navigation dashboard.

Navigation This project follows a goal to have simple and lightweight dashboard with different links. I use it to have my own self-hosted service dash

Yaroslav 23 Dec 23, 2022
Official implementation of SIGIR'2021 paper: "Sequential Recommendation with Graph Neural Networks".

SURGE: Sequential Recommendation with Graph Neural Networks This is our TensorFlow implementation for the paper: Sequential Recommendation with Graph

FIB LAB, Tsinghua University 53 Dec 26, 2022
It is a simple library to speed up CLIP inference up to 3x (K80 GPU)

CLIP-ONNX It is a simple library to speed up CLIP inference up to 3x (K80 GPU) Usage Install clip-onnx module and requirements first. Use this trick !

Gerasimov Maxim 93 Dec 20, 2022
PyTorch implementation of "Image-to-Image Translation Using Conditional Adversarial Networks".

pix2pix-pytorch PyTorch implementation of Image-to-Image Translation Using Conditional Adversarial Networks. Based on pix2pix by Phillip Isola et al.

mrzhu 383 Dec 17, 2022
Paddle-Adversarial-Toolbox (PAT) is a Python library for Deep Learning Security based on PaddlePaddle.

Paddle-Adversarial-Toolbox Paddle-Adversarial-Toolbox (PAT) is a Python library for Deep Learning Security based on PaddlePaddle. Model Zoo Common FGS

AgentMaker 17 Nov 08, 2022
This is an official repository of CLGo: Learning to Predict 3D Lane Shape and Camera Pose from a Single Image via Geometry Constraints

CLGo This is an official repository of CLGo: Learning to Predict 3D Lane Shape and Camera Pose from a Single Image via Geometry Constraints An earlier

刘芮金 32 Dec 20, 2022
Ultra-Data-Efficient GAN Training: Drawing A Lottery Ticket First, Then Training It Toughly

Ultra-Data-Efficient GAN Training: Drawing A Lottery Ticket First, Then Training It Toughly Code for this paper Ultra-Data-Efficient GAN Tra

VITA 77 Oct 05, 2022
PPLNN is a Primitive Library for Neural Network is a high-performance deep-learning inference engine for efficient AI inferencing


943 Jan 07, 2023
My personal Home Assistant configuration.

About This is my personal Home Assistant configuration. My guiding principle is to have full local control of all my devices. I intend everything to ru

Chris Turra 13 Jun 07, 2022
Towards Implicit Text-Guided 3D Shape Generation (CVPR2022)

Towards Implicit Text-Guided 3D Shape Generation Towards Implicit Text-Guided 3D Shape Generation (CVPR2022) Code for the paper [Towards Implicit Text

55 Dec 16, 2022
Regularizing Generative Adversarial Networks under Limited Data (CVPR 2021)

Regularizing Generative Adversarial Networks under Limited Data [Project Page][Paper] Implementation for our GAN regularization method. The proposed r

Google 148 Nov 18, 2022
Official implementation of the paper WAV2CLIP: LEARNING ROBUST AUDIO REPRESENTATIONS FROM CLIP

Wav2CLIP 🚧 WIP 🚧 Official implementation of the paper WAV2CLIP: LEARNING ROBUST AUDIO REPRESENTATIONS FROM CLIP 📄 🔗 Ho-Hsiang Wu, Prem Seetharaman

Descript 240 Dec 13, 2022
Code for "Universal inference meets random projections: a scalable test for log-concavity"

How to use this repository This repository contains code to replicate the results of "Universal inference meets random projections: a scalable test fo

Robin Dunn 0 Nov 21, 2021
🔀 Visual Room Rearrangement

AI2-THOR Rearrangement Challenge Welcome to the 2021 AI2-THOR Rearrangement Challenge hosted at the CVPR'21 Embodied-AI Workshop. The goal of this cha

AI2 55 Dec 22, 2022
Audio2Face - Audio To Face With Python

Audio2Face Description We create a project that transforms audio to blendshape w

FACEGOOD 724 Dec 26, 2022
TakeInfoatNistforICS - Take Information in NIST NVD for ICS

Take Information in NIST NVD for ICS This project developed with Python. When yo

5 Sep 05, 2022
68 keypoint annotations for COFW test data

68 keypoint annotations for COFW test data This repository contains manually annotated 68 keypoints for COFW test data (original annotation of CFOW da

31 Dec 06, 2022
:fire: 2D and 3D Face alignment library build using pytorch

Face Recognition Detect facial landmarks from Python using the world's most accurate face alignment network, capable of detecting points in both 2D an

Adrian Bulat 6k Dec 31, 2022
Apache Spark - A unified analytics engine for large-scale data processing

Apache Spark Spark is a unified analytics engine for large-scale data processing. It provides high-level APIs in Scala, Java, Python, and R, and an op

The Apache Software Foundation 34.7k Jan 04, 2023
Very deep VAEs in JAX/Flax

Very Deep VAEs in JAX/Flax Implementation of the experiments in the paper Very Deep VAEs Generalize Autoregressive Models and Can Outperform Them on I

Jamie Townsend 42 Dec 12, 2022