PyTorch evaluation code for Delving Deep into the Generalization of Vision Transformers under Distribution Shifts.

Overview

Out-of-distribution Generalization Investigation on Vision Transformers

This repository contains PyTorch evaluation code for Delving Deep into the Generalization of Vision Transformers under Distribution Shifts.

A Quick Glance of Our Work

A quick glance at our observations. Left: the investigation of the IID/OOD Generalization Gap implies that ViTs generalize better than CNNs under most types of distribution shifts. Right: combined with generalization-enhancing methods, we achieve significant performance boosts of 4% on OOD data compared with vanilla ViTs, and consistently outperform the corresponding CNN models. The enhanced ViTs also have a smaller IID/OOD Generalization Gap than the enhanced BiT models.

Taxonomy of Distribution Shifts

Illustration of our taxonomy of distribution shifts. We build the taxonomy on which kinds of semantic concepts are modified from the original image. We divide the distribution shifts into five cases: background shifts, corruption shifts, texture shifts, destruction shifts, and style shifts. We use the proxy A-distance (PAD) as an empirical measure of distribution shift. For each distribution shift type, we select a representative sample and rank the samples by their PAD values (shown next to the stars). Please refer to the paper for details.
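As a rough illustration of how a PAD value is usually obtained, the sketch below uses the standard definition PAD = 2(1 - 2ε), where ε is the test error of a classifier trained to tell the two domains apart. The function name is hypothetical, and this README does not specify which classifier or features the paper uses to estimate ε, so treat this only as the generic formula.

```python
def proxy_a_distance(domain_classifier_error: float) -> float:
    """Standard proxy A-distance: PAD = 2 * (1 - 2 * error).

    `domain_classifier_error` is the held-out error of a binary classifier
    trained to separate samples of the two distributions (assumed to be
    estimated elsewhere; the estimation protocol is not given in this README).
    """
    if not 0.0 <= domain_classifier_error <= 1.0:
        raise ValueError("error must lie in [0, 1]")
    return 2.0 * (1.0 - 2.0 * domain_classifier_error)


# Chance-level error means the domains are indistinguishable (PAD = 0);
# a low error means they are easy to separate (PAD close to 2).
print(proxy_a_distance(0.5))
print(proxy_a_distance(0.05))
```

A larger PAD therefore indicates a more severe distribution shift.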

Datasets Used for Investigation

  • Background Shifts. ImageNet-9 is adopted for background shifts. ImageNet-9 is a variety of 9-class datasets with different foreground-background recombination plans, which helps disentangle the impacts of foreground and background signals on classification. In our case, we use the four varieties of generated background with foreground unchanged, including 'Only-FG', 'Mixed-Same', 'Mixed-Rand' and 'Mixed-Next'. The 'Original' data set is used to represent in-distribution data.
  • Corruption Shifts. ImageNet-C is used to examine generalization ability under corruption shifts. ImageNet-C includes 15 types of algorithmically generated corruptions, grouped into 4 categories: ‘noise’, ‘blur’, ‘weather’, and ‘digital’. Each corruption type has five levels of severity, resulting in 75 distinct corruptions.
  • Texture Shifts. Cue Conflict Stimuli and Stylized-ImageNet are used to investigate generalization under texture shifts. Using style transfer, Geirhos et al. generated the Cue Conflict Stimuli benchmark with conflicting shape and texture information; that is, the image texture is replaced with that of another class while the other object semantics are preserved. In this case, we report the shape accuracy and texture accuracy of classifiers separately for analysis. Stylized-ImageNet is also produced by Geirhos et al., by replacing textures with the style of randomly selected paintings through AdaIN style transfer.
  • Destruction Shifts. Random patch-shuffling is used for destruction shifts: each image is split into patches that are then randomly rearranged. This process can destroy long-range object information, and the severity increases as the number of splits grows. In addition, we make a variant by further dividing each patch into two right triangles and shuffling the two types of triangles separately. We name this process triangular patch-shuffling.
  • Style Shifts. ImageNet-R and DomainNet are used for the case of style shifts. ImageNet-R contains 30000 images with various artistic renditions of 200 classes of the original ImageNet validation data set. The renditions in ImageNet-R are real-world, naturally occurring variations, such as paintings or embroidery, with textures and local image statistics which differ from those of ImageNet images. DomainNet is a recent benchmark dataset for large-scale domain adaptation that consists of 345 classes and 6 domains. As labels of some domains are very noisy, we follow the 7 distribution shift scenarios in Saito et al. with 4 domains (Real, Clipart, Painting, Sketch) picked.
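The random patch-shuffling described under Destruction Shifts can be sketched in a few lines. This is a minimal NumPy illustration, not the repository's implementation: the function name `patch_shuffle`, the `splits` parameter, and the requirement that the image size be divisible by `splits` are all assumptions. The triangular variant is not covered here.

```python
import numpy as np


def patch_shuffle(image, splits, rng=None):
    """Split an H x W (x C) image into a splits x splits grid and permute the patches.

    Hypothetical helper for illustration; a larger `splits` destroys more
    long-range structure.
    """
    rng = np.random.default_rng() if rng is None else rng
    h, w = image.shape[:2]
    if h % splits or w % splits:
        raise ValueError("image height and width must be divisible by splits")
    ph, pw = h // splits, w // splits

    # (H, W, C) -> (grid_row, grid_col, ph, pw, C) -> (num_patches, ph, pw, C)
    patches = image.reshape(splits, ph, splits, pw, -1).swapaxes(1, 2)
    patches = patches.reshape(splits * splits, ph, pw, -1)

    # Randomly reorder the patches.
    patches = patches[rng.permutation(splits * splits)]

    # Reassemble the grid back into the original image layout.
    out = patches.reshape(splits, splits, ph, pw, -1).swapaxes(1, 2)
    return out.reshape(image.shape)


image = np.arange(8 * 8 * 3).reshape(8, 8, 3)
shuffled = patch_shuffle(image, splits=4, rng=np.random.default_rng(0))
print(shuffled.shape)
```

The output keeps the same pixels as the input; only the positions of whole patches change.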

Generalization-Enhanced Vision Transformers

A framework overview of the three designed generalization-enhanced ViTs. All networks use a Vision Transformer as the feature encoder and a label prediction head. Under this setting, the inputs to the models are labeled source examples and unlabeled target examples. Top left: T-ADV encourages the network to learn domain-invariant representations by introducing a domain classifier for domain adversarial training. Top right: T-MME leverages a minimax process on the conditional entropy of target data to reduce the distribution gap while learning discriminative features for the task. The network uses a cosine similarity-based classifier architecture to produce class prototypes. Bottom: T-SSL is an end-to-end prototype-based self-supervised learning framework. The architecture uses two memory banks to calculate cluster centroids. A cosine classifier is used for classification in this framework.
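To make the two ingredients of T-MME more concrete, the sketch below shows a cosine similarity-based classifier and the conditional entropy of its predictions on a batch. It is a NumPy illustration under stated assumptions: the function names and the `temperature` value are hypothetical, and the minimax training loop itself is not shown.

```python
import numpy as np


def cosine_logits(features, prototypes, temperature=0.05):
    """Cosine classifier: logits are scaled cosine similarities to class prototypes.

    `temperature` is an assumed hyperparameter, not taken from this README.
    """
    f = features / np.linalg.norm(features, axis=1, keepdims=True)
    w = prototypes / np.linalg.norm(prototypes, axis=1, keepdims=True)
    return f @ w.T / temperature


def conditional_entropy(logits):
    """Mean entropy of the softmax predictions over a batch."""
    z = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
    log_p = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    p = np.exp(log_p)
    return float(-(p * log_p).sum(axis=1).mean())


# Three classes with 4-dimensional prototypes; the features are noisy copies of them.
rng = np.random.default_rng(0)
prototypes = rng.normal(size=(3, 4))
features = prototypes + 0.01 * rng.normal(size=(3, 4))
logits = cosine_logits(features, prototypes)
print(logits.argmax(axis=1), conditional_entropy(logits))
```

Features that lie close to a prototype yield confident, low-entropy predictions, which is the quantity the minimax process acts on for unlabeled target data.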

Run Our Code

Environment Installation

conda create -n vit python=3.6
conda activate vit
conda install pytorch==1.4.0 torchvision==0.5.0 cudatoolkit=10.0 -c pytorch

Before Running

conda activate vit
export PYTHONPATH=$PYTHONPATH:.

Evaluation

CUDA_VISIBLE_DEVICES=0 python main.py \
--model deit_small_b16_384 \
--num-classes 345 \
--checkpoint data/checkpoints/deit_small_b16_384_baseline_real.pth.tar \
--meta-file data/metas/DomainNet/sketch_test.jsonl \
--root-dir data/images/DomainNet/sketch/test
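To evaluate one checkpoint on all four DomainNet domains, the command above can be generated in a loop. The sketch below only reuses the flags shown in that command. It assumes that the other domains follow the same path pattern as the sketch example (`<domain>_test.jsonl` and `<domain>/test`), which this README does not confirm. The helper names are hypothetical.

```python
import os
import subprocess

DOMAINS = ["clipart", "painting", "real", "sketch"]


def build_eval_command(test_domain, checkpoint,
                       model="deit_small_b16_384", num_classes=345):
    """Build the main.py argument list for one test domain.

    The meta-file and root-dir patterns are assumed from the sketch example.
    """
    return [
        "python", "main.py",
        "--model", model,
        "--num-classes", str(num_classes),
        "--checkpoint", checkpoint,
        "--meta-file", f"data/metas/DomainNet/{test_domain}_test.jsonl",
        "--root-dir", f"data/images/DomainNet/{test_domain}/test",
    ]


def run_all(checkpoint, gpu="0"):
    """Run the evaluation on every domain, one after another, on a single GPU."""
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=gpu)
    for domain in DOMAINS:
        subprocess.run(build_eval_command(domain, checkpoint), check=True, env=env)


# Print the commands instead of running them, so they can be inspected first.
checkpoint = "data/checkpoints/deit_small_b16_384_baseline_real.pth.tar"
for domain in DOMAINS:
    print(" ".join(build_eval_command(domain, checkpoint)))
```

Calling `run_all(checkpoint)` from the repository root would then run the four evaluations in sequence.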

Experimental Results

DomainNet

DeiT_small_b16_384

Confusion matrix for the baseline model

            clipart   painting   real    sketch
clipart     80.25     33.75      55.26   43.43
painting    36.89     75.32      52.08   31.14
real        50.59     45.81      84.78   39.31
sketch      52.16     35.27      48.19   71.92
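The matrix above can be summarized as an IID/OOD gap per row. The sketch below assumes that rows are training domains and columns are test domains, which matches the evaluation command (a `real` checkpoint tested on `sketch`) but is not stated explicitly in this README. The gap definition used here, in-domain accuracy minus the mean accuracy on the other three domains, is a simple illustrative choice rather than the paper's exact metric.

```python
DOMAINS = ["clipart", "painting", "real", "sketch"]

# Accuracy (%) copied from the table above; rows are assumed to be training domains.
ACCURACY = {
    "clipart":  [80.25, 33.75, 55.26, 43.43],
    "painting": [36.89, 75.32, 52.08, 31.14],
    "real":     [50.59, 45.81, 84.78, 39.31],
    "sketch":   [52.16, 35.27, 48.19, 71.92],
}


def generalization_gaps(accuracy, domains):
    """Return {domain: in-domain accuracy - mean accuracy on the other domains}."""
    gaps = {}
    for i, train_domain in enumerate(domains):
        row = accuracy[train_domain]
        iid = row[i]
        ood = [acc for j, acc in enumerate(row) if j != i]
        gaps[train_domain] = iid - sum(ood) / len(ood)
    return gaps


for domain, gap in generalization_gaps(ACCURACY, DOMAINS).items():
    print(f"{domain:<9} gap = {gap:.2f}")
```

Under these assumptions, a smaller gap means the model trained on that domain loses less accuracy when tested on the other domains.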

The models used above can be found here.

Remarks

  • These results may differ slightly from those in our paper due to differences in environments.

  • We will continuously update this repo.

Citation

If you find these investigations useful in your research, please consider citing:

@misc{zhang2021delving,  
      title={Delving Deep into the Generalization of Vision Transformers under Distribution Shifts}, 
      author={Chongzhi Zhang and Mingyuan Zhang and Shanghang Zhang and Daisheng Jin and Qiang Zhou and Zhongang Cai and Haiyu Zhao and Shuai Yi and Xianglong Liu and Ziwei Liu},  
      year={2021},  
      eprint={2106.07617},  
      archivePrefix={arXiv},  
      primaryClass={cs.CV}  
}
Owner
Chongzhi Zhang
I am a Master Degree Candidate student, from Beihang University.