ESGD-M - A stochastic non-convex second-order optimizer, suitable for training deep learning models, for PyTorch

ESGD-M

ESGD-M is a stochastic non-convex second-order optimizer, suitable for training deep learning models. It is based on ESGD ("Equilibrated adaptive learning rates for non-convex optimization", Dauphin et al., 2015) and adds quasi-hyperbolic momentum ("Quasi-hyperbolic momentum and Adam for deep learning", Ma and Yarats, 2019) to accelerate convergence, which considerably improves its performance over plain ESGD.

ESGD-M obtains Hessian information through occasional Hessian-vector products (by default, every ten optimizer steps; each Hessian-vector product is approximately the same cost as a gradient evaluation) and uses it to adapt per-parameter learning rates. It estimates the diagonal of the absolute Hessian, diag(|H|), to use as a diagonal preconditioner.
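
The following is a minimal sketch, not the package's own code, of how Hessian-vector products can yield a per-parameter curvature estimate. It assumes the equilibration-style estimator from the ESGD paper: draw a random probe vector v, compute Hv, and average (Hv)^2 elementwise. The square root of that average estimates sqrt(diag(H^2)), which equals |diag(H)| exactly when H is diagonal. Rademacher (random sign) probes are used here; Gaussian probes have the same expectation. The helper name hvp_squared_estimate and its n_samples parameter are hypothetical.

```python
import torch


def hvp_squared_estimate(loss_fn, params, n_samples=1):
    """Hypothetical helper: estimate sqrt(diag(H^2)) per parameter from Hessian-vector products.

    For a diagonal Hessian this equals |diag(H)|; in general it is the
    equilibration-style estimate sqrt(E[(Hv)^2]) with random sign vectors v.
    """
    loss = loss_fn()
    # First backward pass, keeping its graph so it can be differentiated again
    grads = torch.autograd.grad(loss, params, create_graph=True)
    acc = [torch.zeros_like(p) for p in params]
    for _ in range(n_samples):
        # Rademacher probe: each entry is +1 or -1 with equal probability
        vs = [torch.randint_like(p, 0, 2) * 2 - 1 for p in params]
        # Second backward pass: the gradient of <grad, v> with respect to p is H v
        hvps = torch.autograd.grad(grads, params, grad_outputs=vs, retain_graph=True)
        for a, hv in zip(acc, hvps):
            a.add_(hv.detach() ** 2)
    return [(a / n_samples).sqrt() for a in acc]


# Loss 0.5 * sum(A * x**2) has Hessian diag(A), including a negative entry
A = torch.tensor([2.0, -3.0, 0.5])
x = torch.randn(3, requires_grad=True)
est = hvp_squared_estimate(lambda: 0.5 * (A * x * x).sum(), [x], n_samples=4)
print(est[0])  # matches A.abs(), since (A * v)**2 == A**2 for any sign vector v
```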

To use this optimizer you must call .backward() with the create_graph=True option. Gradient accumulation steps and distributed training are currently not supported.
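
Why create_graph=True is needed can be seen in plain PyTorch, without ESGD-M itself. In this sketch (the helper name hvp_from_backward is hypothetical), backward(create_graph=True) leaves each p.grad attached to an autograd graph, so it can be differentiated again to produce a Hessian-vector product; after a plain backward() the .grad tensors are detached and a second differentiation is impossible.

```python
import torch


def hvp_from_backward(loss, params, vs):
    """Hypothetical helper: Hessian-vector products from the .grad fields left by backward()."""
    # create_graph=True records the backward pass, so each p.grad stays differentiable.
    # PyTorch may warn about a parameter/.grad reference cycle; resetting .grad to None
    # afterwards (as zero_grad(set_to_none=True) does) breaks it.
    loss.backward(create_graph=True)
    grads = [p.grad for p in params]
    # Differentiating <grad, v> with respect to the parameters gives H v
    return torch.autograd.grad(grads, params, grad_outputs=vs)


A = torch.tensor([[2.0, 1.0], [1.0, 3.0]])  # symmetric, so it is the Hessian of the loss below
x = torch.tensor([1.0, -1.0], requires_grad=True)
loss = 0.5 * x @ A @ x
(hv,) = hvp_from_backward(loss, [x], [torch.tensor([1.0, 0.0])])
print(hv)  # A @ [1, 0], the first column of A
x.grad = None  # break the reference cycle created by create_graph=True
```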

Learning rates

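ESGD-M learning rates have a different meaning from those of SGD and of Adagrad/Adam-style optimizers, so you may need to try learning rates in the range 1e-3 to 1. The three classes of optimizer respond differently when the parameters or the loss are rescaled: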

SGD class optimizers:

  • If you rescale your parameters by a factor of n, you must scale your learning rate by a factor of n^2.

  • If you rescale your loss by a factor of n, you must scale your learning rate by a factor of 1 / n.

Adagrad/Adam class optimizers:

  • If you rescale your parameters by a factor of n, you must scale your learning rate by a factor of n.

  • If you rescale your loss by a factor of n, you do not have to scale your learning rate.

Second-order optimizers (including ESGD-M):

  • You do not have to scale your learning rate if you rescale either your parameters or your loss.
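
These rules can be checked on a one-dimensional quadratic, loss = c * 0.5 * a * theta**2, by rescaling the parameter (theta' = n * theta) and the loss (by c) and mapping each step back to the original units. This is an illustrative sketch under simplifying assumptions: the Adam-like rule is a single Adam step with epsilon and moment averaging dropped (so it reduces to -lr * sign(g)), and the second-order rule divides the gradient by |h|, the absolute second derivative, as a stand-in for a diag(|H|) preconditioner. All function names are hypothetical.

```python
def sgd_step(g, h, lr):
    # Plain gradient step
    return -lr * g


def adam_like_step(g, h, lr):
    # First Adam step with epsilon and averaging dropped: only the sign of g survives
    return -lr * g / abs(g)


def second_order_step(g, h, lr):
    # Gradient divided by the absolute curvature, as with a diag(|H|) preconditioner
    return -lr * g / abs(h)


def step_in_original_units(rule, theta, a, lr, n=1.0, c=1.0):
    """Apply `rule` after scaling parameters by n and the loss by c; report the step in theta units."""
    theta_p = n * theta          # rescaled parameter theta' = n * theta
    g = c * a * theta_p / n**2   # d loss / d theta'
    h = c * a / n**2             # d^2 loss / d theta'^2
    return rule(g, h, lr) / n    # map the step back to the original theta


# SGD: scaling the parameters by n = 10 needs lr * n**2 to reproduce the original step
print(step_in_original_units(sgd_step, 1.0, 4.0, lr=0.1))
print(step_in_original_units(sgd_step, 1.0, 4.0, lr=0.1 * 10**2, n=10.0))
# Second order: the same lr works under any parameter or loss scaling
print(step_in_original_units(second_order_step, 1.0, 4.0, lr=0.1, n=10.0, c=5.0))
```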

Momentum

The default configuration is Nesterov momentum (if v is not specified, it defaults to the value of beta_1, which produces Nesterov momentum):

opt = ESGD(model.parameters(), lr=1, betas=(0.9, 0.999), v=0.9)

The defaults recommended in the quasi-hyperbolic momentum paper can be obtained with:

opt = ESGD(model.parameters(), lr=1, betas=(0.999, 0.999), v=0.7)

Setting v to 1 gives ordinary (non-Nesterov) momentum.
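
The momentum rule itself is short. Following the quasi-hyperbolic momentum paper, the update direction is a weighted average of the current gradient g and the momentum buffer m: (1 - v) * g + v * m. This plain-Python sketch uses a hypothetical function name and does not show how ESGD-M combines the direction with its preconditioner.

```python
def qhm_direction(m, g, beta1, v):
    """One quasi-hyperbolic momentum step: returns (new buffer, update direction)."""
    # Exponential moving average of gradients (the momentum buffer)
    m = beta1 * m + (1 - beta1) * g
    # Interpolate between the raw gradient (v = 0) and the buffer (v = 1)
    d = (1 - v) * g + v * m
    return m, d


# v = beta_1 gives the Nesterov form beta_1 * m_new + (1 - beta_1) * g;
# v = 1 uses the buffer alone (ordinary momentum); v = 0 ignores momentum.
for v in (0.9, 1.0, 0.0):
    print(v, qhm_direction(m=0.5, g=2.0, beta1=0.9, v=v))
```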

In ESGD-M, the decay coefficient beta_2 applies not to the squared gradient, as in Adam, but to the squared Hessian diagonal estimate, which ESGD-M uses in place of the squared gradient to provide per-parameter adaptive learning rates.
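
To make the role of beta_2 concrete, here is a scalar sketch of one update in which the squared Hessian-vector product takes the place of Adam's squared gradient and the numerator uses quasi-hyperbolic momentum. Everything beyond what is stated above is an assumption: the function and state names are hypothetical, the Adam-style bias correction (counted in Hessian estimates), eps, and the order of operations are guesses, and the real optimizer works on tensors. In one dimension, a Hessian-vector product with a ±1 probe has magnitude |h|, so hvp**2 equals h**2.

```python
import math


def esgd_m_like_step(p, g, hvp, state, lr, beta1=0.9, beta2=0.999, v=None, eps=1e-8):
    """Hypothetical scalar update: QHM momentum preconditioned by an EMA of hvp**2.

    Pass hvp=None on steps where no Hessian-vector product was computed.
    """
    if v is None:
        v = beta1  # v defaults to beta_1, giving Nesterov momentum
    # Momentum buffer and quasi-hyperbolic direction
    state["m"] = beta1 * state["m"] + (1 - beta1) * g
    direction = (1 - v) * g + v * state["m"]
    # beta_2 decays the squared Hessian diagonal estimate, not the squared gradient
    if hvp is not None:
        state["d_sq"] = beta2 * state["d_sq"] + (1 - beta2) * hvp**2
        state["n_d"] += 1
    if state["n_d"] == 0:
        raise ValueError("the first step must supply a Hessian-vector product")
    # Adam-style bias correction over the number of estimates (assumption)
    d_hat = math.sqrt(state["d_sq"] / (1 - beta2 ** state["n_d"]))
    return p - lr * direction / (d_hat + eps)


# Quadratic loss 0.5 * h * p**2 with h = 2: gradient h * p, Hessian-vector product h * (+1)
state = {"m": 0.0, "d_sq": 0.0, "n_d": 0}
p = 3.0
p = esgd_m_like_step(p, g=2.0 * p, hvp=2.0, state=state, lr=1.0, v=0.0)
print(p)  # close to 0.0: with v = 0 and lr = 1 the first step is a Newton step
```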

Hessian-vector products

The absolute Hessian diagonal diag(|H|) is estimated every update_d_every steps (default 10). In addition, it is estimated on each of the first d_warmup steps regardless, so that a lower-variance estimate of diag(|H|) is obtained quickly. Each estimate uses a Hessian-vector product, which takes about as long to compute as a gradient evaluation. Because this requires a double backward pass, you must explicitly tell PyTorch to keep the graph of the first backward pass:

opt.zero_grad(set_to_none=True)
loss = loss_fn(model(inputs), targets)
loss.backward(create_graph=True)  # keep the backward graph for the Hessian-vector product
opt.step()
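
The estimation schedule can be sketched as follows, assuming steps are counted from 1 and that a step triggers an estimate if it falls within the warmup or is a multiple of update_d_every (the package's exact indexing is an assumption). The helper names and the d_warmup value of 20 in the example are hypothetical; no default for d_warmup is stated here. Since each Hessian-vector product costs roughly one gradient evaluation, the count approximates the extra backward-pass cost.

```python
def should_estimate_d(step, update_d_every, d_warmup):
    """Hypothetical: True if diag(|H|) is re-estimated on this step (steps count from 1)."""
    return step <= d_warmup or step % update_d_every == 0


def count_hvps(total_steps, update_d_every, d_warmup):
    """Number of Hessian-vector products (each costing about one gradient evaluation)."""
    return sum(should_estimate_d(s, update_d_every, d_warmup) for s in range(1, total_steps + 1))


# With update_d_every=10 and an illustrative d_warmup=20:
# steps 1-20 all estimate, then every 10th step from 30 onwards.
print(count_hvps(1000, update_d_every=10, d_warmup=20))  # 118
```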

Weight decay

Weight decay is applied separately from the Hessian-vector product and the preconditioner, as in AdamW, except that the user-supplied weight decay value is multiplied by the current learning rate to determine the factor by which the weights are decayed.
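
Decoupling matters more with a curvature preconditioner than without one. In the scalar sketch below (function names hypothetical, d_hat standing in for the diag(|H|)-based denominator, and the order of decay and update an assumption), an L2 penalty folded into the gradient is divided by d_hat, whereas decoupled decay shrinks the weight by lr * weight_decay regardless of curvature.

```python
def coupled_l2_step(p, g, d_hat, lr, weight_decay):
    """L2 penalty folded into the gradient: the decay term is also divided by the preconditioner."""
    return p - lr * (g + weight_decay * p) / d_hat


def decoupled_decay_step(p, g, d_hat, lr, weight_decay):
    """Decoupled decay: shrink by lr * weight_decay, then apply the preconditioned step."""
    p = p * (1 - lr * weight_decay)
    return p - lr * g / d_hat


# High curvature (d_hat = 100) barely shrinks the weight under coupled L2,
# while decoupled decay removes lr * weight_decay = 5% of it.
print(coupled_l2_step(2.0, 0.0, d_hat=100.0, lr=0.5, weight_decay=0.1))
print(decoupled_decay_step(2.0, 0.0, d_hat=100.0, lr=0.5, weight_decay=0.1))
```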

Learning rate warmup

Because the diag(|H|) estimates have high variance, the adaptive learning rates are not very reliable until many steps have been taken and many estimates have been averaged together. To deal with this, ESGD-M applies a short exponential learning rate warmup by default (it is combined with any external learning rate scheduler). On each step (counting from 1) the learning rate is:

lr * (1 - lr_warmup**step)

The default value for lr_warmup is 0.99, which reaches 63% of the specified learning rate in 100 steps and 95% in 300 steps.
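
The warmup figures can be checked numerically. This sketch (function names hypothetical) multiplies whatever learning rate is currently in effect, for example one supplied by an external scheduler, by 1 - lr_warmup**step, and finds the first step at which a given fraction of that learning rate is reached.

```python
def warmup_factor(step, lr_warmup=0.99):
    """Multiplier applied to the learning rate on `step` (counting from 1)."""
    return 1 - lr_warmup**step


def effective_lr(lr, step, lr_warmup=0.99):
    # lr may already come from an external scheduler; the warmup multiplies it
    return lr * warmup_factor(step, lr_warmup)


def steps_to_reach(fraction, lr_warmup=0.99):
    """First step at which the warmup factor is at least `fraction`."""
    step = 1
    while warmup_factor(step, lr_warmup) < fraction:
        step += 1
    return step


print(round(warmup_factor(100), 3))  # 0.634
print(round(warmup_factor(300), 3))  # 0.951
print(steps_to_reach(0.63))  # 99
print(steps_to_reach(0.95))  # 299
```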

Owner
Katherine Crowson
AI/generative artist.