A new mini-batch framework for optimal transport in deep generative models, deep domain adaptation, approximate Bayesian computation, color transfer, and gradient flow.

BoMb-OT

Python3 implementation of the papers "On Transportation of Mini-batches: A Hierarchical Approach" and "Improving Mini-batch Optimal Transport via Partial Transportation".

Please CITE our papers whenever this repository is used to help produce published results or is incorporated into other software.

@article{nguyen2021transportation,
      title={On Transportation of Mini-batches: A Hierarchical Approach}, 
      author={Khai Nguyen and Dang Nguyen and Quoc Nguyen and Tung Pham and Hung Bui and Dinh Phung and Trung Le and Nhat Ho},
      journal={arXiv preprint arXiv:2102.05912},
      year={2021},
}
@article{nguyen2021improving,
      title={Improving Mini-batch Optimal Transport via Partial Transportation}, 
      author={Khai Nguyen and Dang Nguyen and Tung Pham and Nhat Ho},
      journal={arXiv preprint arXiv:2108.09645},
      year={2021},
}

This implementation was made by Khai Nguyen and Dang Nguyen. The README is still being updated.

Requirements

  • python 3.6
  • pytorch 1.7.1
  • torchvision
  • numpy
  • tqdm
  • geomloss
  • POT
  • matplotlib
  • cvxpy

What is included?

A scalable implementation of the batch-of-mini-batches (BoMb) scheme and the conventional averaging scheme for the mini-batch transportation types optimal transport (OT), partial optimal transport (POT), unbalanced optimal transport (UOT), and sliced optimal transport, applied to the following settings (a small sketch contrasting the two schemes follows this list):

  • Deep Generative Models
  • Deep Domain Adaptation
  • Approximate Bayesian Computation
  • Color Transfer
  • Gradient Flow
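
For intuition, here is a minimal sketch (not the repository code) contrasting the two schemes on toy data, assuming uniform weights and a squared Euclidean ground cost: the averaging scheme (m-OT) averages the OT costs of paired mini-batches, while BoMb-OT solves an outer OT problem between the two sets of mini-batches using the k x k matrix of inner mini-batch OT costs.

import numpy as np
import ot  # POT

def minibatch_ot_cost(xs, xt):
    # exact OT cost between two mini-batches with uniform weights
    M = ot.dist(xs, xt)  # squared Euclidean cost matrix
    a = np.ones(len(xs)) / len(xs)
    b = np.ones(len(xt)) / len(xt)
    return ot.emd2(a, b, M)

def m_ot(source_batches, target_batches):
    # averaging scheme: mean OT cost over (here, simply zipped) mini-batch pairs
    return np.mean([minibatch_ot_cost(xs, xt)
                    for xs, xt in zip(source_batches, target_batches)])

def bomb_ot(source_batches, target_batches):
    # hierarchical scheme: outer OT between the two sets of mini-batches,
    # with the k x k matrix of inner OT costs as the ground cost
    k = len(source_batches)
    C = np.array([[minibatch_ot_cost(xs, xt) for xt in target_batches]
                  for xs in source_batches])
    w = np.ones(k) / k
    return ot.emd2(w, w, C)

rng = np.random.RandomState(0)
X = rng.randn(200, 2)            # source samples
Y = rng.randn(200, 2) + 2.0      # target samples
k, m = 4, 50
Xb = [X[i * m:(i + 1) * m] for i in range(k)]  # k mini-batches of size m
Yb = [Y[i * m:(i + 1) * m] for i in range(k)]
print("m-OT   :", m_ot(Xb, Yb))
print("BoMb-OT:", bomb_ot(Xb, Yb))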

Deep Domain Adaptation on digits datasets (DeepDA/digits)

Code organization

cfg.py : this file contains arguments for training.

methods.py : this file implements the training process of the deep DA.

models.py : this file contains the architecture of the generator and the classifier.

train_digits.py: running file for deep DA.

utils.py : this file contains implementation of utility functions.

Terminologies

--method : type of mini-batch deep DA method (jdot, jumbot, jpmbot)

--source_ds : source dataset

--target_ds : target dataset

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--eta1 : weight of embedding loss

--eta2 : weight of transportation loss

--k : number of mini-batches

--mbsize : mini-batch size

--n_epochs : number of running epochs

--test_interval : interval between two consecutive test phases

--lr : initial learning rate

--data_dir : path to dataset

--reg : OT regularization coefficient for Sinkhorn algorithm

--bomb : use the Batch of Mini-batches scheme

--ebomb : use the entropic Batch of Mini-batches scheme

--breg : OT regularization coefficient for entropic Batch of Mini-batches
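
Putting the arguments above together, a hypothetical single-GPU invocation might look as follows (flag values and dataset names are illustrative, not tuned settings from the papers; the sh/ scripts below are the canonical entry points):

CUDA_VISIBLE_DEVICES=0 python train_digits.py --method jumbot --source_ds svhn --target_ds mnist \
    --epsilon 0.1 --tau 1.0 --eta1 0.1 --eta2 1.0 --k 2 --mbsize 500 \
    --n_epochs 100 --lr 2e-4 --data_dir ./data --bomb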

Change the number of mini-batches $k$

bash sh/exp_mOT_change_k.sh
bash sh/exp_BoMbOT_change_k.sh

Change the mini-batch size $m$

bash sh/exp_mOT_change_m.sh
bash sh/exp_BoMbOT_change_m.sh

Deep Adaptation on Office-Home and VisDA datasets (DeepDA/office)

Code organization

data_list.py : this file contains functions to create dataset.

evaluate.py : this file is used to evaluate model trained on VisDA dataset.

lr_schedule.py : this file implements the learning rate scheduler.

network.py : this file contains the architecture of the generator and the classifier.

pre_process.py : this file implements preprocessing techniques.

train.py : this file implements the training process for both datasets.

Terminologies

--net : architecture type of the generator

--dset : name of the dataset

--test_interval : interval between two consecutive test phases

--s_dset_path : path to source dataset

--stratify_source : use stratified sampling

--t_dset_path : path to target dataset

--batch_size : training batch size

--stop_step : number of iterations

--ot_type : type of OT loss (balanced, unbalanced, partial)

--eta1 : weight of embedding loss ($\alpha$ in equation 10)

--eta2 : weight of transportation loss ($\lambda_t$ in equation 10)

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--bomb : use the Batch of Mini-batches scheme

--ebomb : use the entropic Batch of Mini-batches scheme

--breg : OT regularization coefficient for entropic Batch of Mini-batches
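
As above, a hypothetical invocation assembled from these arguments (paths and values are placeholders; sh/train_home.sh and sh/train_visda.sh below are the canonical entry points):

python train.py --net ResNet50 --dset office-home --s_dset_path <source_list.txt> \
    --t_dset_path <target_list.txt> --batch_size 65 --stop_step 10000 \
    --ot_type partial --mass 0.85 --eta1 0.01 --eta2 0.5 --epsilon 0.01 --bomb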

Train on Office-Home

bash sh/train_home.sh

Train on VisDA

bash sh/train_visda.sh

Deep Generative model (DeepGM)

Code organization

Celeba_generator.py, Cifar_generator.py : these files contain the generator architectures for the CelebA and CIFAR10 datasets and include helper functions to compute the losses of the corresponding baselines.

experiments.py : this file contains some functions for generating images.

fid_score.py: this file is used to compute the FID score.

gen_images.py : this file reads saved models and produces 10,000 images to calculate FID.

inception.py: this file contains the architecture of Inception Net V3.

main_celeba.py, main_cifar.py : running files on the corresponding datasets.

utils.py : this file contains implementation of utility functions.

Terminologies

--method : type of OT loss (OT, UOT, POT, sliced)

--reg : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--k : number of mini-batches

--m : mini-batch size

--epochs : number of epochs when k = 1; the actual number of training epochs is this value multiplied by k (e.g., --epochs 100 with --k 2 trains for 200 epochs).

--lr : initial learning rate

--latent-size : latent size of the generator

--datadir : path to dataset

--L : number of projections when using the slicing approach

--bomb : use the Batch of Mini-batches scheme

--ebomb : use the entropic Batch of Mini-batches scheme

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Train on CIFAR10

CUDA_VISIBLE_DEVICES=0 python main_cifar.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 100 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Train on CELEBA

CUDA_VISIBLE_DEVICES=0 python main_celeba.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 200 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data
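
For reference, the following toy sketch (not the repository code) shows how a generator could be trained with a BoMb-OT-style objective: compute entropic OT losses between every pair of data/generated mini-batches with geomloss, solve the outer OT problem on the detached k x k cost matrix with POT, and back-propagate the plan-weighted sum of inner losses. Architectures, dimensions, and hyper-parameters are illustrative.

import numpy as np
import ot
import torch
from geomloss import SamplesLoss

k, m, latent = 2, 100, 32
inner = SamplesLoss(loss="sinkhorn", p=2, blur=0.05)   # entropic inner OT loss
G = torch.nn.Sequential(torch.nn.Linear(latent, 128), torch.nn.ReLU(), torch.nn.Linear(128, 2))
opt = torch.optim.Adam(G.parameters(), lr=5e-4)

real_batches = [torch.randn(m, 2) + 3.0 for _ in range(k)]    # toy stand-in for data mini-batches
fake_batches = [G(torch.randn(m, latent)) for _ in range(k)]  # generated mini-batches

# inner losses between every pair of mini-batches (differentiable w.r.t. the generator)
L = [[inner(fb, rb) for rb in real_batches] for fb in fake_batches]

# outer OT plan between the two sets of mini-batches, computed on detached costs
C = np.array([[l.item() for l in row] for row in L])
w = np.ones(k) / k
P = ot.emd(w, w, C)

# BoMb-OT objective: plan-weighted sum of the inner mini-batch losses
loss = sum(P[i][j] * L[i][j] for i in range(k) for j in range(k))
opt.zero_grad(); loss.backward(); opt.step()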

Gradient Flow (GradientFlow)

python main.py
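
The script runs the gradient-flow experiment. A minimal sketch of the underlying idea (assuming geomloss for the OT loss; mini-batch size, step size, and iteration count are illustrative) is to repeatedly move a particle cloud along the negative gradient of a mini-batch OT loss to a target cloud:

import torch
from geomloss import SamplesLoss

target = torch.randn(200, 2) + torch.tensor([3.0, 0.0])   # target point cloud
x = torch.randn(200, 2).requires_grad_(True)               # particles to be transported
loss_fn = SamplesLoss(loss="sinkhorn", p=2, blur=0.05)

m, lr = 50, 0.5   # mini-batch size and step size (illustrative)
for step in range(300):
    i = torch.randperm(x.shape[0])[:m]
    j = torch.randperm(target.shape[0])[:m]
    loss = loss_fn(x[i], target[j])          # OT loss between two mini-batches
    [g] = torch.autograd.grad(loss, [x])
    with torch.no_grad():
        x -= lr * g                           # gradient-flow step on the particles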

Color Transfer (Color Transfer)

python main.py  --m=100 --T=10000 --source images/s1.bmp --target images/t1.bmp --cluster

Terminologies

--k : number of mini-batches

--m : the size of mini-batches

--T : the number of steps

--cluster : use K-means clustering to compress the images

--palette : show the color palette

--source : path to the source image

--target : path to the target image
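
Below is a minimal mini-batch OT color-transfer sketch (not the repository code) using only numpy, POT, and matplotlib: it skips the --cluster K-means compression and instead repeatedly moves the colors of random source-pixel mini-batches toward their barycentric images under an exact OT plan to target-pixel mini-batches. File names match the example command above; step size and iteration count are illustrative.

import numpy as np
import ot
import matplotlib.pyplot as plt

src = plt.imread("images/s1.bmp").astype(np.float64)[..., :3] / 255.0
tgt = plt.imread("images/t1.bmp").astype(np.float64)[..., :3] / 255.0
Xs, Xt = src.reshape(-1, 3), tgt.reshape(-1, 3)

rng = np.random.RandomState(0)
m, T, lr = 100, 1000, 0.1
out = Xs.copy()
for _ in range(T):
    i = rng.choice(len(out), m, replace=False)
    j = rng.choice(len(Xt), m, replace=False)
    w = np.ones(m) / m
    G = ot.emd(w, w, ot.dist(out[i], Xt[j]))      # exact OT plan between color mini-batches
    out[i] += lr * (m * G @ Xt[j] - out[i])       # move colors toward their barycentric targets

plt.imsave("transferred.png", np.clip(out.reshape(src.shape), 0, 1))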

Acknowledgment

The structure of DeepDA is largely based on JUMBOT and ALDA. The structure of ABC is largely based on SlicedABC. We are very grateful for their open-source code.
