An implementation of Geoffrey Hinton's paper "How to represent part-whole hierarchies in a neural network" in PyTorch.

GLOM

An implementation of Geoffrey Hinton's paper "How to represent part-whole hierarchies in a neural network" for the MNIST dataset. To understand this implementation, please watch Yannic Kilcher's GLOM video, then read this README.md, then read the code.

Running

Open the notebook in Jupyter to run it. The program expects an NVIDIA graphics card for GPU speedup. If you run out of GPU memory, decrease the batch_size variable. If you want to look at the code on GitHub and it fails to load, try reloading the page several times.
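
As a rough sketch (not taken from the notebook), device selection and the batch size might be set up as below. Only the batch_size name comes from the text above; the CPU fallback, the value 16, and the dummy images tensor are assumptions for illustration.

```python
import torch

# Prefer the GPU when PyTorch can see one; fall back to the CPU otherwise (assumed fallback).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# batch_size is the variable named above; 16 is only a placeholder value.
# Lower it if the GPU runs out of memory.
batch_size = 16

# Dummy MNIST-shaped batch: (batch, height, width), created on the chosen device.
images = torch.zeros(batch_size, 28, 28, device=device)
print(images.shape, images.device.type)
```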

Results

The best models, which have been posted under the best_models folder, reached an accuracy of about 91%.

Implementation details

Three Types of networks per layer of vectors

  1. Top-Down Network
  2. Bottom-up Network
  3. Attention on the same layer Network

Intro to State

There is an initial state to which the outputs of all three network types are added after every time step. The bottom layer of the state is the input layer: it holds the MNIST pixel data, and nothing is added to it so that the pixel data is retained. The top layer of the state is the output layer: the loss function is applied there, and it is trained to match the one-hot MNIST target vector.
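
One way to picture the state is as a single tensor with a layer axis. The sketch below is an assumption about the layout, not the notebook's actual code: the shape (batch, layers, height, width, dim), the helper names init_state and apply_update, and the choice to copy each pixel intensity into every component of its bottom-layer vector are all hypothetical.

```python
import torch

def init_state(images, num_layers=4, dim=10):
    """Build a hypothetical state of shape (batch, layers, height, width, dim)."""
    batch, height, width = images.shape
    state = torch.zeros(batch, num_layers, height, width, dim)
    # Assumption: each pixel intensity fills every component of its bottom-layer vector.
    state[:, 0] = images.unsqueeze(-1).expand(-1, -1, -1, dim)
    return state

def apply_update(state, delta):
    """Add network outputs to every layer except the bottom (input) layer."""
    new_state = state.clone()
    new_state[:, 1:] = state[:, 1:] + delta[:, 1:]
    return new_state

images = torch.rand(2, 28, 28)
state = init_state(images)
updated = apply_update(state, torch.ones_like(state))
print(state.shape)  # layer 0 is the input layer, layer -1 the output layer
```

Here layer index 0 plays the role of the input layer and index -1 the output layer that the loss would be applied to.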

Explanation of compute_all function

Each type of network sees a 3x3 grid of vectors surrounding the current input vector at the current layer. This lets information travel faster laterally across vectors, so more information can be sent across an image in fewer steps. The easy way to do this is to shift (or roll) every vector along the x and y axes and then concatenate the shifted copies on top of each other, so that every position in the state that used to hold one vector now holds that vector and its neighboring vectors from the same layer. This also connects the edges of the image, so data can be passed from one edge of the image to the other, reducing the maximum distance between any two pixels or vectors.
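
The roll-and-concatenate idea can be sketched with torch.roll, which wraps around at the borders. The function name gather_neighbors and the (batch, height, width, dim) layout are assumptions for illustration and may differ from what compute_all does internally.

```python
import torch

def gather_neighbors(layer):
    """(batch, height, width, dim) -> (batch, height, width, 9 * dim)."""
    shifted = []
    for dy in (-1, 0, 1):
        for dx in (-1, 0, 1):
            # torch.roll wraps around, so edge pixels see the opposite edge.
            shifted.append(torch.roll(layer, shifts=(dy, dx), dims=(1, 2)))
    # Concatenate the nine shifted copies along the vector dimension.
    return torch.cat(shifted, dim=-1)

layer = torch.arange(2 * 4 * 4 * 3, dtype=torch.float32).reshape(2, 4, 4, 3)
neighbors = gather_neighbors(layer)
print(neighbors.shape)
```

Each position now carries its own vector plus its eight neighbors, so a network applied per position sees the full 3x3 grid.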

For a more complex dataset, it's possible this could pose some issues, since two opposite edges of an image aren't generally continuous, but for MNIST this problem doesn't arise. These vectors are then fed to each type of model. For each pixel, a model receives all neighboring state vectors from a given layer and outputs a single vector. There are three types of model per layer. In this example, every line drawn is a new model that is reused for every pixel this process is applied to. After each model type has given an output, the three lists of vectors are added together.

This will give a single list of vectors that will be added to the corresponding list of vectors at the specific x,y coordinate from the original state.

Repeating this step for every list of vectors per x,y coordinate in the original state will yield the full new State value.

Since each network only sees a 3x3 grid and not larger image patches, this technique can be used for images of any size and is easily parallelizable.
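
The update described above could be sketched as follows. This is a simplified illustration under stated assumptions, not the notebook's implementation: GlomStep is a hypothetical name, each network is a single nn.Linear (the same-layer "attention" network is only a linear stand-in here), the bottom-up network reads the layer below, the top-down network reads the layer above, and the state layout is (batch, layers, height, width, dim).

```python
import torch
from torch import nn

def gather_neighbors(x):
    """(batch, height, width, dim) -> (batch, height, width, 9 * dim), wrapping at edges."""
    return torch.cat(
        [torch.roll(x, shifts=(dy, dx), dims=(1, 2)) for dy in (-1, 0, 1) for dx in (-1, 0, 1)],
        dim=-1,
    )

class GlomStep(nn.Module):
    def __init__(self, num_layers=4, dim=10):
        super().__init__()
        # One network of each type per layer above the input layer.
        self.bottom_up = nn.ModuleList([nn.Linear(9 * dim, dim) for _ in range(num_layers - 1)])
        self.same_layer = nn.ModuleList([nn.Linear(9 * dim, dim) for _ in range(num_layers - 1)])
        # The top layer has no layer above it, so it gets no top-down network.
        self.top_down = nn.ModuleList([nn.Linear(9 * dim, dim) for _ in range(num_layers - 2)])

    def forward(self, state):
        num_layers = state.shape[1]
        neighbors = [gather_neighbors(state[:, l]) for l in range(num_layers)]
        new_layers = [state[:, 0]]  # the input layer is kept unchanged
        for l in range(1, num_layers):
            # Sum the outputs of the network types for this layer.
            delta = self.bottom_up[l - 1](neighbors[l - 1]) + self.same_layer[l - 1](neighbors[l])
            if l + 1 < num_layers:
                delta = delta + self.top_down[l - 1](neighbors[l + 1])
            # Add the summed output to the original state at every x,y position.
            new_layers.append(state[:, l] + delta)
        return torch.stack(new_layers, dim=1)

step = GlomStep()
small = step(torch.rand(2, 4, 8, 8, 10))
wide = step(torch.rand(1, 4, 5, 7, 10))
print(small.shape, wide.shape)
```

Because nn.Linear acts on the last dimension only, the same weights are reused at every pixel, and the same module runs on inputs with different height and width.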

If I had more compute

My 2080 Ti runs into memory errors when the batch size is above roughly 30, so here are my implementation ideas if I had more compute.

  1. Increase batch_size. This probably won't affect the training, but it would make testing the accuracy faster.
  2. Save more states throughout the steps taken and add them together. This would allow gradients to get passed back to the original state, similar to how ResNet can train very large models because gradients can be passed backwards more easily. This has been implemented to a smaller degree already and showed massive accuracy improvements.
  3. Perform some kind of evolutionary parameter search by mutating the model parameters while also using backprop. This has been shown to improve the accuracy of image classifiers and other models, but it would take a ton of compute.
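
The second idea is open to interpretation. One possible reading, sketched below, saves the state at regular intervals and adds the saved states back in, which gives gradients a shorter path to the original state. The function run_with_skips, its parameters, and the toy step function are hypothetical and are not how the notebook does it.

```python
import torch

def run_with_skips(step_fn, state, num_steps=12, save_every=4):
    """Run step_fn repeatedly, adding all saved states back in every save_every steps."""
    saved = [state]
    for t in range(1, num_steps + 1):
        state = step_fn(state)
        if t % save_every == 0:
            # Residual-style shortcut: sum of the states saved so far.
            state = state + torch.stack(saved).sum(dim=0)
            saved.append(state)
    return state

# Toy step function that halves the state, used only to exercise the loop.
start = torch.ones(2, 3, requires_grad=True)
final = run_with_skips(lambda s: s * 0.5, start, num_steps=4, save_every=2)
final.sum().backward()
print(final[0, 0].item(), start.grad[0, 0].item())
```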

Yannic Kilcher's Attention

This has been pushed to GitHub because, during testing and hyperparameter tuning, a better model than the previous one was found. More testing needs to be done, and I'm working on the visual explanation for it now. Previous versions of this code don't have the attention seen in the current version and will have similar performance.

Other Ideas behind the paper implementation

This is basically a neural cellular automaton from the paper Growing Neural Cellular Automata, with some inspiration from the follow-up paper Self-classifying MNIST Digits, except that instead of a single list of numbers (one vector) per pixel, there are several vectors per pixel in each image. The model in Growing Neural Cellular Automata was also very difficult to train because of the long gradient chains, so increasing the model's complexity for this GLOM implementation makes training even harder. The neural cellular automata papers are also the reason the MSE loss function is used and random noise is added to the state during training.
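
A minimal sketch of the MSE loss with state noise follows. Several details are assumptions rather than facts about the notebook: the helper names, the Gaussian noise and its scale, the (batch, layers, height, width, dim) layout with dim equal to the number of classes, and the choice to train every top-layer vector toward the one-hot label of its image.

```python
import torch
import torch.nn.functional as F

def add_state_noise(state, noise_std=0.02):
    """Perturb every layer except the input layer with Gaussian noise."""
    noise = torch.randn_like(state) * noise_std
    noise[:, 0] = 0.0  # keep the pixel data in the bottom layer intact
    return state + noise

def top_layer_mse(state, labels, num_classes=10):
    """MSE between each top-layer vector and the one-hot label of its image."""
    top = state[:, -1]                                  # (batch, height, width, dim)
    one_hot = F.one_hot(labels, num_classes).float()    # (batch, num_classes)
    target = one_hot[:, None, None, :].expand_as(top)   # same label at every pixel
    return F.mse_loss(top, target)

state = torch.zeros(2, 4, 5, 5, 10)
labels = torch.tensor([3, 7])
loss = top_layer_mse(add_state_noise(state, noise_std=0.0), labels)
print(loss.item())
```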

To do

  1. Generate the explanation for Yannic Kilcher's version of attention that is implemented here.
  2. See if part-whole hierarchies are being found.
  3. Keep testing hyperparameters to push accuracy higher.
  4. Test different state initializations.
  5. Train on harder datasets.

If you find any issues, please feel free to contact me.

Owner
Just a random coder