EsViT: Efficient self-supervised Vision Transformers

PyTorch implementation for EsViT, built with two techniques:

  • A multi-stage Transformer architecture. Three multi-stage Transformer variants are implemented under the folder models.
  • A region-level matching pre-train task. The region-level matching task is implemented in function DDINOLoss(nn.Module) (Line 648) in main_esvit.py. Please use --use_dense_prediction True, otherwise only the view-level task is used.
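As a rough illustration of the region-level matching idea, the sketch below pairs each region of one view with its most similar region in the other view and averages a cross-entropy over the matched pairs. It is a simplified pure-Python sketch, not the DDINOLoss code in main_esvit.py: the function names, the use of cosine similarity for matching, and the toy inputs are assumptions made for illustration.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cosine(u, v):
    """Cosine similarity between two feature vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm


def region_matching_loss(student_feats, student_logits, teacher_feats, teacher_logits):
    """Hypothetical region-level loss: match each student region to its most
    similar teacher region, then average the cross-entropy between the
    teacher (target) and student (prediction) distributions."""
    total = 0.0
    for s_feat, s_logit in zip(student_feats, student_logits):
        # Index of the teacher region most similar to this student region.
        best = max(range(len(teacher_feats)),
                   key=lambda j: cosine(s_feat, teacher_feats[j]))
        target = softmax(teacher_logits[best])
        pred = softmax(s_logit)
        total += -sum(t * math.log(p) for t, p in zip(target, pred))
    return total / len(student_feats)


# Toy example: two regions per view, two output dimensions.
student_feats = [[1.0, 0.0], [0.0, 1.0]]
teacher_feats = [[0.0, 2.0], [3.0, 0.0]]
student_logits = [[2.0, 0.0], [0.0, 2.0]]
teacher_logits = [[0.0, 2.0], [2.0, 0.0]]
loss = region_matching_loss(student_feats, student_logits, teacher_feats, teacher_logits)
```

In this toy example each student region is matched to the teacher region pointing in the same direction, so the matched distributions agree and the loss reduces to their entropy.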
Figure: Efficiency vs. accuracy comparison under the linear classification protocol on ImageNet. Left: throughput of all SoTA SSL vision systems; circle sizes indicate model parameter counts. Right: performance over varied parameter counts for models with a moderate (throughput/#parameters) ratio. Please refer to Section 4.1 of the paper for details.

Pretrained models

You can download the full checkpoint (trained with both view-level and region-level tasks, with batch size 512 on ImageNet-1K), which contains the backbone and projection head weights for both the student and teacher networks.

arch params linear k-nn download logs
EsViT (Swin-T, W=7) 28M 78.0% 75.7% full ckpt train linear knn
EsViT (Swin-S, W=7) 49M 79.5% 77.7% full ckpt train linear knn
EsViT (Swin-B, W=7) 87M 80.4% 78.9% full ckpt train linear knn
EsViT (Swin-T, W=14) 28M 78.7% 77.0% full ckpt train linear knn
EsViT (Swin-S, W=14) 49M 80.8% 79.1% full ckpt train linear knn
EsViT (Swin-B, W=14) 87M 81.3% 79.3% full ckpt train linear knn

EsViT (Swin-T, W=7) with different pre-train datasets (view-level task only)

arch params batch size pre-train dataset linear k-nn download logs
EsViT 28M 512 ImageNet-1K 77.0% 74.2% full ckpt train linear knn
EsViT 28M 1024 ImageNet-1K 77.1% 73.7% full ckpt train linear knn
EsViT 28M 1024 WebVision-v1 75.4% 69.4% full ckpt train linear knn
EsViT 28M 1024 OpenImages-v4 69.6% 60.3% full ckpt train linear knn
EsViT 28M 1024 ImageNet-22K 73.5% 66.1% full ckpt train linear knn
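Since a full checkpoint holds both student and teacher weights, downstream code has to pick one network before loading it. The sketch below shows one way this selection could look on a plain dictionary. The "teacher" key matches the --checkpoint_key teacher flag used in the evaluation commands; the "student" key name, the "module." and "backbone." prefixes, and the helper name select_weights are assumptions for illustration and may not match the actual checkpoint layout.

```python
def select_weights(checkpoint, checkpoint_key="teacher",
                   prefixes=("module.", "backbone.")):
    """Pick one network's weights from a full checkpoint dictionary and
    strip wrapper prefixes (hypothetical names) from parameter keys."""
    state = checkpoint[checkpoint_key]
    cleaned = {}
    for name, value in state.items():
        for prefix in prefixes:
            if name.startswith(prefix):
                name = name[len(prefix):]
        cleaned[name] = value
    return cleaned


# Toy stand-in for a checkpoint; in practice it would typically come from
# torch.load(path, map_location="cpu").
toy_checkpoint = {
    "student": {"module.backbone.w": 1},
    "teacher": {"backbone.w": 2, "head.b": 3},
}
teacher_weights = select_weights(toy_checkpoint)
student_weights = select_weights(toy_checkpoint, checkpoint_key="student")
```

Keys that carry none of the listed prefixes (such as the projection head entry in the toy example) are kept unchanged, so a caller would still need to filter them out if only the backbone is wanted.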

Pre-training

One-node training

To train on 1 node with 16 GPUs for Swin-T model size:

PROJ_PATH=your_esvit_project_path
DATA_PATH=$PROJ_PATH/project/data/imagenet

OUT_PATH=$PROJ_PATH/output/esvit_exp/ssl/swin_tiny_imagenet/
python -m torch.distributed.launch --nproc_per_node=16 main_esvit.py --arch swin_tiny --data_path $DATA_PATH/train --output_dir $OUT_PATH --batch_size_per_gpu 32 --epochs 300 --teacher_temp 0.07 --warmup_epochs 10 --warmup_teacher_temp_epochs 30 --norm_last_layer false --use_dense_prediction True --cfg experiments/imagenet/swin/swin_tiny_patch4_window7_224.yaml 

The main training script is main_esvit.py and conducts the training loop, taking the following options (among others) as arguments:

  • --use_dense_prediction: whether or not to use the region matching task in pre-training
  • --arch: switch between different sparse self-attention mechanisms in the multi-stage Transformer architecture. Example architecture choices for EsViT training include [swin_tiny, swin_small, swin_base, swin_large, cvt_tiny, vil_2262]. The configuration file should be adjusted accordingly; examples are provided below. The network configuration can be specified by editing the YAML files under experiments/imagenet/*/*.yaml. The default window size is 7; to use a multi-stage architecture with window size 14, choose the YAML files with window14 in their filenames.
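The pre-training commands also pass --teacher_temp 0.07 and --warmup_teacher_temp_epochs 30. A plausible reading, following the DINO codebase this implementation builds on, is that the teacher softmax temperature is ramped linearly up to its final value during the warm-up epochs and then held constant. The sketch below illustrates that schedule; the starting value of 0.04 and the function name are assumptions, not values taken from this repository.

```python
def teacher_temp_schedule(epochs, teacher_temp, warmup_epochs, warmup_start=0.04):
    """Per-epoch teacher temperature: linear ramp from warmup_start
    (assumed value) to teacher_temp, then constant."""
    schedule = []
    for epoch in range(epochs):
        if epoch < warmup_epochs:
            # Fraction of the ramp completed; reaches 1.0 on the last warm-up epoch.
            frac = epoch / max(warmup_epochs - 1, 1)
            schedule.append(warmup_start + frac * (teacher_temp - warmup_start))
        else:
            schedule.append(teacher_temp)
    return schedule


# Values taken from the commands above: 300 epochs, final temperature 0.07,
# 30 warm-up epochs.
temps = teacher_temp_schedule(epochs=300, teacher_temp=0.07, warmup_epochs=30)
```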

To train on 1 node with 16 GPUs for Convolutional vision Transformer (CvT) models:

python -m torch.distributed.launch --nproc_per_node=16 main_esvit.py --arch cvt_tiny --data_path $DATA_PATH/train --output_dir $OUT_PATH --batch_size_per_gpu 32 --epochs 300 --teacher_temp 0.07 --warmup_epochs 10 --warmup_teacher_temp_epochs 30 --norm_last_layer false --use_dense_prediction True --aug-opt dino_aug --cfg experiments/imagenet/cvt_v4/s1.yaml

To train on 1 node with 16 GPUs for Vision Longformer (ViL) models:

python -m torch.distributed.launch --nproc_per_node=16 main_esvit.py --arch vil_2262 --data_path $DATA_PATH/train --output_dir $OUT_PATH --batch_size_per_gpu 32 --epochs 300 --teacher_temp 0.07 --warmup_epochs 10 --warmup_teacher_temp_epochs 30 --norm_last_layer false --use_dense_prediction True --aug-opt dino_aug --cfg experiments/imagenet/vil/vil_small/base.yaml MODEL.SPEC.MSVIT.ARCH 'l1,h3,d96,n2,s1,g1,p4,f7,a0_l2,h6,d192,n2,s1,g1,p2,f7,a0_l3,h12,d384,n6,s0,g1,p2,f7,a0_l4,h24,d768,n2,s0,g0,p2,f7,a0' MODEL.SPEC.MSVIT.MODE 1 MODEL.SPEC.MSVIT.VIL_MODE_SWITCH 0.75

Multi-node training

To train on 2 nodes with 16 GPUs each (total 32 GPUs) for Swin-Small model size:

OUT_PATH=$PROJ_PATH/exp_output/esvit_exp/swin/swin_small/bl_lr0.0005_gpu16_bs16_multicrop_epoch300_dino_aug_window14
python main_evsit_mnodes.py --num_nodes 2 --num_gpus_per_node 16 --data_path $DATA_PATH/train --output_dir $OUT_PATH/continued_from0200_dense --batch_size_per_gpu 16 --arch swin_small --zip_mode True --epochs 300 --teacher_temp 0.07 --warmup_epochs 10 --warmup_teacher_temp_epochs 30 --norm_last_layer false --cfg experiments/imagenet/swin/swin_small_patch4_window14_224.yaml --use_dense_prediction True --pretrained_weights_ckpt $OUT_PATH/checkpoint0200.pth
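Both the one-node and the two-node commands lead to the same effective batch size, which matches the batch size of 512 quoted for the pretrained checkpoints. The small calculation below makes this explicit; it assumes the global batch is simply the product of nodes, GPUs per node, and per-GPU batch size (no gradient accumulation), and the helper name is hypothetical.

```python
def global_batch_size(num_nodes, gpus_per_node, batch_size_per_gpu):
    """Effective batch size under plain data parallelism (assumed)."""
    return num_nodes * gpus_per_node * batch_size_per_gpu


# One node, 16 GPUs, --batch_size_per_gpu 32 (Swin-T command).
single_node = global_batch_size(1, 16, 32)
# Two nodes, 16 GPUs each, --batch_size_per_gpu 16 (Swin-S command).
two_nodes = global_batch_size(2, 16, 16)
print(single_node, two_nodes)  # prints "512 512"
```

When changing the number of GPUs, adjusting --batch_size_per_gpu in the opposite direction keeps this product fixed.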

Evaluation

k-NN and Linear classification on ImageNet

To train a supervised linear classifier on frozen weights on a single node with 4 GPUs, run eval_linear.py. To train a k-NN classifier on frozen weights on a single node with 4 GPUs, run eval_knn.py. Please specify --arch, --cfg and --pretrained_weights to choose a pre-trained checkpoint. For example, to evaluate the last checkpoint of EsViT with Swin-T, you can run:

PROJ_PATH=your_esvit_project_path
DATA_PATH=$PROJ_PATH/project/data/imagenet

OUT_PATH=$PROJ_PATH/exp_output/esvit_exp/swin/swin_tiny/bl_lr0.0005_gpu16_bs32_dense_multicrop_epoch300
CKPT_PATH=$PROJ_PATH/exp_output/esvit_exp/swin/swin_tiny/bl_lr0.0005_gpu16_bs32_dense_multicrop_epoch300/checkpoint.pth

python -m torch.distributed.launch --nproc_per_node=4 eval_linear.py --data_path $DATA_PATH --output_dir $OUT_PATH/lincls/epoch0300 --pretrained_weights $CKPT_PATH --checkpoint_key teacher --batch_size_per_gpu 256 --arch swin_tiny --cfg experiments/imagenet/swin/swin_tiny_patch4_window7_224.yaml --n_last_blocks 4 --num_labels 1000 MODEL.NUM_CLASSES 0

python -m torch.distributed.launch --nproc_per_node=4 eval_knn.py --data_path $DATA_PATH --dump_features $OUT_PATH/features/epoch0300 --pretrained_weights $CKPT_PATH --checkpoint_key teacher --batch_size_per_gpu 256 --arch swin_tiny --cfg experiments/imagenet/swin/swin_tiny_patch4_window7_224.yaml MODEL.NUM_CLASSES 0
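For readers unfamiliar with k-NN evaluation on frozen features, the sketch below shows the general idea: each test feature is compared with all training features by cosine similarity, and the labels of the k nearest neighbours vote with similarity-based weights. This is a pure-Python illustration rather than the logic of eval_knn.py; the function name, the exponential weighting with temperature 0.07, and the toy data are assumptions modelled on common self-supervised evaluation practice.

```python
import math
from collections import defaultdict


def knn_predict(query, train_feats, train_labels, k=3, temperature=0.07):
    """Weighted k-NN vote on L2-normalised features (illustrative only)."""
    def normalize(v):
        norm = math.sqrt(sum(x * x for x in v))
        return [x / norm for x in v]

    q = normalize(query)
    sims = []
    for feat, label in zip(train_feats, train_labels):
        f = normalize(feat)
        sims.append((sum(a * b for a, b in zip(q, f)), label))
    # Keep the k most similar training samples.
    sims.sort(key=lambda pair: pair[0], reverse=True)
    votes = defaultdict(float)
    for sim, label in sims[:k]:
        # Closer neighbours get exponentially larger weights.
        votes[label] += math.exp(sim / temperature)
    return max(votes, key=votes.get)


# Toy 2-D "features" standing in for frozen backbone outputs.
train_feats = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
train_labels = ["cat", "cat", "dog", "dog"]
prediction = knn_predict([1.0, 0.05], train_feats, train_labels, k=3)
```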

Analysis/Visualization of correspondence and attention maps

You can analyze the learned models by running python run_analysis.py. An example for EsViT (Swin-T) is shown below.

For an individual image (with path --image_path $IMG_PATH), we visualize the attention maps and correspondence of the last layer:

python run_analysis.py --arch swin_tiny --image_path $IMG_PATH --output_dir $OUT_PATH --pretrained_weights $CKPT_PATH --learning ssl --seed $SEED --cfg experiments/imagenet/swin/swin_tiny_patch4_window7_224.yaml --vis_attention True --vis_correspondence True MODEL.NUM_CLASSES 0 

For an image dataset (with path --data_path $DATA_PATH), we quantitatively measure the correspondence:

python run_analysis.py --arch swin_tiny --data_path $DATA_PATH --output_dir $OUT_PATH --pretrained_weights $CKPT_PATH --learning ssl --seed $SEED --cfg experiments/imagenet/swin/swin_tiny_patch4_window7_224.yaml  --measure_correspondence True MODEL.NUM_CLASSES 0 

For more examples, please see scripts/scripts_local/run_analysis.sh.

Citation

If you find this repository useful, please consider giving it a star and citing the paper 🍺:

@article{li2021esvit,
  title={Efficient Self-supervised Vision Transformers for Representation Learning},
  author={Li, Chunyuan and Yang, Jianwei and Zhang, Pengchuan and Gao, Mei and Xiao, Bin and Dai, Xiyang and Yuan, Lu and Gao, Jianfeng},
  journal={arXiv preprint arXiv:2106.09785},
  year={2021}
}

Related Projects/Codebase

[Swin Transformers] [Vision Longformer] [Convolutional vision Transformers (CvT)] [Focal Transformers]

Acknowledgement

Our implementation is built partly upon packages: [Dino] [Timm]

Contributing

This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.

When you submit a pull request, a CLA bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA.

This project has adopted the Microsoft Open Source Code of Conduct. For more information see the Code of Conduct FAQ or contact opencode@microsoft.com with any additional questions or comments.

Trademarks

This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow Microsoft's Trademark & Brand Guidelines. Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos are subject to those third-party's policies.
