Pytorch implementation of MLP-Mixer with loading pre-trained models.

Overview

MLP-Mixer-Pytorch

PyTorch implementation of MLP-Mixer: An all-MLP Architecture for Vision with the function of loading official ImageNet pre-trained parameters.

Usage

import torch
import numpy as np
from mlp_mixer import MlpMixer

pretrain_model='./pretrain_models/imagenet21k_Mixer-B_16.npz'

model = MlpMixer(num_classes=10, 
                 num_blocks=12, 
                 patch_size=16, 
                 hidden_dim=768, 
                 tokens_mlp_dim=384, 
                 channels_mlp_dim=3072, 
                 image_size=224
                 )

# load official ImageNet pre-trained model:
model.load_from(np.load(pretrain_model))
print ('Finish loading the pre-trained model!')

num_param = sum(p.numel() for p in model.parameters()) / 1e6
print ('Total params.: %f M'%num_param)

pred = model(img)

Fine-tuning

Download the official pre-trained models at https://console.cloud.google.com/storage/mixer_models/.

Hypyer-parameters setting for better fine-tuning:

optim = torch.optim.SGD(param_list, 
                        lr=5e-4, 
                        weight_decay=1e-7,
                        momentum=0.9, 
                        nesterov=True
                        )
lr_schdlr = WarmupCosineLrScheduler(optim, 
                                    n_iters_all, 
                                    warmup_iter=0
                                    )

Using the pre-trained model to fine-tune MLP-Mixer can obtain remarkable improvements (e.g., +10% accuracy on a small dataset).

Note that we can also change the patch_size (e.g., patch_size=8) for inputs with different resolutions, but smaller patch_size may not always bring performance improvements.

Citation

@misc{tolstikhin2021mlpmixer,
      title={MLP-Mixer: An all-MLP Architecture for Vision}, 
      author={Ilya Tolstikhin and Neil Houlsby and Alexander Kolesnikov and Lucas Beyer and Xiaohua Zhai and Thomas Unterthiner and Jessica Yung and Daniel Keysers and Jakob Uszkoreit and Mario Lucic and Alexey Dosovitskiy},
      year={2021},
      eprint={2105.01601},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Acknowledgement

  1. The implementation is based on the original paper and the official Tensorflow repo: https://github.com/google-research/vision_transformer.
  2. It also refers to the re-implementation repo: https://github.com/d-li14/mlp-mixer.pytorch.
Owner
Qiushi Yang
Student in CityU, Hong Kong. Interested in computer vision and machine learning.
Qiushi Yang
J.A.R.V.I.S is an AI virtual assistant made in python.

J.A.R.V.I.S is an AI virtual assistant made in python. Running JARVIS Without Python To run JARVIS without python: 1. Head over to our installation pa

somePythonProgrammer 16 Dec 29, 2022
Implementation of the GVP-Transformer, which was used in the paper "Learning inverse folding from millions of predicted structures" for de novo protein design alongside Alphafold2

GVP Transformer (wip) Implementation of the GVP-Transformer, which was used in the paper Learning inverse folding from millions of predicted structure

Phil Wang 19 May 06, 2022
ICNet and PSPNet-50 in Tensorflow for real-time semantic segmentation

Real-Time Semantic Segmentation in TensorFlow Perform pixel-wise semantic segmentation on high-resolution images in real-time with Image Cascade Netwo

Oles Andrienko 219 Nov 21, 2022
PHOTONAI is a high level python API for designing and optimizing machine learning pipelines.

PHOTONAI is a high level python API for designing and optimizing machine learning pipelines. We've created a system in which you can easily select and

Medical Machine Learning Lab - University of Münster 57 Nov 12, 2022
Dense Prediction Transformers

Vision Transformers for Dense Prediction This repository contains code and models for our paper: Vision Transformers for Dense Prediction René Ranftl,

Intel ISL (Intel Intelligent Systems Lab) 1.3k Dec 28, 2022
A vision library for performing sliced inference on large images/small objects

SAHI: Slicing Aided Hyper Inference A vision library for performing sliced inference on large images/small objects Overview Object detection and insta

Open Business Software Solutions 2.3k Jan 04, 2023
A universal memory dumper using Frida

Fridump Fridump (v0.1) is an open source memory dumping tool, primarily aimed to penetration testers and developers. Fridump is using the Frida framew

551 Jan 07, 2023
Full Transformer Framework for Robust Point Cloud Registration with Deep Information Interaction

Full Transformer Framework for Robust Point Cloud Registration with Deep Information Interaction. arxiv This repository contains python scripts for tr

12 Dec 12, 2022
Using Convolutional Neural Networks (CNN) for Semantic Segmentation of Breast Cancer Lesions (BRCA)

Using Convolutional Neural Networks (CNN) for Semantic Segmentation of Breast Cancer Lesions (BRCA). Master's thesis documents. Bibliography, experiments and reports.

Erick Cobos 73 Dec 04, 2022
PyTorch implementation of DCT fast weight RNNs

DCT based fast weights This repository contains the official code for the paper: Training and Generating Neural Networks in Compressed Weight Space. T

Kazuki Irie 4 Dec 24, 2022
ICRA 2021 - Robust Place Recognition using an Imaging Lidar

Robust Place Recognition using an Imaging Lidar A place recognition package using high-resolution imaging lidar. For best performance, a lidar equippe

Tixiao Shan 293 Dec 27, 2022
MINOS: Multimodal Indoor Simulator

MINOS Simulator MINOS is a simulator designed to support the development of multisensory models for goal-directed navigation in complex indoor environ

194 Dec 27, 2022
The World of an Octopus: How Reporting Bias Influences a Language Model's Perception of Color

The World of an Octopus: How Reporting Bias Influences a Language Model's Perception of Color Overview Code and dataset for The World of an Octopus: H

1 Nov 13, 2021
Zero-shot Synthesis with Group-Supervised Learning (ICLR 2021 paper)

GSL - Zero-shot Synthesis with Group-Supervised Learning Figure: Zero-shot synthesis performance of our method with different dataset (iLab-20M, RaFD,

Andy_Ge 62 Dec 21, 2022
Variational Attention: Propagating Domain-Specific Knowledge for Multi-Domain Learning in Crowd Counting (ICCV, 2021)

DKPNet ICCV 2021 Variational Attention: Propagating Domain-Specific Knowledge for Multi-Domain Learning in Crowd Counting Baseline of DKPNet is availa

19 Oct 14, 2022
Library extending Jupyter notebooks to integrate with Apache TinkerPop and RDF SPARQL.

Graph Notebook: easily query and visualize graphs The graph notebook provides an easy way to interact with graph databases using Jupyter notebooks. Us

Amazon Web Services 501 Dec 28, 2022
PyTorch implementation of Anomaly Transformer: Time Series Anomaly Detection with Association Discrepancy

Anomaly Transformer in PyTorch This is an implementation of Anomaly Transformer: Time Series Anomaly Detection with Association Discrepancy. This pape

spencerbraun 160 Dec 19, 2022
Oriented Response Networks, in CVPR 2017

Oriented Response Networks [Home] [Project] [Paper] [Supp] [Poster] Torch Implementation The torch branch contains: the official torch implementation

ZhouYanzhao 217 Dec 12, 2022
SeqTR: A Simple yet Universal Network for Visual Grounding

SeqTR This is the official implementation of SeqTR: A Simple yet Universal Network for Visual Grounding, which simplifies and unifies the modelling fo

seanZhuh 76 Dec 24, 2022
Implementation of Perceiver, General Perception with Iterative Attention, in Pytorch

Perceiver - Pytorch Implementation of Perceiver, General Perception with Iterative Attention, in Pytorch Install $ pip install perceiver-pytorch Usage

Phil Wang 876 Dec 29, 2022