DABO: Data Augmentation with Bilevel Optimization


The goal is to automatically learn an efficient data augmentation regime for image classification.

Accepted at WACV 2021.


Overview

What's new: This method automatically learns data augmentation to improve image classification performance. It does not require hand-coding augmentation techniques, which can demand domain knowledge or an expensive hyper-parameter search on the validation set.

Key insight: Our method efficiently trains a network that performs data augmentation. This network learns augmentations from the gradient of the classifier's validation loss, computed with an online version of bilevel optimization. We also use truncated back-propagation to significantly reduce the computational cost of bilevel optimization.
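The online bilevel optimization behind this idea can be written as follows. The notation is ours, not the paper's: θ are the classifier weights, φ the augmentation-network parameters, and η, β the inner and outer step sizes. Treat it as a generic sketch of online bilevel optimization with a one-step truncation, not as the exact DABO objective.

$$
\min_{\phi}\; \mathcal{L}_{\text{val}}\big(\theta^{*}(\phi)\big)
\quad \text{s.t.} \quad
\theta^{*}(\phi) = \arg\min_{\theta}\; \mathcal{L}_{\text{train}}(\theta, \phi)
$$

Solving the inner problem to convergence for every update of φ is too expensive, so the online version takes a single inner step from the current θ and back-propagates only through that step:

$$
\theta' = \theta - \eta\, \nabla_{\theta} \mathcal{L}_{\text{train}}(\theta, \phi),
\qquad
\nabla_{\phi} \mathcal{L}_{\text{val}}(\theta') = -\eta \left( \frac{\partial^{2} \mathcal{L}_{\text{train}}(\theta, \phi)}{\partial \theta\, \partial \phi^{\top}} \right)^{\!\top} \nabla_{\theta'} \mathcal{L}_{\text{val}}(\theta'),
\qquad
\phi \leftarrow \phi - \beta\, \nabla_{\phi} \mathcal{L}_{\text{val}}(\theta')
$$

In this sketch the classifier then continues training from θ'. Unrolling more inner steps gives a better estimate of the validation loss at a higher cost; truncated back-propagation limits that cost by back-propagating through only the last few steps.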

How it works: Our method jointly trains a classifier and an augmentation network through the following steps:


  • For each mini-batch, a forward pass computes the training loss on the images transformed by the augmentation network.
  • Using the gradient of the training loss, the classifier takes an optimization step in the inner loop.
  • A forward pass of the classifier with the updated weights then computes the validation loss.
  • The gradient of the validation loss is back-propagated through the inner-loop update to train the augmentation network.
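The four steps above can be sketched on a toy problem small enough to differentiate by hand. Everything here is a hypothetical stand-in, not the repository's code: the classifier is a single scalar weight `w` fitted by squared error, the augmentation network is a single learnable scale `a` applied to training inputs, and `bilevel_step`, `lr` and `meta_lr` are illustrative names. The real method uses a deep image classifier, an augmentation network producing image transformations, and automatic differentiation.

```python
# Toy sketch of one online bilevel step (hypothetical, NOT the DABO code).
# "Classifier": one scalar weight w trained with squared error.
# "Augmentation network": one learnable scale a applied to training inputs (x -> a * x).
# Derivatives are written by hand here; real code would rely on autograd.


def train_grad_w(w, a, xs, ys):
    """dL_train/dw for L_train = mean((w * a * x - y) ** 2)."""
    n = len(xs)
    return 2.0 / n * sum((w * a * x - y) * a * x for x, y in zip(xs, ys))


def train_grad_w_a(w, a, xs, ys):
    """Mixed second derivative d^2 L_train / (dw da)."""
    n = len(xs)
    return 2.0 / n * sum(2.0 * w * a * x * x - x * y for x, y in zip(xs, ys))


def val_loss(w, xs, ys):
    """L_val = mean((w * x - y) ** 2) on clean (non-augmented) validation data."""
    return sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def val_grad_w(w, xs, ys):
    """dL_val/dw."""
    n = len(xs)
    return 2.0 / n * sum((w * x - y) * x for x, y in zip(xs, ys))


def bilevel_step(w, a, train, val, lr=0.05, meta_lr=0.05):
    """One online bilevel update; returns (new w, new a, hypergradient dL_val/da)."""
    xs, ys = train
    xv, yv = val
    # Steps 1-2: training loss on the augmented batch, then one inner SGD step on w.
    w_new = w - lr * train_grad_w(w, a, xs, ys)
    # Step 3: the validation loss is evaluated at the updated weight w_new.
    # Step 4: chain rule through the single inner step: dL_val/dw_new * dw_new/da.
    dw_new_da = -lr * train_grad_w_a(w, a, xs, ys)
    hypergrad = val_grad_w(w_new, xv, yv) * dw_new_da
    a_new = a - meta_lr * hypergrad
    return w_new, a_new, hypergrad


if __name__ == "__main__":
    train = ([0.5, 1.0, 1.5], [2.0, 4.0, 6.0])
    val = ([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])
    w, a = 0.0, 1.0
    for step in range(5):
        w, a, g = bilevel_step(w, a, train, val)
        print(f"step {step}: w={w:.4f} a={a:.4f} val_loss={val_loss(w, *val):.4f}")
```

Comparing `hypergrad` with a finite-difference estimate of the post-step validation loss as a function of `a` is a cheap way to check such a derivation; with an autograd framework the mixed second derivative does not have to be written by hand.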

Results: Our model obtains better results than carefully hand-engineered transformations and GAN-based approaches, and its results are competitive with those of policy-search methods on the CIFAR10, CIFAR100, BACH, Tiny-Imagenet and Imagenet datasets.

Why it matters: Proper data augmentation can significantly improve generalization performance. Unfortunately, deriving these augmentations requires domain expertise or an extensive hyper-parameter search. An automatic and fast way to identify efficient data augmentation therefore has a big impact on obtaining better models.

Where to go from here: Performance could be improved by extending the set of learned transformations to non-differentiable transformations. The estimate of the validation loss could also be improved by further exploring the influence of the number of inner-loop iterations. Finally, the method could be extended to other tasks such as object detection or image segmentation.

Experiments

1. Install requirements: Run this command to install the required packages, including the Haven library, which helps manage experiments.

pip install -r requirements.txt

2.1 CIFAR10 experiments: The following command runs the training and validation loop on CIFAR.

python trainval.py -e cifar -sb ../results -d ../data -r 1

where -e defines the experiment group, -sb is the result directory, and -d is the dataset directory.

2.2 BACH experiments: The following command runs the training and validation loop on the BACH dataset.

python trainval.py -e bach -sb ../results -d ../data -r 1


3. Results: Display the results by following the steps below:


Launch Jupyter by running the following in a terminal:

jupyter nbextension enable --py widgetsnbextension
jupyter notebook

Then run the following script in a Jupyter cell:

from haven import haven_jupyter as hj
from haven import haven_results as hr
from haven import haven_utils as hu

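# path to where the experiments got saved (the directory passed to -sb above)
savedir_base = ''
exp_list = None

# optionally, restrict the results to one experiment group; both arguments are placeholders
# exp_list = hu.load_py(<path to the experiment config file>).EXP_GROUPS[<experiment group>]

# get experiments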
rm = hr.ResultManager(exp_list=exp_list, 
                      savedir_base=savedir_base, 
                      verbose=0
                     )
y_metrics = ['test_acc']
bar_agg = 'max'
mode = 'bar'
legend_list = ['model.netA.name']
title_list = 'dataset.name'
legend_format = 'Augmentation Network: {}'
filterby_list = {'dataset':{'name':'cifar10'}, 'model':{'netC':{'name':'resnet18_meta_2'}}}

# launch dashboard
hj.get_dashboard(rm, vars(), wide_display=True)
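As a rough guide to what the dashboard settings above select, the snippet below reproduces the intended filtering and aggregation in plain Python. It rests on our reading of the settings, not on Haven's documented behavior: `filterby_list` is assumed to keep experiments whose nested config matches every given key, `legend_list` entries are assumed to be dot-separated config paths, and `bar_agg = 'max'` is assumed to take the best `test_acc` over epochs. The record layout, the helper names, the augmentation-network names `aug_a`/`aug_b`, and the accuracy values are hypothetical placeholders, not results or Haven's on-disk format.

```python
def get_path(cfg, dotted):
    """Follow a dot-separated path like 'model.netA.name' into a nested dict."""
    for key in dotted.split("."):
        cfg = cfg[key]
    return cfg


def matches(cfg, pattern):
    """True if every key/value in the nested `pattern` dict also appears in `cfg`."""
    for key, want in pattern.items():
        if key not in cfg:
            return False
        if isinstance(want, dict):
            if not isinstance(cfg[key], dict) or not matches(cfg[key], want):
                return False
        elif cfg[key] != want:
            return False
    return True


def bar_values(records, filterby, legend_key, metric, agg=max):
    """Aggregate `metric` over epochs for each matching experiment, keyed by legend label.

    Assumes one experiment per legend label; with duplicates, the last one wins.
    """
    out = {}
    for rec in records:
        if not matches(rec["config"], filterby):
            continue
        label = get_path(rec["config"], legend_key)
        out[label] = agg(epoch[metric] for epoch in rec["scores"])
    return out


# Hypothetical records: placeholder numbers only, not experimental results.
records = [
    {"config": {"dataset": {"name": "cifar10"},
                "model": {"netA": {"name": "aug_a"}, "netC": {"name": "resnet18_meta_2"}}},
     "scores": [{"test_acc": 0.1}, {"test_acc": 0.3}, {"test_acc": 0.2}]},
    {"config": {"dataset": {"name": "cifar10"},
                "model": {"netA": {"name": "aug_b"}, "netC": {"name": "resnet18_meta_2"}}},
     "scores": [{"test_acc": 0.4}, {"test_acc": 0.5}]},
    {"config": {"dataset": {"name": "cifar100"},
                "model": {"netA": {"name": "aug_a"}, "netC": {"name": "resnet18_meta_2"}}},
     "scores": [{"test_acc": 0.9}]},
]
filterby = {"dataset": {"name": "cifar10"}, "model": {"netC": {"name": "resnet18_meta_2"}}}

print(bar_values(records, filterby, "model.netA.name", "test_acc"))
# {'aug_a': 0.3, 'aug_b': 0.5}
```

The CIFAR100 record is dropped by the filter, and each remaining bar shows the best test accuracy reached by that augmentation network.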

Citation

@article{mounsaveng2020learning,
  title={Learning Data Augmentation with Online Bilevel Optimization for Image Classification},
  author={Mounsaveng, Saypraseuth and Laradji, Issam and Ayed, Ismail Ben and Vazquez, David and Pedersoli, Marco},
  journal={arXiv preprint arXiv:2006.14699},
  year={2020}
}
Owner
ElementAI