Implementation of Feedback Transformer in Pytorch

A simple implementation of the Feedback Transformer in PyTorch. It improves on Transformer-XL by giving each token access to the representations of all previous layers through time. This is achieved by aggregating the outputs of all layers into a shared memory, which each token, at every layer, can attend to at each time step.

The main drawback is longer training time, because the model cannot be parallelized across time steps. I built it anyway to encourage further exploration and research in this line of work.
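To make the shared memory and the sequential dependency concrete, here is a minimal toy sketch. It is not the code of this package: the `layers` are plain linear stand-ins for the real attention and feedforward blocks, the mean over the memory stands in for attention, and the softmax-weighted sum over layer outputs is an assumption about how the aggregation is parameterized. All names (`aggregate`, `layer_weights`, and so on) are hypothetical.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

depth, dim, steps, mem_len = 3, 8, 5, 4

# Hypothetical stand-ins for the real attention + feedforward blocks.
layers = [torch.nn.Linear(dim, dim) for _ in range(depth)]

# Assumption: one learned scalar per layer output, plus one for the embedding.
layer_weights = torch.nn.Parameter(torch.zeros(depth + 1))

def aggregate(hiddens, weights):
    """Collapse the per-layer outputs of one time step into one memory vector."""
    stacked = torch.stack(hiddens)                      # (depth + 1, dim)
    probs = F.softmax(weights, dim=0)                   # (depth + 1,)
    return (probs.unsqueeze(-1) * stacked).sum(dim=0)   # (dim,)

memory = []                        # a single buffer shared by every layer
tokens = torch.randn(steps, dim)   # toy token embeddings

for x in tokens:                   # strictly one time step after another
    hiddens = [x]
    for layer in layers:
        # The real model attends over the memory; this toy version only mixes
        # in its mean so that the dependency on earlier steps is visible.
        context = torch.stack(memory).mean(dim=0) if memory else torch.zeros(dim)
        x = torch.tanh(layer(x + context))
        hiddens.append(x)
    memory.append(aggregate(hiddens, layer_weights))
    memory = memory[-mem_len:]     # keep only the most recent mem_len entries
```

Because the memory entry for step t needs the output of every layer at step t, step t + 1 cannot start until step t has finished, which is where the extra training time comes from.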

Yannic Kilcher video

I also took the liberty of adding a few enhancements, including pre-normalization, GLU-gated feedforwards, and simplified T5 relative positional embeddings.
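As a rough illustration of two of these enhancements, the sketch below combines pre-normalization with a GLU-style gated feedforward. It is an assumption-laden example rather than the module used in this package: the class name `GatedFeedForward`, the GELU gate, and the `mult` expansion factor are all chosen for illustration.

```python
import torch
from torch import nn
import torch.nn.functional as F

class GatedFeedForward(nn.Module):
    """Hypothetical pre-norm feedforward block with a GLU-style gate."""

    def __init__(self, dim, mult=4, dropout=0.0):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        # A single projection produces both the values and the gate.
        self.proj_in = nn.Linear(dim, dim * mult * 2)
        self.dropout = nn.Dropout(dropout)
        self.proj_out = nn.Linear(dim * mult, dim)

    def forward(self, x):
        h = self.norm(x)                                 # normalize before the block
        value, gate = self.proj_in(h).chunk(2, dim=-1)   # split into two halves
        h = value * F.gelu(gate)                         # gate the values (GELU is an assumption)
        return x + self.proj_out(self.dropout(h))        # residual connection

ff = GatedFeedForward(dim=16)
y = ff(torch.randn(2, 3, 16))   # output has the same shape as the input
```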

Install

$ pip install feedback-transformer-pytorch

Usage

import torch
from feedback_transformer_pytorch import FeedbackTransformer

model = FeedbackTransformer(
    num_tokens = 20000,           # number of tokens
    dim = 512,                    # dimension
    depth = 6,                    # depth
    seq_len = 2,                  # the sequence length of each segment or window
    mem_len = 256,                # length of the memory buffer
    dim_head = 64,                # dimension of each head
    heads = 8,                    # number of heads
    attn_dropout = 0.1,           # attention dropout
    ff_dropout = 0.1              # feedforward dropout
).cuda()

x = torch.randint(0, 20000, (2, 64)).cuda()
model(x)  # (2, 64, 20000)

If you would like fine control over the memory (when to detach, etc.), you can pass some extra keyword arguments to .forward:

import torch
from feedback_transformer_pytorch import FeedbackTransformer

model = FeedbackTransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    seq_len = 32,
    mem_len = 256
).cuda()

x1 = torch.randint(0, 20000, (2, 32)).cuda()
x2 = torch.randint(0, 20000, (2, 32)).cuda()
x3 = torch.randint(0, 20000, (2, 32)).cuda()

out1, mem1 = model(x1, return_memory = True)
out2, mem2 = model(x2, memory = mem1, return_memory = True)
out3, mem3 = model(x3, memory = mem2, return_memory = True)  # (2, 32, 20000)
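The three calls above can be generalized to a loop over a longer sequence. The helper below is a sketch, not part of the package: `run_in_segments` is a hypothetical name, and it relies only on the `memory` and `return_memory` keyword arguments shown above. Whether a final segment shorter than `seq_len` is handled well by the model is not verified here.

```python
import torch

def run_in_segments(model, tokens, seq_len):
    """Feed `tokens` of shape (batch, n) through `model` one segment at a time,
    handing the memory returned by each call to the next one."""
    memory = None
    outputs = []
    for segment in tokens.split(seq_len, dim=1):
        # On the first call there is no memory yet, so the keyword is omitted.
        kwargs = {} if memory is None else {"memory": memory}
        logits, memory = model(segment, return_memory=True, **kwargs)
        outputs.append(logits)
    # Concatenate the per-segment logits back along the sequence dimension.
    return torch.cat(outputs, dim=1), memory
```

With the model defined above, `run_in_segments(model, x, seq_len = 32)` would be called on a tensor `x` of token ids with shape (batch, n). This loop is also the natural place to decide when the memory should be detached from the graph.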

Citations

@misc{fan2021addressing,
    title   = {Addressing Some Limitations of Transformers with Feedback Memory}, 
    author  = {Angela Fan and Thibaut Lavril and Edouard Grave and Armand Joulin and Sainbayar Sukhbaatar},
    year    = {2021},
    eprint  = {2002.09402},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
Comments
  • Should it really be using lower layers output for keys and values?

    Could you explain the logic of how the key-value pairs are formed at these lines and whether it is necessary?

    https://github.com/lucidrains/feedback-transformer-pytorch/blob/d7d8939910d1491f01a3d93ce81d4663925fb389/feedback_transformer_pytorch/feedback_transformer_pytorch.py#L146-L151

    It looks to me that line 146 transforms the output of the layer below (x) to keys and values, and the following lines combine these keys and values with the memory. I thought that x should only be used for forming the query here, and only the existing memory is used for keys and values.

    opened by tarvaina 6
  • In place operation with gradient

    https://github.com/lucidrains/feedback-transformer-pytorch/blob/main/feedback_transformer_pytorch/feedback_transformer_pytorch.py#L173 I think this is an error.

    opened by hadaev8 4
  • Bug in weighted sum

    Bug in https://github.com/lucidrains/feedback-transformer-pytorch/blob/main/feedback_transformer_pytorch/feedback_transformer_pytorch.py#L264

    Should be layer_weight = rearrange(layer_weight, 'd -> d () () ()')

    opened by Victor0118 1
  • Input/Output dimensions

    Hey @lucidrains

    Can I check the dimensions of the input and output, is it (seq_len, dim) -> (? ,dim, tokens)?

    model = FeedbackTransformer(
        num_tokens = 20000,           # number of tokens
        dim = 512,                    # dimension
        depth = 6,                    # depth
        seq_len = 2,                  # the sequence length of each segment or window
        mem_len = 256,                # length of the memory buffer
        dim_head = 64,                # dimension of each head
        heads = 8,                    # number of heads
        attn_dropout = 0.1,           # attention dropout
        ff_dropout = 0.1              # feedforward dropout
    ).cuda()
    
    x = torch.randint(0, 256, (2, 512)).cuda()
    model(x)  # (1, 512, 20000)
    
    opened by iiSeymour 1
  • Non intuitive memory usage with cross attention

    Given a simple tensor with dimension 256 and length 512, and a memory length of 16, the feedback transformer uses 3.6 GB of memory after a forward pass. With cross attention on a tensor of length 100, usage grows to 14 GB.

    The parallel version, by comparison, uses 3.1 GB and 3.5 GB.

    Notebooks for testing https://colab.research.google.com/drive/1dRImydFn3WthOXdLYIvdf5bsqjXcmhC5?usp=sharing https://colab.research.google.com/drive/1n653j4Pz9_U7OukhTlUbomAHMvpPXwx0?usp=sharing

    opened by hadaev8 0
  • I think mask padding value should be False

    Here https://github.com/lucidrains/feedback-transformer-pytorch/blob/with-cross-attention/feedback_transformer_pytorch/feedback_transformer_pytorch.py#L181

    opened by hadaev8 0
  • ETA for the enwiki8 example

    Hey @lucidrains,

    Any ETA on the auto-regressive enwiki8 example? Others and I would really appreciate it, as always :)

    Also, if you can provide an example for training on custom line-by-line TXT datasets, it would be absolutely fantastic.

    Thank you.

    opened by asigalov61 0
Owner
Phil Wang
Working with Attention. It's all we need.