Official Pytorch implementation of "Learning Debiased Representation via Disentangled Feature Augmentation (Neurips 2021, Oral)"

Overview

Learning Debiased Representation via Disentangled Feature Augmentation (Neurips 2021, Oral): Official Project Webpage

This repository provides the official PyTorch implementation of the following paper:

Learning Debiased Representation via Disentangled Feature Augmentation
Jungsoo Lee* (KAIST AI, Kakao Enterprise), Eungyeup Kim* (KAIST AI, Kakao Enterprise),
Juyoung Lee (Kakao Enterprise), Jihyeon Lee (KAIST AI), and Jaegul Choo (KAIST AI)
(* indicates equal contribution. The order of first authors was chosen by tossing a coin.)
NeurIPS 2021, Oral

Paper: Arxiv

Abstract: Image classification models tend to make decisions based on peripheral attributes of data items that have strong correlation with a target variable (i.e., dataset bias). These biased models suffer from the poor generalization capability when evaluated on unbiased datasets. Existing approaches for debiasing often identify and emphasize those samples with no such correlation (i.e., bias-conflicting) without defining the bias type in advance. However, such bias-conflicting samples are significantly scarce in biased datasets, limiting the debiasing capability of these approaches. This paper first presents an empirical analysis revealing that training with "diverse" bias-conflicting samples beyond a given training set is crucial for debiasing as well as the generalization capability. Based on this observation, we propose a novel feature-level data augmentation technique in order to synthesize diverse bias-conflicting samples. To this end, our method learns the disentangled representation of (1) the intrinsic attributes (i.e., those inherently defining a certain class) and (2) bias attributes (i.e., peripheral attributes causing the bias), from a large number of bias-aligned samples, the bias attributes of which have strong correlation with the target variable. Using the disentangled representation, we synthesize bias-conflicting samples that contain the diverse intrinsic attributes of bias-aligned samples by swapping their latent features. By utilizing these diversified bias-conflicting features during the training, our approach achieves superior classification accuracy and debiasing results against the existing baselines on both synthetic as well as a real-world dataset.

Code Contributors

Jungsoo Lee [Website] [LinkedIn] [Google Scholar] (KAIST AI, Kakao Enterprise)
Eungyeup Kim [Website] [LinkedIn] [Google Scholar] (KAIST AI, Kakao Enterprise)
Juyoung Lee [Website] (Kakao Enterprise)

Pytorch Implementation

Installation

Clone this repository.

git clone https://github.com/kakaoenterprise/Learning-Debiased-Disentangled.git
cd Learning-Debiased-Disentangled
pip install -r requirements.txt

Datasets

We used three datasets in our paper.

Download the datasets with the following url. Note that BFFHQ is the dataset used in "BiaSwap: Removing Dataset Bias with Bias-Tailored Swapping Augmentation" (Kim et al., ICCV 2021). Unzip the files and the directory structures will be as following:

cmnist
 └ 0.5pct / 1pct / 2pct / 5pct
     └ align
     └ conlict
     └ valid
 └ test
cifar10c
 └ 0.5pct / 1pct / 2pct / 5pct
     └ align
     └ conlict
     └ valid
 └ test
bffhq
 └ 0.5pct
 └ valid
 └ test

How to Run

CMNIST

Vanilla
python train.py --dataset cmnist --exp=cmnist_0.5_vanilla --lr=0.01 --percent=0.5pct --train_vanilla --tensorboard --wandb
python train.py --dataset cmnist --exp=cmnist_1_vanilla --lr=0.01 --percent=1pct --train_vanilla --tensorboard --wandb
python train.py --dataset cmnist --exp=cmnist_2_vanilla --lr=0.01 --percent=2pct --train_vanilla --tensorboard --wandb
python train.py --dataset cmnist --exp=cmnist_5_vanilla --lr=0.01 --percent=5pct --train_vanilla --tensorboard --wandb
bash scripts/run_cmnist_vanilla.sh
Ours
python train.py --dataset cmnist --exp=cmnist_0.5_ours --lr=0.01 --percent=0.5pct --curr_step=10000 --lambda_swap=1 --lambda_dis_align=10 --lambda_swap_align=10 --use_lr_decay --lr_decay_step=10000 --lr_gamma=0.5 --train_ours --tensorboard --wandb
python train.py --dataset cmnist --exp=cmnist_1_ours --lr=0.01 --percent=1pct  --curr_step=10000 --lambda_swap=1 --lambda_dis_align=10 --lambda_swap_align=10 --use_lr_decay --lr_decay_step=10000 --lr_gamma=0.5 --train_ours --tensorboard --wandb
python train.py --dataset cmnist --exp=cmnist_2_ours --lr=0.01 --percent=2pct  --curr_step=10000 --lambda_swap=1 --lambda_dis_align=10 --lambda_swap_align=10 --use_lr_decay --lr_decay_step=10000 --lr_gamma=0.5 --train_ours --tensorboard --wandb
python train.py --dataset cmnist --exp=cmnist_5_ours --lr=0.01 --percent=5pct  --curr_step=10000 --lambda_swap=1 --lambda_dis_align=10 --lambda_swap_align=10 --use_lr_decay --lr_decay_step=10000 --lr_gamma=0.5 --train_ours --tensorboard --wandb
bash scripts/run_cmnist_ours.sh

Corrupted CIFAR10

Vanilla
python train.py --dataset cifar10c --exp=cifar10c_0.5_vanilla --lr=0.001 --percent=0.5pct --train_vanilla --tensorboard --wandb
python train.py --dataset cifar10c --exp=cifar10c_1_vanilla --lr=0.001 --percent=1pct --train_vanilla --tensorboard --wandb
python train.py --dataset cifar10c --exp=cifar10c_2_vanilla --lr=0.001 --percent=2pct --train_vanilla --tensorboard --wandb
python train.py --dataset cifar10c --exp=cifar10c_5_vanilla --lr=0.001 --percent=5pct --train_vanilla --tensorboard --wandb
bash scripts/run_cifar10c_vanilla.sh
Ours
python train.py --dataset cifar10c --exp=cifar10c_0.5_ours --lr=0.0005 --percent=0.5pct --curr_step=10000 --lambda_swap=1 --lambda_dis_align=1 --lambda_swap_align=1 --use_lr_decay --lr_decay_step=10000 --lr_gamma=0.5 --train_ours --tensorboard --wandb
python train.py --dataset cifar10c --exp=cifar10c_1_ours --lr=0.001 --percent=1pct --curr_step=10000 --lambda_swap=1 --lambda_dis_align=5 --lambda_swap_align=5 --use_lr_decay --lr_decay_step=10000 --lr_gamma=0.5 --train_ours --tensorboard --wandb
python train.py --dataset cifar10c --exp=cifar10c_2_ours --lr=0.001 --percent=2pct --curr_step=10000 --lambda_swap=1 --lambda_dis_align=5 --lambda_swap_align=5 --use_lr_decay --lr_decay_step=10000 --lr_gamma=0.5 --train_ours --tensorboard --wandb
python train.py --dataset cifar10c --exp=cifar10c_5_ours --lr=0.001 --percent=5pct --curr_step=10000 --lambda_swap=1 --lambda_dis_align=1 --lambda_swap_align=1 --use_lr_decay --lr_decay_step=10000 --lr_gamma=0.5 --train_ours --tensorboard --wandb
bash scripts/run_cifar10c_ours.sh

BFFHQ

Vanilla
python train.py --dataset bffhq --exp=bffhq_0.5_vanilla --lr=0.0001 --percent=0.5pct --train_vanilla --tensorboard --wandb
bash scripts/run_bffhq_vanilla.sh
Ours
python train.py --dataset bffhq --exp=bffhq_0.5_ours --lr=0.0001 --percent=0.5pct --lambda_swap=0.1 --curr_step=10000 --use_lr_decay --lr_decay_step=10000 --lambda_dis_align 2. --lambda_swap_align 2. --dataset bffhq --train_ours --tensorboard --wandb
bash scripts/run_bffhq_ours.sh

Pretrained Models

In order to test our pretrained models, run the following command.

python test.py --pretrained_path=
   
     --dataset=
    
      --percent=
     

     
    
   

We provide the pretrained models in the following urls.
CMNIST 0.5pct
CMNIST 1pct
CMNIST 2pct
CMNIST 5pct

CIFAR10C 0.5pct
CIFAR10C 1pct
CIFAR10C 2pct
CIFAR10C 5pct

BFFHQ 0.5pct

Citations

Bibtex coming soon!

Contact

Jungsoo Lee

Eungyeup Kim

Juyoung Lee

Kakao Enterprise/Vision Team

Acknowledgments

This work was mainly done when both of the first authors were doing internship at Vision Team/AI Lab/Kakao Enterprise. Our pytorch implementation is based on LfF. Thanks for the implementation.

Owner
Kakao Enterprise Corp.
Kakao Enterprise Corp.
Deep Latent Force Models

Deep Latent Force Models This repository contains a PyTorch implementation of the deep latent force model (DLFM), presented in the paper, Compositiona

Tom McDonald 5 Oct 26, 2022
Aiming at the common training datsets split, spectrum preprocessing, wavelength select and calibration models algorithm involved in the spectral analysis process

Aiming at the common training datsets split, spectrum preprocessing, wavelength select and calibration models algorithm involved in the spectral analysis process, a complete algorithm library is esta

Fu Pengyou 50 Jan 07, 2023
Retrieve and analysis data from SDSS (Sloan Digital Sky Survey)

Author: Behrouz Safari License: MIT sdss A python package for retrieving and analysing data from SDSS (Sloan Digital Sky Survey) Installation Install

Behrouz 3 Oct 28, 2022
An open-source, low-cost, image-based weed detection device for fallow scenarios.

Welcome to the OpenWeedLocator (OWL) project, an opensource hardware and software green-on-brown weed detector that uses entirely off-the-shelf compon

Guy Coleman 145 Jan 05, 2023
《Where am I looking at? Joint Location and Orientation Estimation by Cross-View Matching》(CVPR 2020)

This contains the codes for cross-view geo-localization method described in: Where am I looking at? Joint Location and Orientation Estimation by Cross-View Matching, CVPR2020.

41 Oct 27, 2022
Code of paper: "DropAttack: A Masked Weight Adversarial Training Method to Improve Generalization of Neural Networks"

DropAttack: A Masked Weight Adversarial Training Method to Improve Generalization of Neural Networks Abstract: Adversarial training has been proven to

倪仕文 (Shiwen Ni) 58 Nov 10, 2022
A unified framework to jointly model images, text, and human attention traces.

connect-caption-and-trace This repository contains the reference code for our paper Connecting What to Say With Where to Look by Modeling Human Attent

Meta Research 73 Oct 24, 2022
Merlion: A Machine Learning Framework for Time Series Intelligence

Merlion: A Machine Learning Library for Time Series Table of Contents Introduction Installation Documentation Getting Started Anomaly Detection Foreca

Salesforce 2.8k Dec 30, 2022
BEAS: Blockchain Enabled Asynchronous & Secure Federated Machine Learning

BEAS Blockchain Enabled Asynchronous and Secure Federated Machine Learning Default Network Configuration: The default application uses the HyperLedger

Harpreet Virk 11 Nov 20, 2022
GT China coal model

GT China coal model The full version of a China coal transport model with a very high spatial reslution. What it does The code works in a few steps: T

0 Dec 13, 2021
This repository contains the official implementation code of the paper Improving Multimodal Fusion with Hierarchical Mutual Information Maximization for Multimodal Sentiment Analysis, accepted at EMNLP 2021.

MultiModal-InfoMax This repository contains the official implementation code of the paper Improving Multimodal Fusion with Hierarchical Mutual Informa

Deep Cognition and Language Research (DeCLaRe) Lab 89 Dec 26, 2022
Code of U2Fusion: a unified unsupervised image fusion network for multiple image fusion tasks, including multi-modal, multi-exposure and multi-focus image fusion.

U2Fusion Code of U2Fusion: a unified unsupervised image fusion network for multiple image fusion tasks, including multi-modal (VIS-IR, medical), multi

Han Xu 129 Dec 11, 2022
Pytorch Implementation of Value Retrieval with Arbitrary Queries for Form-like Documents.

Value Retrieval with Arbitrary Queries for Form-like Documents Introduction Pytorch Implementation of Value Retrieval with Arbitrary Queries for Form-

Salesforce 13 Sep 15, 2022
Activity image-based video retrieval

Cross-modal-retrieval Our approach is focus on Activity Image-to-Video Retrieval (AIVR) task. The compared methods are state-of-the-art single modalit

BCMI 75 Oct 21, 2021
Code for Pose-Controllable Talking Face Generation by Implicitly Modularized Audio-Visual Representation (CVPR 2021)

Pose-Controllable Talking Face Generation by Implicitly Modularized Audio-Visual Representation (CVPR 2021) Hang Zhou, Yasheng Sun, Wayne Wu, Chen Cha

Hang_Zhou 628 Dec 28, 2022
learned_optimization: Training and evaluating learned optimizers in JAX

learned_optimization: Training and evaluating learned optimizers in JAX learned_optimization is a research codebase for training learned optimizers. I

Google 533 Dec 30, 2022
PyTorch implementation of Glow

glow-pytorch PyTorch implementation of Glow, Generative Flow with Invertible 1x1 Convolutions (https://arxiv.org/abs/1807.03039) Usage: python train.p

Kim Seonghyeon 433 Dec 27, 2022
🔥RandLA-Net in Tensorflow (CVPR 2020, Oral & IEEE TPAMI 2021)

RandLA-Net: Efficient Semantic Segmentation of Large-Scale Point Clouds (CVPR 2020) This is the official implementation of RandLA-Net (CVPR2020, Oral

Qingyong 1k Dec 30, 2022
PyTorch implementation of our ICCV paper DeFRCN: Decoupled Faster R-CNN for Few-Shot Object Detection.

Introduction This repo contains the official PyTorch implementation of our ICCV paper DeFRCN: Decoupled Faster R-CNN for Few-Shot Object Detection. Up

133 Dec 29, 2022
A Python 3 package for state-of-the-art statistical dimension reduction methods

direpack: a Python 3 library for state-of-the-art statistical dimension reduction techniques This package delivers a scikit-learn compatible Python 3

Sven Serneels 32 Dec 14, 2022