An essential implementation of BYOL in PyTorch + PyTorch Lightning

Essential BYOL

A simple and complete implementation of "Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning" (BYOL) in PyTorch + PyTorch Lightning.
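
For orientation, here is a minimal sketch of the two core ingredients of BYOL as described in the paper: the normalized regression loss between online predictions and target projections, and the exponential moving average (EMA) update of the target network. It is not taken from this repo; the function names byol_loss and ema_update and the toy nn.Linear networks are illustrative assumptions.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


def byol_loss(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Normalized MSE, i.e. 2 - 2 * cosine similarity, averaged over the batch."""
    prediction = F.normalize(prediction, dim=-1)
    target = F.normalize(target.detach(), dim=-1)  # no gradient through the target
    return (2 - 2 * (prediction * target).sum(dim=-1)).mean()


@torch.no_grad()
def ema_update(online: nn.Module, target: nn.Module, tau: float) -> None:
    """In place: target <- tau * target + (1 - tau) * online, per parameter."""
    for p_online, p_target in zip(online.parameters(), target.parameters()):
        p_target.mul_(tau).add_(p_online, alpha=1 - tau)


# Toy stand-ins for the online and target networks (assumption, for illustration).
online_net = nn.Linear(8, 4)
target_net = copy.deepcopy(online_net)

view_a, view_b = torch.randn(16, 8), torch.randn(16, 8)
# Symmetrized loss: each view is predicted from the other one.
loss = byol_loss(online_net(view_a), target_net(view_b)) \
     + byol_loss(online_net(view_b), target_net(view_a))
loss.backward()
ema_update(online_net, target_net, tau=0.99)
```

In the paper the online branch also has a predictor head that the target branch lacks, and tau is annealed towards 1 during training; both are omitted here for brevity. The sketch only updates parameters, so buffers such as BatchNorm running statistics would need separate handling.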

Good stuff:

  • good performance (~67% linear eval accuracy on CIFAR100)
  • minimal code, easy to use and extend
  • multi-GPU / TPU and AMP support provided by PyTorch Lightning
  • ImageNet support (needs testing)
  • linear evaluation is performed during training without any additional forward pass
  • logging with Wandb

Performance

Linear Evaluation Accuracy

Here is the accuracy after training for 1000 epochs:

Dataset    Acc@1   Acc@5
CIFAR10    91.1%   99.8%
CIFAR100   67.0%   90.5%
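
The two columns report top-1 and top-5 accuracy: a sample counts as correct if the true label is among the 1 or 5 highest-scoring classes. A minimal sketch of how such numbers can be computed from classifier logits follows; the helper name topk_accuracy is an assumption and is not taken from this repo.

```python
import torch


def topk_accuracy(logits: torch.Tensor, targets: torch.Tensor, ks=(1, 5)):
    """Return, for each k in ks, the fraction of samples whose label is in the top k."""
    max_k = max(ks)
    # Indices of the max_k highest logits per sample, sorted by descending score.
    _, pred = logits.topk(max_k, dim=1)
    correct = pred.eq(targets.unsqueeze(1))  # shape: (batch, max_k)
    return [correct[:, :k].any(dim=1).float().mean().item() for k in ks]


# Random logits for a batch of 8 samples over 100 classes (illustration only).
logits = torch.randn(8, 100)
targets = torch.randint(0, 100, (8,))
acc1, acc5 = topk_accuracy(logits, targets)
```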

Training and Validation Curves

[Figure: training and validation curves on CIFAR10]

[Figure: training and validation curves on CIFAR100]

Environment

conda create --name essential-byol python=3.8
conda activate essential-byol
conda install pytorch=1.7.1 torchvision=0.8.2 cudatoolkit=XX.X -c pytorch
pip install pytorch-lightning==1.1.6 pytorch-lightning-bolts==0.3 wandb opencv-python

The code has been tested with these package versions, but it will probably work with slightly different environments as well. When you run the code (see below for commands), PyTorch Lightning will probably emit a warning advising you to install additional packages such as gym, sklearn and matplotlib. They are not needed for this implementation to work, but you can install them to get rid of the warnings.

Datasets

Three datasets are supported:

  • CIFAR10
  • CIFAR100
  • ImageNet

For ImageNet you need to pass the appropriate --data_dir, while for CIFAR10 and CIFAR100 you can just pass --download to download the dataset.

Commands

The repo comes with a minimal set of model-specific arguments; check main.py for details. All arguments of the PyTorch Lightning Trainer are supported as well. Default parameters are optimized for CIFAR100 but can also be used for CIFAR10.

Sample command for training on CIFAR100 on a single-GPU setup:

python main.py \
    --gpus 1 \
    --dataset CIFAR100 \
    --batch_size 256 \
    --max_epochs 1000 \
    --arch resnet18 \
    --precision 16 \
    --comment wandb-comment

And on a multi-GPU setup:

python main.py \
    --gpus 2 \
    --distributed_backend ddp \
    --sync_batchnorm \
    --dataset CIFAR100 \
    --batch_size 256 \
    --max_epochs 1000 \
    --arch resnet18 \
    --precision 16 \
    --comment wandb-comment

Logging

Logging is performed with Wandb. Create an account and follow the configuration steps shown in the terminal. You can pass your username using --entity. Training and validation stats are logged at every epoch. To disable logging completely, use --offline.

Contribute

Help is appreciated. Stuff that needs work:

  • test ImageNet performance
  • exclude bias and bn from LARS adaptation (see comments in the code)
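
For the second item, one common approach is to split the model parameters into two optimizer groups and mark the group holding biases and batch-norm parameters so that the LARS wrapper skips it. The sketch below assumes that all such parameters are one-dimensional tensors, which holds for standard PyTorch layers; the helper name split_lars_params and the exclude_from_lars key are hypothetical, and a LARS implementation would have to be written to read that key.

```python
import torch.nn as nn


def split_lars_params(model: nn.Module):
    """Build two param groups: LARS-adapted weights and excluded bias/BN params."""
    adapted, excluded = [], []
    for param in model.parameters():
        if not param.requires_grad:
            continue
        # Biases and BatchNorm scales/shifts are 1-D; conv/linear weights are not.
        if param.ndim <= 1:
            excluded.append(param)
        else:
            adapted.append(param)
    return [
        {"params": adapted},
        # Hypothetical flag: the LARS code must check it and skip the adaptation.
        {"params": excluded, "exclude_from_lars": True},
    ]


model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.BatchNorm2d(8),
    nn.Flatten(),
    nn.Linear(8, 2),
)
param_groups = split_lars_params(model)
```

The resulting list can be passed to an optimizer in place of model.parameters(); how the flag is then honored depends on the LARS implementation in use.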

Owner

Enrico Fini, PhD Student at University of Trento