Personal implementation of the paper "Approximate Nearest Neighbor Negative Contrastive Learning for Dense Text Retrieval"

Overview

Approximate Nearest Neighbor Negative Contrastive Learning for Dense Text Retrieval

This repo provides a simplified personal implementation of the paper Approximate Nearest Neighbor Negative Contrastive Learning for Dense Text Retrieval (ANCE). The code is based on the official ANCE implementation.

Environment

    transformers==2.3.0
    pytrec-eval
    faiss-cpu
    wget
    python==3.6.*

Data Download & Preprocessing

To download all the needed data, run:

    bash commands/data_download.sh

Data Preprocessing

The command to preprocess passage and document data is listed below:

    python data/msmarco_data.py \
    --data_dir $raw_data_dir \
    --out_data_dir $preprocessed_data_dir \
    --model_type {use rdot_nll for ANCE FirstP, rdot_nll_multi_chunk for ANCE MaxP} \
    --model_name_or_path roberta-base \
    --max_seq_length {use 512 for ANCE FirstP, 2048 for ANCE MaxP} \
    --data_type {use 1 for passage, 0 for document}

The data preprocessing command is also included as the first step of the training script commands/run_train.sh.
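
To make the FirstP/MaxP distinction concrete, here is a minimal sketch of how a long document could be scored under each scheme. It assumes, as in the paper, that FirstP encodes only the first 512 tokens while MaxP splits the first 2048 tokens into 512-token chunks and keeps the best chunk score. The `encode` argument is a hypothetical stand-in for the RoBERTa encoder, and all function names here are illustrative, not part of this repo.

```python
from typing import Callable, List, Sequence

# Hypothetical encoder type: maps a list of token ids to an embedding vector.
Encoder = Callable[[Sequence[int]], List[float]]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def chunk(token_ids: Sequence[int], chunk_len: int, max_len: int) -> List[Sequence[int]]:
    """Truncate to max_len tokens, then split into consecutive chunks of chunk_len."""
    truncated = token_ids[:max_len]
    return [truncated[i:i + chunk_len] for i in range(0, len(truncated), chunk_len)]


def firstp_score(query_emb: Sequence[float], doc_ids: Sequence[int],
                 encode: Encoder, max_len: int = 512) -> float:
    # FirstP: only the first max_len tokens of the document are encoded.
    return dot(query_emb, encode(doc_ids[:max_len]))


def maxp_score(query_emb: Sequence[float], doc_ids: Sequence[int], encode: Encoder,
               chunk_len: int = 512, max_len: int = 2048) -> float:
    # MaxP: encode each chunk separately and keep the best-matching chunk.
    return max(dot(query_emb, encode(c)) for c in chunk(doc_ids, chunk_len, max_len))
```

MaxP costs up to four encoder passes per document instead of one, which is why its preprocessing uses `--max_seq_length 2048` and the multi-chunk model type.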

Warmup for Training

ANCE training starts from a pretrained BM25 warmup checkpoint. The command, with the parameters we used to train this warmup checkpoint, is in commands/run_train_warmup.sh and is shown below:

    python3 -m torch.distributed.launch --nproc_per_node=1 ../drivers/run_warmup.py \
    --train_model_type rdot_nll \
    --model_name_or_path roberta-base \
    --task_name MSMarco \
    --do_train \
    --evaluate_during_training \
    --data_dir ${location of your raw data} \
    --max_seq_length 128 \
    --per_gpu_eval_batch_size=256 \
    --per_gpu_train_batch_size=32 \
    --learning_rate 2e-4  \
    --logging_steps 100   \
    --num_train_epochs 2.0  \
    --output_dir ${location for checkpoint saving} \
    --warmup_steps 1000  \
    --overwrite_output_dir \
    --save_steps 30000 \
    --gradient_accumulation_steps 1 \
    --expected_train_size 35000000 \
    --logging_steps_per_eval 1 \
    --fp16 \
    --optimizer lamb \
    --log_dir ~/tensorboard/${DLWS_JOB_ID}/logs/OSpass
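
The `rdot_nll` model type scores a query/passage pair by the dot product of their embeddings and trains with a negative log-likelihood that prefers the positive passage over a negative one (BM25 negatives during warmup, ANN negatives later). The NumPy sketch below illustrates that objective for a batch of (query, positive, negative) embeddings. It is our reading of the loss, not code taken from the repo, and `dot_nll_loss` is a hypothetical name.

```python
import numpy as np


def dot_nll_loss(q: np.ndarray, pos: np.ndarray, neg: np.ndarray) -> np.ndarray:
    """Per-example NLL of choosing the positive over the negative.

    q, pos, neg: arrays of shape (batch, dim) holding query, positive-passage
    and negative-passage embeddings.
    """
    pos_score = np.sum(q * pos, axis=-1)  # (batch,)
    neg_score = np.sum(q * neg, axis=-1)  # (batch,)
    logits = np.stack([pos_score, neg_score], axis=1)  # (batch, 2)
    # Numerically stable log-softmax over the two candidates.
    logits = logits - logits.max(axis=1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    # Loss is low when the positive clearly outscores the negative.
    return -log_probs[:, 0]
```

When the two scores tie, the loss is log 2; it approaches zero as the positive's margin grows.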

Training

To train the model(s) in the paper, you need to start two commands in the following order:

  1. Run commands/run_train.sh, which does three things in sequence:

    a. Data preprocessing: this is explained in the data preprocessing section above. This step checks whether the preprocessed data folder already exists and is skipped if it does.

    b. Initial ANN data generation: this step uses the pretrained BM25 warmup checkpoint to generate the initial training data. The command is as follows:

     python -m torch.distributed.launch --nproc_per_node=$gpu_no ../drivers/run_ann_data_gen.py \
     --training_dir {checkpoint location, not used for initial data generation} \
     --init_model_dir {pretrained BM25 warmup checkpoint location} \
     --model_type rdot_nll \
     --output_dir $model_ann_data_dir \
     --cache_dir $model_ann_data_dir_cache \
     --data_dir $preprocessed_data_dir \
     --max_seq_length 512 \
     --per_gpu_eval_batch_size 16 \
     --topk_training {top k candidates for ANN search, e.g. 200} \
     --negative_sample {negative samples per query, e.g. 20} \
     --end_output_num 0 # only set to 0 for initial data generation; do not set it otherwise
    

    c. Training: ANCE is trained on the most recently generated ANN data. The command is as follows:

     python -m torch.distributed.launch --nproc_per_node=$gpu_no ../drivers/run_ann.py \
     --model_type rdot_nll \
     --model_name_or_path $pretrained_checkpoint_dir \
     --task_name MSMarco \
     --triplet \
     --data_dir $preprocessed_data_dir \
     --ann_dir {location of the ANN generated training data} \
     --max_seq_length 512 \
     --per_gpu_train_batch_size=8 \
     --gradient_accumulation_steps 2 \
     --learning_rate 1e-6 \
     --output_dir $model_dir \
     --warmup_steps 5000 \
     --logging_steps 100 \
     --save_steps 10000 \
     --optimizer lamb

    --triplet is a boolean switch (action="store_true", default False); omit it to keep the default.
    
  2. Once training starts, start another job in parallel that fetches the latest checkpoint from the ongoing training and updates the training data with it. To do that, run:

     bash commands/run_ann_data_gen.sh
    

    The command is similar to the initial ANN data generation command explained above.
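
The ANN data generation step turns nearest-neighbor results into training triples: for each query it takes the top `--topk_training` passages retrieved with the current checkpoint, drops the labelled positives, and keeps `--negative_sample` of the rest as hard negatives. A plain-Python sketch of that selection follows. Sampling uniformly at random from the top-k pool, pairing every positive with every negative, and the name `build_triples` are assumptions for illustration, not a transcription of run_ann_data_gen.py.

```python
import random
from typing import Dict, List, Sequence, Set, Tuple


def build_triples(
    ann_results: Dict[str, Sequence[str]],  # query id -> passage ids ranked by ANN score
    positives: Dict[str, Set[str]],         # query id -> relevant passage ids (qrels)
    topk_training: int = 200,
    negative_sample: int = 20,
    seed: int = 0,
) -> List[Tuple[str, str, str]]:
    """Return (query_id, positive_id, negative_id) triples with ANN hard negatives."""
    rng = random.Random(seed)
    triples = []
    for qid, ranked in ann_results.items():
        pos_ids = positives.get(qid, set())
        if not pos_ids:
            continue  # no label, nothing to contrast against
        # Hard-negative pool: top-k ANN candidates that are not labelled relevant.
        pool = [pid for pid in ranked[:topk_training] if pid not in pos_ids]
        negatives = rng.sample(pool, min(negative_sample, len(pool)))
        for pos in sorted(pos_ids):
            for neg in negatives:
                triples.append((qid, pos, neg))
    return triples
```

Because the pool is recomputed each time a newer checkpoint is picked up, the negatives track what the current model finds hard, which is the point of running the two jobs in parallel.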

Inference

The command for inferring query and passage/document embeddings is the same as the one for initial ANN data generation described above, since the first step of ANN data generation is inference. However, you need to add --inference to the command so that the program stops after the initial inference step. commands/run_inference.sh provides a sample command.
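
Once query and passage embeddings have been dumped, retrieval is a maximum-inner-product search. The sketch below does exact top-k search with NumPy, assuming the embeddings are already loaded as 2-D arrays (how the repo stores them on disk is not shown here, and `topk_inner_product` is a hypothetical name). At MS MARCO scale the natural tool is faiss, already in the environment: `index = faiss.IndexFlatIP(dim)`, `index.add(passage_embs)`, `scores, idx = index.search(query_embs, k)` on float32 arrays performs the same exact search.

```python
import numpy as np


def topk_inner_product(query_embs: np.ndarray, passage_embs: np.ndarray, k: int):
    """Exact top-k retrieval by dot product.

    query_embs: (num_queries, dim); passage_embs: (num_passages, dim).
    Returns (scores, indices), each of shape (num_queries, k), best first.
    """
    scores = query_embs @ passage_embs.T  # (num_queries, num_passages)
    k = min(k, passage_embs.shape[0])
    # argpartition finds the k largest scores without a full sort ...
    part = np.argpartition(-scores, k - 1, axis=1)[:, :k]
    part_scores = np.take_along_axis(scores, part, axis=1)
    # ... then only those k candidates are sorted, highest score first.
    order = np.argsort(-part_scores, axis=1)
    idx = np.take_along_axis(part, order, axis=1)
    return np.take_along_axis(scores, idx, axis=1), idx
```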

Evaluation

The evaluation is done in the notebook "Calculate Metrics.ipynb". It calculates the full-ranking and reranking metrics used in the paper, including NDCG, MRR, hole rate and recall, for passage or document retrieval on the dev or eval set specified by the user. To run it, define the following parameters at the beginning of the notebook:

    checkpoint_path = {location of the dumped query and passage/document embeddings, i.e. output_dir from run_ann_data_gen.py}
    checkpoint = {checkpoint the embeddings come from, e.g. 200000}
    data_type = {0 for document, 1 for passage}
    test_set = {0 for MSMARCO dev_set, 1 for TREC eval_set}
    raw_data_dir = {location of the raw data}
    processed_data_dir = {location of the preprocessed data}
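
For reference, the sketch below shows how MRR@10, Recall@k and hole rate can be computed from ranked lists and relevance judgments. It is a plain-Python illustration, not the notebook's code; pytrec-eval, listed in the environment, is the usual tool for NDCG. Hole rate is taken here to be the fraction of top-k results that have no relevance judgment at all, which is our assumption about the definition, and all function names are hypothetical.

```python
from typing import Dict, List, Set

Ranking = Dict[str, List[str]]  # query id -> retrieved ids, best first
Qrels = Dict[str, Set[str]]     # query id -> set of ids


def mrr_at_k(ranking: Ranking, qrels: Qrels, k: int = 10) -> float:
    """Mean reciprocal rank of the first relevant item within the top k."""
    total = 0.0
    for qid, relevant in qrels.items():
        for rank, pid in enumerate(ranking.get(qid, [])[:k], start=1):
            if pid in relevant:
                total += 1.0 / rank
                break
    return total / len(qrels)


def recall_at_k(ranking: Ranking, qrels: Qrels, k: int = 1000) -> float:
    """Average fraction of relevant items retrieved within the top k."""
    scores = []
    for qid, relevant in qrels.items():
        retrieved = set(ranking.get(qid, [])[:k])
        scores.append(len(retrieved & relevant) / len(relevant))
    return sum(scores) / len(scores)


def hole_rate_at_k(ranking: Ranking, judged: Qrels, k: int = 10) -> float:
    """Fraction of top-k results that were never judged (assumed definition)."""
    holes = total = 0
    for qid, ranked in ranking.items():
        top = ranked[:k]
        holes += sum(pid not in judged.get(qid, set()) for pid in top)
        total += len(top)
    return holes / total if total else 0.0
```

A high hole rate means many retrieved passages were never judged, so metrics computed against the judgments may understate the retriever's true quality.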

ANCE vs. DPR on OpenQA Benchmarks

We also evaluate ANCE on the OpenQA benchmarks used in a parallel work (DPR). At the time of our experiments, only the preprocessed NQ and TriviaQA data had been released. Our experiments use these two tasks and inherit DPR's retriever evaluation, which reports top-20/100 accuracy: whether the top-20/100 retrieved passages include the answer. We explain the steps to reproduce our results on the OpenQA benchmarks in this section.
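
Concretely, a question counts as a hit at k if any of its top-k retrieved passages contains one of the gold answers. The sketch below uses lower-cased, punctuation-stripped whole-word substring matching, a simplification chosen for illustration; DPR's own evaluation tokenizes passages and answers before matching, so numbers from this sketch will not exactly reproduce the table. The function names are hypothetical.

```python
import re
from typing import List, Sequence


def _normalize(text: str) -> str:
    # Lower-case and replace runs of non-word characters with one space
    # (simplified; not DPR's tokenizer).
    return re.sub(r"\W+", " ", text.lower()).strip()


def has_answer(passage: str, answers: Sequence[str]) -> bool:
    # Pad with spaces so that "42" does not match inside "1429".
    norm = f" {_normalize(passage)} "
    return any(f" {_normalize(a)} " in norm for a in answers)


def top_k_accuracy(retrieved: List[List[str]], answers: List[List[str]], k: int) -> float:
    """Fraction of questions whose top-k passages contain at least one answer."""
    hits = sum(
        any(has_answer(p, ans) for p in passages[:k])
        for passages, ans in zip(retrieved, answers)
    )
    return hits / len(answers)
```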

Download data

commands/data_download.sh takes care of this step.

ANN data generation & ANCE training

Following the same training philosophy discussed before, ANN data generation and ANCE training for OpenQA require two parallel jobs.

  1. We preprocess the data and generate an initial training set for ANCE to start training. The command for that is provided in:
     commands/run_ann_data_gen_dpr.sh

     We keep this data generation job running after it creates the initial training set, as it will keep generating training data with the newest checkpoints from the training process.

  2. After the initial training set is generated, we start an ANCE training job with the commands provided in:
     commands/run_train_dpr.sh

During training, the evaluation metrics are logged to TensorBoard each time new training data is received. Alternatively, you can check the metrics in the dumped file "ann_ndcg_#" in the directory specified by "model_ann_data_dir" in commands/run_ann_data_gen_dpr.sh each time new training data is generated.

Results

The run_train.sh and run_ann_data_gen.sh files contain the commands with the parameters we used for passage ANCE(FirstP), document ANCE(FirstP) and document ANCE(MaxP). Our models achieve the following performance on the MSMARCO dev set and the TREC eval set:

    MSMARCO Dev Passage Retrieval        MRR@10   Recall@1k   Steps
    ANCE(FirstP)                         0.330    0.959       600K
    ANCE(MaxP)                           -        -           -

    TREC DL Passage (NDCG@10)            Rerank   Retrieval   Steps
    ANCE(FirstP)                         0.677    0.648       600K
    ANCE(MaxP)                           -        -           -

    TREC DL Document (NDCG@10)           Rerank   Retrieval   Steps
    ANCE(FirstP)                         0.641    0.615       210K
    ANCE(MaxP)                           0.671    0.628       139K

    MSMARCO Dev Passage Retrieval        MRR@10   Steps
    Pretrained BM25 warmup checkpoint    0.311    60K

    ANCE Single-task Training            Top-20   Top-100     Steps
    NQ                                   81.9     87.5        136K
    TriviaQA                             80.3     85.3        100K

    ANCE Multi-task Training             Top-20   Top-100     Steps
    NQ                                   82.1     87.9        300K
    TriviaQA                             80.3     85.2        300K

Click the steps in the table to download the corresponding checkpoints.

Our results for document ANCE(FirstP) on the TREC eval set (top 100 retrieved documents per query) can be downloaded here. Our results for document ANCE(MaxP) on the TREC eval set (top 100 retrieved documents per query) can be downloaded here.

The TREC eval set query embeddings and their ids for our passage ANCE(FirstP) experiment can be downloaded here. The TREC eval set query embeddings and their ids for our document ANCE(FirstP) experiment can be downloaded here. The TREC eval set query embeddings and their ids for our document 2048 ANCE(MaxP) experiment can be downloaded here.

The t-SNE plots for all the queries in the TREC document eval set for ANCE(FirstP) can be viewed here.

The run_train.sh and run_ann_data_gen.sh files contain the commands with the parameters we used for passage ANCE(FirstP), document ANCE(FirstP) and document 2048 ANCE(MaxP) to reproduce the results in this section. run_train_warmup.sh contains the commands to reproduce the results for the pretrained BM25 warmup checkpoint in this section.

Note that the number of steps needed to reproduce results similar to those in the table may differ slightly, due to differences in synchronization between the training and ANN data generation processes and other possible differences in the user's environment.

Owner
John
My research interests are machine learning and recommender systems.