A PyTorch implementation of Mugs proposed by our paper "Mugs: A Multi-Granular Self-Supervised Learning Framework".

Related tags

Deep Learningmugs
Overview

Mugs: A Multi-Granular Self-Supervised Learning Framework

This is a PyTorch implementation of Mugs proposed by our paper "Mugs: A Multi-Granular Self-Supervised Learning Framework". arXiv

PWC

Overall framework of Mugs.

Fig 1. Overall framework of Mugs. In (a), for each image, two random crops of one image are fed into backbones of student and teacher. Three granular supervisions: 1) instance discrimination supervision, 2) local-group discrimination supervision, and 3) group discrimination supervision, are adopted to learn multi-granular representation. In (b), local-group modules in student/teacher averages all patch tokens, and finds top-k neighbors from memory buffer to aggregate them with the average for obtaining a local-group feature.

Pretrained models on ImageNet-1K

You can choose to download only the weights of the pretrained backbone used for downstream tasks, or the full checkpoint which contains backbone and projection head weights for both student and teacher networks.

Table 1. KNN and linear probing performance with their corresponding hyper-parameters, logs and model weights.

arch params pretraining epochs k-nn linear download
ViT-S/16 21M 100 72.3% 76.4% backbone only full ckpt args logs eval logs
ViT-S/16 21M 300 74.8% 78.2% backbone only full ckpt args logs eval logs
ViT-S/16 21M 800 75.6% 78.9% backbone only full ckpt args logs eval logs
ViT-B/16 85M 400 78.0% 80.6% backbone only full ckpt args logs eval logs
ViT-L/16 307M 250 80.3% 82.1% backbone only full ckpt args logs eval logs
Comparison of linear probing accuracy on ImageNet-1K.

Fig 2. Comparison of linear probing accuracy on ImageNet-1K.

Pretraining Settings

Environment

For reproducing, please install PyTorch and download the ImageNet dataset. This codebase has been developed with python version 3.8, PyTorch version 1.7.1, CUDA 11.0 and torchvision 0.8.2. For the full environment, please refer to our Dockerfile file.

ViT pretraining 🍺

To pretraining each model, please find the exact hyper-parameter settings at the args column of Table 1. For training log and linear probing log, please refer to the log and eval logs column of Table 1.

ViT-Small pretraining:

To run ViT-small for 100 epochs, we use two nodes of total 8 A100 GPUs (total 512 minibatch size) by using following command:

python -m torch.distributed.launch --nproc_per_node=8 main.py --data_path DATASET_ROOT --output_dir OUTPUT_ROOT --arch vit_small 
--group_teacher_temp 0.04 --group_warmup_teacher_temp_epochs 0 --weight_decay_end 0.2 --norm_last_layer false --epochs 100

To run ViT-small for 300 epochs, we use two nodes of total 16 A100 GPUs (total 1024 minibatch size) by using following command:

python -m torch.distributed.launch --nproc_per_node=16 main.py --data_path DATASET_ROOT --output_dir OUTPUT_ROOT --arch vit_small 
--group_teacher_temp 0.07 --group_warmup_teacher_temp_epochs 30 --weight_decay_end 0.1 --norm_last_layer false --epochs 300

To run ViT-small for 800 epochs, we use two nodes of total 16 A100 GPUs (total 1024 minibatch size) by using following command:

python -m torch.distributed.launch --nproc_per_node=16 main.py --data_path DATASET_ROOT --output_dir OUTPUT_ROOT --arch vit_small 
--group_teacher_temp 0.07 --group_warmup_teacher_temp_epochs 30 --weight_decay_end 0.1 --norm_last_layer false --epochs 800

ViT-Base pretraining:

To run ViT-base for 400 epochs, we use two nodes of total 24 A100 GPUs (total 1024 minibatch size) by using following command:

python -m torch.distributed.launch --nproc_per_node=24 main.py --data_path DATASET_ROOT --output_dir OUTPUT_ROOT --arch vit_base 
--group_teacher_temp 0.07 --group_warmup_teacher_temp_epochs 50 --min_lr 2e-06 --weight_decay_end 0.1 --freeze_last_layer 3 --norm_last_layer 
false --epochs 400

ViT-Large pretraining:

To run ViT-large for 250 epochs, we use two nodes of total 40 A100 GPUs (total 640 minibatch size) by using following command:

python -m torch.distributed.launch --nproc_per_node=40 main.py --data_path DATASET_ROOT --output_dir OUTPUT_ROOT --arch vit_large 
--lr 0.0015 --min_lr 1.5e-4 --group_teacher_temp 0.07 --group_warmup_teacher_temp_epochs 50 --weight_decay 0.025 
--weight_decay_end 0.08 --norm_last_layer true --drop_path_rate 0.3 --freeze_last_layer 3 --epochs 250

Evaluation

We are cleaning up the evalutation code and will release them when they are ready.

Self-attention visualization

Here we provide the self-attention map of the [CLS] token on the heads of the last layer

Self-attention from a ViT-Base/16 trained with Mugs

Fig 3. Self-attention from a ViT-Base/16 trained with Mugs.

T-SNE visualization

Here we provide the T-SNE visualization of the learned feature by ViT-B/16. We show the fish classes in ImageNet-1K, i.e., the first six classes, including tench, goldfish, white shark, tiger shark, hammerhead, electric ray. See more examples in Appendix.

T-SNE visualization of the learned feature by ViT-B/16.

Fig 4. T-SNE visualization of the learned feature by ViT-B/16.

License

This repository is released under the Apache 2.0 license as found in the LICENSE file.

Citation

If you find this repository useful, please consider giving a star and citation 🍺 :

@inproceedings{mugs2022SSL,
  title={Mugs: A Multi-Granular Self-Supervised Learning Framework},
  author={Pan Zhou and Yichen Zhou and Chenyang Si and Weihao Yu and Teck Khim Ng and Shuicheng Yan},
  booktitle={arXiv preprint arXiv:2203.14415},
  year={2022}
}
Owner
Sea AI Lab
Sea AI Lab
Greedy Gaussian Segmentation

GGS Greedy Gaussian Segmentation (GGS) is a Python solver for efficiently segmenting multivariate time series data. For implementation details, please

Stanford University Convex Optimization Group 72 Dec 07, 2022
WRENCH: Weak supeRvision bENCHmark

🔧 What is it? Wrench is a benchmark platform containing diverse weak supervision tasks. It also provides a common and easy framework for development

Jieyu Zhang 176 Dec 28, 2022
Code-free deep segmentation for computational pathology

NoCodeSeg: Deep segmentation made easy! This is the official repository for the manuscript "Code-free development and deployment of deep segmentation

André Pedersen 26 Nov 23, 2022
PyTorch implementation for View-Guided Point Cloud Completion

PyTorch implementation for View-Guided Point Cloud Completion

22 Jan 04, 2023
Learning High-Speed Flight in the Wild

Learning High-Speed Flight in the Wild This repo contains the code associated to the paper Learning Agile Flight in the Wild. For more information, pl

Robotics and Perception Group 391 Dec 29, 2022
Iterative Training: Finding Binary Weight Deep Neural Networks with Layer Binarization

Iterative Training: Finding Binary Weight Deep Neural Networks with Layer Binarization This repository contains the source code for the paper (link wi

Rakuten Group, Inc. 0 Nov 19, 2021
Tutorial in Python targeted at Epidemiologists. Will discuss the basics of analysis in Python 3

Python-for-Epidemiologists This repository is an introduction to epidemiology analyses in Python. Additionally, the tutorials for my library zEpid are

Paul Zivich 120 Nov 17, 2022
A data-driven approach to quantify the value of classifiers in a machine learning ensemble.

Documentation | External Resources | Research Paper Shapley is a Python library for evaluating binary classifiers in a machine learning ensemble. The

Benedek Rozemberczki 188 Dec 29, 2022
CKD - Collaborative Knowledge Distillation for Heterogeneous Information Network Embedding

Collaborative Knowledge Distillation for Heterogeneous Information Network Embed

zhousheng 9 Dec 05, 2022
Weakly Supervised Learning of Instance Segmentation with Inter-pixel Relations, CVPR 2019 (Oral)

Weakly Supervised Learning of Instance Segmentation with Inter-pixel Relations The code of: Weakly Supervised Learning of Instance Segmentation with I

Jiwoon Ahn 472 Dec 29, 2022
Supporting code for "Autoregressive neural-network wavefunctions for ab initio quantum chemistry".

naqs-for-quantum-chemistry This repository contains the codebase developed for the paper Autoregressive neural-network wavefunctions for ab initio qua

Tom Barrett 24 Dec 23, 2022
😊 Python module for face feature changing

PyWarping Python module for face feature changing Installation pip install pywarping If you get an error: No such file or directory: 'cmake': 'cmake',

Dopevog 10 Sep 10, 2021
Python code for the paper How to scale hyperparameters for quickshift image segmentation

How to scale hyperparameters for quickshift image segmentation Python code for the paper How to scale hyperparameters for quickshift image segmentatio

0 Jan 25, 2022
Nodule Generation Algorithm Baseline and template code for node21 generation track

Nodule Generation Algorithm This codebase implements a simple baseline model, by following the main steps in the paper published by Litjens et al. for

node21challenge 10 Apr 21, 2022
Contextualized Perturbation for Textual Adversarial Attack, NAACL 2021

Contextualized Perturbation for Textual Adversarial Attack Introduction This is a PyTorch implementation of Contextualized Perturbation for Textual Ad

cookielee77 30 Jan 01, 2023
Python script for performing depth completion from sparse depth and rgb images using the msg_chn_wacv20. model in Tensorflow Lite.

TFLite-msg_chn_wacv20-depth-completion Python script for performing depth completion from sparse depth and rgb images using the msg_chn_wacv20. model

Ibai Gorordo 2 Oct 04, 2021
Re-implementation of the Noise Contrastive Estimation algorithm for pyTorch, following "Noise-contrastive estimation: A new estimation principle for unnormalized statistical models." (Gutmann and Hyvarinen, AISTATS 2010)

Noise Contrastive Estimation for pyTorch Overview This repository contains a re-implementation of the Noise Contrastive Estimation algorithm, implemen

Denis Emelin 42 Nov 24, 2022
Instance-wise Occlusion and Depth Orders in Natural Scenes (CVPR 2022)

Instance-wise Occlusion and Depth Orders in Natural Scenes Official source code. Appears at CVPR 2022 This repository provides a new dataset, named In

27 Dec 27, 2022
Implementation for Homogeneous Unbalanced Regularized Optimal Transport

HUROT: An Homogeneous formulation of Unbalanced Regularized Optimal Transport. This repository provides code related to this preprint. This is an alph

Théo Lacombe 1 Feb 17, 2022
Rethinking of Pedestrian Attribute Recognition: A Reliable Evaluation under Zero-Shot Pedestrian Identity Setting

Pytorch Pedestrian Attribute Recognition: A strong PyTorch baseline of pedestrian attribute recognition and multi-label classification.

Jian 79 Dec 18, 2022