VideoGPT: Video Generation using VQ-VAE and Transformers

[Paper][Website][Colab][Gradio Demo]

We present VideoGPT: a conceptually simple architecture for scaling likelihood-based generative modeling to natural videos. VideoGPT uses a VQ-VAE that learns downsampled discrete latent representations of a raw video by employing 3D convolutions and axial self-attention. A simple GPT-like architecture is then used to autoregressively model the discrete latents using spatio-temporal position encodings. Despite the simplicity in formulation and ease of training, our architecture is able to generate samples competitive with state-of-the-art GAN models for video generation on the BAIR Robot dataset, and generate high fidelity natural videos from UCF-101 and the Tumblr GIF Dataset (TGIF). We hope our proposed architecture serves as a reproducible reference for a minimalistic implementation of transformer-based video generation models.
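As a rough illustration of what "discrete latent representations" means, the sketch below maps each continuous vector to the index of its nearest codebook entry, which is the quantization step at the heart of any VQ-VAE. The tiny hand-written codebook and the function name `quantize` are made up for this example; the actual model learns its codebook and quantizes 3D feature maps, so treat this as a toy under those assumptions rather than the repo's implementation.

```python
def quantize(vectors, codebook):
    """Return, for each vector, the index of the closest codebook entry (squared L2 distance)."""
    indices = []
    for v in vectors:
        dists = [sum((a - b) ** 2 for a, b in zip(v, code)) for code in codebook]
        indices.append(dists.index(min(dists)))
    return indices


# Hypothetical 2-D codebook with three codes (a real codebook is learned).
codebook = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
# Hypothetical continuous encoder outputs.
latents = [(0.9, 0.1), (0.1, 0.2), (0.2, 0.8)]

codes = quantize(latents, codebook)
print(codes)  # [1, 0, 2]
```

The resulting integer codes are the kind of sequence that a GPT-like model can then be trained on autoregressively.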

Approach

[Figure: VideoGPT]

Installation

Change the cudatoolkit version to one compatible with your machine.

$ conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0
$ pip install git+https://github.com/wilson1yan/VideoGPT.git

Sparse Attention (Optional)

For limited compute scenarios, it may be beneficial to use sparse attention.

$ sudo apt-get install llvm-9-dev
$ DS_BUILD_SPARSE_ATTN=1 pip install deepspeed

After installing deepspeed, you can train a sparse transformer by setting the flag --attn_type sparse in scripts/train_videogpt.py. The default supported sparsity configuration is an N-d strided sparsity layout; however, you can write your own arbitrary layouts to use.
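To give a feel for what a strided layout looks like, here is a simplified 1-D sketch of a causal strided attention pattern in the style described in the sparse transformer literature: each position attends to its most recent `stride` positions and to every `stride`-th earlier position. This is an assumption-laden illustration only. It is not the N-d layout used by this repo, nor DeepSpeed's block-level layout format, and the function name `strided_causal_mask` is hypothetical.

```python
def strided_causal_mask(seq_len, stride):
    """mask[i][j] is True when query position i may attend to key position j."""
    mask = []
    for i in range(seq_len):
        row = []
        for j in range(seq_len):
            causal = j <= i                   # never attend to the future
            local = i - j < stride            # the most recent `stride` positions
            strided = (i - j) % stride == 0   # every `stride`-th earlier position
            row.append(causal and (local or strided))
        mask.append(row)
    return mask


mask = strided_causal_mask(seq_len=8, stride=4)
for row in mask:
    print("".join("x" if allowed else "." for allowed in row))
```

Compared with full causal attention, each row has far fewer allowed entries once the sequence is long, which is where the compute savings come from.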

Dataset

The default code accepts data either as an HDF5 file in the format specified in videogpt/data.py, or as a directory with the following structure:

video_dataset/
    train/
        class_0/
            video1.mp4
            video2.mp4
            ...
        class_1/
            video1.mp4
            ...
        ...
        class_n/
            ...
    test/
        class_0/
            video1.mp4
            video2.mp4
            ...
        class_1/
            video1.mp4
            ...
        ...
        class_n/
            ...

An example of such a dataset can be constructed from UCF-101 data by running the script:

sh scripts/preprocess/create_ucf_dataset.sh datasets/ucf101

You may need to install unrar and unzip for the code to work correctly.

If you do not care about classes, the class folders are not necessary and the dataset file structure can be collapsed into train and test directories of just videos.
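Before training, it can help to sanity-check that a dataset folder matches one of the two layouts described above. The following is a minimal sketch using only the standard library; it is not the loader in videogpt/data.py, and the helper name `list_videos` and the `.mp4`-only extension filter are assumptions made for the example.

```python
import tempfile
from pathlib import Path


def list_videos(split_dir, extensions=(".mp4",)):
    """Return sorted (video_path, class_name) pairs; class_name is None for a flat layout."""
    split_dir = Path(split_dir)
    items = []
    for path in sorted(split_dir.rglob("*")):
        if path.is_file() and path.suffix.lower() in extensions:
            parent = path.parent
            # Videos directly under the split folder have no class folder.
            label = None if parent == split_dir else parent.name
            items.append((path, label))
    return items


# Build a tiny throwaway dataset with the class-folder layout and list it.
with tempfile.TemporaryDirectory() as root:
    train = Path(root) / "train"
    for cls, name in [("class_0", "video1.mp4"), ("class_1", "video1.mp4")]:
        (train / cls).mkdir(parents=True, exist_ok=True)
        (train / cls / name).touch()
    pairs = [(p.name, label) for p, label in list_videos(train)]

print(pairs)  # [('video1.mp4', 'class_0'), ('video1.mp4', 'class_1')]
```

Pointing the same helper at a flat train or test folder yields pairs whose class name is None.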

Using Pretrained VQ-VAEs

There are four available pre-trained VQ-VAE models. All strides listed with each model are downsampling amounts across THW for the encoders.

  • bair_stride4x2x2: trained on 16 frame 64 x 64 videos from the BAIR Robot Pushing dataset
  • ucf101_stride4x4x4: trained on 16 frame 128 x 128 videos from UCF-101
  • kinetics_stride4x4x4: trained on 16 frame 128 x 128 videos from Kinetics-600
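  • kinetics_stride2x4x4: trained on 16 frame 128 x 128 videos from Kinetics-600, with 2x larger temporal latent codes (achieves slightly better reconstruction)

import torch
from torchvision.io import read_video
from videogpt import load_vqvae
from videogpt.data import preprocess

video_filename = 'path/to/video_file.mp4'
sequence_length = 16
resolution = 128
device = torch.device('cuda')

# Load the pretrained VQ-VAE and move it to the same device as the input video
vqvae = load_vqvae('kinetics_stride2x4x4').to(device)
# read_video returns (video frames, audio frames, metadata); keep only the video frames
video = read_video(video_filename, pts_unit='sec')[0]
# Preprocess to `sequence_length` frames at `resolution`, then add a batch dimension
video = preprocess(video, resolution, sequence_length).unsqueeze(0).to(device)

# Encode to discrete latent codes, then decode back to a reconstructed video
encodings = vqvae.encode(video)
video_recon = vqvae.decode(encodings)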
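Since the stride in each model name is the T x H x W downsampling factor, the size of the latent grid can be estimated by dividing the input shape by the stride. The sketch below does this arithmetic for the clip sizes listed above. It assumes the latent grid is exactly the input shape divided by the stride, and the helper name `latent_shape` is hypothetical rather than part of the videogpt package.

```python
import re
from math import prod


def latent_shape(model_name, sequence_length, resolution):
    """Parse 'stride{T}x{H}x{W}' from a model name and divide the input shape by it."""
    match = re.search(r"stride(\d+)x(\d+)x(\d+)", model_name)
    if match is None:
        raise ValueError(f"no stride found in {model_name!r}")
    st, sh, sw = (int(g) for g in match.groups())
    return (sequence_length // st, resolution // sh, resolution // sw)


for name, frames, res in [
    ("bair_stride4x2x2", 16, 64),
    ("ucf101_stride4x4x4", 16, 128),
    ("kinetics_stride2x4x4", 16, 128),
]:
    shape = latent_shape(name, frames, res)
    # prod(shape) is the number of discrete codes per clip under this assumption.
    print(name, shape, prod(shape))
# bair_stride4x2x2 (4, 32, 32) 4096
# ucf101_stride4x4x4 (4, 32, 32) 4096
# kinetics_stride2x4x4 (8, 32, 32) 8192
```

Under this assumption, kinetics_stride2x4x4 produces twice as many codes per clip as kinetics_stride4x4x4, which matches the "2x larger temporal latent codes" note above.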

Training VQ-VAE

Use the scripts/train_vqvae.py script to train a VQ-VAE. Execute python scripts/train_vqvae.py -h for information on all available training settings. A subset of the more relevant settings is listed below, along with their default values.

VQ-VAE Specific Settings

  • --embedding_dim: number of dimensions for codebook embeddings
  • --n_codes 2048: number of codes in the codebook
  • --n_hiddens 240: number of hidden features in the residual blocks
  • --n_res_layers 4: number of residual blocks
  • --downsample 4 4 4: T H W downsampling stride of the encoder

Training Settings

  • --gpus 2: number of gpus for distributed training
  • --sync_batchnorm: uses SyncBatchNorm instead of BatchNorm3d when using > 1 gpu
  • --gradient_clip_val 1: gradient clipping threshold for training
  • --batch_size 16: batch size per gpu
  • --num_workers 8: number of workers for each DataLoader

Dataset Settings

  • --data_path : path to an hdf5 file or a folder containing train and test folders with subdirectories of videos
  • --resolution 128: spatial resolution to train on
  • --sequence_length 16: temporal resolution, or video clip length
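When launching several runs with different settings, it may be convenient to assemble the command line programmatically. The sketch below builds an argument list from the flags listed above and prints the resulting command. The helper `build_command` is hypothetical, the data path is a placeholder, and the values shown simply repeat the defaults from this section.

```python
import shlex


def build_command(script, **options):
    """Turn keyword options into an argv list.

    True becomes a bare flag, False is skipped, and lists are expanded
    into space-separated values.
    """
    argv = ["python", script]
    for name, value in options.items():
        flag = f"--{name}"
        if value is True:
            argv.append(flag)
        elif value is False:
            continue
        elif isinstance(value, (list, tuple)):
            argv.extend([flag, *map(str, value)])
        else:
            argv.extend([flag, str(value)])
    return argv


argv = build_command(
    "scripts/train_vqvae.py",
    data_path="datasets/ucf101",  # placeholder path
    resolution=128,
    sequence_length=16,
    n_codes=2048,
    downsample=[4, 4, 4],
    gpus=2,
    sync_batchnorm=True,
)
print(shlex.join(argv))
```

The list in `argv` can be passed directly to `subprocess.run` if you want to start training from Python instead of a shell.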

Training VideoGPT

You can download a pretrained VQ-VAE, or train your own. Afterwards, use the scripts/train_videogpt.py script to train a VideoGPT model for sampling. Execute python scripts/train_videogpt.py -h for information on all available training settings. A subset of the more relevant settings is listed below, along with their default values.

VideoGPT Specific Settings

  • --vqvae kinetics_stride4x4x4: path to a vqvae checkpoint file, OR a pretrained model name to download. Available pretrained models are: bair_stride4x2x2, ucf101_stride4x4x4, kinetics_stride4x4x4, kinetics_stride2x4x4. BAIR was trained on 64 x 64 videos, and the rest on 128 x 128 videos
  • --n_cond_frames 0: number of frames to condition on. 0 represents a non-frame conditioned model
  • --class_cond: trains a class conditional model if activated
  • --hidden_dim 576: number of transformer hidden features
  • --heads 4: number of heads for multihead attention
  • --layers 8: number of transformer layers
  • --dropout 0.2: dropout probability applied to features after attention and positionwise feedforward layers
  • --attn_type full: full or sparse attention. Refer to the Installation section for instructions on installing sparse attention
  • --attn_dropout 0.3: dropout probability applied to the attention weight matrix

Training Settings

  • --gpus 2: number of gpus for distributed training
  • --sync_batchnorm: uses SyncBatchNorm instead of BatchNorm3d when using > 1 gpu
  • --gradient_clip_val 1: gradient clipping threshold for training
  • --batch_size 16: batch size per gpu
  • --num_workers 8: number of workers for each DataLoader

Dataset Settings

  • --data_path : path to an hdf5 file or a folder containing train and test folders with subdirectories of videos
  • --resolution 128: spatial resolution to train on
  • --sequence_length 16: temporal resolution, or video clip length

Sampling VideoGPT

After training, the VideoGPT model can be sampled using scripts/sample_videogpt.py. You may need to install ffmpeg first: sudo apt-get install ffmpeg

Reproducing Paper Results

Note that this repo is designed primarily for simplicity and for extending our method. The full paper results can be reproduced using the code in a separate repo, though that code is not as clean.

Citation

Please consider using the following citation when using our code:

@misc{yan2021videogpt,
      title={VideoGPT: Video Generation using VQ-VAE and Transformers}, 
      author={Wilson Yan and Yunzhi Zhang and Pieter Abbeel and Aravind Srinivas},
      year={2021},
      eprint={2104.10157},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}