PuppetGAN - Cross-Domain Feature Disentanglement and Manipulation just got way better! 🚀

Overview

Better Cross-Domain Feature Disentanglement and Manipulation with Improved PuppetGAN

Quite cool... Right?

Introduction

This repo contains a TensorFlow implementation of PuppetGAN as well as an improved version of it, capable of manipulating features up to 100% better and up to 3 times faster! 😎

PuppetGAN is a model that extends the CycleGAN idea and is capable of extracting and manipulating features from a domain using examples from a different domain. On top of that, one amazing aspect of PuppetGAN is that it does not require a great amount of data; the biggest dataset I used contained 5000 sets of examples, while the smallest one contained just over 1000!

The Model(s)

Overview

PuppetGAN consists of 4 different components: one that is responsible for learning to reconstruct the input images, one that is responsible for learning to disentangle the Attribute of Interest, a CycleGAN component, and an Attribute CycleGAN. The Attribute CycleGAN acts in a similar manner to CycleGAN, with the exception that it deals with cross-domain inputs.



The full architecture of the baseline PuppetGAN (the image is copied from the original paper)
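PuppetGAN builds on the CycleGAN idea, whose central ingredient is the cycle-consistency loss: translating an image to the other domain and back should give the original image again. Below is a minimal pure-Python sketch of that standard term (a weighted mean absolute error between x and back(to_other(x))) on flattened toy images. The toy "generators", the function names and the weight parameter are illustrative assumptions; this is not this repo's TensorFlow code, and the Attribute CycleGAN applies the idea to cross-domain inputs in ways this sketch does not model.

```python
from typing import Callable, List

Image = List[float]  # a flattened image, for illustration only


def cycle_consistency_loss(
    x: Image,
    to_other: Callable[[Image], Image],
    back: Callable[[Image], Image],
    weight: float = 1.0,
) -> float:
    """Weighted mean absolute error between x and its round-trip translation."""
    reconstructed = back(to_other(x))
    if len(reconstructed) != len(x):
        raise ValueError("the round trip must preserve the image size")
    l1 = sum(abs(a - b) for a, b in zip(x, reconstructed)) / len(x)
    return weight * l1


# Toy stand-ins for the two generators (hypothetical, not learned models).
def brighten(img: Image) -> Image:
    return [p + 0.5 for p in img]


def darken(img: Image) -> Image:
    """Exact inverse of brighten."""
    return [p - 0.5 for p in img]


def quantized_darken(img: Image) -> Image:
    """Lossy inverse: rounding throws away detail that cannot be recovered."""
    return [float(round(p - 0.5)) for p in img]


x = [0.0, 0.25, 1.0]
print(cycle_consistency_loss(x, brighten, darken))            # 0.0
print(cycle_consistency_loss(x, brighten, quantized_darken))  # ~0.083
```

A perfect inverse pair gives zero loss, while a lossy one is penalized; minimizing this term is what pushes CycleGAN-style translations to keep the content of the input.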

With this repo I add a few more components, which I call Roids, that greatly improve the performance of the Baseline PuppetGAN. One Roid is applied in the disentanglement part and the rest in the attribute cycle part, while the objective of all of them is pretty much the same: to guarantee better disentanglement!

  • The original architecture performs the disentanglement only in the synthetic domain, and this ability is passed to the real domain only implicitly. The disentanglement Roid takes advantage of the CycleGAN model and performs the disentanglement on the translations of the synthetic images, passing the ability explicitly to the real domain.

  • The attribute cycle Roids act in a similar way, but they instead force the attributes other than the Attribute of Interest in the cross-domain translations to be as precise as possible. This can also be seen as a stricter version of the disentanglement Roid.



The Disentanglement Roid



The Attribute Cycle Roids

Implementation

The only difference between my Baseline and the model from the paper is that my generators and discriminators are modified versions of the ones used in TensorFlow's CycleGAN tutorial. The fact that the creators of PuppetGAN used ResNet blocks may be partially responsible for the memorization effect that seems to be present in some of the results of the paper, since skip connections can allow information to pass unchanged between layers.

Other than that, all my implementations use exactly the same parameters as the ones in the original model. Also, neither my architectures nor the parameters have been modified at all between different datasets.
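To illustrate the remark about skip connections: in a residual block y = x + F(x), if the learned branch F contributes little (here, all-zero weights), the output is simply a copy of the input, whereas a plain layer with the same weights lets nothing through. This toy pure-Python sketch is not the ResNet block used in the paper, and it only shows the mechanism; whether that mechanism actually causes memorization in PuppetGAN remains the hypothesis stated above.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def relu(v: Vector) -> Vector:
    return [max(0.0, a) for a in v]


def dense(v: Vector, weights: Matrix) -> Vector:
    """Matrix-vector product without bias; weights[i] is the i-th row."""
    return [sum(w * a for w, a in zip(row, v)) for row in weights]


def residual_block(x: Vector, weights: Matrix) -> Vector:
    """y = x + relu(W x): the skip path adds the input back unchanged."""
    return [a + b for a, b in zip(x, relu(dense(x, weights)))]


def plain_block(x: Vector, weights: Matrix) -> Vector:
    """The same layer without the skip connection: y = relu(W x)."""
    return relu(dense(x, weights))


x = [0.3, -1.2, 2.0]
zero = [[0.0] * 3 for _ in range(3)]
print(residual_block(x, zero))  # [0.3, -1.2, 2.0]: the input passes through untouched
print(plain_block(x, zero))     # [0.0, 0.0, 0.0]: nothing gets through
```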

Performance

Both my Baseline implementation and my proposed architecture(s) significantly outperform the original PuppetGAN!

Rotation of MNIST digits

By the Numbers

Just like in the original paper, all the reported scores are for the MNIST dataset. Because I didn't have access to the size dataset, I was able to measure the performance of my models only on the rotation dataset.

PuppetGAN                          | Accuracy | Correlation | Std. deviation | Epoch
Original (paper)                   | 0.97     | 0.40        | 0.01           | -
My Baseline                        | 0.96     | 0.59        | 0.01           | 300
Roids in Attribute Cycle Component | 0.97     | 0.82        | 0.02           | 100
Roids in Disentanglement Component | 0.91     | 0.73        | 0.01           | 250
Roids in Both Components           | 0.97     | 0.79        | 0.01           | 300
  • Accuracy (the closer to 1, the better)

The accuracy measures, using a LeNet-5 network, how well the original class is preserved. In other words, this metric indicates how well the model manages to disentangle without affecting the rest of the attributes. As we'll see later, though, it is possible to get very high accuracy while having suboptimal disentanglement performance...

  • Correlation score (the closer to 1, the better)

This score is the correlation coefficient between the Attribute of Interest of the known images and that of the generated images, and it captures how well the model manipulates the Attribute of Interest.

  • Standard deviation score (the closer to 0, the better)

This score captures how similar the results are between images that share the same Attribute of Interest but differ in the rest of the attributes. For this metric I report the standard deviation instead of the variance mentioned in the paper, because the variance of my models was orders of magnitude smaller than the one reported in the paper. This makes me believe that the paper used the standard deviation as well.
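The three metrics above can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the code of eval_rotation.py: the predicted class labels are assumed to come from a classifier such as the LeNet-5 mentioned above, each image's Attribute of Interest is assumed to be measurable as a single number (e.g. a rotation angle), and the standard deviation score is assumed to be the population standard deviation averaged over groups of images that share the same target AoI.

```python
import math
from collections import defaultdict
from statistics import mean, pstdev
from typing import Dict, Hashable, List, Sequence, Tuple


def class_accuracy(predicted: Sequence[int], original: Sequence[int]) -> float:
    """Fraction of generated images whose predicted class matches the source digit."""
    return sum(p == o for p, o in zip(predicted, original)) / len(original)


def pearson(xs: Sequence[float], ys: Sequence[float]) -> float:
    """Pearson correlation between the known and the measured AoI values."""
    mx, my = mean(xs), mean(ys)
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = math.sqrt(sum((x - mx) ** 2 for x in xs))
    sy = math.sqrt(sum((y - my) ** 2 for y in ys))
    return cov / (sx * sy)


def mean_within_group_std(pairs: Sequence[Tuple[Hashable, float]]) -> float:
    """Average std of the measured AoI over images sharing the same target AoI.

    `pairs` holds (target AoI, measured AoI) for images whose other
    attributes (e.g. the source digit) differ.
    """
    groups: Dict[Hashable, List[float]] = defaultdict(list)
    for target, measured in pairs:
        groups[target].append(measured)
    return mean(pstdev(values) for values in groups.values())


print(class_accuracy([7, 1, 3, 3], [7, 1, 3, 8]))  # 0.75
print(pearson([0.0, 0.5, 1.0], [0.1, 0.55, 0.95]))  # close to 1
std = mean_within_group_std([(0.2, 0.21), (0.2, 0.19), (0.8, 0.81), (0.8, 0.79)])
print(std, std ** 2)  # about 0.01 and about 0.0001
```

The last line shows why reporting a variance instead of a standard deviation matters: a standard deviation of 0.01 corresponds to a variance of 0.0001, two orders of magnitude smaller.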

Discussion about the Results

Mouth manipulation after 440 epochs, using the Baseline.

Mouth manipulation after 190 epochs with Roids in the Attribute Cycle component. The model learns to both open and close the mouth more accurately, disentangles better, produces clearer images, and does all of that much faster!

The most well-balanced model seems to be the one that uses both kinds of Roids, since it achieves the same accuracy and standard deviation score as the original model while increasing the manipulation (correlation) score by more than 30% compared to my Baseline implementation and by almost 100% compared to the original paper. Nevertheless, although it is intuitive that a combination of all the Roids would yield better results, I believe that more experiments are required to determine whether its benefits outweigh the large speed-up of the model that uses Roids only in the Attribute Cycle component.
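The percentages above follow from the correlation scores in the first table (0.79 for Roids in both components, 0.59 for my Baseline, 0.40 for the paper). A quick check, where relative_gain is just an illustrative helper:

```python
def relative_gain(new: float, old: float) -> float:
    """Relative improvement of `new` over `old`; 0.5 means +50%."""
    return (new - old) / old


both_roids, baseline, paper = 0.79, 0.59, 0.40  # correlation scores, first table
attribute_cycle_roids = 0.82

print(f"both Roids vs. Baseline: {relative_gain(both_roids, baseline):+.1%}")        # +33.9%
print(f"both Roids vs. paper:    {relative_gain(both_roids, paper):+.1%}")           # +97.5%
print(f"AC Roids vs. paper:      {relative_gain(attribute_cycle_roids, paper):+.1%}")  # +105.0%
print(f"epochs, Baseline / AC Roids: {300 / 100:.0f}x")                              # 3x
```

The last two lines correspond to the "up to 100% better" and "3 times faster" figures in the introduction, measured in correlation score and in training epochs respectively.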

MNIST rotation after adding Roids on the Attribute Cycle component

For now, I would personally favor the model that uses only the Roids of the Attribute Cycle component, because it outperforms every other model in the AoI manipulation (correlation) score in 1/3 of the epochs, while the difference in the standard deviation score is insignificant. As an extra trick, I found that not updating the discriminator in the Attribute Cycle Roids could improve the performance slightly, but that's just an additional hack.

Each Roid implicitly increases the weight of its respective loss, because extra terms are added to it. To ensure that the performance boost is not caused by the increased loss weight, I also compare the model with Roids in the Attribute Cycle component against the Baseline model with twice the weights of the Attribute Cycle component.

PuppetGAN                          | Accuracy | Correlation | Std. deviation | Epoch
Original (paper)                   | 0.97     | 0.40        | 0.01           | -
My Baseline                        | 0.96     | 0.59        | 0.01           | 300
Weighted Baseline                  | 0.84     | 0.85        | 0.01           | 100
Weighted Baseline                  | 0.93     | 0.72        | 0.01           | 150
Weighted Baseline                  | 0.92     | 0.68        | 0.01           | 200
Weighted Baseline                  | 0.95     | 0.63        | 0.01           | 300
Roids in Attribute Cycle Component | 0.97     | 0.82        | 0.02           | 100

The above results show that increasing the weights of the Attribute Cycle losses can slightly increase the performance of PuppetGAN, but such a model would be comparable to the Baseline and not to the model that utilizes the Roids.
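The reasoning behind using a Baseline with doubled weights as a control can be written out as an exact identity. The symbols are introduced here for illustration only: $\lambda$ is the weight of an attribute cycle loss $\mathcal{L}_{\text{ac}}$, and $\mathcal{L}_{\text{roid}}$ stands for the extra terms a Roid adds to it.

```latex
\lambda\left(\mathcal{L}_{\text{ac}} + \mathcal{L}_{\text{roid}}\right)
  = \underbrace{2\lambda\,\mathcal{L}_{\text{ac}}}_{\text{Weighted Baseline}}
  + \underbrace{\lambda\left(\mathcal{L}_{\text{roid}} - \mathcal{L}_{\text{ac}}\right)}_{\text{what the Roid adds beyond a larger weight}}
```

The first term is what the Weighted Baseline optimizes for this loss, so comparing the two models isolates the effect of the second term rather than that of the larger weight alone.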

Comparison to the original results

A significant drawback of the original model is that it seems to memorize seen images instead of editing the given ones. This can be observed in the rotation results reported in the paper, where the representation of a real digit may change during the rotation, or different representations of a real digit may have the same rotated representations. This does not prevent it from achieving very high accuracy, though, which highlights why this metric is not necessarily ideal for measuring the quality of the disentanglement.

The rotation results of the paper

Another issue with both the model of the paper and my models can be observed in the mouth dataset, where PuppetGAN confuses the microphone with the opening of the mouth: when the synthetic image dictates a wider opening, PuppetGAN moves the microphone closer to the mouth. This effect is slightly stronger in my Baseline, but I believe this is because I haven't done any hyper-parameter tuning; some experimentation with the magnitude of the noise or with the weights of the different components could eliminate it. Also, the model with Roids in the Attribute Cycle component seems to deal with this issue better than the Baseline.

Running the Code

You can manage all the dependencies with Pipenv using the provided Pipfile. This makes the code easier to reproduce, since Pipenv creates a virtual environment containing all the necessary libraries. Just run pipenv shell in the base directory of the project and you're ready to go!

Alternatively, if for any reason you don't want to use Pipenv, you can install all the required libraries using the provided requirements.txt file. Neither this file nor Pipenv takes care of CUDA, though; in all my experiments I used CUDA 7.5.18.

To download the datasets, you can use the fetch_data.sh script, which downloads them and extracts them into the correct directory, by running:

. fetch_data.sh

Unfortunately, I am not allowed to publish any dataset other than MNIST, but feel free to ask the authors of the original PuppetGAN for them, following the instructions on their website 🙂.

Training a Model

To start a new training, simply run:

python3 main.py

This first looks for any existing checkpoints and automatically restores the latest one. If you want to continue training from a specific checkpoint, just run:

python3 main.py -c [checkpoint number]

or

python3 main.py --ckpt=[checkpoint number]
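The command-line behaviour described in this section (restore the latest checkpoint unless -c/--ckpt is given, with -t/--test switching to evaluation) could be mirrored with argparse roughly as below. This is a hypothetical sketch: the option names come from this README, but build_parser, resolve_checkpoint and the idea of passing in the list of available checkpoint numbers are assumptions, not the actual code of main.py.

```python
import argparse
from typing import Optional, Sequence


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Train or evaluate PuppetGAN (sketch).")
    parser.add_argument("-c", "--ckpt", type=int, default=None,
                        help="checkpoint number to restore (default: the latest one)")
    parser.add_argument("-t", "--test", action="store_true",
                        help="evaluate instead of training")
    return parser


def resolve_checkpoint(available: Sequence[int], requested: Optional[int]) -> Optional[int]:
    """Pick the checkpoint to restore, or None to start from scratch."""
    if requested is not None:
        if requested not in available:
            raise ValueError(f"checkpoint {requested} not found")
        return requested
    return max(available) if available else None


args = build_parser().parse_args(["--ckpt=100"])
print(args.ckpt, args.test)                            # 100 False
print(resolve_checkpoint([50, 100, 150], args.ckpt))  # 100
print(resolve_checkpoint([50, 100, 150], None))       # 150
print(resolve_checkpoint([], None))                   # None
```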

To help you keep better track of your work, every time you start a new training, a detailed report of your current configuration is created in ./PuppetGAN/results/config.txt. This report contains all your hyper-parameters and their respective values, as well as the whole architecture of the model you are using, including every single layer, its parameters and how it is connected to the rest of the model.

Also, every few epochs (as set by save images every), the model creates in ./PuppetGAN/results a sample of evaluation rows of generated images, along with GIF animations of these rows, to help you better visualize its performance.

On top of that, ./PuppetGAN/results also stores plots of both the supervised and the adversarial losses, as well as the images that are produced during training. This gives you, in a single folder, everything you need to evaluate an experiment, keep track of its progress and reproduce its results!

Unless you want to experiment with different architectures, PuppetGAN/config.json is the only file you'll need. It lets you control all the hyper-parameters of the model without having to look at any of the code! More specifically, the parameters you can control are:

  • dataset : The dataset to use. You can choose between "mnist", "mouth" and "light".

  • epochs : The number of epochs that the model will be trained for.

  • noise std : The standard deviation of the noise that will be applied to the translated images. The mean of the noise is 0.

  • bottleneck noise : The standard deviation of the noise that will be applied to the bottleneck. The mean of the noise is 0.

  • on roids : Whether or not to use the proposed Roids.

  • learning rates-real generator : The learning rate of the real generator.

  • learning rates-real discriminator : The learning rate of the real discriminator.

  • learning rates-synthetic generator : The learning rate of the synthetic generator.

  • learning rates-synthetic discriminator : The learning rate of the synthetic discriminator.

  • losses weights-reconstruction : The weight of the reconstruction loss.

  • losses weights-disentanglement : The weight of the disentanglement loss.

  • losses weights-cycle : The weight of the cycle loss.

  • losses weights-attribute cycle b3 : The weight of the part of the attribute cycle loss that is a function of the synthetic image that has both the Attribute of Interest and all the rest of the attributes.

  • losses weights-attribute cycle a : The weight of the part of the attribute cycle loss that is a function of the real image.

  • batch size : The batch size. Different values can be given depending on the dataset.

  • image size : The size to which the images of the dataset are resized.

  • save images every : How often, in epochs, to save the training images and the sample of the evaluation images.

  • save model every : How often, in epochs, to create a checkpoint. Keep in mind that the 5 most recent checkpoints are always kept during training.
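A minimal sketch of how some of the parameters above could be read and used, assuming that a name such as losses weights-cycle denotes a nested JSON object (a "losses weights" object containing a "cycle" key). The nesting and every number below are placeholders chosen for illustration, not the repository's defaults; check PuppetGAN/config.json for the real structure and values. The bottleneck noise parameter would presumably be applied the same way as noise std, but to the bottleneck features instead of the translated images.

```python
import json
import random
from typing import Dict, List

# Placeholder config: the nesting is guessed from the "section-key" names above
# and every value is an arbitrary illustration, NOT one of the repo's defaults.
CONFIG_TEXT = """
{
  "noise std": 0.1,
  "losses weights": {
    "reconstruction": 1.0,
    "disentanglement": 1.0,
    "cycle": 10.0,
    "attribute cycle b3": 1.0,
    "attribute cycle a": 1.0
  },
  "save images every": 10,
  "save model every": 10
}
"""


def weighted_total(losses: Dict[str, float], weights: Dict[str, float]) -> float:
    """Multiply each loss by its configured weight and add the results up."""
    return sum(weights[name] * value for name, value in losses.items())


def add_gaussian_noise(pixels: List[float], std: float, rng: random.Random) -> List[float]:
    """Add zero-mean Gaussian noise with the configured standard deviation."""
    return [p + rng.gauss(0.0, std) for p in pixels]


def should_save(epoch: int, every: int) -> bool:
    """True on epochs that are multiples of the configured interval."""
    return epoch % every == 0


config = json.loads(CONFIG_TEXT)
weights = config["losses weights"]
print(weighted_total({"reconstruction": 0.2, "cycle": 0.3}, weights))  # 0.2*1 + 0.3*10
noisy = add_gaussian_noise([0.0, 0.5, 1.0], config["noise std"], random.Random(0))
print(noisy)  # three values scattered around 0.0, 0.5 and 1.0
print([ep for ep in range(1, 31) if should_save(ep, config["save model every"])])  # [10, 20, 30]
```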

Evaluation of a Model

You can start an evaluation just by running:

python3 main.py -t

or

python3 main.py --test

Just like with training, this will look for the latest checkpoint; if you want to evaluate a different checkpoint, you can use the -c or --ckpt option the same way as before.

During the evaluation process, the model creates all the rows of generated images, where each cell corresponds to the image generated from the respective synthetic and real inputs. Additionally, a corresponding GIF file is created for each evaluation image to help you get a better idea of your results!
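The evaluation grid can be pictured as every pairing of a real input with a synthetic input. The sketch below assumes one row per real image and one column per synthetic image, and uses a hypothetical generate(real, synthetic) stand-in for the trained generator; the actual layout and function names in this repo may differ.

```python
from typing import Callable, List, Sequence, TypeVar

RealT = TypeVar("RealT")
SynT = TypeVar("SynT")
OutT = TypeVar("OutT")


def evaluation_grid(reals: Sequence[RealT], synthetics: Sequence[SynT],
                    generate: Callable[[RealT, SynT], OutT]) -> List[List[OutT]]:
    """grid[i][j] is the image generated from real input i and synthetic input j."""
    return [[generate(real, syn) for syn in synthetics] for real in reals]


# Toy stand-in: the "generated image" is the real digit's id tagged with the
# rotation dictated by the synthetic input.
grid = evaluation_grid(["digit_a", "digit_b"], [0, 45, 90],
                       lambda real, angle: f"{real}@{angle}")
for row in grid:
    print(row)
# ['digit_a@0', 'digit_a@45', 'digit_a@90']
# ['digit_b@0', 'digit_b@45', 'digit_b@90']
```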

If you want to calculate the scores of your model on the MNIST dataset, you can use the ./PuppetGAN/eval_rotation.py script by running:

python3 eval_rotation.py -p [path to the directory of your evaluation images]

or

python3 eval_rotation.py --path=[path to the directory of your evaluation images]

You can also specify where to save the evaluation report file using the -t or --target-path option. For example, say you have just trained a model and produced its evaluation images, and you want to get the evaluation scores for epoch 100 and save the report in that epoch's folder. Then you just need to run:

# make sure you are in ./PuppetGAN
python3 eval_rotation.py -p results/test/100/images -t results/test/100

For a fair comparison, I am also providing the checkpoint of my LeNet-5 network in ./PuppetGAN/checkpoints/lenet5. If the eval_rotation.py script doesn't detect this checkpoint, it will train a new network from scratch, in which case the accuracy of your model may differ slightly.
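The fallback described above (reuse the provided LeNet-5 checkpoint if present, otherwise train one) is a common load-or-train pattern. Below is a minimal sketch in which load_fn and train_fn are hypothetical stand-ins, and a non-empty checkpoint directory is assumed to mean that a checkpoint exists; with TensorFlow, one would more likely check whether tf.train.latest_checkpoint(directory) returns None.

```python
import tempfile
from pathlib import Path
from typing import Callable, TypeVar

Model = TypeVar("Model")


def load_or_train(ckpt_dir: Path,
                  load_fn: Callable[[Path], Model],
                  train_fn: Callable[[Path], Model]) -> Model:
    """Restore from ckpt_dir if it holds any files, otherwise train and save there."""
    if ckpt_dir.is_dir() and any(ckpt_dir.iterdir()):
        return load_fn(ckpt_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    return train_fn(ckpt_dir)


def fake_train(path: Path) -> str:
    (path / "weights.txt").write_text("trained")  # pretend to save a checkpoint
    return "freshly trained"


def fake_load(path: Path) -> str:
    return "restored " + (path / "weights.txt").read_text()


with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp) / "checkpoints" / "lenet5"
    print(load_or_train(ckpt, fake_load, fake_train))  # freshly trained
    print(load_or_train(ckpt, fake_load, fake_train))  # restored trained
```

Because a freshly trained classifier will not have exactly the same weights as the provided one, its predictions, and therefore the accuracy score, can differ slightly, which is why the checkpoint is shipped.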

Owner
Giorgos Karantonis
Passionate about AI, ML, DL and other abbreviations.