PyTorch implementation of "Supervised Contrastive Learning" (and SimCLR incidentally)

Overview

SupContrast: Supervised Contrastive Learning

This repo covers an reference implementation for the following papers in PyTorch, using CIFAR as an illustrative example:
(1) Supervised Contrastive Learning. Paper
(2) A Simple Framework for Contrastive Learning of Visual Representations. Paper

Loss Function

The loss function SupConLoss in losses.py takes features (L2 normalized) and labels as input, and return the loss. If labels is None or not passed to the it, it degenerates to SimCLR.

Usage:

from losses import SupConLoss

# define loss with a temperature `temp`
criterion = SupConLoss(temperature=temp)

# features: [bsz, n_views, f_dim]
# `n_views` is the number of crops from each image
# better be L2 normalized in f_dim dimension
features = ...
# labels: [bsz]
labels = ...

# SupContrast
loss = criterion(features, labels)
# or SimCLR
loss = criterion(features)
...

Comparison

Results on CIFAR-10:

Arch Setting Loss Accuracy(%)
SupCrossEntropy ResNet50 Supervised Cross Entropy 95.0
SupContrast ResNet50 Supervised Contrastive 96.0
SimCLR ResNet50 Unsupervised Contrastive 93.6

Results on CIFAR-100:

Arch Setting Loss Accuracy(%)
SupCrossEntropy ResNet50 Supervised Cross Entropy 75.3
SupContrast ResNet50 Supervised Contrastive 76.5
SimCLR ResNet50 Unsupervised Contrastive 70.7

Results on ImageNet (Stay tuned):

Arch Setting Loss Accuracy(%)
SupCrossEntropy ResNet50 Supervised Cross Entropy -
SupContrast ResNet50 Supervised Contrastive 79.1 (MoCo trick)
SimCLR ResNet50 Unsupervised Contrastive -

Running

You might use CUDA_VISIBLE_DEVICES to set proper number of GPUs, and/or switch to CIFAR100 by --dataset cifar100.
(1) Standard Cross-Entropy

python main_ce.py --batch_size 1024 \
  --learning_rate 0.8 \
  --cosine --syncBN \

(2) Supervised Contrastive Learning
Pretraining stage:

python main_supcon.py --batch_size 1024 \
  --learning_rate 0.5 \
  --temp 0.1 \
  --cosine

You can also specify --syncBN but I found it not crucial for SupContrast (syncBN 95.9% v.s. BN 96.0%).
Linear evaluation stage:

python main_linear.py --batch_size 512 \
  --learning_rate 5 \
  --ckpt /path/to/model.pth

(3) SimCLR
Pretraining stage:

python main_supcon.py --batch_size 1024 \
  --learning_rate 0.5 \
  --temp 0.5 \
  --cosine --syncBN \
  --method SimCLR

The --method SimCLR flag simply stops labels from being passed to SupConLoss criterion. Linear evaluation stage:

python main_linear.py --batch_size 512 \
  --learning_rate 1 \
  --ckpt /path/to/model.pth

On custom dataset:

python main_supcon.py --batch_size 1024 \
  --learning_rate 0.5  \ 
  --temp 0.1 --cosine \
  --dataset path \
  --data_folder ./path \
  --mean "(0.4914, 0.4822, 0.4465)" \
  --std "(0.2675, 0.2565, 0.2761)" \
  --method SimCLR

The --data_folder must be of form ./path/label/xxx.png folowing https://pytorch.org/docs/stable/torchvision/datasets.html#torchvision.datasets.ImageFolder convension.

and

t-SNE Visualization

(1) Standard Cross-Entropy

(2) Supervised Contrastive Learning

(3) SimCLR

Reference

@Article{khosla2020supervised,
    title   = {Supervised Contrastive Learning},
    author  = {Prannay Khosla and Piotr Teterwak and Chen Wang and Aaron Sarna and Yonglong Tian and Phillip Isola and Aaron Maschinot and Ce Liu and Dilip Krishnan},
    journal = {arXiv preprint arXiv:2004.11362},
    year    = {2020},
}
Owner
Yonglong Tian
CS Ph.D. student in AI @ MIT
Yonglong Tian
GPOEO is a micro-intrusive GPU online energy optimization framework for iterative applications

GPOEO GPOEO is a micro-intrusive GPU online energy optimization framework for iterative applications. We also implement ODPP [1] as a comparison. [1]

瑞雪轻飏 8 Sep 10, 2022
Prometheus Exporter for data scraped from datenplattform.darmstadt.de

darmstadt-opendata-exporter Scrapes data from https://datenplattform.darmstadt.de and presents it in the Prometheus Exposition format. Pull requests w

Martin Weinelt 2 Apr 12, 2022
Self-supervised Product Quantization for Deep Unsupervised Image Retrieval - ICCV2021

Self-supervised Product Quantization for Deep Unsupervised Image Retrieval Pytorch implementation of SPQ Accepted to ICCV 2021 - paper Young Kyun Jang

Young Kyun Jang 71 Dec 27, 2022
Rlmm blender toolkit - A set of tools to streamline level generation in UDK straight from Blender

rlmm_blender_toolkit A set of tools to streamline level generation in UDK straig

Rocket League Mapmaking 0 Jan 15, 2022
Code for Greedy Gradient Ensemble for Visual Question Answering (ICCV 2021, Oral)

Greedy Gradient Ensemble for De-biased VQA Code release for "Greedy Gradient Ensemble for Robust Visual Question Answering" (ICCV 2021, Oral). GGE can

21 Jun 29, 2022
Pose Detection and Machine Learning for real-time body posture analysis during exercise to provide audiovisual feedback on improvement of form.

Posture: Pose Tracking and Machine Learning for prescribing corrective suggestions to improve posture and form while exercising. This repository conta

Pratham Mehta 10 Nov 11, 2022
GRaNDPapA: Generator of Rad Names from Decent Paper Acronyms

GRaNDPapA: Generator of Rad Names from Decent Paper Acronyms Trying to publish a new machine learning model and can't write a decent title for your pa

264 Nov 08, 2022
The first dataset on shadow generation for the foreground object in real-world scenes.

Object-Shadow-Generation-Dataset-DESOBA Object Shadow Generation is to deal with the shadow inconsistency between the foreground object and the backgr

BCMI 105 Dec 30, 2022
Implementation for paper: Self-Regulation for Semantic Segmentation

Self-Regulation for Semantic Segmentation This is the PyTorch implementation for paper Self-Regulation for Semantic Segmentation, ICCV 2021. Citing SR

Dong ZHANG 30 Nov 21, 2022
Code release for "BoxeR: Box-Attention for 2D and 3D Transformers"

BoxeR By Duy-Kien Nguyen, Jihong Ju, Olaf Booij, Martin R. Oswald, Cees Snoek. This repository is an official implementation of the paper BoxeR: Box-A

Nguyen Duy Kien 111 Dec 07, 2022
Torchlight2 lan game server tool - A message forwarding tool for Torchlight 2 lan game

Torchlight 2 Lan Game Server Tool A message forwarding tool for Torchlight 2 lan

Huaijun Jiang 3 Nov 01, 2022
Easily Process a Batch of Cox Models

ezcox: Easily Process a Batch of Cox Models The goal of ezcox is to operate a batch of univariate or multivariate Cox models and return tidy result. ⏬

Shixiang Wang 15 May 23, 2022
PyTorch implementation of "Supervised Contrastive Learning" (and SimCLR incidentally)

PyTorch implementation of "Supervised Contrastive Learning" (and SimCLR incidentally)

Yonglong Tian 2.2k Jan 08, 2023
Object detection (YOLO) with pytorch, OpenCV and python

Real Time Object/Face Detection Using YOLO-v3 This project implements a real time object and face detection using YOLO algorithm. You only look once,

1 Aug 04, 2022
PyTorch implementation of MoCo: Momentum Contrast for Unsupervised Visual Representation Learning

MoCo: Momentum Contrast for Unsupervised Visual Representation Learning This is a PyTorch implementation of the MoCo paper: @Article{he2019moco, aut

Meta Research 3.7k Jan 02, 2023
Face Recognition & AI Based Smart Attendance Monitoring System.

In today’s generation, authentication is one of the biggest problems in our society. So, one of the most known techniques used for authentication is h

Sagar Saha 1 Jan 14, 2022
This repo provides the official code for TransBTS: Multimodal Brain Tumor Segmentation Using Transformer (https://arxiv.org/pdf/2103.04430.pdf).

TransBTS: Multimodal Brain Tumor Segmentation Using Transformer This repo is the official implementation for TransBTS: Multimodal Brain Tumor Segmenta

Raymond 247 Dec 28, 2022
This is the repository for The Machine Learning Workshops, published by AI DOJO

This is the repository for The Machine Learning Workshops, published by AI DOJO. It contains all the workshop's code with supporting project files necessary to work through the code.

AI Dojo 12 May 06, 2022
Pytorch Implementation of Auto-Compressing Subset Pruning for Semantic Image Segmentation

Pytorch Implementation of Auto-Compressing Subset Pruning for Semantic Image Segmentation Introduction ACoSP is an online pruning algorithm that compr

Merantix 8 Dec 07, 2022
Discord Multi Tool that focuses on design and easy usage

Multi-Tool-v1.0 Discord Multi Tool that focuses on design and easy usage Delete webhook Block all friends Spam webhook Modify webhook Webhook info Tok

Lodi#0001 24 May 23, 2022