Learning to Prompt for Vision-Language Models.

CoOp

Paper: Learning to Prompt for Vision-Language Models

Authors: Kaiyang Zhou, Jingkang Yang, Chen Change Loy, Ziwei Liu

CoOp (Context Optimization) is a differentiable approach that learns continuous prompts to facilitate the deployment of pre-trained vision-language models (like CLIP) on downstream datasets.
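Once the prompts are encoded, classification works as in zero-shot CLIP. In CoOp only the context vectors are learned, while CLIP's encoders stay frozen. The image feature is compared with each class's text feature by cosine similarity, and a temperature-scaled softmax turns these similarities into class probabilities. The pure-Python sketch below illustrates only this final step. Toy vectors stand in for the encoder outputs. The function names and the temperature of 0.01 (a logit scale of 100) are illustrative assumptions, not taken from the CoOp code.

```python
import math
from typing import List, Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two feature vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def class_probabilities(
    image_feature: Sequence[float],
    text_features: List[Sequence[float]],
    temperature: float = 0.01,  # assumed value, i.e. a logit scale of 100
) -> List[float]:
    """Softmax over temperature-scaled cosine similarities, one entry per class."""
    logits = [cosine_similarity(t, image_feature) / temperature for t in text_features]
    max_logit = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - max_logit) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Toy example: three classes with 3-dimensional features.
probs = class_probabilities([1.0, 0.2, 0.0], [[0.9, 0.1, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
print(max(range(len(probs)), key=probs.__getitem__))  # → 0 (index of the predicted class)
```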

Updates

  • 15.10.2021: We find that the best_val model and the last_step model achieve similar performance, so we set TEST.FINAL_MODEL = "last_step" for all datasets to save training time. Why did we use best_val? The (tiny) validation set was designed for the linear probe approach, which requires extensive hyperparameter tuning, so we used the best_val model for CoOp as well for a fair comparison (this way, both approaches have access to the validation set).

  • 09.10.2021: Important changes have been made to Dassl's transforms.py. Please pull the latest commits from https://github.com/KaiyangZhou/Dassl.pytorch and this repo to make sure the code works properly. In particular, 1) center_crop is now a default transform at test time (applied after resizing the smaller edge to a given size, which preserves the image aspect ratio), and 2) for training, Resize(cfg.INPUT.SIZE) is deactivated when random_crop or random_resized_crop is used. Please read this issue on how these changes might affect performance.

  • 18.09.2021: We have fixed an error in Dassl that could cause a training data loader to have zero length (so no training would be performed) when the dataset size is smaller than the batch size (due to drop_last=True). Please pull the latest commit for Dassl (>= 8eecc3c). This error led to lower results for CoOp in EuroSAT's 1- and 2-shot settings (all other results are correct). We will update the paper on arXiv to fix this error.
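The zero-length loader follows directly from how drop_last works. With drop_last=True, a loader yields num_samples // batch_size batches, which is zero whenever there are fewer samples than the batch size. The sketch below shows the arithmetic for EuroSAT, whose 10 classes give 10 and 20 training images in the 1- and 2-shot settings. The batch size of 32 is an assumption chosen for illustration.

```python
import math


def num_batches(num_samples: int, batch_size: int, drop_last: bool) -> int:
    """Number of batches a loader yields in one pass over the data."""
    if drop_last:
        return num_samples // batch_size  # the incomplete last batch is discarded
    return math.ceil(num_samples / batch_size)


NUM_CLASSES = 10  # EuroSAT has 10 classes
BATCH_SIZE = 32  # assumed batch size, for illustration only

for shots in (1, 2, 4):
    n = NUM_CLASSES * shots
    # prints: shots, batches with drop_last=True, batches with drop_last=False
    print(shots, num_batches(n, BATCH_SIZE, drop_last=True), num_batches(n, BATCH_SIZE, drop_last=False))
```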

How to Install

This code is built on top of the awesome toolbox Dassl.pytorch, so you need to install the dassl environment first. Simply follow the instructions described here to install dassl as well as PyTorch. After that, run pip install -r requirements.txt under CoOp/ (with the dassl environment activated) to install a few more packages required by CLIP. Then, you are ready to go.

Follow DATASETS.md to install the datasets.

How to Run

We provide the run scripts in scripts/. Before using them, change the path assigned to DATA in the scripts to your dataset directory, and run the commands from CoOp/scripts/.

Few-Shot Learning

All you need is CoOp/scripts/main.sh, which takes six input arguments.

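  • DATASET takes as input a dataset name, like imagenet or caltech101. The valid names are the file names in CoOp/configs/datasets/.
  • CFG specifies which config file to use, such as rn50, rn101 or vit_b32 (see CoOp/configs/trainers/CoOp/). Note that for ImageNet, we use CoOp/configs/trainers/CoOp/*_ep50.yaml for all settings (please follow the implementation details shown in the paper).
  • CTP specifies the class token position (end or middle).
  • NCTX specifies the number of context tokens.
  • SHOTS specifies the number of shots per class.
  • CSC specifies whether to use class-specific context (False or True).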
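Three of main.sh's arguments determine the shape of the learnable prompt: the class token position (end or middle), the number of context tokens M (16 in the examples below), and whether each class gets its own context (CSC). The sketch below lays out one class's prompt as a list of placeholder strings, where V_k stands for the k-th learnable context vector. The placeholder names are hypothetical. Splitting the context in half around the class token for middle is an assumption about the layout.

```python
from typing import List


def prompt_layout(class_name: str, class_index: int, n_ctx: int, position: str, csc: bool) -> List[str]:
    """Return the token layout of one class's prompt as placeholder strings.

    V_k stands for the k-th learnable context vector. With class-specific
    context (CSC), each class has its own set, tagged here by class index.
    """
    prefix = f"V{class_index}_" if csc else "V_"
    ctx = [f"{prefix}{k}" for k in range(1, n_ctx + 1)]
    if position == "end":
        return ctx + [class_name]
    if position == "middle":
        half = n_ctx // 2  # assumed split: first half of the context before the class token
        return ctx[:half] + [class_name] + ctx[half:]
    raise ValueError(f"unsupported class token position: {position!r}")


print(prompt_layout("airplane", 0, 4, "end", csc=False))
print(prompt_layout("airplane", 0, 4, "middle", csc=True))
```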

Below we provide examples of how to run CoOp on Caltech101.

CLIP + CoOp (M=16, end):

  • 1 shot: bash main.sh caltech101 rn50_ep50 end 16 1 False
  • 2 shots: bash main.sh caltech101 rn50_ep100 end 16 2 False
  • 4 shots: bash main.sh caltech101 rn50_ep100 end 16 4 False
  • 8 shots: bash main.sh caltech101 rn50 end 16 8 False
  • 16 shots: bash main.sh caltech101 rn50 end 16 16 False

CLIP + CoOp (M=16, mid):

  • 1 shot: bash main.sh caltech101 rn50_ep50 middle 16 1 False
  • 2 shots: bash main.sh caltech101 rn50_ep100 middle 16 2 False
  • 4 shots: bash main.sh caltech101 rn50_ep100 middle 16 4 False
  • 8 shots: bash main.sh caltech101 rn50 middle 16 8 False
  • 16 shots: bash main.sh caltech101 rn50 middle 16 16 False

CLIP + CoOp (M=16, end, CSC):

  • 1 shot: bash main.sh caltech101 rn50_ep50 end 16 1 True
  • 2 shots: bash main.sh caltech101 rn50_ep100 end 16 2 True
  • 4 shots: bash main.sh caltech101 rn50_ep100 end 16 4 True
  • 8 shots: bash main.sh caltech101 rn50 end 16 8 True
  • 16 shots: bash main.sh caltech101 rn50 end 16 16 True

CLIP + CoOp (M=16, mid, CSC):

  • 1 shot: bash main.sh caltech101 rn50_ep50 middle 16 1 True
  • 2 shots: bash main.sh caltech101 rn50_ep100 middle 16 2 True
  • 4 shots: bash main.sh caltech101 rn50_ep100 middle 16 4 True
  • 8 shots: bash main.sh caltech101 rn50 middle 16 8 True
  • 16 shots: bash main.sh caltech101 rn50 middle 16 16 True
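The 20 commands above follow a regular pattern. The config depends only on the number of shots (rn50_ep50 for 1 shot, rn50_ep100 for 2 and 4 shots, rn50 for 8 and 16), and the remaining arguments enumerate the class token position and the CSC setting. Instead of typing each command, a small helper can print the whole sweep. The shot-to-config mapping is read off the Caltech101/RN50 examples above and may not hold for other datasets or backbones. The function name is hypothetical.

```python
from itertools import product
from typing import List

# Config per shot count, as used in the Caltech101 examples above.
CFG_BY_SHOTS = {1: "rn50_ep50", 2: "rn50_ep100", 4: "rn50_ep100", 8: "rn50", 16: "rn50"}


def coop_commands(dataset: str = "caltech101", n_ctx: int = 16) -> List[str]:
    """Build the main.sh command for every (CSC, class token position, shots) combination."""
    commands = []
    for csc, position, shots in product(("False", "True"), ("end", "middle"), sorted(CFG_BY_SHOTS)):
        cfg = CFG_BY_SHOTS[shots]
        commands.append(f"bash main.sh {dataset} {cfg} {position} {n_ctx} {shots} {csc}")
    return commands


for command in coop_commands():
    print(command)
```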

Once the experiments finish, you can use parse_test_res.py to compute the average results instead of inspecting the log files manually. Suppose output/ is structured as follows:

output
|–– caltech101/
|   |–– CoOp/
|   |   |–– rn50_16shots/
|   |   |   |–– nctx16_cscFalse_ctpend/
|   |   |   |   |–– seed1/
|   |   |   |   |–– seed2/
|   |   |   |   |–– seed3/
|   |   |–– rn50_8shots/
|   |   |   |–– nctx16_cscFalse_ctpend/
|   |   |   |   |–– seed1/
|   |   |   |   |–– seed2/
|   |   |   |   |–– seed3/
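The tree suggests that each run folder is named after the main.sh arguments and the seed. The helper below rebuilds such a path. The naming pattern is inferred from the example tree rather than read from main.sh, so check the script if your folders look different. The function name is hypothetical.

```python
def output_dir(dataset: str, cfg: str, shots: int, n_ctx: int, csc: bool, position: str, seed: int,
               trainer: str = "CoOp") -> str:
    """Rebuild the run folder naming seen in the output/ tree (inferred pattern)."""
    return f"output/{dataset}/{trainer}/{cfg}_{shots}shots/nctx{n_ctx}_csc{csc}_ctp{position}/seed{seed}"


print(output_dir("caltech101", "rn50", 16, 16, False, "end", 1))
```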

To calculate the average results for the folder rn50_16shots/nctx16_cscFalse_ctpend/, you can run

python parse_test_res.py output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend

You will then see something like this in your terminal:

Parsing files in output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend
file: output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend/seed1/log.txt. accuracy: 91.81%. error: 8.19%.
file: output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend/seed2/log.txt. accuracy: 92.01%. error: 7.99%.
file: output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend/seed3/log.txt. accuracy: 92.17%. error: 7.83%.
===
Summary of directory: output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend
* accuracy: 92.00% +- 0.15%
* error: 8.00% +- 0.15%
===
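The summary is consistent with the mean and population standard deviation over seeds. The three accuracies above average 92.00% with a population standard deviation of 0.15%, whereas the sample standard deviation would be about 0.18%. The sketch below reproduces that calculation. The regular expression is an assumption modeled on the per-file lines printed above and does not reproduce the actual logic of parse_test_res.py.

```python
import re
import statistics
from typing import List

# Assumed line format, modeled on the per-file lines printed above.
ACC_PATTERN = re.compile(r"accuracy: ([\d.]+)%")


def parse_accuracy(line: str) -> float:
    """Extract the accuracy percentage from one result line."""
    match = ACC_PATTERN.search(line)
    if match is None:
        raise ValueError(f"no accuracy found in: {line!r}")
    return float(match.group(1))


def summarize(accuracies: List[float]) -> str:
    """Mean and population standard deviation, formatted like the summary above."""
    mean = statistics.mean(accuracies)
    std = statistics.pstdev(accuracies)
    return f"* accuracy: {mean:.2f}% +- {std:.2f}%"


lines = [
    "accuracy: 91.81%. error: 8.19%.",
    "accuracy: 92.01%. error: 7.99%.",
    "accuracy: 92.17%. error: 7.83%.",
]
print(summarize([parse_accuracy(line) for line in lines]))  # → * accuracy: 92.00% +- 0.15%
```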

How to initialize the context tokens with pre-trained word vectors? Set the parameter TRAINER.COOP.CTX_INIT in your config file to the initialization words. In our paper, we use configs/trainers/rn50_ctxv1.yaml (pass this file to --config-file; see scripts/main.sh), which uses "a photo of a" as the initialization words.
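Initializing from words means the context vectors start out as the word embeddings of the initialization phrase. Assuming one context vector per word, "a photo of a" gives four context tokens. The sketch below shows this idea with a hypothetical toy embedding table and plain whitespace splitting. CLIP's real byte-pair-encoding tokenizer and embedding layer are not reproduced here.

```python
from typing import Dict, List

# Hypothetical toy word-embedding table.
TOY_EMBEDDINGS: Dict[str, List[float]] = {
    "a": [0.1, 0.0, 0.3],
    "photo": [0.7, 0.2, 0.1],
    "of": [0.0, 0.5, 0.2],
}


def init_context(ctx_init: str, embeddings: Dict[str, List[float]]) -> List[List[float]]:
    """Initialize one context vector per word by copying that word's embedding."""
    words = ctx_init.split()
    # Copy each vector so that updating the context later does not modify the table.
    return [list(embeddings[w]) for w in words]


ctx = init_context("a photo of a", TOY_EMBEDDINGS)
print(len(ctx))  # → 4
```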

How to visualize the nearest words for the learned context tokens? All you need is interpret_prompt.py. For example, if the learned tokens are saved in a/b/c/prompt_learner/model.pth.tar and you want to see the top-3 nearest words for each token, run python interpret_prompt.py a/b/c/prompt_learner/model.pth.tar 3
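Interpreting a learned context vector amounts to a nearest-neighbor search over the vocabulary's word embeddings. The sketch below ranks a toy vocabulary by Euclidean distance to each learned vector and keeps the top k. The vocabulary, vectors and function name are hypothetical. The distance metric is an assumption here, so check interpret_prompt.py for the one it actually uses.

```python
import math
from typing import Dict, List, Sequence


def nearest_words(ctx_vectors: List[Sequence[float]], vocab: Dict[str, Sequence[float]], k: int) -> List[List[str]]:
    """For each context vector, return the k vocabulary words closest in Euclidean distance."""
    result = []
    for v in ctx_vectors:
        ranked = sorted(vocab, key=lambda word: math.dist(v, vocab[word]))
        result.append(ranked[:k])
    return result


vocab = {"photo": [1.0, 0.0], "picture": [0.9, 0.1], "dog": [0.0, 1.0], "cat": [0.1, 0.9]}
print(nearest_words([[0.97, 0.02], [0.0, 0.95]], vocab, k=2))  # → [['photo', 'picture'], ['dog', 'cat']]
```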

Robustness to Distribution Shift

To reproduce the robustness experiments, you can simply load the models learned on ImageNet and evaluate them on the following datasets: imagenetv2, imagenet-sketch, imagenet-a and imagenet-r.

The command is provided in CoOp/scripts/eval.sh. The key arguments are --model-dir, --load-epoch and --eval-only. --model-dir indicates the directory where the models are saved (i.e. the entire folder containing log.txt, the tensorboard file and prompt_learner/). --load-epoch tells the code to load the model saved at a specific epoch, like --load-epoch 50 for ImageNet (see the source code for more details).

For example, to evaluate CLIP + CoOp (M=16, end) on ImageNetV2, run

# No need to use rn50_ep50 here, as no training is performed
bash eval.sh imagenetv2 rn50

The default setting is SHOTS=16. Feel free to modify the script.
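To cover all four target datasets listed above, you can loop over them. The sketch below only builds the eval.sh commands, which you could then run from CoOp/scripts/ (for example with subprocess.run). Passing rn50 as the config mirrors the ImageNetV2 example. The function name is hypothetical.

```python
from typing import List

# Target datasets for the robustness experiments, as named above.
TARGET_DATASETS = ["imagenetv2", "imagenet-sketch", "imagenet-a", "imagenet-r"]


def eval_commands(cfg: str = "rn50") -> List[str]:
    """One eval.sh invocation per target dataset, using the model learned on ImageNet."""
    return [f"bash eval.sh {dataset} {cfg}" for dataset in TARGET_DATASETS]


for command in eval_commands():
    print(command)
```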

Again, you can use parse_test_res.py to automate the calculation of average performance. This time you should append --test-log, e.g., python parse_test_res.py directory --test-log.

Zero-Shot CLIP

See CoOp/scripts/zeroshot.sh.

Linear Probe CLIP

Please see lpclip/.

How to Cite CoOp

If you use this code in your research, please cite the following paper:

@article{zhou2021coop,
    title={Learning to Prompt for Vision-Language Models},
    author={Zhou, Kaiyang and Yang, Jingkang and Loy, Chen Change and Liu, Ziwei},
    journal={arXiv preprint arXiv:2109.01134},
    year={2021}
}