Official code: Self-Supervised Learning by Estimating Twin Class Distributions

TWIST: Self-Supervised Learning by Estimating Twin Class Distributions

Architecture

Code and pretrained models for TWIST:

@article{wang2021self,
  title={Self-Supervised Learning by Estimating Twin Class Distributions},
  author={Wang, Feng and Kong, Tao and Zhang, Rufeng and Liu, Huaping and Li, Hang},
  journal={arXiv preprint arXiv:2110.07402},
  year={2021}
}

TWIST is a self-supervised representation learning method that classifies large-scale unlabeled datasets end to end. We employ a siamese network terminated by a softmax operation to produce twin class distributions for two augmented views of an image. Without supervision, we enforce the class distributions of the different augmentations to be consistent. At the same time, we regularize the class distributions to make them sharp and diverse. TWIST naturally avoids trivial solutions without special designs such as an asymmetric network, a stop-gradient operation, or a momentum encoder.
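
As a rough illustration of the three terms described above (consistency, sharpness, diversity), the following pure-Python sketch computes a TWIST-style objective for a small batch. The exact form is an assumption: it uses KL divergence for consistency, the mean per-sample entropy for sharpness, and the entropy of the batch-mean distribution for diversity, all with equal weights. The function names are hypothetical and do not come from the repository, and any weighting the repository applies (for example via `--lam1` and `--lam2`) is not reproduced here.

```python
import math

EPS = 1e-12  # avoids log(0) for one-hot distributions


def entropy(p):
    """Shannon entropy of one class distribution (a list of probabilities)."""
    return -sum(x * math.log(x + EPS) for x in p)


def kl(p, q):
    """KL divergence D_KL(p || q) between two class distributions."""
    return sum(x * (math.log(x + EPS) - math.log(y + EPS)) for x, y in zip(p, q))


def mean_distribution(batch):
    """Average the class distributions of a batch, class by class."""
    n = len(batch)
    return [sum(col) / n for col in zip(*batch)]


def twist_style_loss(p1, p2):
    """Assumed objective: consistency + sharpness - diversity (lower is better).

    p1, p2: lists of softmax outputs for the two augmented views of each image.
    """
    n = len(p1)
    consistency = sum(kl(a, b) for a, b in zip(p1, p2)) / n
    sharpness = sum(entropy(a) for a in p1) / n   # low when predictions are confident
    diversity = entropy(mean_distribution(p1))    # high when classes are used evenly
    return consistency + sharpness - diversity


def symmetric_loss(p1, p2):
    """Average the objective over both orderings of the two views."""
    return 0.5 * (twist_style_loss(p1, p2) + twist_style_loss(p2, p1))
```

Under this assumed form, a collapsed batch (every image assigned to the same class) and a uniform batch (every prediction spread evenly) both evaluate to about 0, while confident predictions spread evenly over two classes evaluate to about -log 2. This is one way to see why the sharpness and diversity terms together discourage the trivial solutions.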

formula

Models and Results

Main Models for Representation Learning

arch         params  epochs  linear  download

Model with multi-crop and self-labeling
ResNet-50    24M     850     75.5%   backbone only, full ckpt, args, log, eval logs
ResNet-50w2  94M     250     77.7%   backbone only, full ckpt, args, log, eval logs
DeiT-S       21M     300     75.6%   backbone only, full ckpt, args, log, eval logs
ViT-B        86M     300     77.3%   backbone only, full ckpt, args, log, eval logs

Model without multi-crop and self-labeling
ResNet-50    24M     800     72.6%   backbone only, full ckpt, args, log, eval logs

Model for unsupervised classification

arch       params  epochs  NMI   AMI   ARI   ACC   download
ResNet-50  24M     800     74.4  57.7  30.1  40.5  backbone only, full ckpt, args, log

Top-3 predictions for unsupervised classification

Top-3

Semi-Supervised Results

arch         1% labels  10% labels  100% labels
ResNet-50    61.5%      71.7%       78.4%
ResNet-50w2  67.2%      75.3%       80.3%

Detection Results

Task                        AP (all)  AP50  AP75
VOC07+12 detection          58.1      84.2  65.4
COCO detection              41.9      62.6  45.7
COCO instance segmentation  37.9      59.7  40.6

Single-node Training

ResNet-50 (requires 8 GPUs, Top-1 Linear 72.6%)

python3 -m torch.distributed.launch --nproc_per_node=8 --use_env train.py \
  --data-path ${DATAPATH} \
  --output_dir ${OUTPUT} \
  --aug barlow \
  --batch-size 256 \
  --dim 32768 \
  --epochs 800 

Multi-node Training

ResNet-50 (requires 16 GPUs split across 2 nodes for multi-crop training, Top-1 Linear 75.5%)

python3 -m torch.distributed.launch --nproc_per_node=8 --use_env \
  --nnodes=${WORKER_NUM} \
  --node_rank=${MACHINE_ID} \
  --master_addr=${HOST} \
  --master_port=${PORT} train.py \
  --data-path ${DATAPATH} \
  --output_dir ${OUTPUT}
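
The placeholders in the multi-node commands fit together as follows: every node runs the same command with its own `${MACHINE_ID}`, and the total number of GPU processes is `nnodes × nproc_per_node`. The sketch below assembles the argument list for one node. The helper name `build_launch_command` and the example host, port, and paths are hypothetical; the flags are the ones shown in the command above.

```python
def build_launch_command(nnodes, node_rank, master_addr, master_port,
                         data_path, output_dir, nproc_per_node=8):
    """Return (argv, world_size) for one node of a multi-node run."""
    if not 0 <= node_rank < nnodes:
        raise ValueError("node_rank must be in [0, nnodes)")
    argv = [
        "python3", "-m", "torch.distributed.launch",
        f"--nproc_per_node={nproc_per_node}", "--use_env",
        f"--nnodes={nnodes}",
        f"--node_rank={node_rank}",
        f"--master_addr={master_addr}",
        f"--master_port={master_port}",
        "train.py",
        "--data-path", data_path,
        "--output_dir", output_dir,
    ]
    world_size = nnodes * nproc_per_node  # total number of GPU processes
    return argv, world_size


# Example values are placeholders, not taken from the repository.
argv, world_size = build_launch_command(
    nnodes=2, node_rank=0, master_addr="10.0.0.1", master_port=29500,
    data_path="/path/to/imagenet", output_dir="/path/to/output")
print(" ".join(argv))
print("total GPUs:", world_size)
```

With 2 nodes of 8 GPUs each this gives the 16 GPUs quoted for ResNet-50; with 4 nodes it gives the 32 GPUs quoted for the larger models.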

ResNet-50w2 (requires 32 GPUs split across 4 nodes for multi-crop training, Top-1 Linear 77.7%)

python3 -m torch.distributed.launch --nproc_per_node=8 --use_env \
  --nnodes=${WORKER_NUM} \
  --node_rank=${MACHINE_ID} \
  --master_addr=${HOST} \
  --master_port=${PORT} train.py \
  --data-path ${DATAPATH} \
  --output_dir ${OUTPUT} \
  --backbone 'resnet50w2' \
  --batch-size 60 \
  --bunch-size 240 \
  --epochs 250 \
  --mme_epochs 200 

DeiT-S (requires 16 GPUs split across 2 nodes for multi-crop training, Top-1 Linear 75.6%)

python3 -m torch.distributed.launch --nproc_per_node=8 --use_env \
  --nnodes=${WORKER_NUM} \
  --node_rank=${MACHINE_ID} \
  --master_addr=${HOST} \
  --master_port=${PORT} train.py \
  --data-path ${DATAPATH} \
  --output_dir ${OUTPUT} \
  --backbone 'vit_s' \
  --batch-size 128 \
  --bunch-size 256 \
  --clip_norm 3.0 \
  --epochs 300 \
  --mme_epochs 300 \
  --lam1 -0.6 \
  --lam2 1.0 \
  --local_crops_number 6 \
  --lr 0.0005 \
  --momentum_start 0.996 \
  --momentum_end 1.0 \
  --optim admw \
  --use_momentum_encoder 1 \
  --weight_decay 0.06 \
  --weight_decay_end 0.06 
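
The `--momentum_start 0.996` and `--momentum_end 1.0` flags, together with `--use_momentum_encoder 1`, suggest that a momentum encoder is updated with a coefficient that increases over training. How the repository interpolates between the two values is not stated here; the sketch below assumes a cosine ramp, which several momentum-encoder methods use. Both function names are hypothetical.

```python
import math


def momentum_at(step, total_steps, start=0.996, end=1.0):
    """Cosine ramp from `start` (first step) to `end` (last step). Assumed schedule."""
    if total_steps <= 1:
        return end
    progress = step / (total_steps - 1)
    return end - (end - start) * (math.cos(math.pi * progress) + 1) / 2


def ema_update(target, online, m):
    """Exponential moving average: target <- m * target + (1 - m) * online."""
    return [m * t + (1 - m) * o for t, o in zip(target, online)]
```

With these defaults the coefficient starts at 0.996 and reaches 1.0 on the last step, at which point the momentum encoder stops changing.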

ViT-B (requires 32 GPUs split across 4 nodes for multi-crop training, Top-1 Linear 77.3%)

python3 -m torch.distributed.launch --nproc_per_node=8 --use_env \
  --nnodes=${WORKER_NUM} \
  --node_rank=${MACHINE_ID} \
  --master_addr=${HOST} \
  --master_port=${PORT} train.py \
  --data-path ${DATAPATH} \
  --output_dir ${OUTPUT} \
  --backbone 'vit_b' \
  --batch-size 64 \
  --bunch-size 256 \
  --clip_norm 3.0 \
  --epochs 300 \
  --mme_epochs 300 \
  --lam1 -0.6 \
  --lam2 1.0 \
  --local_crops_number 6 \
  --lr 0.00075 \
  --momentum_start 0.996 \
  --momentum_end 1.0 \
  --optim admw \
  --use_momentum_encoder 1 \
  --weight_decay 0.06 \
  --weight_decay_end 0.06 

Linear Classification

For ResNet-50

python3 evaluate.py \
  ${DATAPATH} \
  ${OUTPUT}/checkpoint.pth \
  --weight-decay 0 \
  --checkpoint-dir ${OUTPUT}/linear_multihead/ \
  --batch-size 1024 \
  --val_epoch 1 \
  --lr-classifier 0.2

For DeiT-S

python3 -m torch.distributed.launch --nproc_per_node=8 evaluate_vitlinear.py \
  --arch vit_s \
  --pretrained_weights ${OUTPUT}/checkpoint.pth \
  --lr 0.02 \
  --data_path ${DATAPATH} \
  --output_dir ${OUTPUT}

For ViT-B

python3 -m torch.distributed.launch --nproc_per_node=8 evaluate_vitlinear.py \
  --arch vit_b \
  --pretrained_weights ${OUTPUT}/checkpoint.pth \
  --lr 0.0015 \
  --data_path ${DATAPATH} \
  --output_dir ${OUTPUT}

Semi-supervised Learning

Commands for training semi-supervised classification

1% labels (61.5%)

python3 evaluate.py ${DATAPATH} ${MODELPATH} \
  --weights finetune \
  --lr-backbone 0.04 \
  --lr-classifier 0.2 \
  --train-percent 1 \
  --weight-decay 0 \
  --epochs 20 \
  --backbone 'resnet50'

10% labels (71.7%)

python3 evaluate.py ${DATAPATH} ${MODELPATH} \
  --weights finetune \
  --lr-backbone 0.02 \
  --lr-classifier 0.2 \
  --train-percent 10 \
  --weight-decay 0 \
  --epochs 20 \
  --backbone 'resnet50'

100% labels (78.4%)

python3 evaluate.py ${DATAPATH} ${MODELPATH} \
  --weights finetune \
  --lr-backbone 0.01 \
  --lr-classifier 0.2 \
  --train-percent 100 \
  --weight-decay 0 \
  --epochs 30 \
  --backbone 'resnet50'

Detection

Instructions

  1. Install detectron2.

  2. Convert the pre-trained model to detectron2's format:

    python3 detection/convert-pretrain-to-detectron2.py ${MODELPATH} ${OUTPUTPKLPATH}
    
  3. Put the datasets under the "detection/datasets" directory, following the directory structure required by detectron2.

  4. Training: VOC

    cd detection/
    python3 train_net.py \
      --config-file voc_fpn_1fc/pascal_voc_R_50_FPN_24k_infomin.yaml \
      --num-gpus 8 \
      MODEL.WEIGHTS ../${OUTPUTPKLPATH}
    

    COCO

    python3 train_net.py \
      --config-file infomin_configs/R_50_FPN_1x_infomin.yaml \
      --num-gpus 8 \
      MODEL.WEIGHTS ../${OUTPUTPKLPATH}
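
For context on the conversion in step 2: detectron2's ResNet expects parameter names such as `stem.conv1` and `res2` rather than the torchvision-style `conv1` and `layer1`. The sketch below shows a renaming of that kind, modeled on the conversion script published with MoCo. It is an assumption that `detection/convert-pretrain-to-detectron2.py` behaves similarly, and the function name `to_detectron2_key` is hypothetical.

```python
def to_detectron2_key(key):
    """Map a torchvision-style ResNet parameter name to a detectron2-style one."""
    if "layer" not in key:
        key = "stem." + key  # conv1/bn1 belong to the stem
    for t in (1, 2, 3, 4):
        key = key.replace(f"layer{t}", f"res{t + 1}")
    for t in (1, 2, 3):
        key = key.replace(f"bn{t}", f"conv{t}.norm")
    key = key.replace("downsample.0", "shortcut")
    key = key.replace("downsample.1", "shortcut.norm")
    return key


for name in ("conv1.weight", "layer1.0.bn2.weight", "layer2.0.downsample.1.weight"):
    print(name, "->", to_detectron2_key(name))
```

A real conversion would also strip any checkpoint-specific prefix from the keys and save the result as a pickle that `MODEL.WEIGHTS` can load; those details depend on the checkpoint layout and are left out here.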
    
Owner
Bytedance Inc.