Hierarchical Few-Shot Generative Models

Overview


Giorgio Giannone, Ole Winther

This repo contains code and experiments for the paper Hierarchical Few-Shot Generative Models.


Settings

Clone the repo:

git clone https://github.com/georgosgeorgos/hierarchical-few-shot-generative-models
cd hierarchical-few-shot-generative-models

Create and activate the conda env:

conda env create -f environment.yml
conda activate hfsgm

The code has been tested on Ubuntu 18.04, Python 3.6 and CUDA 11.3.

We use wandb for visualization. The first time you run the code, you will need to log in.

Data

We provide a preprocessed version of the Omniglot dataset.

From the main folder, download the data and place it in data/omniglot_ns/:

wget https://github.com/georgosgeorgos/hierarchical-few-shot-generative-models/releases/download/Omniglot/omni_train_val_test.pkl
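The internal layout of omni_train_val_test.pkl is not documented here. As a quick sanity check after downloading, a minimal sketch like the following loads a standard Python pickle and reports its top-level structure. It assumes only that the file is an ordinary pickle; summarize_pickle is a hypothetical helper, not part of the repo. Unpickling can execute arbitrary code, so only load files from sources you trust.

```python
import pickle
from pathlib import Path


def summarize_pickle(path):
    """Load a pickle file and describe its top-level structure (hypothetical helper).

    Only load pickles from trusted sources: unpickling can execute arbitrary code.
    """
    with open(path, "rb") as f:
        obj = pickle.load(f)
    summary = {"type": type(obj).__name__}
    if isinstance(obj, dict):
        # For each top-level key, report the value's type and length (if it has one).
        summary["keys"] = {
            str(k): (type(v).__name__, len(v) if hasattr(v, "__len__") else None)
            for k, v in obj.items()
        }
    elif hasattr(obj, "__len__"):
        summary["len"] = len(obj)
    return summary


if __name__ == "__main__":
    path = Path("data/omniglot_ns/omni_train_val_test.pkl")
    if path.exists():
        print(summarize_pickle(path))
    else:
        print(f"{path} not found; download it first")
```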

For CelebA you need to download the dataset from here.

Dataset

The dataset folder provides utilities to process and augment datasets in the few-shot setting. Each dataset is a large collection of small sets, and sets can be created dynamically. The file dataset/base.py collects basic information about the datasets. Binary datasets (omniglot_ns.py) are augmented with flips and rotations; RGB datasets (celeba.py) use only flips.
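The ideas above (building sets on the fly and augmenting binary data with flips and rotations, RGB data with flips only) can be sketched in NumPy as follows. This is not the repo's code: the function names are hypothetical, the data is assumed to be a dict mapping each class to an array of images, and the specific transforms (horizontal/vertical flips, multiples of 90 degrees) are assumptions. The sketch applies one transform to the whole set so all elements still share a single underlying concept; how the repo actually combines augmentation with set construction may differ.

```python
import numpy as np


def augment_binary(images, rng):
    """Random flips plus a random multiple-of-90-degree rotation (assumed transforms).

    images: (N, H, W) binary images with H == W. The same transform is applied
    to every image so the set stays internally consistent.
    """
    if rng.random() < 0.5:
        images = images[:, :, ::-1]  # horizontal flip (reverse width axis)
    if rng.random() < 0.5:
        images = images[:, ::-1, :]  # vertical flip (reverse height axis)
    k = int(rng.integers(4))
    return np.rot90(images, k=k, axes=(1, 2))


def augment_rgb(images, rng):
    """Horizontal flip only, for RGB images of shape (N, H, W, C)."""
    if rng.random() < 0.5:
        images = images[:, :, ::-1, :]
    return images


def sample_set(class_to_images, sample_size, rng, augment=None):
    """Build one set dynamically: pick a class, then sample_size distinct images from it."""
    classes = list(class_to_images)
    label = classes[int(rng.integers(len(classes)))]
    pool = class_to_images[label]
    idx = rng.choice(len(pool), size=sample_size, replace=False)
    batch = pool[idx]
    if augment is not None:
        batch = augment(batch, rng)
    return label, batch


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    data = {c: rng.integers(0, 2, size=(20, 28, 28)) for c in ("a", "b")}
    label, s = sample_set(data, 5, rng, augment=augment_binary)
    print(label, s.shape)
```

Because sets are sampled on every call, each epoch can see different combinations of elements from the same class.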

Experiment

The experiment folder contains scripts for model evaluation, experiments and visualization.

  • attention.py - visualize attention weights and heads for models with learnable aggregations (LAG).
  • cardinality.py - compute ELBOs for different input set sizes: [1, 2, 5, 10, 20].
  • classifier_mnist.py - few-shot classifiers on MNIST.
  • kl_layer.py - compute the KL divergence over z and c for each layer of the latent space.
  • marginal.py - compute approximate log-marginal likelihood with 1K importance samples.
  • refine_vis.py - visualize refined samples.
  • sampling_rgb.py - reconstruction, conditional, refined, unconditional sampling for RGB datasets.
  • sampling_transfer.py - reconstruction, conditional, refined, unconditional sampling on transfer datasets.
  • sampling.py - reconstruction, conditional, refined, unconditional sampling for binary datasets.
  • transfer.py - compute ELBOs on MNIST, DoubleMNIST, TripleMNIST.
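marginal.py is described as computing an approximate log-marginal likelihood with 1K importance samples. The standard importance-weighted estimator is log p(x) ≈ log (1/K) Σ_k p(x, z_k) / q(z_k | x), with z_k ~ q(z | x). The sketch below illustrates that estimator on a toy one-dimensional Gaussian model whose exact marginal is known, so the estimate can be checked. The toy model, the proposal, and all names are assumptions for illustration; they are not the repository's hierarchical model (which has latents c and z per layer) or its code.

```python
import numpy as np


def log_normal(x, mean, var):
    """Log density of a univariate Gaussian N(mean, var) at x."""
    return -0.5 * (np.log(2 * np.pi * var) + (x - mean) ** 2 / var)


def logsumexp(a):
    """Numerically stable log(sum(exp(a))) over a 1-D array."""
    m = np.max(a)
    return m + np.log(np.sum(np.exp(a - m)))


def iw_log_marginal(x, num_samples, rng, sigma2=0.5, q_mean_scale=0.5, q_var=1.0):
    """Importance-sampled estimate of log p(x) for a toy model.

    Toy model (assumption):  z ~ N(0, 1),  x | z ~ N(z, sigma2)
    Proposal q(z | x) = N(q_mean_scale * x, q_var), standing in for an encoder.
    """
    q_mean = q_mean_scale * x
    z = rng.normal(q_mean, np.sqrt(q_var), size=num_samples)
    log_w = (
        log_normal(z, 0.0, 1.0)          # log p(z)
        + log_normal(x, z, sigma2)       # log p(x | z)
        - log_normal(z, q_mean, q_var)   # - log q(z | x)
    )
    # log of the average importance weight: log (1/K) sum_k w_k
    return logsumexp(log_w) - np.log(num_samples)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = 0.7
    print("estimate:", iw_log_marginal(x, 1000, rng))
    print("exact:   ", log_normal(x, 0.0, 1.0 + 0.5))  # x ~ N(0, 1 + sigma2)
```

Working in log space with logsumexp avoids underflow when individual weights are tiny, which matters for high-dimensional image likelihoods.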

Model

The model folder implements the baselines and model variants.

  • base.py - base class for all the models.
  • vae.py - Variational Autoencoder (VAE).
  • ns.py - Neural Statistician (NS).
  • tns.py - NS with learnable aggregation (NS-LAG).
  • cns.py - NS with convolutional latent space (CNS).
  • ctns.py - CNS with learnable aggregation (CNS-LAG).
  • hfsgm.py - Hierarchical Few-Shot Generative Model (HFSGM).
  • thfsgm.py - HFSGM with learnable aggregation (HFSGM-LAG).
  • chfsgm.py - HFSGM with convolutional latent space (CHFSGM).
  • cthfsgm.py - CHFSGM with learnable aggregation (CHFSGM-LAG).
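Several variants differ only in how per-element features are pooled into a set summary: plain averaging versus a learnable aggregation (LAG). The NumPy sketch below contrasts mean pooling with one common form of learnable aggregation, single-head attention pooling with a learned query. It is an illustration under stated assumptions, not the repo's implementation: attention.py mentions multiple heads, so the actual LAG modules likely differ in their details, and all names here are hypothetical.

```python
import numpy as np


def mean_aggregate(h):
    """Permutation-invariant set summary by averaging: (n, d) -> (d,)."""
    return h.mean(axis=0)


def softmax(x):
    """Numerically stable softmax over a 1-D array."""
    e = np.exp(x - np.max(x))
    return e / e.sum()


def attention_aggregate(h, query, w_k, w_v):
    """Single-head attention pooling with a learnable query (hypothetical sketch).

    h:     (n, d) per-element features
    query: (d_k,) learned query vector, shared across sets
    w_k:   (d, d_k) key projection; w_v: (d, d_v) value projection
    Returns the pooled (d_v,) summary and the (n,) attention weights.
    """
    keys = h @ w_k
    values = h @ w_v
    scores = keys @ query / np.sqrt(keys.shape[1])  # scaled dot-product scores
    weights = softmax(scores)
    return weights @ values, weights
```

With a zero query the weights are uniform and the output reduces to the mean of the projected values, so mean pooling is a special case; learning the query lets the model weight informative set elements more heavily while staying permutation invariant.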

Script

The script folder contains the scripts used to train the models in the paper.

To run a CNS on Omniglot:

sh script/main_cns.sh GPU_NUMBER omniglot_ns

Train a model

To train a generic model run:

python main.py --name {VAE, NS, CNS, CTNS, CHFSGM, CTHFSGM} \
               --model {vae, ns, cns, ctns, chfsgm, cthfsgm} \
               --augment \
               --dataset omniglot_ns \
               --likelihood binary \
               --hidden-dim 128 \
               --c-dim 32 \
               --z-dim 32 \
               --output-dir /output \
               --alpha-step 0.98 \
               --alpha 2 \
               --adjust-lr \
               --scheduler plateau \
               --sample-size {2, 5, 10} \
               --sample-size-test {2, 5, 10} \
               --num-classes 1 \
               --learning-rate 1e-4 \
               --epochs 400 \
               --batch-size 100 \
               --tag (optional string)

If you do not want to save logs, add the --dry_run flag. It calls utils/trainer_dry.py instead of trainer.py.
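To train several configurations (for example, all three set sizes), one option is to generate the commands programmatically. The sketch below reuses only the flags and values shown in the training command above. The build_command helper is hypothetical, and setting --sample-size-test equal to --sample-size is an assumption. By default it only prints the commands; set LAUNCH to True to run them sequentially. It avoids shlex.join so it also works on Python 3.6.

```python
import shlex
import subprocess

LAUNCH = False  # set to True to actually start the training runs

# Flags and values copied from the training command above.
BASE_ARGS = [
    "--augment",
    "--dataset", "omniglot_ns",
    "--likelihood", "binary",
    "--hidden-dim", "128",
    "--c-dim", "32",
    "--z-dim", "32",
    "--output-dir", "/output",
    "--alpha-step", "0.98",
    "--alpha", "2",
    "--adjust-lr",
    "--scheduler", "plateau",
    "--num-classes", "1",
    "--learning-rate", "1e-4",
    "--epochs", "400",
    "--batch-size", "100",
]


def build_command(name, model, sample_size, dry_run=False, tag=None):
    """Return the argv list for one `python main.py` run (hypothetical helper)."""
    cmd = ["python", "main.py", "--name", name, "--model", model] + BASE_ARGS
    # Assumption: evaluate with the same set size used for training.
    cmd += ["--sample-size", str(sample_size), "--sample-size-test", str(sample_size)]
    if tag is not None:
        cmd += ["--tag", tag]
    if dry_run:
        cmd.append("--dry_run")  # skip saving logs
    return cmd


if __name__ == "__main__":
    for sample_size in (2, 5, 10):
        cmd = build_command("CHFSGM", "chfsgm", sample_size, dry_run=True)
        print(" ".join(shlex.quote(arg) for arg in cmd))
        if LAUNCH:
            subprocess.run(cmd, check=True)
```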


Acknowledgments

A lot of code and ideas were borrowed from:
