Pytorch implementation of Compressive Transformers, from Deepmind

Overview

Compressive Transformer in Pytorch

Pytorch implementation of Compressive Transformers, a variant of Transformer-XL with compressed memory for long-range language modelling. I will also combine this with an idea from another paper that adds gating at the residual intersection. The memory and the gating may be synergistic, and lead to further improvements in both language modeling as well as reinforcement learning.

PyPI version

Install

$ pip install compressive_transformer_pytorch

Usage

import torch
from compressive_transformer_pytorch import CompressiveTransformer

model = CompressiveTransformer(
    num_tokens = 20000,
    emb_dim = 128,                 # embedding dimensions, embedding factorization from Albert paper
    dim = 512,
    depth = 12,
    seq_len = 1024,
    mem_len = 1024,                # memory length
    cmem_len = 1024 // 4,          # compressed memory buffer length
    cmem_ratio = 4,                # compressed memory ratio, 4 was recommended in paper
    reconstruction_loss_weight = 1,# weight to place on compressed memory reconstruction loss
    attn_dropout = 0.1,            # dropout post-attention
    ff_dropout = 0.1,              # dropout in feedforward
    attn_layer_dropout = 0.1,      # dropout for attention layer output
    gru_gated_residual = True,     # whether to gate the residual intersection, from 'Stabilizing Transformer for RL' paper
    mogrify_gru = False,           # experimental feature that adds a mogrifier for the update and residual before gating by the GRU
    memory_layers = range(6, 13),  # specify which layers to use long-range memory, from 'Do Transformers Need LR Memory' paper
    ff_glu = True                  # use GLU variant for feedforward
)

inputs = torch.randint(0, 256, (1, 2048))
masks = torch.ones_like(inputs).bool()

segments = inputs.reshape(1, -1, 1024).transpose(0, 1)
masks = masks.reshape(1, -1, 1024).transpose(0, 1)

logits, memories, aux_loss = model(segments[0], mask = masks[0])
logits,        _, aux_loss = model(segments[1], mask = masks[1], memories = memories)

# memories is a named tuple that contains the memory (mem) and the compressed memory (cmem)

When training, you can use the AutoregressiveWrapper to have memory management across segments taken care of for you. As easy as it gets.

import torch
from compressive_transformer_pytorch import CompressiveTransformer
from compressive_transformer_pytorch import AutoregressiveWrapper

model = CompressiveTransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    seq_len = 1024,
    mem_len = 1024,
    cmem_len = 256,
    cmem_ratio = 4,
    memory_layers = [5,6]
).cuda()

model = AutoregressiveWrapper(model)

inputs = torch.randint(0, 20000, (1, 2048 + 1)).cuda()

for loss, aux_loss, _ in model(inputs, return_loss = True):
    (loss + aux_loss).backward()
    # optimizer step and zero grad

# ... after much training ...

# generation is also greatly simplified and automated away
# just pass in the prime, which can be 1 start token or any length
# all is taken care of for you

prime = torch.ones(1, 1).cuda()  # assume 1 is start token
sample = model.generate(prime, 4096)

Citations

@misc{rae2019compressive,
    title   = {Compressive Transformers for Long-Range Sequence Modelling},
    author  = {Jack W. Rae and Anna Potapenko and Siddhant M. Jayakumar and Timothy P. Lillicrap},
    year    = {2019},
    eprint  = {1911.05507},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{parisotto2019stabilizing,
    title   = {Stabilizing Transformers for Reinforcement Learning},
    author  = {Emilio Parisotto and H. Francis Song and Jack W. Rae and Razvan Pascanu and Caglar Gulcehre and Siddhant M. Jayakumar and Max Jaderberg and Raphael Lopez Kaufman and Aidan Clark and Seb Noury and Matthew M. Botvinick and Nicolas Heess and Raia Hadsell},
    year    = {2019},
    eprint  = {1910.06764},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@inproceedings{rae-razavi-2020-transformers,
    title   = "Do Transformers Need Deep Long-Range Memory?",
    author  = "Rae, Jack  and
      Razavi, Ali",
    booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics",
    month   = jul,
    year    = "2020",
    address = "Online",
    publisher = "Association for Computational Linguistics",
    url     = "https://www.aclweb.org/anthology/2020.acl-main.672"
}
@article{Shazeer2019FastTD,
    title   = {Fast Transformer Decoding: One Write-Head is All You Need},
    author  = {Noam Shazeer},
    journal = {ArXiv},
    year    = {2019},
    volume  = {abs/1911.02150}
}
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}
}
@misc{lan2019albert,
    title       = {ALBERT: A Lite BERT for Self-supervised Learning of Language Representations},
    author      = {Zhenzhong Lan and Mingda Chen and Sebastian Goodman and Kevin Gimpel and Piyush Sharma and Radu Soricut},
    year        = {2019},
    url         = {https://arxiv.org/abs/1909.11942}
}
@misc{ding2021erniedoc,
    title   = {ERNIE-Doc: A Retrospective Long-Document Modeling Transformer},
    author  = {Siyu Ding and Junyuan Shang and Shuohuan Wang and Yu Sun and Hao Tian and Hua Wu and Haifeng Wang},
    year    = {2021},
    eprint  = {2012.15688},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
Comments
  • aux_loss does not update any weigth

    aux_loss does not update any weigth

    Hi lucidrains, thanks for your implementation, it is very elegant and helped me a lot with my disertation. Anyway I can't understand a particular: it seems like aux_loss is not related to any weight because of the detaching in the last part of the SelfAttention layer. With the following code, for example, I get that there is no layer optimized by aux_loss:

    import torch
    from compressive_transformer_pytorch import CompressiveTransformer
    from compressive_transformer_pytorch import AutoregressiveWrapper
    
    model = CompressiveTransformer(
        num_tokens = 20000,
        dim = 512,
        depth = 6,
        seq_len = 1024,
        mem_len = 1024,
        cmem_len = 256,
        cmem_ratio = 4,
        memory_layers = [5,6]
    ).cuda()
    
    model = AutoregressiveWrapper(model)
    
    inputs = torch.randint(0, 20000, (1, 1024)).cuda()
    
    optimizer = torch.optim.Adam(model.parameters())
    
    for loss, aux_loss, _ in model(inputs, return_loss = True):
        optimizer.zero_grad(set_to_none=True)
        loss.backward(retain_graph=True)
        print("OPTIMIZED BY LOSS ************************************************************")
        for module_name, parameter in model.named_parameters():
            if parameter.grad is not None:
                print(module_name)
        optimizer.zero_grad(set_to_none=True)
        aux_loss.backward(retain_graph=True)
        print("OPTIMIZED BY AUX_LOSS ************************************************************")
        for module_name, parameter in model.named_parameters():
            if parameter.grad is not None:
                print(module_name)
    

    I am not expert about the PyTorch mechanisms, so maybe I am getting something wrong. Again thank you

    opened by StefanoBerti 3
  • How to use this for speech/audio generation?

    How to use this for speech/audio generation?

    Great work Phil! In their paper, the authors applied this model to speech modeling, how would you advise on what should I change to use for speech. Because in speech, the data are signals, we do not have num_tokens, nor do we have emb_dim. Our data input is simply, [batch, channel, time]. Any advice?

    opened by jinglescode 3
  • [Error] NameError: name 'math' is not defined in compressive_transformer_pytorch.py

    [Error] NameError: name 'math' is not defined in compressive_transformer_pytorch.py

    hello, I run code "examples/enwik8_simple" now, and I got error as follows:

    train.py:65: DeprecationWarning: The binary mode of fromstring is deprecated, as it behaves surprisingly on unicode inputs. Use frombuffer instead X = np.fromstring(file.read(int(95e6)), dtype=np.uint8) training: 0%| | 0/100000 [00:00<?, ?it/s] Traceback (most recent call last): File "train.py", line 101, in <module> for mlm_loss, aux_loss, is_last in model(next(train_loader), max_batch_size = MAX_BATCH_SIZE, return_loss = True): File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/compressive_transformer_pytorch/autoregressive_wrapper.py", line 151, in forward logits, new_mem, aux_loss = self.net(xi_seg_b, mask = mask_seg_b, memories = mem, **kwargs) File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__ result = self.forward(*input, **kwargs) File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/compressive_transformer_pytorch/compressive_transformer_pytorch.py", line 338, in f orward x, = ff(x) File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__ result = self.forward(*input, **kwargs) File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/compressive_transformer_pytorch/compressive_transformer_pytorch.py", line 84, in fo rward out = self.fn(x, **kwargs) File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__ result = self.forward(*input, **kwargs) File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/compressive_transformer_pytorch/compressive_transformer_pytorch.py", line 106, in f orward return self.fn(x, **kwargs) File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__ result = self.forward(*input, **kwargs) File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/compressive_transformer_pytorch/compressive_transformer_pytorch.py", line 140, in f orward x = self.act(x) File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__ result = self.forward(*input, **kwargs) File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/compressive_transformer_pytorch/compressive_transformer_pytorch.py", line 122, in f orward return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) NameError: name 'math' is not defined

    so I inserted "import math" into compressive_transformer_pytorch.py file and it work well. I hope you modify compressive_transformer_pytorch.py code.

    opened by dinoSpeech 3
  • Training enwik8 but loss fail to converge

    Training enwik8 but loss fail to converge

    Hi lucidrains, I appreciate your implementation very much, and it helps me a lot with understanding compressive transformer. However when I tried running your code (enwik8 and exactly the same code in github), and the loss failed to converge after 100 epochs. Is this in expectation ? Or should I do other additional effort to improve, for example tokenizing the raw data in enwik8 and remove all the xml tags ? The figure below is the training and validation loss while I train enwik8 with the same code as in github.

    截圖 2021-03-26 下午5 40 16 截圖 2021-03-26 下午5 41 22

    Thanks and look forward to your reply!

    opened by KaiPoChang 2
  • Details about text generation

    Details about text generation

    Hi lucidrains, Thank you for your excellent code. I am curious about the generation scripts. Could you tell me how to generate text with the compressive transformer? Because it has the compressive memory, maybe we cannot use the current predicted word as the input for the next generation (input length ==1). In addition, if the prompt has 100 words and we use tokens [0:100], tokens[1:101], tokens[2:102]... as the input for the following timesteps, the tokens[1:100] may overlap with the memory, because the memory already contains hidden states for tokens[1:100].

    I would be very appeciated if you can provide the generation scripts!

    Thank you

    opened by theseventhflow 3
  • Links to original tf code - fyi

    Links to original tf code - fyi

    After reading deepmind blog post I was looking forward to downloading model but no luck. Looking forward to your implementation.

    You may be aware of this post and link but if not this is the coder's original tf implementation. Hope it helps.

    Copy of comment to original model request:

    https://github.com/huggingface/transformers/issues/4688

    Interested in model weights too but currently not available. Author does mention releasing tf code here:

    https://news.ycombinator.com/item?id=22290227

    Requires tf 1.15+ and deepmind/sonnet ver 1.36. Link to python script here:

    https://github.com/deepmind/sonnet/blob/cd5b5fa48e15e4d020f744968f5209949ebe750f/sonnet/python/modules/nets/transformer.py#L915

    Have tried running as-is but doesn't appear to have options for training on custom data as per the paper and available data sets.

    opened by GenTxt 8
Releases(0.4.0)
Owner
Phil Wang
Working with Attention. It's all we need
Phil Wang
SOFT: Softmax-free Transformer with Linear Complexity, NeurIPS 2021 Spotlight

SOFT: Softmax-free Transformer with Linear Complexity SOFT: Softmax-free Transformer with Linear Complexity, Jiachen Lu, Jinghan Yao, Junge Zhang, Xia

Fudan Zhang Vision Group 272 Dec 25, 2022
A hyperparameter optimization framework

Optuna: A hyperparameter optimization framework Website | Docs | Install Guide | Tutorial Optuna is an automatic hyperparameter optimization software

7.4k Jan 04, 2023
CvT-ASSD: Convolutional vision-Transformerbased Attentive Single Shot MultiBox Detector (ICTAI 2021 CCF-C 会议)The 33rd IEEE International Conference on Tools with Artificial Intelligence

CvT-ASSD including extra CvT, CvT-SSD, VGG-ASSD models original-code-website: https://github.com/albert-jin/CvT-SSD new-code-website: https://github.c

金伟强 -上海大学人工智能小渣渣~ 5 Mar 07, 2022
Haze Removal can remove slight to extreme cases of haze affecting an image

Haze Removal can remove slight to extreme cases of haze affecting an image. Its most typical use is for landscape photography where the haze causes low contrast and low saturation, but it can also be

Grace Ugochi Nneji 3 Feb 15, 2022
code for paper -- "Seamless Satellite-image Synthesis"

Seamless Satellite-image Synthesis by Jialin Zhu and Tom Kelly. Project site. The code of our models borrows heavily from the BicycleGAN repository an

Light 14 Apr 05, 2022
The official code repo of "HTS-AT: A Hierarchical Token-Semantic Audio Transformer for Sound Classification and Detection"

Hierarchical Token Semantic Audio Transformer Introduction The Code Repository for "HTS-AT: A Hierarchical Token-Semantic Audio Transformer for Sound

Knut(Ke) Chen 134 Jan 01, 2023
A set of Deep Reinforcement Learning Agents implemented in Tensorflow.

Deep Reinforcement Learning Agents This repository contains a collection of reinforcement learning algorithms written in Tensorflow. The ipython noteb

Arthur Juliani 2.2k Jan 01, 2023
Code and data accompanying our SVRHM'21 paper.

Code and data accompanying our SVRHM'21 paper. Requires tensorflow 1.13, python 3.7, scikit-learn, and pytorch 1.6.0 to be installed. Python scripts i

5 Nov 17, 2021
Deep Markov Factor Analysis (NeurIPS2021)

Deep Markov Factor Analysis (DMFA) Codes and experiments for deep Markov factor analysis (DMFA) model accepted for publication at NeurIPS2021: A. Farn

Sarah Ostadabbas 2 Dec 16, 2022
Relaxed-machines - explorations in neuro-symbolic differentiable interpreters

Relaxed Machines Explorations in neuro-symbolic differentiable interpreters. Baby steps: inc_stop Libraries JAX Haiku Optax Resources Chapter 3 (∂4: A

Nada Amin 6 Feb 02, 2022
A model that attempts to learn and benefit from data collected on card counting.

A model that attempts to learn and benefit from data collected on card counting. A decision tree like model is built to win more often than loose and increase the bet of the player appropriately to c

1 Dec 17, 2021
[ICCV 2021] Code release for "Sub-bit Neural Networks: Learning to Compress and Accelerate Binary Neural Networks"

Sub-bit Neural Networks: Learning to Compress and Accelerate Binary Neural Networks By Yikai Wang, Yi Yang, Fuchun Sun, Anbang Yao. This is the pytorc

Yikai Wang 26 Nov 20, 2022
Code for the paper "Controllable Video Captioning with an Exemplar Sentence"

SMCG Code for the paper "Controllable Video Captioning with an Exemplar Sentence" Introduction We investigate a novel and challenging task, namely con

10 Dec 04, 2022
An OpenAI Gym environment for multi-agent car racing based on Gym's original car racing environment.

Multi-Car Racing Gym Environment This repository contains MultiCarRacing-v0 a multiplayer variant of Gym's original CarRacing-v0 environment. This env

Igor Gilitschenski 56 Nov 01, 2022
Answer a series of contextually-dependent questions like they may occur in natural human-to-human conversations.

SCAI-QReCC-21 [leaderboards] [registration] [forum] [contact] [SCAI] Answer a series of contextually-dependent questions like they may occur in natura

19 Sep 28, 2022
CTRL-C: Camera calibration TRansformer with Line-Classification

CTRL-C: Camera calibration TRansformer with Line-Classification This repository contains the official code and pretrained models for CTRL-C (Camera ca

57 Nov 14, 2022
Enhancing Aspect-Based Sentiment Analysis with Supervised Contrastive Learning.

Enhancing Aspect-Based Sentiment Analysis with Supervised Contrastive Learning. Enhancing Aspect-Based Sentiment Analysis with Supervised Contrastive

<a href=[email protected](SZ)"> 7 Dec 16, 2021
Identifying Stroke Indicators Using Rough Sets

Identifying Stroke Indicators Using Rough Sets With the spirit of reproducible research, this repository contains all the codes required to produce th

Muhammad Salman Pathan 0 Jun 09, 2022
🐤 Nix-TTS: An Incredibly Lightweight End-to-End Text-to-Speech Model via Non End-to-End Distillation

🐤 Nix-TTS An Incredibly Lightweight End-to-End Text-to-Speech Model via Non End-to-End Distillation Rendi Chevi, Radityo Eko Prasojo, Alham Fikri Aji

Rendi Chevi 156 Jan 09, 2023
Crab is a flexible, fast recommender engine for Python that integrates classic information filtering recommendation algorithms in the world of scientific Python packages (numpy, scipy, matplotlib).

Crab - A Recommendation Engine library for Python Crab is a flexible, fast recommender engine for Python that integrates classic information filtering r

python-recsys 1.2k Dec 21, 2022