Non-Autoregressive Translation with Layer-Wise Prediction and Deep Supervision

Overview

This repo contains the implementation of our paper:

Non-Autoregressive Translation with Layer-Wise Prediction and Deep Supervision

Paper Link

Replication

Python environment

pip install -e . # under DSLP directory
pip install tensorflow tensorboard sacremoses nltk Ninja omegaconf
pip install 'fuzzywuzzy[speedup]'
pip install hydra-core==1.0.6
pip install sacrebleu==1.5.1
pip install git+https://github.com/dugu9sword/lunanlp.git
git clone --recursive https://github.com/parlance/ctcdecode.git
cd ctcdecode && pip install .

Dataset

We downloaded the distilled data from FairSeq

Preprocessed by

TEXT=wmt14_ende_distill
python3 fairseq_cli/preprocess.py --source-lang en --target-lang de \
   --trainpref $TEXT/train.en-de --validpref $TEXT/valid.en-de --testpref $TEXT/test.en-de \
   --destdir data-bin/wmt14.en-de_kd --workers 40 --joined-dictionary

Or you can download all the binarized files here.

Hyperparameters

EN<->RO EN<->DE
--validate-interval-updates 300 500
number of tokens per batch 32K 128K
--dropout 0.3 0.1

Note:

  1. We found that label smoothing for CTC-based models are not useful (at least not with our implementation), it is suggested to keep --label-smoothing as 0 for them.
  2. Dropout rate plays a significant role for GLAT, CMLM, and the Vanilla NAT. On WMT'14 EN->De, for example, the Vanilla NAT with dropout 0.1 reaches 21.18 BLEU; but only gives 19.68 BLEU with dropout 0.3.

Training:

We provide the scripts for replicating the results on WMT'14 EN->DE task. For other tasks, you need to adapt the binary path, --source-lang, --target-lang, and some other hyperparameters accordingly.

GLAT with DSLP

python3 train.py data-bin/wmt14.en-de_kd --source-lang en --target-lang de  --save-dir checkpoints  --eval-tokenized-bleu \
   --keep-interval-updates 5 --save-interval-updates 500 --validate-interval-updates 500 --maximize-best-checkpoint-metric \
   --eval-bleu-remove-bpe --eval-bleu-print-samples --best-checkpoint-metric bleu --log-format simple --log-interval 100 \
   --eval-bleu --eval-bleu-detok space --keep-last-epochs 5 --keep-best-checkpoints 5  --fixed-validation-seed 7 --ddp-backend=no_c10d \
   --share-all-embeddings --decoder-learned-pos --encoder-learned-pos  --optimizer adam --adam-betas "(0.9,0.98)" --lr 0.0005 \ 
   --lr-scheduler inverse_sqrt --stop-min-lr 1e-09 --warmup-updates 10000 --warmup-init-lr 1e-07 --apply-bert-init --weight-decay 0.01 \
   --fp16 --clip-norm 2.0 --max-update 300000  --task translation_glat --criterion glat_loss --arch glat_sd --noise full_mask \ 
   --concat-yhat --concat-dropout 0.0  --label-smoothing 0.1 \ 
   --activation-fn gelu --dropout 0.1  --max-tokens 8192 --glat-mode glat \ 
   --length-loss-factor 0.1 --pred-length-offset 

CMLM with DSLP

python3 train.py data-bin/wmt14.en-de_kd --source-lang en --target-lang de  --save-dir checkpoints  --eval-tokenized-bleu \
   --keep-interval-updates 5 --save-interval-updates 500 --validate-interval-updates 500 --maximize-best-checkpoint-metric \
   --eval-bleu-remove-bpe --eval-bleu-print-samples --best-checkpoint-metric bleu --log-format simple --log-interval 100 \
   --eval-bleu --eval-bleu-detok space --keep-last-epochs 5 --keep-best-checkpoints 5  --fixed-validation-seed 7 --ddp-backend=no_c10d \
   --share-all-embeddings --decoder-learned-pos --encoder-learned-pos  --optimizer adam --adam-betas "(0.9,0.98)" --lr 0.0005 \ 
   --lr-scheduler inverse_sqrt --stop-min-lr 1e-09 --warmup-updates 10000 --warmup-init-lr 1e-07 --apply-bert-init --weight-decay 0.01 \
   --fp16 --clip-norm 2.0 --max-update 300000  --task translation_lev --criterion nat_loss --arch glat_sd --noise full_mask \ 
   --concat-yhat --concat-dropout 0.0  --label-smoothing 0.1 \ 
   --activation-fn gelu --dropout 0.1  --max-tokens 8192 \
   --length-loss-factor 0.1 --pred-length-offset 

Vanilla NAT with DSLP

python3 train.py data-bin/wmt14.en-de_kd --source-lang en --target-lang de  --save-dir checkpoints  --eval-tokenized-bleu \
   --keep-interval-updates 5 --save-interval-updates 500 --validate-interval-updates 500 --maximize-best-checkpoint-metric \
   --eval-bleu-remove-bpe --eval-bleu-print-samples --best-checkpoint-metric bleu --log-format simple --log-interval 100 \
   --eval-bleu --eval-bleu-detok space --keep-last-epochs 5 --keep-best-checkpoints 5  --fixed-validation-seed 7 --ddp-backend=no_c10d \
   --share-all-embeddings --decoder-learned-pos --encoder-learned-pos  --optimizer adam --adam-betas "(0.9,0.98)" --lr 0.0005 \ 
   --lr-scheduler inverse_sqrt --stop-min-lr 1e-09 --warmup-updates 10000 --warmup-init-lr 1e-07 --apply-bert-init --weight-decay 0.01 \
   --fp16 --clip-norm 2.0 --max-update 300000  --task translation_lev --criterion nat_loss --arch nat_sd --noise full_mask \ 
   --concat-yhat --concat-dropout 0.0  --label-smoothing 0.1 \ 
   --activation-fn gelu --dropout 0.1  --max-tokens 8192 \
   --length-loss-factor 0.1 --pred-length-offset 

Vanilla NAT with DSLP and Mixed Training:

python3 train.py data-bin/wmt14.en-de_kd --source-lang en --target-lang de  --save-dir checkpoints  --eval-tokenized-bleu \
   --keep-interval-updates 5 --save-interval-updates 500 --validate-interval-updates 500 --maximize-best-checkpoint-metric \
   --eval-bleu-remove-bpe --eval-bleu-print-samples --best-checkpoint-metric bleu --log-format simple --log-interval 100 \
   --eval-bleu --eval-bleu-detok space --keep-last-epochs 5 --keep-best-checkpoints 5  --fixed-validation-seed 7 --ddp-backend=no_c10d \
   --share-all-embeddings --decoder-learned-pos --encoder-learned-pos  --optimizer adam --adam-betas "(0.9,0.98)" --lr 0.0005 \ 
   --lr-scheduler inverse_sqrt --stop-min-lr 1e-09 --warmup-updates 10000 --warmup-init-lr 1e-07 --apply-bert-init --weight-decay 0.01 \
   --fp16 --clip-norm 2.0 --max-update 300000  --task translation_lev --criterion nat_loss --arch nat_sd --noise full_mask \ 
   --concat-yhat --concat-dropout 0.0  --label-smoothing 0.1 \ 
   --activation-fn gelu --dropout 0.1  --max-tokens 8192  --ss-ratio 0.3 --fixed-ss-ratio --masked-loss \ 
   --length-loss-factor 0.1 --pred-length-offset 

CTC with DSLP:

python3 train.py data-bin/wmt14.en-de_kd --source-lang en --target-lang de  --save-dir checkpoints  --eval-tokenized-bleu \
   --keep-interval-updates 5 --save-interval-updates 500 --validate-interval-updates 500 --maximize-best-checkpoint-metric \
   --eval-bleu-remove-bpe --eval-bleu-print-samples --best-checkpoint-metric bleu --log-format simple --log-interval 100 \
   --eval-bleu --eval-bleu-detok space --keep-last-epochs 5 --keep-best-checkpoints 5  --fixed-validation-seed 7 --ddp-backend=no_c10d \
   --share-all-embeddings --decoder-learned-pos --encoder-learned-pos  --optimizer adam --adam-betas "(0.9,0.98)" --lr 0.0005 \ 
   --lr-scheduler inverse_sqrt --stop-min-lr 1e-09 --warmup-updates 10000 --warmup-init-lr 1e-07 --apply-bert-init --weight-decay 0.01 \
   --fp16 --clip-norm 2.0 --max-update 300000  --task translation_lev --criterion nat_loss --arch nat_ctc_sd --noise full_mask \ 
   --src-upsample-scale 2 --use-ctc-decoder --ctc-beam-size 1  --concat-yhat --concat-dropout 0.0  --label-smoothing 0.0 \ 
   --activation-fn gelu --dropout 0.1  --max-tokens 8192 

CTC with DSLP and Mixed Training:

python3 train.py data-bin/wmt14.en-de_kd --source-lang en --target-lang de  --save-dir checkpoints  --eval-tokenized-bleu \
   --keep-interval-updates 5 --save-interval-updates 500 --validate-interval-updates 500 --maximize-best-checkpoint-metric \
   --eval-bleu-remove-bpe --eval-bleu-print-samples --best-checkpoint-metric bleu --log-format simple --log-interval 100 \
   --eval-bleu --eval-bleu-detok space --keep-last-epochs 5 --keep-best-checkpoints 5  --fixed-validation-seed 7 --ddp-backend=no_c10d \
   --share-all-embeddings --decoder-learned-pos --encoder-learned-pos  --optimizer adam --adam-betas "(0.9,0.98)" --lr 0.0005 \ 
   --lr-scheduler inverse_sqrt --stop-min-lr 1e-09 --warmup-updates 10000 --warmup-init-lr 1e-07 --apply-bert-init --weight-decay 0.01 \
   --fp16 --clip-norm 2.0 --max-update 300000  --task translation_lev --criterion nat_loss --arch nat_ctc_sd_ss --noise full_mask \ 
   --src-upsample-scale 2 --use-ctc-decoder --ctc-beam-size 1  --concat-yhat --concat-dropout 0.0  --label-smoothing 0.0 \ 
   --activation-fn gelu --dropout 0.1  --max-tokens 8192 --ss-ratio 0.3 --fixed-ss-ratio

Evaluation

Average the last best 5 checkpoints with scripts/average_checkpoints.py, our results are based on either the best checkpoint or the averaged checkpoint, depending on their valid set BLEU.

fairseq-generate data-bin/wmt14.en-de_kd  --path PATH_TO_A_CHECKPOINT \
    --gen-subset test --task translation_lev --iter-decode-max-iter 0 \
    --iter-decode-eos-penalty 0 --beam 1 --remove-bpe --print-step --batch-size 100

Note: 1) Add --plain-ctc --model-overrides '{"ctc_beam_size": 1, "plain_ctc": True}' if it is CTC based; 2) Change the task to translation_glat if it is GLAT based.

Output

We in addition provide the output of CTC w/ DSLP, CTC w/ DSLP & Mixed Training, Vanilla NAT w/ DSLP, Vanilla NAT w/ DSLP with Mixed Training, GLAT w/ DSLP, and CMLM w/ DSLP for review purpose.

Model Reference Hypothesis
CTC w/ DSLP ref hyp
CTC w/ DSLP & Mixed Training ref hyp
Vanilla NAT w/ DSLP ref hyp
Vanilla NAT w/ DSLP & Mixed Training ref hyp
GLAT w/ DSLP ref hyp
CMLM w/ DSLP ref hyp

Note: The output is on WMT'14 EN-DE. The references are paired with hypotheses for each model.

Training Efficiency

We show the training efficiency of our DSLP model based on vanilla NAT model. Specifically, we compared the BLUE socres of vanilla NAT and vanilla NAT with DSLP & Mixed Training on the same traning time (in hours).

As we observed, our DSLP model achieves much higher BLUE scores shortly after the training started (~3 hours). It shows that our DSLP is much more efficient in training, as our model ahieves higher BLUE scores with the same amount of training cost.

Efficiency

We run the experiments with 8 Tesla V100 GPUs. The batch size is 128K tokens, and each model is trained with 300K updates.

Owner
Chenyang Huang
Stay hungry, stay foolish
Chenyang Huang
Rethinking the Truly Unsupervised Image-to-Image Translation - Official PyTorch Implementation (ICCV 2021)

Rethinking the Truly Unsupervised Image-to-Image Translation (ICCV 2021) Each image is generated with the source image in the left and the average sty

Clova AI Research 436 Dec 27, 2022
NLP made easy

GluonNLP: Your Choice of Deep Learning for NLP GluonNLP is a toolkit that helps you solve NLP problems. It provides easy-to-use tools that helps you l

Distributed (Deep) Machine Learning Community 2.5k Jan 04, 2023
Chinese Grammatical Error Diagnosis

nlp-CGED Chinese Grammatical Error Diagnosis 中文语法纠错研究 基于序列标注的方法 所需环境 Python==3.6 tensorflow==1.14.0 keras==2.3.1 bert4keras==0.10.6 笔者使用了开源的bert4keras

12 Nov 25, 2022
Graphical user interface for Argos Translate

Argos Translate GUI Website | GitHub | PyPI Graphical user interface for Argos Translate. Install pip3 install argostranslategui

Argos Open Tech 16 Dec 07, 2022
This repository contains the code, models and datasets discussed in our paper "Few-Shot Question Answering by Pretraining Span Selection"

Splinter This repository contains the code, models and datasets discussed in our paper "Few-Shot Question Answering by Pretraining Span Selection", to

Ori Ram 88 Dec 31, 2022
DaCy: The State of the Art Danish NLP pipeline using SpaCy

DaCy: A SpaCy NLP Pipeline for Danish DaCy is a Danish preprocessing pipeline trained in SpaCy. At the time of writing it has achieved State-of-the-Ar

Kenneth Enevoldsen 71 Jan 06, 2023
Long text token classification using LongFormer

Long text token classification using LongFormer

abhishek thakur 161 Aug 07, 2022
自然言語で書かれた時間情報表現を抽出/規格化するルールベースの解析器

ja-timex 自然言語で書かれた時間情報表現を抽出/規格化するルールベースの解析器 概要 ja-timex は、現代日本語で書かれた自然文に含まれる時間情報表現を抽出しTIMEX3と呼ばれるアノテーション仕様に変換することで、プログラムが利用できるような形に規格化するルールベースの解析器です。

Yuki Okuda 116 Nov 09, 2022
Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition

SEW (Squeezed and Efficient Wav2vec) The repo contains the code of the paper "Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speec

ASAPP Research 67 Dec 01, 2022
Understanding the Difficulty of Training Transformers

Admin Understanding the Difficulty of Training Transformers Guided by our analyses, we propose Adaptive Model Initialization (Admin), which successful

Liyuan Liu 300 Dec 29, 2022
Multi Task Vision and Language

12-in-1: Multi-Task Vision and Language Representation Learning Please cite the following if you use this code. Code and pre-trained models for 12-in-

Meta Research 711 Jan 08, 2023
Paddle2.x version AI-Writer

Paddle2.x 版本AI-Writer 用魔改 GPT 生成网文。Tuned GPT for novel generation.

yujun 74 Jan 04, 2023
🤕 spelling exceptions builder for lazy people

🤕 spelling exceptions builder for lazy people

Vlad Bokov 3 May 12, 2022
Knowledge Graph,Question Answering System,基于知识图谱和向量检索的医疗诊断问答系统

Knowledge Graph,Question Answering System,基于知识图谱和向量检索的医疗诊断问答系统

wangle 823 Dec 28, 2022
⚡ Automatically decrypt encryptions without knowing the key or cipher, decode encodings, and crack hashes ⚡

Translations 🇩🇪 DE 🇫🇷 FR 🇭🇺 HU 🇮🇩 ID 🇮🇹 IT 🇳🇱 NL 🇧🇷 PT-BR 🇷🇺 RU 🇨🇳 ZH ➡️ Documentation | Discord | Installation Guide ⬅️ Fully autom

11.2k Jan 05, 2023
A simple Flask site that allows users to create, update, and delete posts in a database, as well as perform basic NLP tasks on the posts.

A simple Flask site that allows users to create, update, and delete posts in a database, as well as perform basic NLP tasks on the posts.

Ian 1 Jan 15, 2022
VoiceFixer VoiceFixer is a framework for general speech restoration.

VoiceFixer VoiceFixer is a framework for general speech restoration. We aim at the restoration of severly degraded speech and historical speech. Paper

Leo 174 Jan 06, 2023
Python library for processing Chinese text

SnowNLP: Simplified Chinese Text Processing SnowNLP是一个python写的类库,可以方便的处理中文文本内容,是受到了TextBlob的启发而写的,由于现在大部分的自然语言处理库基本都是针对英文的,于是写了一个方便处理中文的类库,并且和TextBlob

Rui Wang 6k Jan 02, 2023
Google AI 2018 BERT pytorch implementation

BERT-pytorch Pytorch implementation of Google AI's 2018 BERT, with simple annotation BERT 2018 BERT: Pre-training of Deep Bidirectional Transformers f

Junseong Kim 5.3k Jan 07, 2023
A2T: Towards Improving Adversarial Training of NLP Models (EMNLP 2021 Findings)

A2T: Towards Improving Adversarial Training of NLP Models This is the source code for the EMNLP 2021 (Findings) paper "Towards Improving Adversarial T

QData 17 Oct 15, 2022