Accelerate Neural Net Training by Progressively Freezing Layers

Overview

FreezeOut

A simple technique to accelerate neural net training by progressively freezing layers.

[Figure: LRCURVE]

This repository contains code for the extended abstract "FreezeOut."

FreezeOut directly accelerates training by annealing each layer's learning rate to zero on a fixed schedule and excluding a layer from the backward pass once its learning rate bottoms out.
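A minimal sketch of the per-layer schedule in plain Python. It assumes cosine annealing from each layer's initial learning rate down to zero at its freeze time t_i, with training progress t measured as a fraction of total iterations in [0, 1]. It also assumes the t_i values are non-decreasing with depth. Under that assumption the frozen layers always form a prefix of the network, so the backward pass only needs to reach the first layer that is still active. The names `layer_lr` and `first_active_layer` are hypothetical and are not the repository's API.

```python
import math
from typing import List


def layer_lr(t: float, t_i: float, lr0: float) -> float:
    """Learning rate of one layer at training progress t (fraction in [0, 1]).

    Assumed cosine annealing from lr0 to zero at t_i:
        lr(t) = 0.5 * lr0 * (1 + cos(pi * t / t_i))  for t < t_i, else 0.
    """
    if t >= t_i:
        return 0.0  # learning rate has bottomed out: the layer is frozen
    return 0.5 * lr0 * (1.0 + math.cos(math.pi * t / t_i))


def first_active_layer(t: float, freeze_times: List[float]) -> int:
    """Index of the first layer still training at progress t.

    Assumes freeze_times is non-decreasing with depth, so frozen layers form
    a prefix and backprop can stop at this index. Returns len(freeze_times)
    once every layer is frozen.
    """
    for i, t_i in enumerate(freeze_times):
        if t < t_i:
            return i
    return len(freeze_times)


if __name__ == "__main__":
    freeze_times = [0.5, 0.75, 1.0]  # hypothetical t_i for a 3-layer net
    for t in (0.0, 0.6, 0.9):
        lrs = [layer_lr(t, t_i, lr0=0.1) for t_i in freeze_times]
        print(t, first_active_layer(t, freeze_times), [round(lr, 4) for lr in lrs])
```

In a dynamic-graph framework such as PyTorch, the index returned by `first_active_layer` is where you could stop tracking gradients for the earlier layers.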

I had this idea at 4 AM while replying to a Reddit comment. I threw it into an experiment, and it just worked out of the box (with linear scaling and t_0=0.5), so I went on a 96-hour SCIENCE binge, and now, here we are.

[Figure: DESIGNCURVE]

The exact speedup you get depends on how much error you can tolerate: higher speedups appear to come at the cost of increased error, but speedups below 20% should stay within a 3% relative error envelope, and speedups of around 10% seem to incur no error cost with the Scaled Cubic and Unscaled Linear strategies.

Installation

To run this script, you will need PyTorch and a CUDA-capable GPU. If you wish to run it on CPU, just remove all the .cuda() calls.

Running

To run with default parameters, simply call

python train.py

By default, this downloads CIFAR-100, splits it into training, validation, and test sets, and then trains a DenseNet-BC with growth rate k=12 and depth L=76 using SGD with Nesterov momentum.

The script accepts command-line arguments for a variety of parameters; the FreezeOut-specific ones are:

  • how_scale selects the annealing strategy: linear, squared, or cubic. Cubic by default.
  • scale_lr determines whether to scale each layer's initial learning rate based on t_i, the point in training at which layer i is frozen. True by default.
  • t_0 is a float between 0 and 1 that sets how far into training the first layer is frozen. 0.8 by default; this value is pre-cubed, so with the default cubic strategy the first layer freezes at 0.8³ ≈ 0.51 of the way through training.
  • const_time is an experimental setting that increases the number of epochs based on the estimated speedup, so that the total training time matches a non-FreezeOut baseline. I have not validated whether this is worthwhile.

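A sketch of how the FreezeOut-specific parameters might combine into per-layer freeze times and initial learning rates. It rests on several assumptions that are not taken from the repository's code. First, raw freeze times are spaced linearly from t_0 (first layer) to 1.0 (last layer) and then passed through the how_scale transform, which is one reading of "0.8 (pre-cubed)". Second, with scale_lr enabled, each layer's initial learning rate is the base rate divided by t_i, so layers that freeze early take proportionally larger steps. `freeze_schedule` and `SCALE_FNS` are hypothetical names.

```python
from typing import Callable, Dict, List, Tuple

# Hypothetical mapping from how_scale to the transform applied to each
# layer's linearly interpolated freeze time.
SCALE_FNS: Dict[str, Callable[[float], float]] = {
    "linear": lambda x: x,
    "squared": lambda x: x ** 2,
    "cubic": lambda x: x ** 3,
}


def freeze_schedule(
    num_layers: int,
    t_0: float = 0.8,
    how_scale: str = "cubic",
    scale_lr: bool = True,
    base_lr: float = 0.1,
) -> List[Tuple[float, float]]:
    """Return (t_i, initial_lr) for each layer, first layer first.

    Assumptions: raw freeze times run linearly from t_0 to 1.0 and are then
    scaled by SCALE_FNS[how_scale]; with scale_lr, initial_lr = base_lr / t_i.
    """
    if not 0.0 < t_0 <= 1.0:
        raise ValueError("t_0 must lie in (0, 1]")
    if num_layers < 2:
        raise ValueError("need at least two layers")
    scale = SCALE_FNS[how_scale]
    schedule = []
    for i in range(num_layers):
        raw = t_0 + (1.0 - t_0) * i / (num_layers - 1)
        t_i = scale(raw)
        lr0 = base_lr / t_i if scale_lr else base_lr
        schedule.append((t_i, lr0))
    return schedule


for t_i, lr0 in freeze_schedule(num_layers=5):
    print(f"t_i={t_i:.3f}  initial lr={lr0:.4f}")
```

Under these assumptions, the defaults make the first layer freeze at 0.8³ ≈ 0.51 of training while the last layer trains until the end.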
You can also set the filenames of the saved weights and the metrics log, which model to use, how many epochs to train for, and so on.

If you want to calculate an estimated speedup for a given strategy and t_0 value, use the calc_speedup() function in utils.py.
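For real estimates, use calc_speedup(). The toy model below is not that function and is meant only to build intuition. It assumes every layer costs the same, that a backward pass costs about twice a forward pass (the `backward_ratio` parameter), and that frozen layers form a prefix, so each layer pays its backward cost only until its freeze time t_i. `const_time_epochs` is a hypothetical illustration of what a const_time-style adjustment could compute: scaling the epoch count by 1 / (1 - speedup) so that total compute roughly matches the baseline. Both function names are hypothetical.

```python
from typing import Sequence


def toy_speedup(freeze_times: Sequence[float], backward_ratio: float = 2.0) -> float:
    """Fraction of training compute saved under a simplified cost model.

    Assumptions: each layer costs 1 unit forward and backward_ratio units
    backward per iteration; layer i pays its backward cost only for the first
    t_i fraction of training; forward passes are always paid in full.
    """
    n = len(freeze_times)
    baseline = n * (1.0 + backward_ratio)
    freezeout = n * 1.0 + backward_ratio * sum(freeze_times)
    return 1.0 - freezeout / baseline


def const_time_epochs(epochs: int, speedup: float) -> int:
    """Epochs needed to roughly match the baseline's total compute (assumption)."""
    return round(epochs / (1.0 - speedup))


if __name__ == "__main__":
    # Hypothetical: cubic spacing from t_0 = 0.8 over 100 equal-cost layers.
    n = 100
    ts = [(0.8 + 0.2 * i / (n - 1)) ** 3 for i in range(n)]
    s = toy_speedup(ts)
    print(f"toy speedup: {s:.1%}, epochs to match 100: {const_time_epochs(100, s)}")
```

Real layers differ in cost, and the first layer of a DenseNet is much cheaper than its later layers. That is one reason the repository's estimate, which is specific to its model, should be preferred over this sketch.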

Notes

If you know how to implement this in a static-graph framework (specifically TensorFlow or Caffe2), shoot me an email! It's really easy to do with dynamic graphs, but I believe it should also be possible in a static graph with some simple conditionals.

There's (at least) one typo in the paper's definition of the learning rate schedule: there should be a 1/2 in front of alpha.
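For reference, here is the corrected schedule. This assumes it is the per-layer cosine annealing the paper describes, where \alpha_i(0) is layer i's initial learning rate, t is training progress, and t_i is the point at which layer i freezes:

```latex
\alpha_i(t) =
\begin{cases}
  \tfrac{1}{2}\,\alpha_i(0)\left(1 + \cos\!\left(\frac{\pi t}{t_i}\right)\right), & t < t_i,\\[4pt]
  0, & t \ge t_i.
\end{cases}
```

Without the 1/2, the schedule would start at 2\alpha_i(0) rather than \alpha_i(0), since \cos(0) = 1.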

Acknowledgments

Owner
Andy Brock