A Tensorflow implementation of Attend, Infer, Repeat

Overview

Attend, Infer, Repeat: Fast Scene Understanding with Generative Models

This is an unofficial Tensorflow implementation of Attend, Infer, Repeat (AIR), as presented in the following paper: S. M. Ali Eslami et al., Attend, Infer, Repeat: Fast Scene Understanding with Generative Models.

  • Author (of the implementation): Adam Kosiorek, Oxford Robotics Institute, University of Oxford
  • Email: adamk(at)robots.ox.ac.uk
  • Webpage: http://akosiorek.github.io/

I describe the implementation, and the issues I ran into while working on it, in this blog post.

Installation

Install Tensorflow v1.1.0rc1, Sonnet v1.1 and the following dependencies, preferably with pip install -r requirements.txt, or one at a time with pip install [package]:

  • matplotlib==1.5.3
  • numpy==1.12.1
  • attrdict==2.0.0
  • scipy==0.18.1

Sample Results

AIR learns to reconstruct objects by painting them one by one on a blank canvas. The figure below comes from a model trained for 175k iterations; the maximum number of steps is set to 3, but no image contains more than 2 objects. The first row shows the input images, and rows 2-4 show the reconstructions after steps 1, 2 and 3 (with the location of the attention glimpse marked in red, if it exists). Rows 5-7 show the reconstructed image crops, and above each crop is the probability of executing 1, 2 or 3 steps. If a reconstructed crop is black and "0 with ..." is written above it, that step was not used.

AIR results
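The step probabilities in the figure can be read as a distribution over how many steps the model executes. In the AIR paper, the number of objects is encoded by a sequence of binary presence variables, and inference stops at the first step whose presence variable is 0. The sketch below assumes the model outputs one Bernoulli presence probability per step; the function name `step_count_distribution` and the example probabilities are hypothetical and not taken from this repository. It computes the implied probability of executing exactly 0, 1, ..., N steps in plain Python.

```python
from typing import List


def step_count_distribution(presence_probs: List[float]) -> List[float]:
    """Return P(n steps) for n = 0..N, given per-step presence probabilities.

    Assumption (hypothetical helper, not part of this repo): step i runs only
    if every earlier step ran and its own Bernoulli presence variable is 1,
    so the step count is the index of the first 0, or N if none occurs.
    """
    dist = []
    p_all_present = 1.0  # probability that every step so far was executed
    for p in presence_probs:
        # Exactly len(dist) steps: all previous steps present, this one absent.
        dist.append(p_all_present * (1.0 - p))
        p_all_present *= p
    # All N steps executed; the step limit forces termination afterwards.
    dist.append(p_all_present)
    return dist


if __name__ == "__main__":
    # Hypothetical presence probabilities for a maximum of 3 steps.
    probs = step_count_distribution([0.9, 0.6, 0.1])
    for n, p in enumerate(probs):
        print(f"P({n} steps) = {p:.3f}")
```

Because the probabilities are chained, a low presence probability at an early step caps the probability of every later count, which is one way to see why unused steps appear as black crops.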

Data

Run ./scripts/create_dataset.sh. The script creates training and validation datasets of multi-digit MNIST.
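The contents of the dataset script are not reproduced here. As a rough illustration of what a multi-digit MNIST image is, the sketch below pastes digit-sized patches at uniformly random positions on a blank canvas, summing overlapping pixels and clipping them to 1. The function name `compose_canvas`, the 50-pixel default canvas size, the clipping rule and the uniform placement are all assumptions and may differ from what ./scripts/create_dataset.sh actually does.

```python
import random
from typing import List, Optional, Sequence, Tuple

Image = List[List[float]]


def compose_canvas(patches: Sequence[Image], canvas_size: int = 50,
                   rng: Optional[random.Random] = None
                   ) -> Tuple[Image, List[Tuple[int, int]]]:
    """Paste rectangular patches at random positions on a blank square canvas.

    Hypothetical helper: overlapping pixels are summed and clipped to 1.0.
    Returns the canvas and the (row, col) of each patch's top-left corner.
    """
    rng = rng or random.Random()
    canvas = [[0.0] * canvas_size for _ in range(canvas_size)]
    positions = []
    for patch in patches:
        h, w = len(patch), len(patch[0])
        if h > canvas_size or w > canvas_size:
            raise ValueError("patch does not fit on the canvas")
        # randint is inclusive, so the patch always lies fully inside.
        r = rng.randint(0, canvas_size - h)
        c = rng.randint(0, canvas_size - w)
        positions.append((r, c))
        for i in range(h):
            for j in range(w):
                canvas[r + i][c + j] = min(1.0, canvas[r + i][c + j] + patch[i][j])
    return canvas, positions


if __name__ == "__main__":
    # Stand-in for a 28x28 MNIST digit; a real script would load MNIST images.
    digit = [[1.0] * 28 for _ in range(28)]
    canvas, positions = compose_canvas([digit, digit], rng=random.Random(0))
    print("digit positions:", positions)
```

Recording the positions alongside the image keeps the number of digits per image available as a ground-truth count.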

Training

Run ./scripts/train_multi_mnist.sh. The training script runs for 300k iterations and saves model checkpoints and training-progress figures every 10k iterations in results/multi_mnist. Tensorflow summaries are stored in the same folder, so Tensorboard can be used for monitoring.

The model appears to be very sensitive to initialisation. Training may need to be run several times before the step-count accuracy comes close to the value reported in the paper.
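Step-count accuracy is understood here as the fraction of images for which the number of steps the model executes equals the true number of digits. A minimal sketch, assuming the inferred and ground-truth counts are already available as integer sequences; the function name `count_accuracy` and the example counts are hypothetical.

```python
from typing import Sequence


def count_accuracy(inferred_steps: Sequence[int], true_counts: Sequence[int]) -> float:
    """Fraction of images whose inferred number of steps equals the true object count."""
    if len(inferred_steps) != len(true_counts):
        raise ValueError("inferred_steps and true_counts must have the same length")
    if not true_counts:
        raise ValueError("need at least one example")
    correct = sum(int(n == t) for n, t in zip(inferred_steps, true_counts))
    return correct / len(true_counts)


if __name__ == "__main__":
    # Hypothetical counts for a validation batch of five multi-digit MNIST images.
    inferred = [2, 1, 0, 2, 1]
    truth = [2, 1, 1, 2, 0]
    print(f"count accuracy: {count_accuracy(inferred, truth):.2f}")
```

Since results vary between initialisations, one option is to compute this metric on the validation set for each of several training runs and keep the run that scores best.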

Experimentation

The Jupyter notebook at attend_infer_repeat/experiment.ipynb can be used for experimentation.

Citation

If you find this repo useful in your research, please consider citing the original paper:

@incollection{Eslami2016,
    title = {Attend, Infer, Repeat: Fast Scene Understanding with Generative Models},
    author = {Eslami, S. M. Ali and Heess, Nicolas and Weber, Theophane and Tassa, Yuval and Szepesvari, David and Kavukcuoglu, Koray and Hinton, Geoffrey E},
    booktitle = {Advances in Neural Information Processing Systems 29},
    editor = {D. D. Lee and M. Sugiyama and U. V. Luxburg and I. Guyon and R. Garnett},
    pages = {3225--3233},
    year = {2016},
    publisher = {Curran Associates, Inc.},
    url = {http://papers.nips.cc/paper/6230-attend-infer-repeat-fast-scene-understanding-with-generative-models.pdf}
}

License

This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 3 of the License, or (at your option) any later version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.

You should have received a copy of the GNU General Public License along with this program. If not, see http://www.gnu.org/licenses/.

Release Notes

Version 1.0

  • Original unofficial implementation; contains the multi-digit MNIST experiment.