Prototypical Networks for Few shot Learning in PyTorch

Overview

Prototypical Networks for Few shot Learning in PyTorch

Simple alternative Implementation of Prototypical Networks for Few Shot Learning (paper, code) in PyTorch.

Prototypical Networks

As shown in the reference paper Prototypical Networks are trained to embed samples features in a vectorial space, in particular, at each episode (iteration), a number of samples for a subset of classes are selected and sent through the model, for each subset of class c a number of samples' features (n_support) are used to guess the prototype (their barycentre coordinates in the vectorial space) for that class, so then the distances between the remaining n_query samples and their class barycentre can be minimized.

Prototypical Networks

T-SNE

After training, you can compute the t-SNE for the features generated by the model (not done in this repo, more infos about t-SNE here), this is a sample as shown in the paper.

Reference Paper t-SNE

Omniglot Dataset

Kudos to @ludc for his contribute: https://github.com/pytorch/vision/pull/46. We will use the official dataset when it will be added to torchvision if it doesn't imply big changes to the code.

Dataset splits

We implemented the Vynials splitting method as in [Matching Networks for One Shot Learning]. That sould be the same method used in the paper (in fact I download the split files from the "offical" repo). We then apply the same rotations there described. In this way we should be able to compare results obtained by running this code with results described in the reference paper.

Prototypical Batch Sampler

As described in its PyDoc, this class is used to generate the indexes of each batch for a prototypical training algorithm.

In particular, the object is instantiated by passing the list of the labels for the dataset, the sampler infers then the total number of classes and creates a set of indexes for each class ni the dataset. At each episode the sampler selects n_classes random classes and returns a number (n_support + n_query) of samples indexes for each one of the selected classes.

Prototypical Loss

Compute the loss as in the cited paper, mostly inspired by this code by one of its authors.

In prototypical_loss.py both loss function and loss class à la PyTorch are implemented.

The function takes in input the batch input from the model, samples' ground truths and the number n_suppport of samples to be used as support samples. Episode classes get infered from the target list, n_support samples get randomly extracted for each class, their class barycentres get computed, as well as the distances of each remaining samples' embedding from each class barycentre and the probability of each sample of belonging to each episode class get finmally computed; then the loss is then computed from the wrong predictions probabilities (for the query samples) as usual in classification problems.

Training

Please note that the training code is here just for demonstration purposes.

To train the Protonet on this task, cd into this repo's src root folder and execute:

$ python train.py

The script takes the following command line options:

  • dataset_root: the root directory where tha dataset is stored, default to '../dataset'

  • nepochs: number of epochs to train for, default to 100

  • learning_rate: learning rate for the model, default to 0.001

  • lr_scheduler_step: StepLR learning rate scheduler step, default to 20

  • lr_scheduler_gamma: StepLR learning rate scheduler gamma, default to 0.5

  • iterations: number of episodes per epoch. default to 100

  • classes_per_it_tr: number of random classes per episode for training. default to 60

  • num_support_tr: number of samples per class to use as support for training. default to 5

  • num_query_tr: nnumber of samples per class to use as query for training. default to 5

  • classes_per_it_val: number of random classes per episode for validation. default to 5

  • num_support_val: number of samples per class to use as support for validation. default to 5

  • num_query_val: number of samples per class to use as query for validation. default to 15

  • manual_seed: input for the manual seeds initializations, default to 7

  • cuda: enables cuda (store True)

Running the command without arguments will train the models with the default hyperparamters values (producing results shown above).

Performances

We are trying to reproduce the reference paper performaces, we'll update here our best results.

Model 1-shot (5-way Acc.) 5-shot (5-way Acc.) 1 -shot (20-way Acc.) 5-shot (20-way Acc.)
Reference Paper 98.8% 99.7% 96.0% 98.9%
This repo 98.5%** 99.6%* 95.1%° 98.6%°°

* achieved using default parameters (using --cuda option)

** achieved running python train.py --cuda -nsTr 1 -nsVa 1

° achieved running python train.py --cuda -nsTr 1 -nsVa 1 -cVa 20

°° achieved running python train.py --cuda -nsTr 5 -nsVa 5 -cVa 20

Helpful links

.bib citation

cite the paper as follows (copied-pasted it from arxiv for you):

@article{DBLP:journals/corr/SnellSZ17,
  author    = {Jake Snell and
               Kevin Swersky and
               Richard S. Zemel},
  title     = {Prototypical Networks for Few-shot Learning},
  journal   = {CoRR},
  volume    = {abs/1703.05175},
  year      = {2017},
  url       = {http://arxiv.org/abs/1703.05175},
  archivePrefix = {arXiv},
  eprint    = {1703.05175},
  timestamp = {Wed, 07 Jun 2017 14:41:38 +0200},
  biburl    = {http://dblp.org/rec/bib/journals/corr/SnellSZ17},
  bibsource = {dblp computer science bibliography, http://dblp.org}
}

License

This project is licensed under the MIT License

Copyright (c) 2018 Daniele E. Ciriello, Orobix Srl (www.orobix.com).

Owner
Orobix
Orobix
Spatial color quantization in Rust

rscolorq Rust port of Derrick Coetzee's scolorq, based on the 1998 paper "On spatial quantization of color images" by Jan Puzicha, Markus Held, Jens K

Collyn O'Kane 37 Dec 22, 2022
Compact Bidirectional Transformer for Image Captioning

Compact Bidirectional Transformer for Image Captioning Requirements Python 3.8 Pytorch 1.6 lmdb h5py tensorboardX Prepare Data Please use git clone --

YE Zhou 19 Dec 12, 2022
Official implementation of "StyleCariGAN: Caricature Generation via StyleGAN Feature Map Modulation" (SIGGRAPH 2021)

StyleCariGAN in PyTorch Official implementation of StyleCariGAN:Caricature Generation via StyleGAN Feature Map Modulation in PyTorch Requirements PyTo

PeterZhouSZ 49 Oct 31, 2022
A distributed deep learning framework that supports flexible parallelization strategies.

FlexFlow FlexFlow is a deep learning framework that accelerates distributed DNN training by automatically searching for efficient parallelization stra

528 Dec 25, 2022
PyTorch implementation of neural style randomization for data augmentation

README Augment training images for deep neural networks by randomizing their visual style, as described in our paper: https://arxiv.org/abs/1809.05375

84 Nov 23, 2022
Python code to fuse multiple RGB-D images into a TSDF voxel volume.

Volumetric TSDF Fusion of RGB-D Images in Python This is a lightweight python script that fuses multiple registered color and depth images into a proj

Andy Zeng 845 Jan 03, 2023
Recurrent Conditional Query Learning

Recurrent Conditional Query Learning (RCQL) This repository contains the Pytorch implementation of One Model Packs Thousands of Items with Recurrent C

Dongda 4 Nov 28, 2022
The official code repo of "HTS-AT: A Hierarchical Token-Semantic Audio Transformer for Sound Classification and Detection"

Hierarchical Token Semantic Audio Transformer Introduction The Code Repository for "HTS-AT: A Hierarchical Token-Semantic Audio Transformer for Sound

Knut(Ke) Chen 134 Jan 01, 2023
A disassembler for the RP2040 Programmable I/O State-machine!

piodisasm A disassembler for the RP2040 Programmable I/O State-machine! Usage Just run piodisasm.py on a file that contains the PIO code as hex! (Such

Ghidra Ninja 29 Dec 06, 2022
PiRank: Learning to Rank via Differentiable Sorting

PiRank: Learning to Rank via Differentiable Sorting This repository provides a reference implementation for learning PiRank-based models as described

54 Dec 17, 2022
An AFL implementation with UnTracer (our coverage-guided tracer)

UnTracer-AFL This repository contains an implementation of our prototype coverage-guided tracing framework UnTracer in the popular coverage-guided fuz

113 Dec 17, 2022
CCP dataset from Clothing Co-Parsing by Joint Image Segmentation and Labeling

Clothing Co-Parsing (CCP) Dataset Clothing Co-Parsing (CCP) dataset is a new clothing database including elaborately annotated clothing items. 2, 098

Wei Yang 434 Dec 24, 2022
Implementation of BI-RADS-BERT & The Advantages of Section Tokenization.

BI-RADS BERT Implementation of BI-RADS-BERT & The Advantages of Section Tokenization. This implementation could be used on other radiology in house co

1 May 17, 2022
Official PyTorch implementation of Retrieve in Style: Unsupervised Facial Feature Transfer and Retrieval.

Retrieve in Style: Unsupervised Facial Feature Transfer and Retrieval PyTorch This is the PyTorch implementation of Retrieve in Style: Unsupervised Fa

60 Oct 12, 2022
OpenMMLab Video Perception Toolbox. It supports Video Object Detection (VID), Multiple Object Tracking (MOT), Single Object Tracking (SOT), Video Instance Segmentation (VIS) with a unified framework.

English | 简体中文 Documentation: https://mmtracking.readthedocs.io/ Introduction MMTracking is an open source video perception toolbox based on PyTorch.

OpenMMLab 2.7k Jan 08, 2023
Keeping it safe - AI Based COVID-19 Tracker using Deep Learning and facial recognition

Keeping it safe - AI Based COVID-19 Tracker using Deep Learning and facial recognition

Vansh Wassan 15 Jun 17, 2021
Predicting Semantic Map Representations from Images with Pyramid Occupancy Networks

This is the code associated with the paper Predicting Semantic Map Representations from Images with Pyramid Occupancy Networks, published at CVPR 2020.

Thomas Roddick 219 Dec 20, 2022
El-Gamal on Elliptic Curve (Python)

El-Gamal-on-EC El-Gamal on Elliptic Curve (Python) References: https://docsdrive.com/pdfs/ansinet/itj/2005/299-306.pdf https://arxiv.org/ftp/arxiv/pap

3 May 04, 2022
implementation of paper - You Only Learn One Representation: Unified Network for Multiple Tasks

YOLOR implementation of paper - You Only Learn One Representation: Unified Network for Multiple Tasks To reproduce the results in the paper, please us

Kin-Yiu, Wong 1.8k Jan 04, 2023
Recovering Brain Structure Network Using Functional Connectivity

Recovering-Brain-Structure-Network-Using-Functional-Connectivity Framework: Papers: This repository provides a PyTorch implementation of the models ad

5 Nov 30, 2022