Learning to Prompt for Vision-Language Models.

Related tags

Deep LearningCoOp
Overview

CoOp

Paper: Learning to Prompt for Vision-Language Models

Authors: Kaiyang Zhou, Jingkang Yang, Chen Change Loy, Ziwei Liu

CoOp (Context Optimization) is a differentiable approach that focuses on continuous prompt learning to facilitate deployment of pre-trained vision language models (like CLIP) in downstream datasets.

Updates

  • 15.10.2021: We find that the best_val model and the last_step model achieve similar performance, so we set TEST.FINAL_MODEL = "last_step" for all datasets to save training time. Why we used best_val: the (tiny) validation set was designed for the linear probe approach, which requires extensive tuning for its hyperparameters, so we used the best_val model for CoOp as well for fair comparison (in this way, both approaches have access to the validation set).

  • 09.10.2021: Important changes are made to Dassl's transforms.py. Please pull the latest commits from https://github.com/KaiyangZhou/Dassl.pytorch and this repo to make sure the code works properly. In particular, 1) center_crop now becomes a default transform in testing (applied after resizing the smaller edge to a certain size to keep the image aspect ratio), and 2) for training, Resize(cfg.INPUT.SIZE) is deactivated when random_crop or random_resized_crop is used. Please read this issue on how these changes might affect the performance.

  • 18.09.2021: We have fixed an error in Dassl which could cause a training data loader to have zero length (so no training will be performed) when the dataset size is smaller than the batch size (due to drop_last=True). Please pull the latest commit for Dassl (>= 8eecc3c). This error led to lower results for CoOp in EuroSAT's 1- and 2-shot settings (others are all correct). We will update the paper on arxiv to fix this error.

How to Install

This code is built on top of the awesome toolbox Dassl.pytorch so you need to install the dassl environment first. Simply follow the instructions described here to install dassl as well as PyTorch. After that, run pip install -r requirements.txt under CoOp/ to install a few more packages required by CLIP (this should be done when dassl is activated). Then, you are ready to go.

Follow DATASETS.md to install the datasets.

How to Run

We provide the running scripts in scripts/. Make sure you change the path in DATA and run the commands under CoOp/scripts/.

Few-Shot Learning

All you need is CoOp/scripts/main.sh, which contains six input arguments.

DATASET takes as input a dataset name, like imagenet or caltech101. The valid names are the files' names in CoOp/configs/datasets/.

CFG means which config file to use, such as rn50, rn101 or vit_b32 (see CoOp/configs/trainers/CoOp/). Note that for ImageNet, we use CoOp/configs/trainers/CoOp/*_ep50.yaml for all settings (please follow the implementation details shown in the paper).

Below we provide examples on how to run CoOp on Caltech101.

CLIP + CoOp (M=16, end):

  • 1 shot: bash main.sh caltech101 rn50_ep50 end 16 1 False
  • 2 shots: bash main.sh caltech101 rn50_ep100 end 16 2 False
  • 4 shots: bash main.sh caltech101 rn50_ep100 end 16 4 False
  • 8 shots: bash main.sh caltech101 rn50 end 16 8 False
  • 16 shots: bash main.sh caltech101 rn50 end 16 16 False

CLIP + CoOp (M=16, mid):

  • 1 shot: bash main.sh caltech101 rn50_ep50 middle 16 1 False
  • 2 shots: bash main.sh caltech101 rn50_ep100 middle 16 2 False
  • 4 shots: bash main.sh caltech101 rn50_ep100 middle 16 4 False
  • 8 shots: bash main.sh caltech101 rn50 middle 16 8 False
  • 16 shots: bash main.sh caltech101 rn50 middle 16 16 False

CLIP + CoOp (M=16, end, CSC):

  • 1 shot: bash main.sh caltech101 rn50_ep50 end 16 1 True
  • 2 shots: bash main.sh caltech101 rn50_ep100 end 16 2 True
  • 4 shots: bash main.sh caltech101 rn50_ep100 end 16 4 True
  • 8 shots: bash main.sh caltech101 rn50 end 16 8 True
  • 16 shots: bash main.sh caltech101 rn50 end 16 16 True

CLIP + CoOp (M=16, mid, CSC):

  • 1 shot: bash main.sh caltech101 rn50_ep50 middle 16 1 True
  • 2 shots: bash main.sh caltech101 rn50_ep100 middle 16 2 True
  • 4 shots: bash main.sh caltech101 rn50_ep100 middle 16 4 True
  • 8 shots: bash main.sh caltech101 rn50 middle 16 8 True
  • 16 shots: bash main.sh caltech101 rn50 middle 16 16 True

After the experiments are finished, you can use parse_test_res.py to calculate the average results instead of manually looking into the log files. Say the structure of output/ is

output
|–– caltech101/
|   |–– CoOp/
|   |   |–– rn50_16shots/
|   |   |   |–– nctx16_cscFalse_ctpend/
|   |   |   |   |–– seed1/
|   |   |   |   |–– seed2/
|   |   |   |   |–– seed3/
|   |   |–– rn50_8shots/
|   |   |   |–– nctx16_cscFalse_ctpend/
|   |   |   |   |–– seed1/
|   |   |   |   |–– seed2/
|   |   |   |   |–– seed3/

To calculate the average results for the folder rn50_16shots/nctx16_cscFalse_ctpend/, you can run

python parse_test_res.py output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend

Then, you will see something like this in your terminal

Parsing files in output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend
file: output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend/seed1/log.txt. accuracy: 91.81%. error: 8.19%.
file: output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend/seed2/log.txt. accuracy: 92.01%. error: 7.99%.
file: output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend/seed3/log.txt. accuracy: 92.17%. error: 7.83%.
===
Summary of directory: output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend
* accuracy: 92.00% +- 0.15%
* error: 8.00% +- 0.15%
===

How to initialize the context tokens with pre-trained word vectors? Specify the words for the parameter TRAINER.COOP.CTX_INIT in your config file. In our paper, we use configs/trainers/rn50_ctxv1.yaml (give this file to --config-file, see scripts/main.sh), which uses "a photo of a" as the initialization words.

How to visualize nearest words for the learned context tokens? All you need is interpret_prompt.py. Say the learned tokens are saved in a/b/c/prompt_learner/model.pth.tar and you would like to see the top-3 nearest words for each token. In this case, run python interpret_prompt.py a/b/c/prompt_learner/model.pth.tar 3

Robustness to Distribution Shift

To reproduce the robustness experiments, you can simply load the models learned on ImageNet and evaluate them on the following datasets: imagenetv2, imagenet-sketch, imagenet-a and imagenet-r.

The command is provided in CoOp/scripts/eval.sh. The key arguments are --model-dir, --load-epoch and --eval-only. --model-dir indicates the directory where the models are saved (i.e. the entire folder containing log.txt, the tensorboard file and prompt_learner/). --load-epoch tells the code to load the model saved at a specific epoch, like --load-epoch 50 for ImageNet (see the source code for more details).

For example, to evaluate CLIP + CoOp (M=16, end) on ImageNetV2, you can do

# Don't need to use rn5_ep50 here as no training is performed
bash eval.sh imagenetv2 rn50

The default setting is SHOTS=16. Feel free to modify the script.

Again, you can use parse_test_res.py to automate the calculation of average performance. This time you should append --test-log, e.g., python parse_test_res.py directory --test-log.

Zero-Shot CLIP

See CoOp/scripts/zeroshot.sh.

Linear Probe CLIP

Please move to lpclip/.

How to Cite CoOp

If you use this code in your research, please kindly cite the following paper

@article{zhou2021coop,
    title={Learning to Prompt for Vision-Language Models},
    author={Zhou, Kaiyang and Yang, Jingkang and Loy, Chen Change and Liu, Ziwei},
    journal={arXiv preprint arXiv:2109.01134},
    year={2021}
}
Owner
Kaiyang
Kaiyang
A flexible and extensible framework for gait recognition.

A flexible and extensible framework for gait recognition. You can focus on designing your own models and comparing with state-of-the-arts easily with the help of OpenGait.

Shiqi Yu 335 Dec 22, 2022
Classification models 1D Zoo - Keras and TF.Keras

Classification models 1D Zoo - Keras and TF.Keras This repository contains 1D variants of popular CNN models for classification like ResNets, DenseNet

Roman Solovyev 12 Jan 06, 2023
[ICLR 2022 Oral] F8Net: Fixed-Point 8-bit Only Multiplication for Network Quantization

F8Net Fixed-Point 8-bit Only Multiplication for Network Quantization (ICLR 2022 Oral) OpenReview | arXiv | PDF | Model Zoo | BibTex PyTorch implementa

Snap Research 76 Dec 13, 2022
Realtime micro-expression recognition using OpenCV and PyTorch

Micro-expression Recognition Realtime micro-expression recognition from scratch using OpenCV and PyTorch Try it out with a webcam or video using the e

Irfan 35 Dec 05, 2022
To propose and implement a multi-class classification approach to disaster assessment from the given data set of post-earthquake satellite imagery.

To propose and implement a multi-class classification approach to disaster assessment from the given data set of post-earthquake satellite imagery.

Kunal Wadhwa 2 Jan 05, 2022
Bald-to-Hairy Translation Using CycleGAN

GANiry: Bald-to-Hairy Translation Using CycleGAN Official PyTorch implementation of GANiry. GANiry: Bald-to-Hairy Translation Using CycleGAN, Fidan Sa

Fidan Samet 10 Oct 27, 2022
ICLR 2021 i-Mix: A Domain-Agnostic Strategy for Contrastive Representation Learning

Introduction PyTorch code for the ICLR 2021 paper [i-Mix: A Domain-Agnostic Strategy for Contrastive Representation Learning]. @inproceedings{lee2021i

Kibok Lee 68 Nov 27, 2022
Convert Pytorch model to onnx or tflite, and the converted model can be visualized by Netron

Convert Pytorch model to onnx or tflite, and the converted model can be visualized by Netron

Roxbili 5 Nov 19, 2022
Start-to-finish tutorial for interactive music co-creation in PyTorch and Tensorflow.js

Start-to-finish tutorial for interactive music co-creation in PyTorch and Tensorflow.js

Chris Donahue 98 Dec 14, 2022
Graph-Refined Convolutional Network for Multimedia Recommendation with Implicit Feedback

Graph-Refined Convolutional Network for Multimedia Recommendation with Implicit Feedback This is our Pytorch implementation for the paper: Yinwei Wei,

17 Jun 10, 2022
Alias-Free Generative Adversarial Networks (StyleGAN3) Official PyTorch implementation

Alias-Free Generative Adversarial Networks (StyleGAN3) Official PyTorch implementation

NVIDIA Research Projects 4.8k Jan 09, 2023
[IROS'21] SurRoL: An Open-source Reinforcement Learning Centered and dVRK Compatible Platform for Surgical Robot Learning

SurRoL IROS 2021 SurRoL: An Open-source Reinforcement Learning Centered and dVRK Compatible Platform for Surgical Robot Learning Features dVRK compati

<a href=[email protected]"> 55 Jan 03, 2023
This YoloV5 based model is fit to detect people and different types of land vehicles, and displaying their density on a fitted map, according to their coordinates and detected labels.

This YoloV5 based model is fit to detect people and different types of land vehicles, and displaying their density on a fitted map, according to their

Liron Bdolah 8 May 22, 2022
Beginner-friendly repository for Hacktober Fest 2021. Start your contribution to open source through baby steps. 💜

Hacktober Fest 2021 🎉 Open source is changing the world – one contribution at a time! 🎉 This repository is made for beginners who are unfamiliar wit

Abhilash M Nair 32 Dec 11, 2022
The official implementation of the Hybrid Self-Attention NEAT algorithm

PUREPLES - Pure Python Library for ES-HyperNEAT About This is a library of evolutionary algorithms with a focus on neuroevolution, implemented in pure

Adrian Westh 91 Dec 12, 2022
Hypercomplex Neural Networks with PyTorch

HyperNets Hypercomplex Neural Networks with PyTorch: this repository would be a container for hypercomplex neural network modules to facilitate resear

Eleonora Grassucci 21 Dec 27, 2022
Run Keras models in the browser, with GPU support using WebGL

**This project is no longer active. Please check out TensorFlow.js.** The Keras.js demos still work but is no longer updated. Run Keras models in the

Leon Chen 4.9k Dec 29, 2022
Official implementation for NIPS'17 paper: PredRNN: Recurrent Neural Networks for Predictive Learning Using Spatiotemporal LSTMs.

PredRNN: A Recurrent Neural Network for Spatiotemporal Predictive Learning The predictive learning of spatiotemporal sequences aims to generate future

THUML: Machine Learning Group @ THSS 243 Dec 26, 2022
PyTorch-Multi-Style-Transfer - Neural Style and MSG-Net

PyTorch-Style-Transfer This repo provides PyTorch Implementation of MSG-Net (ours) and Neural Style (Gatys et al. CVPR 2016), which has been included

Hang Zhang 906 Jan 04, 2023
PyTorch Kafka Dataset: A definition of a dataset to get training data from Kafka.

PyTorch Kafka Dataset: A definition of a dataset to get training data from Kafka.

ERTIS Research Group 7 Aug 01, 2022