PyTorch implementation of "Supervised Contrastive Learning" (and SimCLR incidentally)

Overview

SupContrast: Supervised Contrastive Learning

This repo covers an reference implementation for the following papers in PyTorch, using CIFAR as an illustrative example:
(1) Supervised Contrastive Learning. Paper
(2) A Simple Framework for Contrastive Learning of Visual Representations. Paper

Loss Function

The loss function SupConLoss in losses.py takes features (L2 normalized) and labels as input, and return the loss. If labels is None or not passed to the it, it degenerates to SimCLR.

Usage:

from losses import SupConLoss

# define loss with a temperature `temp`
criterion = SupConLoss(temperature=temp)

# features: [bsz, n_views, f_dim]
# `n_views` is the number of crops from each image
# better be L2 normalized in f_dim dimension
features = ...
# labels: [bsz]
labels = ...

# SupContrast
loss = criterion(features, labels)
# or SimCLR
loss = criterion(features)
...

Comparison

Results on CIFAR-10:

Arch Setting Loss Accuracy(%)
SupCrossEntropy ResNet50 Supervised Cross Entropy 95.0
SupContrast ResNet50 Supervised Contrastive 96.0
SimCLR ResNet50 Unsupervised Contrastive 93.6

Results on CIFAR-100:

Arch Setting Loss Accuracy(%)
SupCrossEntropy ResNet50 Supervised Cross Entropy 75.3
SupContrast ResNet50 Supervised Contrastive 76.5
SimCLR ResNet50 Unsupervised Contrastive 70.7

Results on ImageNet (Stay tuned):

Arch Setting Loss Accuracy(%)
SupCrossEntropy ResNet50 Supervised Cross Entropy -
SupContrast ResNet50 Supervised Contrastive 79.1 (MoCo trick)
SimCLR ResNet50 Unsupervised Contrastive -

Running

You might use CUDA_VISIBLE_DEVICES to set proper number of GPUs, and/or switch to CIFAR100 by --dataset cifar100.
(1) Standard Cross-Entropy

python main_ce.py --batch_size 1024 \
  --learning_rate 0.8 \
  --cosine --syncBN \

(2) Supervised Contrastive Learning
Pretraining stage:

python main_supcon.py --batch_size 1024 \
  --learning_rate 0.5 \
  --temp 0.1 \
  --cosine

You can also specify --syncBN but I found it not crucial for SupContrast (syncBN 95.9% v.s. BN 96.0%).
Linear evaluation stage:

python main_linear.py --batch_size 512 \
  --learning_rate 5 \
  --ckpt /path/to/model.pth

(3) SimCLR
Pretraining stage:

python main_supcon.py --batch_size 1024 \
  --learning_rate 0.5 \
  --temp 0.5 \
  --cosine --syncBN \
  --method SimCLR

The --method SimCLR flag simply stops labels from being passed to SupConLoss criterion. Linear evaluation stage:

python main_linear.py --batch_size 512 \
  --learning_rate 1 \
  --ckpt /path/to/model.pth

On custom dataset:

python main_supcon.py --batch_size 1024 \
  --learning_rate 0.5  \ 
  --temp 0.1 --cosine \
  --dataset path \
  --data_folder ./path \
  --mean "(0.4914, 0.4822, 0.4465)" \
  --std "(0.2675, 0.2565, 0.2761)" \
  --method SimCLR

The --data_folder must be of form ./path/label/xxx.png folowing https://pytorch.org/docs/stable/torchvision/datasets.html#torchvision.datasets.ImageFolder convension.

and

t-SNE Visualization

(1) Standard Cross-Entropy

(2) Supervised Contrastive Learning

(3) SimCLR

Reference

@Article{khosla2020supervised,
    title   = {Supervised Contrastive Learning},
    author  = {Prannay Khosla and Piotr Teterwak and Chen Wang and Aaron Sarna and Yonglong Tian and Phillip Isola and Aaron Maschinot and Ce Liu and Dilip Krishnan},
    journal = {arXiv preprint arXiv:2004.11362},
    year    = {2020},
}
Owner
Yonglong Tian
CS Ph.D. student in AI @ MIT
Yonglong Tian
Regularizing Generative Adversarial Networks under Limited Data (CVPR 2021)

Regularizing Generative Adversarial Networks under Limited Data [Project Page][Paper] Implementation for our GAN regularization method. The proposed r

Google 148 Nov 18, 2022
TigerLily: Finding drug interactions in silico with the Graph.

Drug Interaction Prediction with Tigerlily Documentation | Example Notebook | Youtube Video | Project Report Tigerlily is a TigerGraph based system de

Benedek Rozemberczki 91 Dec 30, 2022
DPT: Deformable Patch-based Transformer for Visual Recognition (ACM MM2021)

DPT This repo is the official implementation of DPT: Deformable Patch-based Transformer for Visual Recognition (ACM MM2021). We provide code and model

CASIA-IVA-Lab 111 Dec 21, 2022
Official code for the ICCV 2021 paper "DECA: Deep viewpoint-Equivariant human pose estimation using Capsule Autoencoders"

DECA Official code for the ICCV 2021 paper "DECA: Deep viewpoint-Equivariant human pose estimation using Capsule Autoencoders". All the code is writte

23 Dec 01, 2022
Simple, efficient and flexible vision toolbox for mxnet framework.

MXbox: Simple, efficient and flexible vision toolbox for mxnet framework. MXbox is a toolbox aiming to provide a general and simple interface for visi

Ligeng Zhu 31 Oct 19, 2019
Code and data for paper "Deep Photo Style Transfer"

deep-photo-styletransfer Code and data for paper "Deep Photo Style Transfer" Disclaimer This software is published for academic and non-commercial use

Fujun Luan 9.9k Dec 29, 2022
MixRNet(Using mixup as regularization and tuning hyper-parameters for ResNets)

MixRNet(Using mixup as regularization and tuning hyper-parameters for ResNets) Using mixup data augmentation as reguliraztion and tuning the hyper par

Bhanu 2 Jan 16, 2022
A memory-efficient implementation of DenseNets

efficient_densenet_pytorch A PyTorch =1.0 implementation of DenseNets, optimized to save GPU memory. Recent updates Now works on PyTorch 1.0! It uses

Geoff Pleiss 1.4k Dec 25, 2022
Notspot robot simulation - Python version

Notspot robot simulation - Python version This repository contains all the files and code needed to simulate the notspot quadrupedal robot using Gazeb

50 Sep 26, 2022
Implementation and replication of ProGen, Language Modeling for Protein Generation, in Jax

ProGen - (wip) Implementation and replication of ProGen, Language Modeling for Protein Generation, in Pytorch and Jax (the weights will be made easily

Phil Wang 71 Dec 01, 2022
Pytorch cuda extension of grid_sample1d

Grid Sample 1d pytorch cuda extension of grid sample 1d. Since pytorch only supports grid sample 2d/3d, I extend the 1d version for efficiency. The fo

lyricpoem 24 Dec 03, 2022
Code for CVPR2019 paper《Unequal Training for Deep Face Recognition with Long Tailed Noisy Data》

Unequal-Training-for-Deep-Face-Recognition-with-Long-Tailed-Noisy-Data. This is the code of CVPR 2019 paper《Unequal Training for Deep Face Recognition

Zhong Yaoyao 68 Jan 07, 2023
Tools for robust generative diffeomorphic slice to volume reconstruction

RGDSVR Tools for Robust Generative Diffeomorphic Slice to Volume Reconstructions (RGDSVR) This repository provides tools to implement the methods in t

Lucilio Cordero-Grande 0 Oct 29, 2021
Banglore House Prediction Using Flask Server (Python)

Banglore House Prediction Using Flask Server (Python) 🌐 Links 🌐 📂 Repo In this repository, I've implemented a Machine Learning-based Bangalore Hous

Dhyan Shah 1 Jan 24, 2022
An AutoML Library made with Optuna and PyTorch Lightning

An AutoML Library made with Optuna and PyTorch Lightning Installation Recommended pip install -U gradsflow From source pip install git+https://github.

GradsFlow 294 Dec 17, 2022
Codebase for Attentive Neural Hawkes Process (A-NHP) and Attentive Neural Datalog Through Time (A-NDTT)

Introduction Codebase for the paper Transformer Embeddings of Irregularly Spaced Events and Their Participants. This codebase contains two packages: a

Alan Yang 28 Dec 12, 2022
Sub-tomogram-Detection - Deep learning based model for Cyro ET Sub-tomogram-Detection

Deep learning based model for Cyro ET Sub-tomogram-Detection High degree of stru

Siddhant Kumar 2 Feb 04, 2022
RSC-Net: 3D Human Pose, Shape and Texture from Low-Resolution Images and Videos

RSC-Net: 3D Human Pose, Shape and Texture from Low-Resolution Images and Videos Implementation for "3D Human Pose, Shape and Texture from Low-Resoluti

XiangyuXu 42 Nov 10, 2022