GeDML is an easy-to-use generalized deep metric learning library

Overview

Logo

Documentation build

News

  • [2021-9-6]: v0.0.0 has been released.

Introduction

GeDML is an easy-to-use generalized deep metric learning library, which contains:

  • State-of-the-art DML algorithms: We contrain 18+ losses functions and 6+ sampling strategies, and divide these algorithms into three categories (i.e., collectors, selectors, and losses).
  • Bridge bewteen DML and SSL: We attempt to bridge the gap between deep metric learning and self-supervised learning through specially designed modules, such as collectors.
  • Auxiliary modules to assist in building: We also encapsulates the upper interface for users to start programs quickly and separates the codes and configs for managing hyper-parameters conveniently.

Installation

Pip

pip install gedml

Framework

This project is modular in design. The pipeline diagram is as follows:

Pipeline

Code structure

  • _debug: Debug files.
  • demo: Demos of configuration files.
  • docs: Documentation.
  • src: Source code.
    • core: Losses, selectors, collectors, etc.
    • client: Tmux manager.
    • config: Config files including link, convert, assert and params.
    • launcher: Manager, Trainer, Tester, etc.
    • recorder: Recorder.

Method

Collectors

method description
BaseCollector Base class
DefaultCollector Do nothing
ProxyCollector Maintain a set of proxies
MoCoCollector paper: Momentum Contrast for Unsupervised Visual Representation Learning
SimSiamCollector paper: Exploring Simple Siamese Representation Learning
HDMLCollector paper: Hardness-Aware Deep Metric Learning
DAMLCollector paper: Deep Adversarial Metric Learning
DVMLCollector paper: Deep Variational Metric Learning

Losses

classifier-based

method description
CrossEntropyLoss Cross entropy loss for unsupervised methods
LargeMarginSoftmaxLoss paper: Large-Margin Softmax Loss for Convolutional Neural Networks
ArcFaceLoss paper: ArcFace: Additive Angular Margin Loss for Deep Face Recognition
CosFaceLoss paper: CosFace: Large Margin Cosine Loss for Deep Face Recognition

pair-based

method description
ContrastiveLoss paper: Learning a Similarity Metric Discriminatively, with Application to Face Verification
MarginLoss paper: Sampling Matters in Deep Embedding Learning
TripletLoss paper: Learning local feature descriptors with triplets and shallow convolutional neural networks
AngularLoss paper: Deep Metric Learning with Angular Loss
CircleLoss paper: Circle Loss: A Unified Perspective of Pair Similarity Optimization
FastAPLoss paper: Deep Metric Learning to Rank
LiftedStructureLoss paper: Deep Metric Learning via Lifted Structured Feature Embedding
MultiSimilarityLoss paper: Multi-Similarity Loss With General Pair Weighting for Deep Metric Learning
NPairLoss paper: Improved Deep Metric Learning with Multi-class N-pair Loss Objective
SignalToNoiseRatioLoss paper: Signal-To-Noise Ratio: A Robust Distance Metric for Deep Metric Learning
PosPairLoss paper: Exploring Simple Siamese Representation Learning

proxy-based

method description
ProxyLoss paper: No Fuss Distance Metric Learning Using Proxies
ProxyAnchorLoss paper: Proxy Anchor Loss for Deep Metric Learning
SoftTripleLoss paper: SoftTriple Loss: Deep Metric Learning Without Triplet Sampling

Selectors

method description
BaseSelector Base class
DefaultSelector Do nothing
DenseTripletSelector Select all triples
DensePairSelector Select all pairs

Quickstart

Please set the environment variable WORKSPACE first to indicate where to manage your project.

Initialization

Use ConfigHandler to create all objects.

config_handler = ConfigHandler()
config_handler.get_params_dict()
objects_dict = config_handler.create_all()

Start

Use manager to automatically call trainer and tester.

manager = utils.get_default(objects_dict, "managers")
manager.run()

Directly use trainer and tester.

trainer = utils.get_default(objects_dict, "trainers")
tester = utils.get_default(objects_dict, "testers")
recorder = utils.get_default(objects_dict, "recorders")

# start to train
utils.func_params_mediator(
    [objects_dict],
    trainer.__call__
)

# start to test
metrics = utils.func_params_mediator(
    [
        {"recorders": recorder},
        objects_dict,
    ],
    tester.__call__
)

Document

For more information, please refer to:

đŸ“– đŸ‘‰ Docs

Some specific guidances:

Configs

We will continually update the optimal parameters of different configs in TsinghuaCloud

Code Reference

TODO:

  • assert parameters
  • distributed methods and Non-distributed methods!!!
  • write github action to automate unit-test, package publish and docs building.
  • add cross-validation splits protocol.
Owner
Borui Zhang
I am a first year Ph.D student in the Department of Automation at THU. My research direction is computer vision.
Borui Zhang
Using deep learning model to detect breast cancer.

Breast-Cancer-Detection Breast cancer is the most frequent cancer among women, with around one in every 19 women at risk. The number of cases of breas

1 Feb 13, 2022
Pytorch implementation for "Density-aware Chamfer Distance as a Comprehensive Metric for Point Cloud Completion" (NeurIPS 2021)

Density-aware Chamfer Distance This repository contains the official PyTorch implementation of our paper: Density-aware Chamfer Distance as a Comprehe

Tong WU 93 Dec 15, 2022
A Neural Net Training Interface on TensorFlow, with focus on speed + flexibility

Tensorpack is a neural network training interface based on TensorFlow. Features: It's Yet Another TF high-level API, with speed, and flexibility built

Tensorpack 6.2k Jan 09, 2023
PyTorch implementation of MLP-Mixer

PyTorch implementation of MLP-Mixer MLP-Mixer: an all-MLP architecture composed of alternate token-mixing and channel-mixing operations. The token-mix

Duo Li 33 Nov 27, 2022
JUSTICE: A Benchmark Dataset for Supreme Court’s Judgment Prediction

JUSTICE: A Benchmark Dataset for Supreme Court’s Judgment Prediction CSCI 544 Final Project done by: Mohammed Alsayed, Shaayan Syed, Mohammad Alali, S

Smit Patel 3 Dec 28, 2022
Toolbox to analyze temporal context invariance of deep neural networks

PyTCI A toolbox that estimates the integration window of a sensory response using the "Temporal Context Invariance" paradigm (TCI). The TCI method Int

4 Oct 23, 2022
Unofficial TensorFlow implementation of Protein Interface Prediction using Graph Convolutional Networks.

[TensorFlow] Protein Interface Prediction using Graph Convolutional Networks Unofficial TensorFlow implementation of Protein Interface Prediction usin

YeongHyeon Park 9 Oct 25, 2022
A multi-scale unsupervised learning for deformable image registration

A multi-scale unsupervised learning for deformable image registration Shuwei Shao, Zhongcai Pei, Weihai Chen, Wentao Zhu, Xingming Wu and Baochang Zha

ShuweiShao 2 Apr 13, 2022
Generates all variables from your .tf files into a variables.tf file.

tfvg Generates all variables from your .tf files into a variables.tf file. It searches for every var.variable_name in your .tf files and generates a v

1 Dec 01, 2022
A Survey on Deep Learning Technique for Video Segmentation

A Survey on Deep Learning Technique for Video Segmentation A Survey on Deep Learning Technique for Video Segmentation Wenguan Wang, Tianfei Zhou, Fati

Tianfei Zhou 112 Dec 12, 2022
[3DV 2021] Channel-Wise Attention-Based Network for Self-Supervised Monocular Depth Estimation

Channel-Wise Attention-Based Network for Self-Supervised Monocular Depth Estimation This is the official implementation for the method described in Ch

Jiaxing Yan 27 Dec 30, 2022
Data Consistency for Magnetic Resonance Imaging

Data Consistency for Magnetic Resonance Imaging Data Consistency (DC) is crucial for generalization in multi-modal MRI data and robustness in detectin

Dimitris Karkalousos 19 Dec 12, 2022
The official implementation of CSG-Stump: A Learning Friendly CSG-Like Representation for Interpretable Shape Parsing

CSGStumpNet The official implementation of CSG-Stump: A Learning Friendly CSG-Like Representation for Interpretable Shape Parsing Paper | Project page

Daxuan 39 Dec 26, 2022
A Deep Reinforcement Learning Framework for Stock Market Trading

DQN-Trading This is a framework based on deep reinforcement learning for stock market trading. This project is the implementation code for the two pap

61 Jan 01, 2023
A Keras implementation of YOLOv3 (Tensorflow backend)

keras-yolo3 Introduction A Keras implementation of YOLOv3 (Tensorflow backend) inspired by allanzelener/YAD2K. Quick Start Download YOLOv3 weights fro

7.1k Jan 03, 2023
Implementation of "StrengthNet: Deep Learning-based Emotion Strength Assessment for Emotional Speech Synthesis"

StrengthNet Implementation of "StrengthNet: Deep Learning-based Emotion Strength Assessment for Emotional Speech Synthesis" https://arxiv.org/abs/2110

RuiLiu 65 Dec 20, 2022
Code for paper "Context-self contrastive pretraining for crop type semantic segmentation"

Code for paper "Context-self contrastive pretraining for crop type semantic segmentation" Setting up a python environment Follow the instruction in ht

Michael Tarasiou 11 Oct 09, 2022
Models, datasets and tools for Facial keypoints detection

Template for Data Science Project This repo aims to give a robust starting point to any Data Science related project. It contains readymade tools setu

girafe.ai 1 Feb 11, 2022
Datasets, Transforms and Models specific to Computer Vision

torchvision The torchvision package consists of popular datasets, model architectures, and common image transformations for computer vision. Installat

13.1k Jan 02, 2023
PyTorch reimplementation of hand-biomechanical-constraints (ECCV2020)

Hand Biomechanical Constraints Pytorch Unofficial PyTorch reimplementation of Hand-Biomechanical-Constraints (ECCV2020). This project reimplement foll

Hao Meng 59 Dec 20, 2022