GAN JAX - A toy project to generate images from GANs with JAX

Related tags

Deep LearningGANJax
Overview

GAN JAX - A toy project to generate images from GANs with JAX

This project aims to bring the power of JAX, a Python framework developped by Google and DeepMind to train Generative Adversarial Networks for images generation.

JAX

JAX logo

JAX is a framework developed by Deep-Mind (Google) that allows to build machine learning models in a more powerful (XLA compilation) and flexible way than its counterpart Tensorflow, using a framework almost entirely based on the nd.array of numpy (but stored on the GPU, or TPU if available). It also provides new utilities for gradient computation (per sample, jacobian with backward propagation and forward-propagation, hessian...) as well as a better seed system (for reproducibility) and a tool to batch complicated operations automatically and efficiently.

Github link: https://github.com/google/jax

GAN

GAN diagram

Generative adversarial networks (GANs) are algorithmic architectures that use two neural networks, pitting one against the other (thus the adversarial) in order to generate new, synthetic instances of data that can pass for real data. They are used widely in image generation, video generation and voice generation. GANs were introduced in a paper by Ian Goodfellow and other researchers at the University of Montreal, including Yoshua Bengio, in 2014. Referring to GANs, Facebook’s AI research director Yann LeCun called adversarial training the most interesting idea in the last 10 years in ML. (source)

Original paper: https://arxiv.org/abs/1406.2661

Some ideas have improved the training of the GANs by the years. For example:

Deep Convolution GAN (DCGAN) paper: https://arxiv.org/abs/1511.06434

Progressive Growing GAN (ProGAN) paper: https://arxiv.org/abs/1710.10196

The goal of this project is to implement these ideas in JAX framework.

Installation

You can install JAX following the instruction on JAX - Installation

It is strongly recommended to run JAX on Linux with CUDA available (Windows has no stable support yet). In this case you can install JAX using the following command:

pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

Then you can install Tensorflow to benefit from tf.data.Dataset to handle the data and the pre-installed dataset. However, Tensorfow allocate memory of the GPU on use (which is not optimal for running calculation with JAX). Therefore, you should install Tensorflow on the CPU instead of the GPU. Visit this site Tensorflow - Installation with pip to install the CPU-only version of Tensorflow 2 depending on your OS and your Python version.

Exemple with Linux and Python 3.9:

pip install tensorflow -f https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow_cpu-2.6.0-cp39-cp39-manylinux2010_x86_64.whl

Then you can install the other librairies from requirements.txt. It will install Haiku and Optax, two usefull add-on libraries to implement and optimize machine learning models with JAX.

pip install -r requirements.txt

Install CelebA dataset (optional)

To use the CelebA dataset, you need to download the dataset from Kaggle and install the images in the folder img_align_celeba/ in data/CelebA/images. It is recommended to download the dataset from this source because the faces are already cropped.

Note: the other datasets will be automatically installed with keras or tensorflow-datasets.

Quick Start

You can test a pretrained GAN model by using apps/test.py. It will download the model from pretrained models (in pre_trained/) and generate pictures. You can change the GAN to test by changing the path in the script.

You can also train your own GAN from scratch with apps/train.py. To change the parameters of the training, you can change the configs in the script. You can also change the dataset or the type of GAN by changing the imports (there is only one workd to change for each).

Example to train a GAN in celeba (64x64):

from utils.data import load_images_celeba_64 as load_images

To train a DCGAN:

from gan.dcgan import DCGAN as GAN

Then you can implement your own GAN and train/test them in your own dataset (by overriding the appropriate functions, check the examples in the repository).

Some results of pre-trained models

- Deep Convolution GAN

  • On MNIST:

DCGAN Cifar10

  • On Cifar10:

DCGAN Cifar10

  • On CelebA (64x64):

DCGAN CelebA-64

- Progressive Growing GAN

  • On MNIST:

  • On Cifar10:

  • On CelebA (64x64):

  • On CelebA (128x128):

Owner
Valentin Goldité
Student at CentraleSupelec (top french Engineer School) specialized in machine learning (Computer Vision, NLP, Audio, RL, Time Analysis).
Valentin Goldité
(3DV 2021 Oral) Filtering by Cluster Consistency for Large-Scale Multi-Image Matching

Scalable Cluster-Consistency Statistics for Robust Multi-Object Matching (3DV 2021 Oral Presentation) Filtering by Cluster Consistency (FCC) is a very

Yunpeng Shi 11 Sep 28, 2022
Nested cross-validation is necessary to avoid biased model performance in embedded feature selection in high-dimensional data with tiny sample sizes

Pruner for nested cross-validation - Sphinx-Doc Nested cross-validation is necessary to avoid biased model performance in embedded feature selection i

1 Dec 15, 2021
Code for STFT Transformer used in BirdCLEF 2021 competition.

STFT_Transformer Code for STFT Transformer used in BirdCLEF 2021 competition. The STFT Transformer is a new way to use Transformers similar to Vision

Jean-François Puget 69 Sep 29, 2022
This repository contains a re-implementation of the code for the CVPR 2021 paper "Omnimatte: Associating Objects and Their Effects in Video."

Omnimatte in PyTorch This repository contains a re-implementation of the code for the CVPR 2021 paper "Omnimatte: Associating Objects and Their Effect

Erika Lu 728 Dec 28, 2022
This repository focus on Image Captioning & Video Captioning & Seq-to-Seq Learning & NLP

Awesome-Visual-Captioning Table of Contents ACL-2021 CVPR-2021 AAAI-2021 ACMMM-2020 NeurIPS-2020 ECCV-2020 CVPR-2020 ACL-2020 AAAI-2020 ACL-2019 NeurI

Ziqi Zhang 362 Jan 03, 2023
Infrastructure as Code (IaC) for a self-hosted version of Gnosis Safe on AWS

Welcome to Yearn Gnosis Safe! Setting up your local environment Infrastructure Deploying Gnosis Safe Prerequisites 1. Create infrastructure for secret

Numan 16 Jul 18, 2022
FaceQgen: Semi-Supervised Deep Learning for Face Image Quality Assessment

FaceQgen FaceQgen: Semi-Supervised Deep Learning for Face Image Quality Assessment This repository is based on the paper: "FaceQgen: Semi-Supervised D

Javier Hernandez-Ortega 3 Aug 04, 2022
A general-purpose programming language, focused on simplicity, safety and stability.

The Rivet programming language A general-purpose programming language, focused on simplicity, safety and stability. Rivet's goal is to be a very power

The Rivet programming language 17 Dec 29, 2022
This repository accompanies our paper “Do Prompt-Based Models Really Understand the Meaning of Their Prompts?”

This repository accompanies our paper “Do Prompt-Based Models Really Understand the Meaning of Their Prompts?” Usage To replicate our results in Secti

Albert Webson 64 Dec 11, 2022
JittorVis - Visual understanding of deep learning models

JittorVis: Visual understanding of deep learning model JittorVis is an open-source library for understanding the inner workings of Jittor models by vi

thu-vis 182 Jan 06, 2023
π-GAN: Periodic Implicit Generative Adversarial Networks for 3D-Aware Image Synthesis

π-GAN: Periodic Implicit Generative Adversarial Networks for 3D-Aware Image Synthesis Project Page | Paper | Data Eric Ryan Chan*, Marco Monteiro*, Pe

375 Dec 31, 2022
Implementation for ACProp ( Momentum centering and asynchronous update for adaptive gradient methdos, NeurIPS 2021)

This repository contains code to reproduce results for submission NeurIPS 2021, "Momentum Centering and Asynchronous Update for Adaptive Gradient Meth

Juntang Zhuang 15 Jun 11, 2022
CyTran: Cycle-Consistent Transformers for Non-Contrast to Contrast CT Translation

CyTran: Cycle-Consistent Transformers for Non-Contrast to Contrast CT Translation We propose a novel approach to translate unpaired contrast computed

Nicolae Catalin Ristea 13 Jan 02, 2023
EvoJAX is a scalable, general purpose, hardware-accelerated neuroevolution toolkit

EvoJAX: Hardware-Accelerated Neuroevolution EvoJAX is a scalable, general purpose, hardware-accelerated neuroevolution toolkit. Built on top of the JA

Google 598 Jan 07, 2023
To build a regression model to predict the concrete compressive strength based on the different features in the training data.

Cement-Strength-Prediction Problem Statement To build a regression model to predict the concrete compressive strength based on the different features

Ashish Kumar 4 Jun 11, 2022
CVPR2021: Temporal Context Aggregation Network for Temporal Action Proposal Refinement

Temporal Context Aggregation Network - Pytorch This repo holds the pytorch-version codes of paper: "Temporal Context Aggregation Network for Temporal

Zhiwu Qing 63 Sep 27, 2022
Object Tracking and Detection Using OpenCV

Object tracking is one such application of computer vision where an object is detected in a video, otherwise interpreted as a set of frames, and the object’s trajectory is estimated. For instance, yo

Happy N. Monday 4 Aug 21, 2022
Google AI Open Images - Object Detection Track: Open Solution

Google AI Open Images - Object Detection Track: Open Solution This is an open solution to the Google AI Open Images - Object Detection Track 😃 More c

minerva.ml 46 Jun 22, 2022
Revisiting Oxford and Paris: Large-Scale Image Retrieval Benchmarking

Revisiting Oxford and Paris: Large-Scale Image Retrieval Benchmarking We revisit and address issues with Oxford 5k and Paris 6k image retrieval benchm

Filip Radenovic 188 Dec 17, 2022
[arXiv22] Disentangled Representation Learning for Text-Video Retrieval

Disentangled Representation Learning for Text-Video Retrieval This is a PyTorch implementation of the paper Disentangled Representation Learning for T

Qiang Wang 49 Dec 18, 2022