A PyTorch implementation of the continual learning experiments with deep neural networks

Overview

Brain-Inspired Replay

A PyTorch implementation of the continual learning experiments with deep neural networks described in the following paper:

This paper proposes a new, brain-inspired version of generative replay that can scale to continual learning problems with natural images as inputs. This is demonstrated with the Split CIFAR-100 protocol, both for task-incremental learning and for class-incremental learning.

Installation & requirements

The current version of the code has been tested with Python 3.5.2 on several Linux operating systems with the following versions of PyTorch and Torchvision:

  • pytorch 1.1.0
  • torchvision 0.2.2

The versions that were used for other Python-packages are listed in requirements.txt.

To use the code, download the repository and change into it:

git clone https://github.com/GMvandeVen/brain-inspired-replay.git
cd brain-inspired-replay

(If downloading the zip-file, extract the files and change into the extracted folder.)

Assuming Python and pip are set up, the Python-packages used by this code can be installed using:

pip install -r requirements.txt

However, you might want to install pytorch and torchvision in a slightly different way to ensure compatability with your version of CUDA (see https://pytorch.org/).

Finally, the code in this repository itself does not need to be installed, but a number of scripts should be made executable:

chmod +x main_*.py compare_*.py create_figures.sh

Demos

Demo 1: Brain-inspired replay on split MNIST

./main_cl.py --experiment=splitMNIST --scenario=class --replay=generative --brain-inspired --pdf

This runs a single continual learning experiment: brain-inspired replay on the class-incremental learning scenario of split MNIST. Information about the data, the model, the training progress and the produced outputs (e.g., a pdf with results) is printed to the screen. Expected run-time on a standard laptop is ~12 minutes, with a GPU it should take ~4 minutes.

Demo 2: Comparison of continual learning methods

./compare_MNIST.py --scenario=class

This runs a series of continual learning experiments to compare the performance of various methods. Information about the different experiments, their progress and the produced outputs (e.g., a summary pdf) is printed to the screen. Expected run-time on a standard laptop is ~50 minutes, with a GPU it should take ~18 minutes.

These two demos can also be run with on-the-fly plots using the flag --visdom. For this visdom must be activated first, see instructions below.

Running comparisons from the paper

The script create_figures.sh provides step-by-step instructions for re-running the experiments and re-creating the figures reported in the paper.

Although it is possible to run this script as it is, it will take very long and it is probably sensible to parallellize the experiments.

Running custom experiments

Using main_cl.py, it is possible to run custom individual experiments. The main options for this script are:

  • --experiment: which task protocol? (splitMNIST|permMNIST|CIFAR100)
  • --scenario: according to which scenario? (task|domain|class)
  • --tasks: how many tasks?

To run specific methods, use the following:

  • Context-dependent-Gating (XdG): ./main_cl.py --xdg --xdg-prop=0.8
  • Elastic Weight Consolidation (EWC): ./main_cl.py --ewc --lambda=5000
  • Online EWC: ./main_cl.py --ewc --online --lambda=5000 --gamma=1
  • Synaptic Intelligenc (SI): ./main_cl.py --si --c=0.1
  • Learning without Forgetting (LwF): ./main_cl.py --replay=current --distill
  • Generative Replay (GR): ./main_cl.py --replay=generative
  • Brain-Inspired Replay (BI-R): ./main_cl.py --replay=generative --brain-inspired

For information on further options: ./main_cl.py -h.

PyTorch-implementations for several methods relying on stored data (Experience Replay, iCaRL and A-GEM), as well as for additional metrics (FWT, BWT, forgetting, intransigence), can be found here: https://github.com/GMvandeVen/continual-learning.

On-the-fly plots during training

With this code it is possible to track progress during training with on-the-fly plots. This feature requires visdom. Before running the experiments, the visdom server should be started from the command line:

python -m visdom.server

The visdom server is now alive and can be accessed at http://localhost:8097 in your browser (the plots will appear there). The flag --visdom should then be added when calling ./main_cl.py to run the experiments with on-the-fly plots.

For more information on visdom see https://github.com/facebookresearch/visdom.

Citation

Please consider citing our paper if you use this code in your research:

@article{vandeven2020brain,
  title={Brain-inspired replay for continual learning with artificial neural networks},
  author={van de Ven, Gido M and Siegelmann, Hava T and Tolias, Andreas S},
  journal={Nature Communications},
  volume={11},
  pages={4069},
  year={2020}
}

Acknowledgments

The research project from which this code originated has been supported by an IBRO-ISN Research Fellowship, by the Lifelong Learning Machines (L2M) program of the Defence Advanced Research Projects Agency (DARPA) via contract number HR0011-18-2-0025 and by the Intelligence Advanced Research Projects Activity (IARPA) via Department of Interior/Interior Business Center (DoI/IBC) contract number D16PC00003. Disclaimer: views and conclusions contained herein are those of the authors and should not be interpreted as necessarily representing the official policies or endorsements, either expressed or implied, of DARPA, IARPA, DoI/IBC, or the U.S. Government.

Owner
Working at the intersection of Machine Learning, Computational Neuroscience and Cognitive Science.
Neural network for digit classification powered by cuda

cuda_nn_mnist Neural network library for digit classification powered by cuda Resources The library was built to work with MNIST dataset. python-mnist

Nikita Ardashev 1 Dec 20, 2021
Official Implementation and Dataset of "PPR10K: A Large-Scale Portrait Photo Retouching Dataset with Human-Region Mask and Group-Level Consistency", CVPR 2021

Portrait Photo Retouching with PPR10K Paper | Supplementary Material PPR10K: A Large-Scale Portrait Photo Retouching Dataset with Human-Region Mask an

184 Dec 11, 2022
A PyTorch implementation of "CoAtNet: Marrying Convolution and Attention for All Data Sizes".

CoAtNet Overview This is a PyTorch implementation of CoAtNet specified in "CoAtNet: Marrying Convolution and Attention for All Data Sizes", arXiv 2021

Justin Wu 268 Jan 07, 2023
The hippynn python package - a modular library for atomistic machine learning with pytorch.

The hippynn python package - a modular library for atomistic machine learning with pytorch. We aim to provide a powerful library for the training of a

Los Alamos National Laboratory 37 Dec 29, 2022
Build and run Docker containers leveraging NVIDIA GPUs

NVIDIA Container Toolkit Introduction The NVIDIA Container Toolkit allows users to build and run GPU accelerated Docker containers. The toolkit includ

NVIDIA Corporation 15.6k Jan 01, 2023
From Perceptron model to Deep Neural Network from scratch in Python.

Neural-Network-Basics Aim of this Repository: From Perceptron model to Deep Neural Network (from scratch) in Python. ** Currently working on a basic N

Aditya Kahol 1 Jan 14, 2022
VolumeGAN - 3D-aware Image Synthesis via Learning Structural and Textural Representations

VolumeGAN - 3D-aware Image Synthesis via Learning Structural and Textural Representations 3D-aware Image Synthesis via Learning Structural and Textura

GenForce: May Generative Force Be with You 116 Dec 26, 2022
A framework for analyzing computer vision models with simulated data

3DB: A framework for analyzing computer vision models with simulated data Paper Quickstart guide Blog post Installation Follow instructions on: https:

3DB 112 Jan 01, 2023
Explainability for Vision Transformers (in PyTorch)

Explainability for Vision Transformers (in PyTorch) This repository implements methods for explainability in Vision Transformers

Jacob Gildenblat 442 Jan 04, 2023
Permute Me Softly: Learning Soft Permutations for Graph Representations

Permute Me Softly: Learning Soft Permutations for Graph Representations

Giannis Nikolentzos 7 Jul 10, 2022
A data-driven maritime port simulator

PySeidon - A Data-Driven Maritime Port Simulator 🌊 Extendable and modular software for maritime port simulation. This software uses entity-component

6 Apr 10, 2022
RETRO-pytorch - Implementation of RETRO, Deepmind's Retrieval based Attention net, in Pytorch

RETRO - Pytorch (wip) Implementation of RETRO, Deepmind's Retrieval based Attent

Phil Wang 556 Jan 04, 2023
这是一个yolo3-tf2的源码,可以用于训练自己的模型。

YOLOV3:You Only Look Once目标检测模型在Tensorflow2当中的实现 目录 性能情况 Performance 所需环境 Environment 文件下载 Download 训练步骤 How2train 预测步骤 How2predict 评估步骤 How2eval 参考资料

Bubbliiiing 68 Dec 21, 2022
Deepparse is a state-of-the-art library for parsing multinational street addresses using deep learning

Here is deepparse. Deepparse is a state-of-the-art library for parsing multinational street addresses using deep learning. Use deepparse to Use the pr

GRAAL/GRAIL 192 Dec 20, 2022
This repo holds codes of the ICCV21 paper: Visual Alignment Constraint for Continuous Sign Language Recognition.

VAC_CSLR This repo holds codes of the paper: Visual Alignment Constraint for Continuous Sign Language Recognition.(ICCV 2021) [paper] Prerequisites Th

Yuecong Min 64 Dec 19, 2022
3D cascade RCNN for object detection on point cloud

3D Cascade RCNN This is the implementation of 3D Cascade RCNN: High Quality Object Detection in Point Clouds. We designed a 3D object detection model

Qi Cai 22 Dec 02, 2022
Uncertainty-aware Semantic Segmentation of LiDAR Point Clouds for Autonomous Driving

SalsaNext: Fast, Uncertainty-aware Semantic Segmentation of LiDAR Point Clouds for Autonomous Driving Abstract In this paper, we introduce SalsaNext f

308 Jan 04, 2023
Towards Calibrated Model for Long-Tailed Visual Recognition from Prior Perspective

Towards Calibrated Model for Long-Tailed Visual Recognition from Prior Perspective Zhengzhuo Xu, Zenghao Chai, Chun Yuan This is the PyTorch implement

Sincere 16 Dec 15, 2022
Code for the paper "Functional Regularization for Reinforcement Learning via Learned Fourier Features"

Reinforcement Learning with Learned Fourier Features State-space Soft Actor-Critic Experiments Move to the state-SAC-LFF repository. cd state-SAC-LFF

Alex Li 10 Nov 11, 2022
Official Repository for the paper "Improving Baselines in the Wild".

iWildCam and FMoW baselines (WILDS) This repository was originally forked from the official repository of WILDS datasets (commit 7e103ed) For general

Kazuki Irie 3 Nov 24, 2022