A Tensorflow implementation of the Text Conditioned Auxiliary Classifier Generative Adversarial Network for Generating Images from text descriptions

Overview

TAC-GAN

This is the official Tensorflow implementation of the TAC-GAN model presented in https://arxiv.org/abs/1703.06412.

Text Conditioned Auxiliary Classifier Generative Adversarial Network (TAC-GAN) is a text-to-image Generative Adversarial Network (GAN) for synthesizing images from their text descriptions. TAC-GAN builds upon the AC-GAN by conditioning the generated images on a text description instead of a class label. In the presented TAC-GAN model, the input vector of the generator network is built from a noise vector and another vector containing an embedded representation of the textual description. The discriminator is similar to that of the AC-GAN, but it is also augmented to receive the text information as input before performing its classification.

To embed the textual descriptions of the images into vectors, we used skip-thought vectors.
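As a rough illustration of how the generator input described above could be assembled, the NumPy sketch below projects a 4800-dimensional caption embedding down to `t_dim` values and concatenates it with a `z_dim` noise vector. The random matrix `W` stands in for a learned fully connected layer, the leaky ReLU is an assumption, and the function name `build_generator_input` is hypothetical; this is not the code used in this repository.

```python
import numpy as np

def build_generator_input(caption_vec, W, z_dim=100, rng=None):
    """Concatenate a noise vector with a reduced text embedding.

    caption_vec: (batch, caption_vector_length) skip-thought features.
    W: (caption_vector_length, t_dim) projection, a stand-in for a learned layer.
    """
    rng = np.random.default_rng() if rng is None else rng
    batch = caption_vec.shape[0]
    # Reduce the text embedding to t_dim values (leaky ReLU is an assumption).
    t = caption_vec @ W
    t = np.where(t > 0, t, 0.2 * t)
    # Sample one noise vector per caption.
    z = rng.uniform(-1.0, 1.0, size=(batch, z_dim))
    # The generator would receive noise and text features side by side.
    return np.concatenate([z, t], axis=1)

rng = np.random.default_rng(0)
captions = rng.normal(size=(4, 4800))          # fake caption embeddings
W = rng.normal(scale=0.01, size=(4800, 256))   # fake projection weights
g_in = build_generator_input(captions, W, rng=rng)
print(g_in.shape)  # (4, 356)
```

With the default flag values listed later (`z_dim=100`, `t_dim=256`), each generator input row would have 356 entries under these assumptions.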

The following is the architecture of the TAC-GAN model

Prerequisites

Some important dependencies are listed below; the rest can be installed using the requirements.txt file.

  1. Python 3.5
  2. Tensorflow 1.2.0
  3. Theano 0.9.0 : for skip thought vectors
  4. scikit-learn : for skip thought vectors
  5. NLTK 3.2.1 : for skip thought vectors

It is recommended to run this project in a virtual environment and to install the required dependencies into it using the requirements.txt file.

The project has been tested on an Ubuntu 14.04 machine with a 12 GB NVIDIA Titan X GPU.

1. Setup and Run

1.1. Clone the Repository

git clone https://github.com/dashayushman/TAC-GAN.git
cd TAC-GAN

1.2. Download the Dataset

The model presented in the paper was trained on the flowers dataset. To train the TAC-GAN on the flowers dataset, first download the dataset by doing the following:

  1. Download the flower images from here. Extract the 102flowers.tgz file and copy the extracted jpg folder to Data/datasets/flowers

  2. Download the captions from here. Extract the downloaded file, copy the text_c10 folder and paste it in Data/datasets/flowers directory

  3. Download the pretrained skip-thought vectors model from here and copy the downloaded files to Data/skipthoughts

NB: It is recommended to keep all the images on an SSD if available. This makes batch loading and processing faster.
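Before moving on to preprocessing, it can help to confirm that the three download steps put everything in the expected place. The sketch below only checks the directory names mentioned in the steps above; the helper name `missing_paths` is hypothetical and not part of this repository.

```python
from pathlib import Path

# Directories named in the download steps, relative to the repository root.
EXPECTED = [
    "Data/datasets/flowers/jpg",
    "Data/datasets/flowers/text_c10",
    "Data/skipthoughts",
]

def missing_paths(root="."):
    """Return the expected data directories that do not exist under root."""
    root = Path(root)
    return [p for p in EXPECTED if not (root / p).is_dir()]

if __name__ == "__main__":
    missing = missing_paths()
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("All expected data directories are present.")
```

Run it from the root of the cloned repository; an empty result means the layout matches the steps above.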

1.3. Data Preprocessing

Extract the skip-thought features for the captions and prepare the dataset for training by running the following script

python dataprep.py --data_dir=Data --dataset=flowers

This script will create a set of pickled files in the dataset directory, which will be used during training. The following flags are available for data preparation:

| FLAG | VALUE TYPE | DEFAULT VALUE | DESCRIPTION |
|---|---|---|---|
| data_dir | str | Data | The data directory |
| dataset | str | flowers | Dataset to use, e.g. "flowers" |

1.4. Training

To train TAC-GAN with the default hyper parameters run the following script

python train.py --data_set=flowers --model_name=TAC_GAN

While training, you can monitor samples generated by the model in the Data/training/TAC_GAN/samples directory. Notice that a directory is created according to the "model_name" that you provide. This directory contains all the data related to a particular experiment, so the model name can also be considered an "experiment name".
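To make the per-experiment layout concrete, the sketch below builds the paths that this README mentions for a given model name (samples here, checkpoints and summaries in later sections). The helper `experiment_dirs` is hypothetical; the training script may create additional directories that are not listed here.

```python
import os

def experiment_dirs(model_name, data_dir="Data"):
    """Build the per-experiment paths mentioned in this README."""
    root = os.path.join(data_dir, "training", model_name)
    # Only the sub-directories named in this README are listed.
    return {name: os.path.join(root, name)
            for name in ("samples", "checkpoints", "summaries")}

for name, path in experiment_dirs("TAC_GAN").items():
    print(name, "->", path)
```

Using a distinct model name per run keeps the samples, checkpoints and summaries of different experiments apart.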

The following flags can be set to change the hyperparameters of the network.

| FLAG | VALUE TYPE | DEFAULT VALUE | DESCRIPTION |
|---|---|---|---|
| z_dim | int | 100 | Number of dimensions of the noise vector |
| t_dim | int | 256 | Number of dimensions for the latent representation of the text embedding |
| batch_size | int | 64 | Mini-batch size |
| image_size | int | 128 | Size of the images |
| gf_dim | int | 64 | Number of conv filters in the first layer of the generator |
| df_dim | int | 64 | Number of conv filters in the first layer of the discriminator |
| caption_vector_length | int | 4800 | Length of the caption vector embedding (vector generated using the skip-thought vectors model) |
| n_classes | int | 102 | Number of classes |
| data_dir | str | Data | Data directory |
| learning_rate | float | 0.0002 | Learning rate |
| beta1 | float | 0.5 | Momentum for the Adam update |
| epochs | int | 200 | Maximum number of epochs to train |
| save_every | int | 30 | Save the model and samples after this many iterations |
| resume_model | bool | False | Whether to load the pre-trained model |
| data_set | str | flowers | Which dataset to use: "flowers" |
| model_name | str | model_1 | Name of the model; can be anything |
| train | bool | True | True while training and False otherwise. Used for batch normalization |
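For readers who want to wrap or script these flags, the sketch below shows one way a subset of them could be declared with `argparse`, using the defaults from the table. It is an assumption about how such flags are typically parsed, not a copy of the parser in train.py; `build_parser` and `str2bool` are hypothetical names.

```python
import argparse

def str2bool(value):
    """Parse 'True'/'False' style strings, since bool('False') is True."""
    if value.lower() in ("true", "1", "yes"):
        return True
    if value.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean, got %r" % value)

def build_parser():
    """Declare an illustrative subset of the training flags listed above."""
    parser = argparse.ArgumentParser(description="Illustrative subset of training flags")
    parser.add_argument("--z_dim", type=int, default=100)
    parser.add_argument("--t_dim", type=int, default=256)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--image_size", type=int, default=128)
    parser.add_argument("--learning_rate", type=float, default=0.0002)
    parser.add_argument("--epochs", type=int, default=200)
    parser.add_argument("--resume_model", type=str2bool, default=False)
    parser.add_argument("--data_set", type=str, default="flowers")
    parser.add_argument("--model_name", type=str, default="model_1")
    return parser

# Flags that are not passed keep the defaults from the table.
args = build_parser().parse_args(["--epochs=400", "--resume_model=True"])
print(args.epochs, args.resume_model, args.batch_size)  # 400 True 64
```

The explicit `str2bool` converter is a design choice: passing `type=bool` to `argparse` would treat any non-empty string, including "False", as true.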

We used the following command (hyper-parameters) for the results that we show in our paper:

python train.py --t_dim=100 --image_size=128 --data_set=flowers --model_name=TAC_GAN --train=True --resume_model=True --z_dim=100 --n_classes=102 --epochs=400 --save_every=20 --caption_vector_length=4800 --batch_size=128

1.5. Monitoring

While training, you can monitor the updates on the terminal as well as by using tensorboard

1.5.1 The Terminal:

Terminal log

1.5.2 Tensorboard:

You can use the following script to start tensorboard and visualize realtime changes:

tensorboard --logdir=Data/training/TAC_GAN/summaries

Tensorboard

2. Generating Images for the text in the dataset

Once you have trained the model for a number of epochs, you can generate images for all the text descriptions in the dataset using the following script. This will create a synthetic dataset with images generated by the generator.

python train.py --data_set=flowers --epochs=100 --output_dir=Data/synthetic_dataset --checkpoints_dir=Data/training/TAC_GAN/checkpoints

Notice that the checkpoints directory is created automatically inside the model directory after you run the training script.

This script will create the following directory structure:

Data
  |__synthetic_dataset
        |___ds
             |___train
             |___val

The train directory will contain all the images generated from the text descriptions of the images in the training set, and the val directory will contain those generated for the validation set.

3. Generating Images from any Text

To generate images from any text, do the following

3.1 Add Text Descriptions:

Write your text descriptions in a file or use the example file Data/text.txt that we have provided in the Data directory. The text description file should contain one text description per line. For example,

a flower with red petals which are pointed
many pointed petals
A yellow flower
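If you prefer to create the description file programmatically, the sketch below writes and reads a file in the one-description-per-line format shown above. The helpers `write_captions` and `read_captions` are hypothetical and not part of this repository, and a temporary file is used so that the provided Data/text.txt is left untouched.

```python
import os
import tempfile

def write_captions(path, captions):
    """Write one text description per line."""
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(captions) + "\n")

def read_captions(path):
    """Return the non-empty lines of the file, one description each."""
    with open(path, encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]

captions = [
    "a flower with red petals which are pointed",
    "many pointed petals",
    "A yellow flower",
]
# A temporary location is used here; point `path` at your own file in practice.
path = os.path.join(tempfile.mkdtemp(), "text.txt")
write_captions(path, captions)
loaded = read_captions(path)
print(len(loaded), "captions")  # 3 captions
```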

3.2 Extract Skip-Thought Vectors:

Run the following script for extracting the Skip-Thought vectors for the text descriptions

python encode_text.py --caption_file=Data/text.txt --data_dir=Data

This script will create a pickle file called Data/enc_text.pkl with features extracted from the text descriptions.

3.3 Generate Images:

To generate images for the text descriptions, run the following script,

python generate_images.py --data_set=flowers --checkpoints_dir=Data/training/TAC_GAN/checkpoints --images_per_caption=30 --data_dir=Data

This will create a directory Data/images_generated_from_text/ with a folder corresponding to every row of the text.txt file. Each of these folders will contain images for that text.

The following are the parameters you need to set, in case you have used different parameters for training the model.

| FLAG | VALUE TYPE | DEFAULT VALUE | DESCRIPTION |
|---|---|---|---|
| z_dim | int | 100 | Number of dimensions of the noise vector |
| t_dim | int | 256 | Number of dimensions for the latent representation of the text embedding |
| batch_size | int | 64 | Mini-batch size |
| image_size | int | 128 | Size of the images |
| gf_dim | int | 64 | Number of conv filters in the first layer of the generator |
| df_dim | int | 64 | Number of conv filters in the first layer of the discriminator |
| caption_vector_length | int | 4800 | Length of the caption vector embedding (vector generated using the skip-thought vectors model) |
| n_classes | int | 102 | Number of classes |
| data_dir | str | Data | Data directory |
| learning_rate | float | 0.0002 | Learning rate |
| beta1 | float | 0.5 | Momentum for the Adam update |
| images_per_caption | int | 30 | Maximum number of images to generate for each text description |
| data_set | str | flowers | Which dataset to use: "flowers" |
| checkpoints_dir | str | /tmp | Path to the checkpoints directory which will be used to generate the images |

4. Evaluation

We have used two metrics for evaluating TAC-GAN:

  1. Inception Score
  2. MS-SSIM score

The code for evaluating TAC-GAN was adapted from existing implementations of these metrics. Before evaluating the model, generate a synthetic dataset by referring to Section 2.

4.1 Inception Score

To calculate the inception score, use the following script,

python inception_score.py --output_dir=Data/synthetic_dataset --data_dir=Data --n_images=30000 --image_size=128

This will create a collection of all the generated images in Data/synthetic_dataset/ds_inception and show the inception score on the terminal.
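For intuition about what the reported number means, the sketch below computes the inception score formula, exp(E_x[KL(p(y|x) || p(y))]), from an array of class probabilities. It assumes the probabilities have already been produced by a classifier; it does not run an Inception network, it skips the averaging over several splits that published implementations commonly apply, and it is not the code in inception_score.py.

```python
import numpy as np

def inception_score(probs, eps=1e-12):
    """Inception score from an (n_images, n_classes) array of class probabilities."""
    probs = np.asarray(probs, dtype=np.float64)
    # Marginal class distribution p(y), averaged over all images.
    marginal = probs.mean(axis=0, keepdims=True)
    # KL(p(y|x) || p(y)) for every image; eps avoids log(0).
    kl = np.sum(probs * (np.log(probs + eps) - np.log(marginal + eps)), axis=1)
    return float(np.exp(kl.mean()))

# Every image gets the same uniform prediction: the score is close to 1.
uniform = np.full((10, 5), 0.2)
# Confident predictions spread evenly over 5 classes: the score is close to 5.
one_hot = np.eye(5)
print(inception_score(uniform))
print(inception_score(one_hot))
```

The two toy inputs mark the extremes: the score is lowest when predictions carry no information and approaches the number of classes when predictions are both confident and diverse.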

The following are the set of available parameters/flags

| FLAG | VALUE TYPE | DEFAULT VALUE | DESCRIPTION |
|---|---|---|---|
| output_dir | str | Data/ds_inception | Directory to dump all the images for calculating the inception score |
| data_dir | str | Data/synthetic_dataset/ds | The root directory of the synthetic dataset |
| n_images | int | 30000 | Number of images to consider for calculating the inception score |
| image_size | int | 128 | Size of the image to consider for calculating the inception score |

4.2 MS-SSIM

To calculate the MS-SSIM score, use the following script,

python inception_score.py --output_dir=Data --data_dir=Data --dataset=flowers --syn_dataset_dir=Data/synthetic_dataset/ds

This will create Data/msssim.tsv, a tab-separated file.


Once you have generated the msssim.tsv file, you can use the following script to generate a figure to compare the MS-SSIM score of the images in the real dataset with the images in the synthetic dataset belonging to the same class,

python utility/plot_msssim.py --input_file=Data/msssim.tsv --output_file=Data/msssim

This will create Data/msssim.pdf, which is the .pdf file of the generated figure.

5. Generate Interpolated Images

In our paper we show the effect of interpolating the noise and the text embedding vectors on the generated image. Images are randomly selected and their text descriptions are used to generate synthetic images. The following sub-sections explain how to do this and which scripts to use.

5.1 Z (Noise) Interpolation

To interpolate the noise vector and generate images, use the following script:

python z_interpolation.py --output_dir=Data/synthetic_dataset --data_set=flowers --checkpoints_dir=Data/training/TAC_GAN/checkpoints --n_images=500

This will generate the interpolated images in Data/synthetic_dataset/z_interpolation/.
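The idea behind interpolation can be sketched as a linear blend between two vectors. The example below generates evenly spaced points between two noise vectors; the same helper would apply to two text embeddings. The function name `interpolate` and its `n_steps` parameter are hypothetical, and how the scripts map the `n_interp` flag to the number of steps is not specified here.

```python
import numpy as np

def interpolate(v_start, v_end, n_steps=10):
    """Return n_steps vectors evenly spaced from v_start to v_end (inclusive)."""
    # One blending weight per step, shaped as a column for broadcasting.
    alphas = np.linspace(0.0, 1.0, n_steps)[:, None]
    return (1.0 - alphas) * v_start[None, :] + alphas * v_end[None, :]

rng = np.random.default_rng(0)
z_a = rng.uniform(-1, 1, size=100)  # first noise vector (z_dim = 100)
z_b = rng.uniform(-1, 1, size=100)  # second noise vector
zs = interpolate(z_a, z_b, n_steps=10)
print(zs.shape)  # (10, 100)
```

Feeding each row to the generator while keeping the other input fixed would show how the image changes along the path between the two vectors.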

5.2 T (Text Embedding) Interpolation

To interpolate the text embedding vectors and generate images, use the following script:

python t_interpolation.py --output_dir=Data/synthetic_dataset --data_set=flowers --checkpoints_dir=Data/training/TAC_GAN/checkpoints --n_images=500

This will generate the interpolated images in Data/synthetic_dataset/t_interpolation/.

NOTE: Both the above mentioned scripts have the same flags/arguments, which are the following,

| FLAG | VALUE TYPE | DEFAULT VALUE | DESCRIPTION |
|---|---|---|---|
| z_dim | int | 100 | Number of dimensions of the noise vector |
| t_dim | int | 256 | Number of dimensions for the latent representation of the text embedding |
| batch_size | int | 64 | Mini-batch size |
| image_size | int | 128 | Size of the images |
| gf_dim | int | 64 | Number of conv filters in the first layer of the generator |
| df_dim | int | 64 | Number of conv filters in the first layer of the discriminator |
| caption_vector_length | int | 4800 | Length of the caption vector embedding (vector generated using the skip-thought vectors model) |
| n_classes | int | 102 | Number of classes |
| data_dir | str | Data | Data directory |
| learning_rate | float | 0.0002 | Learning rate |
| beta1 | float | 0.5 | Momentum for the Adam update |
| data_set | str | flowers | The dataset to use: "flowers" |
| output_dir | str | Data/synthetic_dataset | The directory in which the interpolated images will be generated |
| checkpoints_dir | str | /tmp | Path to the checkpoints directory which will be used to generate the images |
| n_interp | int | 100 | The factor difference between each interpolation (should ideally be a multiple of 10) |
| n_images | int | 500 | Number of images to randomly sample for generating interpolation results |

6. References

TAC-GAN

If you find this code useful, then please use the following BibTex to cite our work.

@article{dash2017tac,
  title={TAC-GAN-Text Conditioned Auxiliary Classifier Generative Adversarial Network},
  author={Dash, Ayushman and Gamboa, John Cristian Borges and Ahmed, Sheraz and Afzal, Muhammad Zeshan and Liwicki, Marcus},
  journal={arXiv preprint arXiv:1703.06412},
  year={2017}
}

Oxford-102 Flowers Dataset

If you use the Oxford-102 Flowers Dataset, then please cite their work using the following BibTex.

@InProceedings{Nilsback08,
   author = "Nilsback, M-E. and Zisserman, A.",
   title = "Automated Flower Classification over a Large Number of Classes",
   booktitle = "Proceedings of the Indian Conference on Computer Vision, Graphics and Image Processing",
   year = "2008",
   month = "Dec"
}

Skip-Thought

If you use the Skip-Thought model in your work as we did, then please cite their work using the following BibTex.

@article{kiros2015skip,
  title={Skip-Thought Vectors},
  author={Kiros, Ryan and Zhu, Yukun and Salakhutdinov, Ruslan and Zemel, Richard S and Torralba, Antonio and Urtasun, Raquel and Fidler, Sanja},
  journal={arXiv preprint arXiv:1506.06726},
  year={2015}
}

Code

We have referred to the text-to-image and DCGAN-tensorflow repositories for developing our code, and we are extremely thankful to them.
