Semi-supervised learning for object detection

Overview

Source code for STAC: A Simple Semi-Supervised Learning Framework for Object Detection

STAC is a simple yet effective SSL framework for visual object detection along with a data augmentation strategy. STAC deploys highly confident pseudo labels of localized objects from an unlabeled image and updates the model by enforcing consistency via strong augmentation.

This code is only used for research. This is not an official Google product.

Instruction

Install dependencies

Set global enviroment variables.

export PRJROOT=/path/to/your/project/directory/STAC
export DATAROOT=/path/to/your/dataroot
export COCODIR=$DATAROOT/coco
export VOCDIR=$DATAROOT/voc
export PYTHONPATH=$PYTHONPATH:${PRJROOT}/third_party/FasterRCNN:${PRJROOT}/third_party/auto_augment:${PRJROOT}/third_party/tensorpack

Install virtual environment in the root folder of the project

cd ${PRJROOT}

sudo apt install python3-dev python3-virtualenv python3-tk imagemagick
virtualenv -p python3 --system-site-packages env3
. env3/bin/activate
pip install -r requirements.txt

# Make sure your tensorflow version is 1.14 not only in virtual environment but also in
# your machine, 1.15 can cause OOM issues.
python -c 'import tensorflow as tf; print(tf.__version__)'

# install coco apis
pip3 install 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'

(Optional) Install tensorpack

tensorpack with a compatible version is already included at third_party/tensorpack. bash cd ${PRJROOT}/third_party pip install --upgrade git+https://github.com/tensorpack/tensorpack.git

Download COCO/PASCAL VOC data and pre-trained models

Download data

See DATA.md

Download backbone model

cd ${COCODIR}
wget http://models.tensorpack.com/FasterRCNN/ImageNet-R50-AlignPadding.npz

Training

There are three steps:

  • 1. Train a standard detector on labeled data (detection/scripts/coco/train_stg1.sh).
  • 2. Predict pseudo boxes and labels of unlabeled data using the trained detector (detection/scripts/coco/eval_stg1.sh).
  • 3. Use labeled data and unlabeled data with pseudo labels to train a STAC detector (detection/scripts/coco/train_stg2.sh).

Besides instruction at here, detection/scripts/coco/train_stac.sh provides a combined script to train STAC.

detection/scripts/voc/train_stac.sh is a combined script to train STAC on PASCAL VOC.

The following example use labeled data as 10% train2017 and rest 90% train2017 data as unlabeled data.

Step 0: Set variables

cd ${PRJROOT}/detection

# Labeled and Unlabeled datasets
[email protected]
UNLABELED_DATASET=${DATASET}-unlabeled

# PATH to save trained models
CKPT_PATH=result/${DATASET}

# PATH to save pseudo labels for unlabeled data
PSEUDO_PATH=${CKPT_PATH}/PSEUDO_DATA

# Train with 8 GPUs
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

Step 1: Train FasterRCNN on labeled data

. scripts/coco/train_stg1.sh.

Set TRAIN.AUGTYPE_LAB=strong to apply strong data augmentation.

# --simple_path makes train_log/${DATASET}/${EXPNAME} as exact location to save
python3 train_stg1.py \
    --logdir ${CKPT_PATH} --simple_path --config \
    BACKBONE.WEIGHTS=${COCODIR}/ImageNet-R50-AlignPadding.npz \
    DATA.BASEDIR=${COCODIR} \
    DATA.TRAIN="('${DATASET}',)" \
    MODE_MASK=False \
    FRCNN.BATCH_PER_IM=64 \
    PREPROC.TRAIN_SHORT_EDGE_SIZE="[500,800]" \
    TRAIN.EVAL_PERIOD=20 \
    TRAIN.AUGTYPE_LAB='default'

Step 2: Generate pseudo labels of unlabeled data

. scripts/coco/eval_stg1.sh.

Evaluate using COCO metrics and save eval.json

# Check pseudo path
if [ ! -d ${PSEUDO_PATH} ]; then
    mkdir -p ${PSEUDO_PATH}
fi

# Evaluate the model for sanity check
# model-180000 is the last checkpoint
# save eval.json at $PSEUDO_PATH

python3 predict.py \
    --evaluate ${PSEUDO_PATH}/eval.json \
    --load "${CKPT_PATH}"/model-180000 \
    --config \
    DATA.BASEDIR=${COCODIR} \
    DATA.TRAIN="('${UNLABELED_DATASET}',)"

Generate pseudo labels for unlabeled data

Set EVAL.PSEUDO_INFERENCE=True to use original images rather than resized ones for inference.

# Extract pseudo label
python3 predict.py \
    --predict_unlabeled ${PSEUDO_PATH} \
    --load "${CKPT_PATH}"/model-180000 \
    --config \
    DATA.BASEDIR=${COCODIR} \
    DATA.TRAIN="('${UNLABELED_DATASET}',)" \
    EVAL.PSEUDO_INFERENCE=True

Step 3: Train STAC

. scripts/coco/train_stg2.sh.

The dataloader loads pseudo labels from ${PSEUDO_PATH}/pseudo_data.npy.

Apply default augmentation on labeled data and strong augmentation on unlabeled data.

TRAIN.CONFIDENCE and TRAIN.WU are two major parameters of the method.

python3 train_stg2.py \
    --logdir=${CKPT_PATH}/STAC --simple_path \
    --pseudo_path=${PSEUDO_PATH} \
    --config \
    BACKBONE.WEIGHTS=${COCODIR}/ImageNet-R50-AlignPadding.npz \
    DATA.BASEDIR=${COCODIR} \
    DATA.TRAIN="('${DATASET}',)" \
    DATA.UNLABEL="('${UNLABELED_DATASET}',)" \
    MODE_MASK=False \
    FRCNN.BATCH_PER_IM=64 \
    PREPROC.TRAIN_SHORT_EDGE_SIZE="[500,800]" \
    TRAIN.EVAL_PERIOD=20 \
    TRAIN.AUGTYPE_LAB='default' \
    TRAIN.AUGTYPE='strong' \
    TRAIN.CONFIDENCE=0.9 \
    TRAIN.WU=2

Tensorboard

All training logs and tensorboard info are under ${PRJROOT}/detection/train_log. Visualize using

tensorboard --logdir=${PRJROOT}/detection/train_log

Citation

@inproceedings{sohn2020detection,
  title={A Simple Semi-Supervised Learning Framework for Object Detection},
  author={Kihyuk Sohn and Zizhao Zhang and Chun-Liang Li and Han Zhang and Chen-Yu Lee and Tomas Pfister},
  year={2020},
  booktitle={arXiv:2005.04757}
}

Acknowledgement

Owner
Google Research
Google Research
利用Tensorflow实现基于CNN的中文短文本分类

Text Classification with CNN 使用卷积神经网络进行中文文本分类 CNN做句子分类的论文可以参看: Convolutional Neural Networks for Sentence Classification 还可以去读dennybritz大牛的博客:Implemen

Jeremiah 4 Nov 08, 2022
Statistical-Rethinking-with-Python-and-PyMC3 - Python/PyMC3 port of the examples in " Statistical Rethinking A Bayesian Course with Examples in R and Stan" by Richard McElreath

Statistical Rethinking with Python and PyMC3 This repository has been deprecated in favour of this one, please check that repository for updates, for

Osvaldo Martin 786 Dec 29, 2022
Official repository for Natural Image Matting via Guided Contextual Attention

GCA-Matting: Natural Image Matting via Guided Contextual Attention The source codes and models of Natural Image Matting via Guided Contextual Attentio

Li Yaoyi 349 Dec 26, 2022
A collection of Google research projects related to Federated Learning and Federated Analytics.

Federated Research Federated Research is a collection of research projects related to Federated Learning and Federated Analytics. Federated learning i

Google Research 483 Jan 05, 2023
Making a music video with Wav2CLIP and VQGAN-CLIP

music2video Overview A repo for making a music video with Wav2CLIP and VQGAN-CLIP. The base code was derived from VQGAN-CLIP The CLIP embedding for au

Joel Jang | 장요엘 163 Dec 26, 2022
Official implementation for the paper: Multi-label Classification with Partial Annotations using Class-aware Selective Loss

Multi-label Classification with Partial Annotations using Class-aware Selective Loss Paper | Pretrained models Official PyTorch Implementation Emanuel

99 Dec 27, 2022
Voila - Voilà turns Jupyter notebooks into standalone web applications

Rendering of live Jupyter notebooks with interactive widgets. Introduction Voilà turns Jupyter notebooks into standalone web applications. Unlike the

Voilà Dashboards 4.5k Jan 03, 2023
House-GAN++: Generative Adversarial Layout Refinement Network towards Intelligent Computational Agent for Professional Architects

House-GAN++ Code and instructions for our paper: House-GAN++: Generative Adversarial Layout Refinement Network towards Intelligent Computational Agent

122 Dec 28, 2022
Deep Learning Based Fasion Recommendation System for Ecommerce

Project Name: Fasion Recommendation System for Ecommerce A Deep learning based streamlit web app which can recommened you various types of fasion prod

BAPPY AHMED 13 Dec 13, 2022
Pure python implementation reverse-mode automatic differentiation

MiniGrad A minimal implementation of reverse-mode automatic differentiation (a.k.a. autograd / backpropagation) in pure Python. Inspired by Andrej Kar

Kenny Song 76 Sep 12, 2022
Re-implement CycleGAN in Tensorlayer

CycleGAN_Tensorlayer Re-implement CycleGAN in TensorLayer Original CycleGAN Improved CycleGAN with resize-convolution Prerequisites: TensorLayer Tenso

89 Aug 15, 2022
🛰️ Awesome Satellite Imagery Datasets

Awesome Satellite Imagery Datasets List of aerial and satellite imagery datasets with annotations for computer vision and deep learning. Newest datase

Christoph Rieke 3k Jan 03, 2023
Improving Generalization Bounds for VC Classes Using the Hypergeometric Tail Inversion

Improving Generalization Bounds for VC Classes Using the Hypergeometric Tail Inversion Preface This directory provides an implementation of the algori

Jean-Samuel Leboeuf 0 Nov 03, 2021
Lightweight, Portable, Flexible Distributed/Mobile Deep Learning with Dynamic, Mutation-aware Dataflow Dep Scheduler; for Python, R, Julia, Scala, Go, Javascript and more

Apache MXNet (incubating) for Deep Learning Apache MXNet is a deep learning framework designed for both efficiency and flexibility. It allows you to m

The Apache Software Foundation 20.2k Jan 08, 2023
Database Reasoning Over Text project for ACL paper

Database Reasoning over Text This repository contains the code for the Database Reasoning Over Text paper, to appear at ACL2021. Work is performed in

Facebook Research 320 Dec 12, 2022
Kaggle | 9th place (part of) solution for the Bristol-Myers Squibb – Molecular Translation challenge

Part of the 9th place solution for the Bristol-Myers Squibb – Molecular Translation challenge translating images containing chemical structures into I

Erdene-Ochir Tuguldur 22 Nov 30, 2022
DimReductionClustering - Dimensionality Reduction + Clustering + Unsupervised Score Metrics

Dimensionality Reduction + Clustering + Unsupervised Score Metrics Introduction

11 Nov 15, 2022
Using LSTM write Tang poetry

本教程将通过一个示例对LSTM进行介绍。通过搭建训练LSTM网络,我们将训练一个模型来生成唐诗。本文将对该实现进行详尽的解释,并阐明此模型的工作方式和原因。并不需要过多专业知识,但是可能需要新手花一些时间来理解的模型训练的实际情况。为了节省时间,请尽量选择GPU进行训练。

56 Dec 15, 2022
StocksMA is a package to facilitate access to financial and economic data of Moroccan stocks.

Creating easier access to the Moroccan stock market data What is StocksMA ? StocksMA is a package to facilitate access to financial and economic data

Salah Eddine LABIAD 28 Jan 04, 2023
Pyserini is a Python toolkit for reproducible information retrieval research with sparse and dense representations.

Pyserini Pyserini is a Python toolkit for reproducible information retrieval research with sparse and dense representations. Retrieval using sparse re

Castorini 706 Dec 29, 2022