The official implementation of ELSA: Enhanced Local Self-Attention for Vision Transformer

ELSA: Enhanced Local Self-Attention for Vision Transformer

By Jingkai Zhou, Pichao Wang*, Fan Wang, Qiong Liu, Hao Li, Rong Jin

This repo is the official implementation of "ELSA: Enhanced Local Self-Attention for Vision Transformer".

Introduction

Self-attention is powerful in modeling long-range dependencies, but it is weak at learning finer-grained local features. As shown in Figure 1, the performance of local self-attention (LSA) is merely on par with convolution and inferior to dynamic filters. This leaves researchers puzzled: should they use LSA or its counterparts, which one is better, and what makes LSA mediocre? In this work, we comprehensively investigate LSA and its counterparts. We find that the devil lies in the generation and application of spatial attention.

Based on these findings, we propose enhanced local self-attention (ELSA) with Hadamard attention and the ghost head, as illustrated in Figure 2. Experiments demonstrate the effectiveness of ELSA: used as a drop-in replacement, without any architecture or hyperparameter modification, ELSA consistently boosts baseline methods in both upstream and downstream tasks.
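For readers new to LSA, here is a minimal sketch of plain local self-attention, the baseline that ELSA enhances. Each token attends only to neighbours inside a fixed window, using softmax-normalized scaled dot-product weights. This is a simplified 1D, single-head, pure-Python illustration with hypothetical function names. It is not code from this repository, it works on 1D sequences rather than 2D feature maps, and it does not implement ELSA's Hadamard attention or ghost head.

```python
import math
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def softmax(scores: List[float]) -> List[float]:
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def local_self_attention(q: List[Vector], k: List[Vector], v: List[Vector],
                         window: int) -> List[Vector]:
    """Hypothetical 1D sliding-window self-attention (plain LSA, not ELSA).

    Token i attends only to tokens j with |i - j| <= window // 2;
    neighbours that would fall past the sequence edges are skipped.
    """
    assert window % 2 == 1, "use an odd window so it is centred on each token"
    n, d = len(q), len(q[0])
    half = window // 2
    out = []
    for i in range(n):
        nbrs = range(max(0, i - half), min(n, i + half + 1))
        # Scaled dot-product scores, restricted to the local neighbourhood.
        scores = [dot(q[i], k[j]) / math.sqrt(d) for j in nbrs]
        weights = softmax(scores)
        # Weighted sum of the neighbours' value vectors, channel by channel.
        out.append([sum(w * v[j][c] for w, j in zip(weights, nbrs))
                    for c in range(len(v[0]))])
    return out
```

With all-zero queries every neighbour gets the same weight, so each output is the mean of the values in its window. With `window=1` each token attends only to itself and the output equals `v`.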

Please refer to our paper for more details.

Model zoo

ImageNet Classification

| Model | #Params | Pretrain | Resolution | Top-1 Acc (%) | Download |
| --- | --- | --- | --- | --- | --- |
| ELSA-Swin-T | 28M | ImageNet-1K | 224 | 82.7 | google / baidu |
| ELSA-Swin-S | 53M | ImageNet-1K | 224 | 83.5 | google / baidu |
| ELSA-Swin-B | 93M | ImageNet-1K | 224 | 84.0 | google / baidu |

COCO Object Detection

| Backbone | Method | Pretrain | Lr Schd | Box mAP | Mask mAP | #Params | Download |
| --- | --- | --- | --- | --- | --- | --- | --- |
| ELSA-Swin-T | Mask R-CNN | ImageNet-1K | 1x | 45.7 | 41.1 | 49M | google / baidu |
| ELSA-Swin-T | Mask R-CNN | ImageNet-1K | 3x | 47.5 | 42.7 | 49M | google / baidu |
| ELSA-Swin-S | Mask R-CNN | ImageNet-1K | 1x | 48.3 | 43.0 | 72M | google / baidu |
| ELSA-Swin-S | Mask R-CNN | ImageNet-1K | 3x | 49.2 | 43.6 | 72M | google / baidu |
| ELSA-Swin-T | Cascade Mask R-CNN | ImageNet-1K | 1x | 49.8 | 43.0 | 86M | google / baidu |
| ELSA-Swin-T | Cascade Mask R-CNN | ImageNet-1K | 3x | 51.0 | 44.2 | 86M | google / baidu |
| ELSA-Swin-S | Cascade Mask R-CNN | ImageNet-1K | 1x | 51.6 | 44.4 | 110M | google / baidu |
| ELSA-Swin-S | Cascade Mask R-CNN | ImageNet-1K | 3x | 52.3 | 45.2 | 110M | google / baidu |

ADE20K Semantic Segmentation

| Backbone | Method | Pretrain | Crop Size | Lr Schd | mIoU (ms+flip) | #Params | Download |
| --- | --- | --- | --- | --- | --- | --- | --- |
| ELSA-Swin-T | UPerNet | ImageNet-1K | 512x512 | 160K | 47.9 | 61M | google / baidu |
| ELSA-Swin-S | UPerNet | ImageNet-1K | 512x512 | 160K | 50.4 | 85M | google / baidu |

Install

  • Clone this repo:
git clone https://github.com/damo-cv/ELSA.git elsa
cd elsa
  • Create a conda virtual environment and activate it:
conda create -n elsa python=3.7 -y
conda activate elsa
  • Install PyTorch==1.8.0 and torchvision==0.9.0 with CUDA==10.1, then build NVIDIA Apex:
conda install pytorch==1.8.0 torchvision==0.9.0 cudatoolkit=10.1 -c pytorch
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
cd ../
  • Install mmcv-full==1.3.0:
pip install mmcv-full==1.3.0 -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.8.0/index.html
  • Install other requirements:
pip install -r requirements.txt
  • Install mmdet and mmseg:
cd ./det
pip install -v -e .
cd ../seg
pip install -v -e .
cd ../
  • Build the ELSA operator and copy the compiled .so files into the detection and segmentation backbones:
cd ./cls/models/elsa
python setup.py install
mv build/lib*/* .
cp *.so ../../../det/mmdet/models/backbones/elsa/
cp *.so ../../../seg/mmseg/models/backbones/elsa/
cd ../../../
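Because the steps above pin exact versions, a quick sanity check after installation can save a failed build later. The sketch below compares each package's `__version__` against the pins listed above; the script itself and its function names are hypothetical, and it ignores local build suffixes such as `+cu101`. It does not check Apex, the CUDA toolkit, or the compiled ELSA operator.

```python
import importlib
from typing import Dict, List, Optional

# Version pins taken from the install steps above.
REQUIRED = {"torch": "1.8.0", "torchvision": "0.9.0", "mmcv": "1.3.0"}


def normalize(version: str) -> str:
    """Drop a local build suffix such as '+cu101' before comparing."""
    return version.split("+", 1)[0]


def find_mismatches(installed: Dict[str, Optional[str]],
                    required: Dict[str, str] = REQUIRED) -> List[str]:
    """List human-readable problems; an empty list means every pin matches."""
    problems = []
    for name, want in required.items():
        have = installed.get(name)
        if have is None:
            problems.append(f"{name}: not installed (want {want})")
        elif normalize(have) != want:
            problems.append(f"{name}: found {have}, want {want}")
    return problems


def installed_versions() -> Dict[str, Optional[str]]:
    """Read __version__ from each pinned package; None if it cannot be imported."""
    versions: Dict[str, Optional[str]] = {}
    for name in REQUIRED:
        try:
            versions[name] = importlib.import_module(name).__version__
        except ImportError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    issues = find_mismatches(installed_versions())
    print("\n".join(issues) if issues else "all pinned versions match")
```

Run it inside the activated `elsa` environment. Reading `__version__` rather than `importlib.metadata` keeps it compatible with the Python 3.7 environment created above.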

Data preparation

We use the standard ImageNet dataset, which you can download from http://image-net.org/. Please organize it with the following file structure:

$ tree data
imagenet
├── train
│   ├── class1
│   │   ├── img1.jpeg
│   │   ├── img2.jpeg
│   │   └── ...
│   ├── class2
│   │   ├── img3.jpeg
│   │   └── ...
│   └── ...
└── val
    ├── class1
    │   ├── img4.jpeg
    │   ├── img5.jpeg
    │   └── ...
    ├── class2
    │   ├── img6.jpeg
    │   └── ...
    └── ...
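Before training, it can help to verify the layout above. The sketch assumes an ImageFolder-style loader that maps class folder names to labels, so `train` and `val` should contain the same set of class folders. The function names are hypothetical, and the check does not inspect the images themselves.

```python
import os
from typing import Dict, List


def class_dirs(split_dir: str) -> List[str]:
    """Sorted names of the class sub-folders inside one split."""
    return sorted(e.name for e in os.scandir(split_dir) if e.is_dir())


def check_imagenet_layout(root: str) -> List[str]:
    """Return problems with root/{train,val}/<class>/<image>; [] means OK."""
    problems = []
    splits: Dict[str, List[str]] = {}
    for split in ("train", "val"):
        path = os.path.join(root, split)
        if not os.path.isdir(path):
            problems.append(f"missing folder: {path}")
            continue
        splits[split] = class_dirs(path)
        if not splits[split]:
            problems.append(f"no class folders in {path}")
    # Label indices come from folder names, so both splits must agree.
    if len(splits) == 2 and splits["train"] != splits["val"]:
        only_train = sorted(set(splits["train"]) - set(splits["val"]))
        only_val = sorted(set(splits["val"]) - set(splits["train"]))
        problems.append(f"class mismatch: train-only={only_train}, val-only={only_val}")
    return problems
```

For example, `check_imagenet_layout("data/imagenet")` should return an empty list for the structure shown above.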

Also, prepare the COCO and ADE20K datasets following their official instructions, then link them to det/data and seg/data, respectively.
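Linking can be done with `ln -s`; the Python sketch below does the same thing and refuses to link a missing source. The target names `det/data/coco` and `seg/data/ade/ADEChallengeData2016` are an assumption based on the default mmdet / mmseg config conventions, and the `/datasets/...` source paths are hypothetical. Check the `data_root` in the configs you actually use.

```python
import os


def link_dataset(src: str, dst: str) -> None:
    """Create dst as a symlink to src (like `ln -s src dst`), making parent dirs."""
    src = os.path.abspath(src)
    if not os.path.isdir(src):
        raise FileNotFoundError(f"dataset folder not found: {src}")
    os.makedirs(os.path.dirname(os.path.abspath(dst)), exist_ok=True)
    if os.path.islink(dst):
        os.remove(dst)  # replace a stale link; real folders are never removed
    os.symlink(src, dst)


if __name__ == "__main__":
    # Hypothetical source paths; the target names are assumed defaults.
    link_dataset("/datasets/coco", "det/data/coco")
    link_dataset("/datasets/ADEChallengeData2016", "seg/data/ade/ADEChallengeData2016")
```

If `dst` already exists as a real directory, `os.symlink` raises `FileExistsError`, which is deliberate: the sketch never deletes data.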

Evaluation

ImageNet Classification

Run the following scripts to evaluate the pre-trained models on ImageNet-1K:

cd cls

python validate.py <PATH_TO_IMAGENET> --model elsa_swin_tiny --checkpoint <CHECKPOINT_FILE> \
  --no-test-pool --apex-amp --img-size 224 -b 128

python validate.py <PATH_TO_IMAGENET> --model elsa_swin_small --checkpoint <CHECKPOINT_FILE> \
  --no-test-pool --apex-amp --img-size 224 -b 128

python validate.py <PATH_TO_IMAGENET> --model elsa_swin_base --checkpoint <CHECKPOINT_FILE> \
  --no-test-pool --apex-amp --img-size 224 -b 128 --use-ema
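The three commands above differ only in the model name and in `--use-ema`, which is passed for `elsa_swin_base` alone. A small hypothetical wrapper can build the argument lists so all three runs stay consistent. The flags are copied from the commands above, and the paths in the example are placeholders.

```python
import shlex
from typing import List

MODELS = ("elsa_swin_tiny", "elsa_swin_small", "elsa_swin_base")


def validate_cmd(model: str, imagenet: str, checkpoint: str,
                 batch_size: int = 128) -> List[str]:
    """Argument list for cls/validate.py, mirroring the commands above."""
    cmd = ["python", "validate.py", imagenet, "--model", model,
           "--checkpoint", checkpoint, "--no-test-pool", "--apex-amp",
           "--img-size", "224", "-b", str(batch_size)]
    if model == "elsa_swin_base":
        cmd.append("--use-ema")  # only the base model is evaluated with EMA weights
    return cmd


if __name__ == "__main__":
    # Placeholder paths; from inside cls/, each list can be passed to
    # subprocess.run(cmd, check=True) instead of being printed.
    for m in MODELS:
        print(" ".join(shlex.quote(a) for a in validate_cmd(m, "/data/imagenet", f"{m}.pth")))
```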

COCO Detection

Run the following scripts to evaluate a detector on COCO:

cd det

# single-gpu testing
python tools/test.py <CONFIG_FILE> <DET_CHECKPOINT_FILE> --eval bbox segm

# multi-gpu testing
tools/dist_test.sh <CONFIG_FILE> <DET_CHECKPOINT_FILE> <GPU_NUM> --eval bbox segm
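The choice between the two entry points depends only on the GPU count. The hypothetical helper below encodes that rule using the argument order shown above, so the same function also covers the ADE20K commands in the next section through a different set of evaluation flags.

```python
from typing import List

COCO_EVAL = ["--eval", "bbox", "segm"]        # box and mask mAP
ADE_EVAL = ["--aug-test", "--eval", "mIoU"]   # multi-scale + flip mIoU


def build_test_cmd(config: str, checkpoint: str, gpus: int,
                   eval_args: List[str]) -> List[str]:
    """Pick the single- or multi-GPU test entry point."""
    if gpus < 1:
        raise ValueError("need at least one GPU")
    if gpus == 1:
        return ["python", "tools/test.py", config, checkpoint] + eval_args
    return ["tools/dist_test.sh", config, checkpoint, str(gpus)] + eval_args
```

Run the resulting command from inside `det` (or `seg`), as in the shell examples.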

ADE20K Semantic Segmentation

Run the following scripts to evaluate a model on ADE20K:

cd seg

# single-gpu testing
python tools/test.py <CONFIG_FILE> <SEG_CHECKPOINT_FILE> --aug-test --eval mIoU

# multi-gpu testing
tools/dist_test.sh <CONFIG_FILE> <SEG_CHECKPOINT_FILE> <GPU_NUM> --aug-test --eval mIoU
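`--eval mIoU` reports mean intersection-over-union, and `--aug-test` enables the multi-scale + flip testing behind the "mIoU (ms+flip)" column in the model zoo. The sketch below shows the standard metric computed from a confusion matrix: IoU_c = TP / (TP + FP + FN), averaged over classes that appear in either the prediction or the ground truth. It is a simplified pure-Python illustration, not mmseg's implementation; `ignore_index=255` for unlabeled pixels is an assumption based on common ADE20K setups.

```python
from typing import List, Sequence


def confusion_matrix(pred: Sequence[int], gt: Sequence[int], num_classes: int,
                     ignore_index: int = 255) -> List[List[int]]:
    """cm[g][p] counts pixels with ground-truth class g predicted as class p."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for p, g in zip(pred, gt):
        if g == ignore_index:
            continue  # unlabeled pixels do not contribute
        cm[g][p] += 1
    return cm


def mean_iou(cm: List[List[int]]) -> float:
    """Average IoU over classes that occur in the prediction or the ground truth."""
    ious = []
    for c in range(len(cm)):
        tp = cm[c][c]
        fn = sum(cm[c]) - tp                  # class c pixels predicted as something else
        fp = sum(row[c] for row in cm) - tp   # other pixels predicted as class c
        denom = tp + fp + fn
        if denom > 0:
            ious.append(tp / denom)
    return sum(ious) / len(ious) if ious else float("nan")
```

For example, predictions `[0, 0, 1, 1]` against labels `[0, 1, 1, 1]` give IoUs of 1/2 and 2/3, so the mIoU is 7/12.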

Training from scratch

Due to randomness, re-training results may differ from the numbers in the paper by about 0.1 to 0.2%.

ImageNet Classification

Run the following scripts to train classifiers on ImageNet-1K:

cd cls

bash ./distributed_train.sh 8 <PATH_TO_IMAGENET> --model elsa_swin_tiny \
  --epochs 300 -b 128 -j 8 --opt adamw --lr 1e-3 --sched cosine --weight-decay 5e-2 \
  --warmup-epochs 20 --warmup-lr 1e-6 --min-lr 1e-5 --drop-path 0.1 --aa rand-m9-mstd0.5-inc1 \
  --mixup 0.8 --cutmix 1. --remode pixel --reprob 0.25 --clip-grad 5. --amp

bash ./distributed_train.sh 8 <PATH_TO_IMAGENET> --model elsa_swin_small \
  --epochs 300 -b 128 -j 8 --opt adamw --lr 1e-3 --sched cosine --weight-decay 5e-2 \
  --warmup-epochs 20 --warmup-lr 1e-6 --min-lr 1e-5 --drop-path 0.3 --aa rand-m9-mstd0.5-inc1 \
  --mixup 0.8 --cutmix 1. --remode pixel --reprob 0.25 --clip-grad 5. --amp

bash ./distributed_train.sh 8 <PATH_TO_IMAGENET> --model elsa_swin_base \
  --epochs 300 -b 128 -j 8 --opt adamw --lr 1e-3 --sched cosine --weight-decay 5e-2 \
  --warmup-epochs 20 --warmup-lr 1e-6 --min-lr 1e-5 --drop-path 0.5 --aa rand-m9-mstd0.5-inc1 \
  --mixup 0.8 --cutmix 1. --remode pixel --reprob 0.25 --clip-grad 5. --amp --model-ema

If GPU memory is insufficient when training elsa_swin_base, you can use two nodes (2 × 8 GPUs) with a batch size of 64 images per GPU.
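With `-b` as the per-GPU batch size, the single-node commands use a global batch of 8 × 128 = 1024 images, and the two-node fallback (16 × 64) keeps the same global batch, so the learning-rate flags can stay unchanged. The sketch below checks that arithmetic and evaluates a simplified linear-warmup + cosine schedule built from the flag values (`--lr 1e-3`, `--warmup-lr 1e-6`, `--warmup-epochs 20`, `--min-lr 1e-5`, `--epochs 300`). It is an approximation for intuition only; timm's actual cosine scheduler may differ in details such as cooldown epochs and how warmup overlaps the cosine phase.

```python
import math


def global_batch(nodes: int, gpus_per_node: int, per_gpu: int) -> int:
    """Total images per optimizer step across all GPUs."""
    return nodes * gpus_per_node * per_gpu


def lr_at(epoch: float, base_lr: float = 1e-3, min_lr: float = 1e-5,
          warmup_lr: float = 1e-6, warmup_epochs: int = 20, epochs: int = 300) -> float:
    """Simplified schedule: linear warmup, then cosine decay to min_lr."""
    if epoch < warmup_epochs:
        return warmup_lr + (base_lr - warmup_lr) * epoch / warmup_epochs
    progress = min((epoch - warmup_epochs) / (epochs - warmup_epochs), 1.0)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))
```

The learning rate rises from 1e-6 to 1e-3 over the first 20 epochs, then decays along a half cosine to 1e-5 by epoch 300.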

COCO Detection / ADE20K Semantic Segmentation

Run the following scripts to train models on COCO / ADE20K:

cd det 
# (or cd seg)

# multi-gpu training
tools/dist_train.sh <CONFIG_FILE> <GPU_NUM> --cfg-options model.pretrained=<PRETRAIN_MODEL> [model.backbone.use_checkpoint=True] [other optional arguments] 
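`--cfg-options` overrides nested config fields with dotted `KEY=VALUE` pairs. For example, `model.backbone.use_checkpoint=True` turns on gradient checkpointing in Swin-style backbones, which typically trades extra compute for lower GPU memory. The pure-Python sketch below mimics how such dotted keys map onto a nested config dict. It is a hypothetical illustration, not mmcv's parser (mmcv's own handling also supports lists and tuples), and the `ELSASwin` type name in the usage note is made up.

```python
from typing import Any, Dict


def parse_value(text: str) -> Any:
    """Tiny literal parser: True/False/None, ints, floats, otherwise a string."""
    literals = {"True": True, "False": False, "None": None}
    if text in literals:
        return literals[text]
    for cast in (int, float):
        try:
            return cast(text)
        except ValueError:
            pass
    return text


def merge_cfg_options(cfg: Dict[str, Any], options: Dict[str, str]) -> Dict[str, Any]:
    """Apply dotted KEY=VALUE overrides to a nested dict, in place."""
    for dotted, raw in options.items():
        *parents, leaf = dotted.split(".")
        node = cfg
        for key in parents:
            node = node.setdefault(key, {})  # create missing levels on the way down
        node[leaf] = parse_value(raw)
    return cfg
```

Applying `{"model.pretrained": "elsa_swin_tiny.pth", "model.backbone.use_checkpoint": "True"}` to `{"model": {"pretrained": None, "backbone": {"type": "ELSASwin"}}}` sets the checkpoint path and adds `use_checkpoint: True` while keeping the other backbone fields.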

Acknowledgement

This work was supported by Alibaba Group through Alibaba Research Intern Program and the National Natural Science Foundation of China (No.61976094).

Our codebase builds on pytorch-image-models, ddfnet, VOLO, Swin-Transformer, Swin-Transformer-Detection, and Swin-Transformer-Semantic-Segmentation.

Citing ELSA

@article{zhou2021ELSA,
  title={ELSA: Enhanced Local Self-Attention for Vision Transformer},
  author={Zhou, Jingkai and Wang, Pichao and Wang, Fan and Liu, Qiong and Li, Hao and Jin, Rong},
  journal={arXiv preprint arXiv:2112.12786},
  year={2021}
}
Owner: DamoCV (CV team of DAMO Academy)