Official PyTorch implementation of "Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets" (ICLR 2021)

Related tags

Deep LearningMetaD2A
Overview

Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets

This is the official PyTorch implementation for the paper Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets (ICLR 2021) : https://openreview.net/forum?id=rkQuFUmUOg3.

Abstract

Despite the success of recent Neural Architecture Search (NAS) methods on various tasks which have shown to output networks that largely outperform human-designed networks, conventional NAS methods have mostly tackled the optimization of searching for the network architecture for a single task (dataset), which does not generalize well across multiple tasks (datasets). Moreover, since such task-specific methods search for a neural architecture from scratch for every given task, they incur a large computational cost, which is problematic when the time and monetary budget are limited. In this paper, we propose an efficient NAS framework that is trained once on a database consisting of datasets and pretrained networks and can rapidly search a neural architecture for a novel dataset. The proposed MetaD2A (Meta Dataset-to-Architecture) model can stochastically generate graphs (architectures) from a given set (dataset) via a cross-modal latent space learned with amortized meta-learning. Moreover, we also propose a meta-performance predictor to estimate and select the best architecture without direct training on target datasets. The experimental results demonstrate that our model meta-learned on subsets of ImageNet-1K and architectures from NAS-Bench 201 search space successfully generalizes to multiple benchmark datasets including CIFAR-10 and CIFAR-100, with an average search time of 33 GPU seconds. Even under a large search space, MetaD2A is 5.5K times faster than NSGANetV2, a transferable NAS method, with comparable performance. We believe that the MetaD2A proposes a new research direction for rapid NAS as well as ways to utilize the knowledge from rich databases of datasets and architectures accumulated over the past years.

Framework of MetaD2A Model

Prerequisites

  • Python 3.6 (Anaconda)
  • PyTorch 1.6.0
  • CUDA 10.2
  • python-igraph==0.8.2
  • tqdm==4.50.2
  • torchvision==0.7.0
  • python-igraph==0.8.2
  • nas-bench-201==1.3
  • scipy==1.5.2

If you are not familiar with preparing conda environment, please follow the below instructions

$ conda create --name metad2a python=3.6
$ conda activate metad2a
$ conda install pytorch==1.6.0 torchvision cudatoolkit=10.2 -c pytorch
$ pip install nas-bench-201
$ conda install -c conda-forge tqdm
$ conda install -c conda-forge python-igraph
$ pip install scipy

And for data preprocessing,

$ pip install requests

Hardware Spec used for experiments of the paper

  • GPU: A single Nvidia GeForce RTX 2080Ti
  • CPU: Intel(R) Xeon(R) Silver 4114 CPU @ 2.20GHz

NAS-Bench-201

Go to the folder for NAS-Bench-201 experiments (i.e. MetaD2A_nas_bench_201)

$ cd MetaD2A_nas_bench_201

Data Preparation

To download preprocessed data files, run get_files/get_preprocessed_data.py:

$ python get_files/get_preprocessed_data.py

It will take some time to download and preprocess each dataset.

To download MNIST, Pets and Aircraft Datasets, run get_files/get_{DATASET}.py

$ python get_files/get_mnist.py
$ python get_files/get_aircraft.py
$ python get_files/get_pets.py

Other datasets such as Cifar10, Cifar100, SVHN will be automatically downloaded when you load dataloader by torchvision.

If you want to use your own dataset, please first make your own preprocessed data, by modifying process_dataset.py .

$ process_dataset.py

MetaD2A Evaluation (Meta-Test)

You can download trained checkpoint files for generator and predictor

$ python get_files/get_checkpoint.py
$ python get_files/get_predictor_checkpoint.py

1. Evaluation on Cifar10 and Cifar100

By set --data-name as the name of dataset (i.e. cifar10, cifar100), you can evaluate the specific dataset only

# Meta-testing for generator 
$ python main.py --gpu 0 --model generator --hs 56 --nz 56 --test --load-epoch 400 --num-gen-arch 500 --data-name {DATASET_NAME}

After neural architecture generation is completed, meta-performance predictor selects high-performing architectures among the candidates

# Meta-testing for predictor
$ python main.py --gpu 0 --model predictor --hs 512 --nz 56 --test --num-gen-arch 500 --data-name {DATASET_NAME}

2. Evaluation on Other Datasets

By set --data-name as the name of dataset (i.e. mnist, svhn, aircraft, pets), you can evaluate the specific dataset only

# Meta-testing for generator
$ python main.py --gpu 0 --model generator --hs 56 --nz 56 --test --load-epoch 400 --num-gen-arch 50 --data-name {DATASET_NAME}

After neural architecture generation is completed, meta-performance predictor selects high-performing architectures among the candidates

# Meta-testing for predictor
$ python main.py --gpu 0 --model predictor --hs 512 --nz 56 --test --num-gen-arch 50 --data-name {DATASET_NAME}

Meta-Training MetaD2A Model

You can train the generator and predictor as follows

# Meta-training for generator
$ python main.py --gpu 0 --model generator --hs 56 --nz 56 
                 
# Meta-training for predictor
$ python main.py --gpu 0 --model predictor --hs 512 --nz 56 

Results

The results of training architectures which are searched by meta-trained MetaD2A model for each dataset

Accuracy

CIFAR10 CIFAR100 MNIST SVHN Aircraft Oxford-IIT Pets
PC-DARTS 93.66±0.17 66.64±0.04 99.66±0.04 95.40±0.67 46.08±7.00 25.31±1.38
MetaD2A (Ours) 94.37±0.03 73.51±0.00 99.71±0.08 96.34±0.37 58.43±1.18 41.50±4.39

Search Time (GPU Sec)

CIFAR10 CIFAR100 MNIST SVHN Aircraft Oxford-IIT Pets
PC-DARTS 10395 19951 24857 31124 3524 2844
MetaD2A (Ours) 69 96 7 7 10 8

MobileNetV3 Search Space

Go to the folder for MobileNetV3 Search Space experiments (i.e. MetaD2A_mobilenetV3)

$ cd MetaD2A_mobilenetV3

And follow README.md written for experiments of MobileNetV3 Search Space

Citation

If you found the provided code useful, please cite our work.

@inproceedings{
    lee2021rapid,
    title={Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets},
    author={Hayeon Lee and Eunyoung Hyung and Sung Ju Hwang},
    booktitle={ICLR},
    year={2021}
}

Reference

Owner
Ph.D. student @ School of Computing, Korea Advanced Institute of Science and Technology (KAIST)
[2021][ICCV][FSNet] Full-Duplex Strategy for Video Object Segmentation

Full-Duplex Strategy for Video Object Segmentation (ICCV, 2021) Authors: Ge-Peng Ji, Keren Fu, Zhe Wu, Deng-Ping Fan*, Jianbing Shen, & Ling Shao This

Daniel-Ji 55 Dec 22, 2022
Python implementation of Bayesian optimization over permutation spaces.

Bayesian Optimization over Permutation Spaces This repository contains the source code and the resources related to the paper "Bayesian Optimization o

Aryan Deshwal 9 Dec 23, 2022
Code for Massive-scale Decoding for Text Generation using Lattices

Massive-scale Decoding for Text Generation using Lattices Jiacheng Xu, Greg Durrett TL;DR: a new search algorithm to construct lattices encoding many

Jiacheng Xu 37 Dec 18, 2022
Wind Speed Prediction using LSTMs in PyTorch

Implementation of Deep-Forecast using PyTorch Deep Forecast: Deep Learning-based Spatio-Temporal Forecasting Adapted from original implementation Setu

Onur Kaplan 151 Dec 14, 2022
Codecov coverage standard for Python

Python-Standard Last Updated: 01/07/22 00:09:25 What is this? This is a Python application, with basic unit tests, for which coverage is uploaded to C

Codecov 10 Nov 04, 2022
A library for performing coverage guided fuzzing of neural networks

TensorFuzz: Coverage Guided Fuzzing for Neural Networks This repository contains a library for performing coverage guided fuzzing of neural networks,

Brain Research 195 Dec 28, 2022
Cross-platform-profile-pic-changer - Script to change profile pictures across multiple platforms

cross-platform-profile-pic-changer script to change profile pictures across mult

4 Jan 17, 2022
SIEM Logstash parsing for more than hundred technologies

LogIndexer Pipeline Logstash Parsing Configurations for Elastisearch SIEM and OpenDistro for Elasticsearch SIEM Why this project exists The overhead o

146 Dec 29, 2022
(CVPR 2022 Oral) Official implementation for "Surface Representation for Point Clouds"

RepSurf - Surface Representation for Point Clouds [CVPR 2022 Oral] By Haoxi Ran* , Jun Liu, Chengjie Wang ( * : corresponding contact) The pytorch off

Haoxi Ran 264 Dec 23, 2022
Illuminated3D This project participates in the Nasa Space Apps Challenge 2021.

Illuminated3D This project participates in the Nasa Space Apps Challenge 2021.

Eleftheriadis Emmanouil 1 Oct 09, 2021
This repository contains the scripts for downloading and validating scripts for the documents

HC4: HLTCOE CLIR Common-Crawl Collection This repository contains the scripts for downloading and validating scripts for the documents. Document ids,

JHU Human Language Technology Center of Excellence 6 Jun 07, 2022
Road Crack Detection Using Deep Learning Methods

Road-Crack-Detection-Using-Deep-Learning-Methods This is my Diploma Thesis ¨Road Crack Detection Using Deep Learning Methods¨ under the supervision of

Aggelos Katsaliros 3 May 03, 2022
[NeurIPS 2021 Spotlight] Aligning Pretraining for Detection via Object-Level Contrastive Learning

SoCo [NeurIPS 2021 Spotlight] Aligning Pretraining for Detection via Object-Level Contrastive Learning By Fangyun Wei*, Yue Gao*, Zhirong Wu, Han Hu,

Yue Gao 139 Dec 14, 2022
Quantile Regression DQN a Minimal Working Example, Distributional Reinforcement Learning with Quantile Regression

Quantile Regression DQN Quantile Regression DQN a Minimal Working Example, Distributional Reinforcement Learning with Quantile Regression (https://arx

Arsenii Senya Ashukha 80 Sep 17, 2022
PyTorch implementation of the paper Dynamic Token Normalization Improves Vision Transfromers.

Dynamic Token Normalization Improves Vision Transformers This is the PyTorch implementation of the paper Dynamic Token Normalization Improves Vision T

Wenqi Shao 20 Oct 09, 2022
A Python framework for developing parallelized Computational Fluid Dynamics software to solve the hyperbolic 2D Euler equations on distributed, multi-block structured grids.

pyHype: Computational Fluid Dynamics in Python pyHype is a Python framework for developing parallelized Computational Fluid Dynamics software to solve

Mohamed Khalil 21 Nov 22, 2022
PArallel Distributed Deep LEarning: Machine Learning Framework from Industrial Practice (『飞桨』核心框架,深度学习&机器学习高性能单机、分布式训练和跨平台部署)

English | 简体中文 Welcome to the PaddlePaddle GitHub. PaddlePaddle, as the only independent R&D deep learning platform in China, has been officially open

19.4k Jan 04, 2023
NumPy로 구현한 딥러닝 라이브러리입니다. (자동 미분 지원)

Deep Learning Library only using NumPy 본 레포지토리는 NumPy 만으로 구현한 딥러닝 라이브러리입니다. 자동 미분이 구현되어 있습니다. 자동 미분 자동 미분은 미분을 자동으로 계산해주는 기능입니다. 아래 코드는 자동 미분을 활용해 역전파

조준희 17 Aug 16, 2022
Proposal, Tracking and Segmentation (PTS): A Cascaded Network for Video Object Segmentation

Proposal, Tracking and Segmentation (PTS): A Cascaded Network for Video Object Segmentation By Qiang Zhou*, Zilong Huang*, Lichao Huang, Han Shen, Yon

Forest 117 Apr 01, 2022
An essential implementation of BYOL in PyTorch + PyTorch Lightning

Essential BYOL A simple and complete implementation of Bootstrap your own latent: A new approach to self-supervised Learning in PyTorch + PyTorch Ligh

Enrico Fini 48 Sep 27, 2022