We present a regularized self-labeling approach to improve the generalization and robustness properties of fine-tuning.

Overview

Overview

This repository provides the implementation for the paper "Improved Regularization and Robustness for Fine-tuning in Neural Networks", which will be presented as a poster paper in NeurIPS'21.

In this work, we propose a regularized self-labeling approach that combines regularization and self-training methods for improving the generalization and robustness properties of fine-tuning. Our approach includes two components:

  • First, we encode layer-wise regularization to penalize the model weights at different layers of the neural net.
  • Second, we add self-labeling that relabels data points based on current neural net's belief and reweights data points whose confidence is low.

Requirements

To install requirements:

pip install -r requirements.txt

Data Preparation

We use seven image datasets in our paper. We list the link for downloading these datasets and describe how to prepare data to run our code below.

  • Aircrafts: download and extract into ./data/aircrafts
    • remove the class 257.clutter out of the data directory
  • CUB-200-2011: download and extract into ./data/CUB_200_2011/
  • Caltech-256: download and extract into ./data/caltech256/
  • Stanford-Cars: download and extract into ./data/StanfordCars/
  • Stanford-Dogs: download and extract into ./data/StanfordDogs/
  • Flowers: download and extract into ./data/flowers/
  • MIT-Indoor: download and extract into ./data/Indoor/

Our code automatically handles the split of the datasets.

Usage

Our algorithm (RegSL) interpolates between layer-wise regularization and self-labeling. Run the following commands for conducting experiments in this paper.

Fine-tuning ResNet-101 on image classification tasks.

python train_constraint.py --model ResNet101 \
    --config configs/config_constraint_indoor.json \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 0.136809975858091 --reg_predictor 6.40780158171339 --scale_factor 2.52883770643206\
    --device 1

python train_constraint.py --model ResNet101 \
    --config configs/config_constraint_aircrafts.json \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 1.18330556653284 --reg_predictor 5.27713618808711 --scale_factor 1.27679969876201\
    --device 1

python train_constraint.py --model ResNet101 \
    --config configs/config_constraint_birds.json \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 0.204403908747731 --reg_predictor 23.7850606577679 --scale_factor 4.73803591794678\
    --device 1

python train_constraint.py --model ResNet101 \
    --config configs/config_constraint_caltech.json \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 0.0867998872549272 --reg_predictor 9.4552942790218 --scale_factor 1.1785989596144\
    --device 1

python train_constraint.py --model ResNet101 \
    --config configs/config_constraint_cars.json \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 1.3340347414257 --reg_predictor 8.26940794089601 --scale_factor 3.47676759842434\
    --device 1

python train_constraint.py --model ResNet101 \
    --config configs/config_constraint_dogs.json \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 0.0561320847651626 --reg_predictor 4.46281825974388 --scale_factor 1.58722606909531\
    --device 1

python train_constraint.py --model ResNet101 \
    --config configs/config_constraint_flower.json \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 0.131991042311165 --reg_predictor 10.7674132173309 --scale_factor 4.98010215976503\
    --device 1

Fine-tuning ResNet-18 under label noise.

python train_label_noise.py --config configs/config_constraint_indoor.json --model ResNet18 \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 7.80246991703043 --reg_predictor 14.077402847906 \
    --noise_rate 0.2 --train_correct_label --reweight_epoch 5 --reweight_temp 2.0 --correct_epoch 10 --correct_thres 0.9 

python train_label_noise.py --config configs/config_constraint_indoor.json --model ResNet18 \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 8.47139398080791 --reg_predictor 19.0191127114923 \
    --noise_rate 0.4 --train_correct_label --reweight_epoch 5 --reweight_temp 2.0 --correct_epoch 10 --correct_thres 0.9 

python train_label_noise.py --config configs/config_constraint_indoor.json --model ResNet18 \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 10.7576018531961 --reg_predictor 19.8157649727473 \
    --noise_rate 0.6 --train_correct_label --reweight_epoch 5 --reweight_temp 2.0 --correct_epoch 10 --correct_thres 0.9 
    
python train_label_noise.py --config configs/config_constraint_indoor.json --model ResNet18 \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 9.2031662757248 --reg_predictor 6.41568500472423 \
    --noise_rate 0.8 --train_correct_label --reweight_epoch 5 --reweight_temp 1.5 --correct_epoch 10 --correct_thres 0.9 

Fine-tuning Vision Transformer on noisy labels.

python train_label_noise.py --config configs/config_constraint_indoor.json \
    --model VisionTransformer --is_vit --img_size 224 --vit_type ViT-B_16 --vit_pretrained_dir pretrained/imagenet21k_ViT-B_16.npz \
    --reg_method none --reg_norm none \
    --lr 0.0001 --device 1 --noise_rate 0.4

python train_label_noise.py --config configs/config_constraint_indoor.json \
    --model VisionTransformer --is_vit --img_size 224 --vit_type ViT-B_16 --vit_pretrained_dir pretrained/imagenet21k_ViT-B_16.npz \
    --reg_method none --reg_norm none \
    --lr 0.0001 --device 1 --noise_rate 0.8

python train_label_noise.py --config configs/config_constraint_indoor.json \
    --model VisionTransformer --is_vit --img_size 224 --vit_type ViT-B_16 --vit_pretrained_dir pretrained/imagenet21k_ViT-B_16.npz \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 0.7488074175044196 --reg_predictor 9.842955837419588 \
    --train_correct_label --reweight_epoch 24 --correct_epoch 18\
    --lr 0.0001 --device 1 --noise_rate 0.4

python train_label_noise.py --config configs/config_constraint_indoor.json \
    --model VisionTransformer --is_vit --img_size 224 --vit_type ViT-B_16 --vit_pretrained_dir pretrained/imagenet21k_ViT-B_16.npz \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 0.1568903647089986 --reg_predictor 1.407080880079702 \
    --train_correct_label --reweight_epoch 18 --correct_epoch 2\
    --lr 0.0001 --device 1 --noise_rate 0.8

Please follow the instructions in ViT-pytorch to download the pre-trained models.

Fine-tuning ResNet-18 on ChestX-ray14 data set.

Run experiments on ChestX-ray14 in reproduce-chexnet path:

cd reproduce-chexnet

python retrain.py --reg_method None --reg_norm None --device 0

python retrain.py --reg_method constraint --reg_norm frob \
    --reg_extractor 5.728564437344309 --reg_predictor 2.5669480884876905 --scale_factor 1.0340072757925474 \
    --device 0

Citation

If you find this repository useful, consider citing our work titled above.

Acknowledgment

Thanks to the authors of the following repositories for providing their implementation publicly available.

Owner
NEU-StatsML-Research
We are a group of faculty and students from the Computer Science College of Northeastern University
NEU-StatsML-Research
This project intends to use SVM supervised learning to determine whether or not an individual is diabetic given certain attributes.

Diabetes Prediction Using SVM I explore a diabetes prediction algorithm using a Diabetes dataset. Using a Support Vector Machine for my prediction alg

Jeff Shen 1 Jan 14, 2022
Time-stretch audio clips quickly with PyTorch (CUDA supported)! Additional utilities for searching efficient transformations are included.

Time-stretch audio clips quickly with PyTorch (CUDA supported)! Additional utilities for searching efficient transformations are included.

Kento Nishi 22 Jul 07, 2022
Saeed Lotfi 28 Dec 12, 2022
CVPR2022 paper "Dense Learning based Semi-Supervised Object Detection"

[CVPR2022] DSL: Dense Learning based Semi-Supervised Object Detection DSL is the first work on Anchor-Free detector for Semi-Supervised Object Detecti

Bhchen 69 Dec 08, 2022
An implementation for `Text2Event: Controllable Sequence-to-Structure Generation for End-to-end Event Extraction`

Text2Event An implementation for Text2Event: Controllable Sequence-to-Structure Generation for End-to-end Event Extraction Please contact Yaojie Lu (@

Roger 153 Jan 07, 2023
Knowledge Distillation Toolbox for Semantic Segmentation

SegDistill: Toolbox for Knowledge Distillation on Semantic Segmentation Networks This repo contains the supported code and configuration files for Seg

9 Dec 12, 2022
An end-to-end implementation of intent prediction with Metaflow and other cool tools

You Don't Need a Bigger Boat An end-to-end (Metaflow-based) implementation of an intent prediction flow for kids who can't MLOps good and wanna learn

Jacopo Tagliabue 614 Dec 31, 2022
Books, Presentations, Workshops, Notebook Labs, and Model Zoo for Software Engineers and Data Scientists wanting to learn the TF.Keras Machine Learning framework

Books, Presentations, Workshops, Notebook Labs, and Model Zoo for Software Engineers and Data Scientists wanting to learn the TF.Keras Machine Learning framework

Google Cloud Platform 792 Dec 28, 2022
An Inverse Kinematics library aiming performance and modularity

IKPy Demo Live demos of what IKPy can do (click on the image below to see the video): Also, a presentation of IKPy: Presentation. Features With IKPy,

Pierre Manceron 481 Jan 02, 2023
VLG-Net: Video-Language Graph Matching Networks for Video Grounding

VLG-Net: Video-Language Graph Matching Networks for Video Grounding Introduction Official repository for VLG-Net: Video-Language Graph Matching Networ

Mattia Soldan 25 Dec 04, 2022
Reinforcement learning library in JAX.

Reinforcement learning library in JAX.

Yicheng Luo 96 Oct 30, 2022
A modification of Daniel Russell's notebook merged with Katherine Crowson's hq-skip-net changes

Edits made to this repo by Katherine Crowson I have added several features to this repository for use in creating higher quality generative art (featu

Paul Fishwick 10 May 07, 2022
Dataset and Source code of paper 'Enhancing Keyphrase Extraction from Academic Articles with their Reference Information'.

Enhancing Keyphrase Extraction from Academic Articles with their Reference Information Overview Dataset and code for paper "Enhancing Keyphrase Extrac

15 Nov 24, 2022
Cognate Detection Repository

Cognate Detection Repository Details This repository contains the data for two publications: Challenge Dataset of Cognates and False Friend Pairs from

Diptesh Kanojia 1 Apr 26, 2022
Semi-supervised Stance Detection of Tweets Via Distant Network Supervision

SANDS This is an annonymous repository containing code and data necessary to reproduce the results published in "Semi-supervised Stance Detection of T

2 Sep 22, 2022
10th place solution for Google Smartphone Decimeter Challenge at kaggle.

Under refactoring 10th place solution for Google Smartphone Decimeter Challenge at kaggle. Google Smartphone Decimeter Challenge Global Navigation Sat

12 Oct 25, 2022
Official repository of the paper "GPR1200: A Benchmark for General-PurposeContent-Based Image Retrieval"

GPR1200 Dataset GPR1200: A Benchmark for General-Purpose Content-Based Image Retrieval (ArXiv) Konstantin Schall, Kai Uwe Barthel, Nico Hezel, Klaus J

Visual Computing Group 16 Nov 21, 2022
A modular, open and non-proprietary toolkit for core robotic functionalities by harnessing deep learning

A modular, open and non-proprietary toolkit for core robotic functionalities by harnessing deep learning Website • About • Installation • Using OpenDR

OpenDR 304 Dec 28, 2022
Turi Create simplifies the development of custom machine learning models.

Quick Links: Installation | Documentation | WWDC 2019 | WWDC 2018 Turi Create Check out our talks at WWDC 2019 and at WWDC 2018! Turi Create simplifie

Apple 10.9k Jan 01, 2023