Exploring whether attention is necessary for vision transformers

Overview

Do You Even Need Attention?
A Stack of Feed-Forward Layers Does Surprisingly Well on ImageNet

Paper/Report

TL;DR

We replace the attention layer in a vision transformer with a feed-forward layer and find that it still works quite well on ImageNet.

Abstract

The strong performance of vision transformers on image classification and other vision tasks is often attributed to the design of their multi-head attention layers. However, the extent to which attention is responsible for this strong performance remains unclear. In this short report, we ask: is the attention layer even necessary? Specifically, we replace the attention layer in a vision transformer with a feed-forward layer applied over the patch dimension. The resulting architecture is simply a series of feed-forward layers applied over the patch and feature dimensions in an alternating fashion. In experiments on ImageNet, this architecture performs surprisingly well: a ViT/DeiT-base-sized model obtains 74.9% top-1 accuracy, compared to 77.9% and 79.9% for ViT and DeiT respectively. These results indicate that aspects of vision transformers other than attention, such as the patch embedding, may be more responsible for their strong performance than previously thought. We hope these results prompt the community to spend more time trying to understand why our current models are as effective as they are.

Note

This is concurrent research with MLP-Mixer from Google Research. The ideas are exacty the same, with the one difference being that they use (a lot) more compute.

Pretrained models and logs

Here is a Weights and Biases report with the expected training trajectory: W&B

name [email protected] #params url
FF-tiny 61.4 7.7M model
FF-base 74.9 62M model
FF-large 71.4 206M -

Note: I haven't uploaded the FF-Large model because (1) it's over GitHub's file storage limit, and (2) I don't see why anyone would want it, given that it performs worse than the base model. That being said, if you want it, reach out to me and I'll send it to you.

How to train

The model definition in vision_transformer_linear.py is designed to be run with the repo from DeiT, which is itself based on the wonderful timm package.

Steps:

  • Clone the DeiT repo and move the file into it
git clone https://github.com/facebookresearch/deit
mv vision_transformer_linear.py deit
cd deit
  • Add a line to import vision_transformer_linear in main.py. For example, add the following after the import statements (around line 27):
+ import vision_transformer_linear
  • Train:
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch \
--nproc_per_node=8 \
--master_port 10490 \
--use_env main.py \
--model linear_tiny \
--batch-size 128 \
--drop 0.1 \
--output_dir outputs/linear-tiny \
--data-path your/path/to/imagenet

Citation

If you build upon this idea, feel free to drop a citation (and also cite MLP-Mixer).

@article{melaskyriazi2021doyoueven,
  title={Do You Even Need Attention? A Stack of Feed-Forward Layers Does Surprisingly Well on ImageNet},
  author={Luke Melas-Kyriazi},
  journal=arxiv,
  year=2021
}
You might also like...
The first dataset of composite images with rationality score indicating whether the object placement in a composite image is reasonable.
The first dataset of composite images with rationality score indicating whether the object placement in a composite image is reasonable.

Object-Placement-Assessment-Dataset-OPA Object-Placement-Assessment (OPA) is to verify whether a composite image is plausible in terms of the object p

Face Mask Detection is a project to determine whether someone is wearing mask or not, using deep neural network.
Face Mask Detection is a project to determine whether someone is wearing mask or not, using deep neural network.

face-mask-detection Face Mask Detection is a project to determine whether someone is wearing mask or not, using deep neural network. It contains 3 scr

This program was designed to detect whether someone is wearing a facemask through a live video stream.

This program was designed to detect whether someone is wearing a facemask through a live video stream. A custom lightweight CNN trained with TensorFlow on a public dataset provided by Kaggle is used to detect whether each face detected by the cv2 face detection dnn is wearing a mask

A system used to detect whether a person is wearing a medical mask or not.

Mask_Detection_System A system used to detect whether a person is wearing a medical mask or not. To open the program, please follow these steps: Make

Diabet Feature Engineering - Predict whether people have diabetes when their characteristics are specified

Diabet Feature Engineering - Predict whether people have diabetes when their characteristics are specified

Multivariate Time Series Forecasting with efficient Transformers. Code for the paper
Multivariate Time Series Forecasting with efficient Transformers. Code for the paper "Long-Range Transformers for Dynamic Spatiotemporal Forecasting."

Spacetimeformer Multivariate Forecasting This repository contains the code for the paper, "Long-Range Transformers for Dynamic Spatiotemporal Forecast

Implementation of SE3-Transformers for Equivariant Self-Attention, in Pytorch.

SE3 Transformer - Pytorch Implementation of SE3-Transformers for Equivariant Self-Attention, in Pytorch. May be needed for replicating Alphafold2 resu

Official PyTorch implementation for Generic Attention-model Explainability for Interpreting Bi-Modal and Encoder-Decoder Transformers, a novel method to visualize any Transformer-based network. Including examples for DETR, VQA.
Official PyTorch implementation for Generic Attention-model Explainability for Interpreting Bi-Modal and Encoder-Decoder Transformers, a novel method to visualize any Transformer-based network. Including examples for DETR, VQA.

PyTorch Implementation of Generic Attention-model Explainability for Interpreting Bi-Modal and Encoder-Decoder Transformers 1 Using Colab Please notic

Comments
  • On details about the experiments

    On details about the experiments

    In the experiment section of the report, you mention:

    Such a comparison is not exactly fair because the feed-forward model uses stronger training augmentations.

    I wonder what the augmentations are, for details about the augmentation seems missing from the paper.

    opened by boredtylin 2
  • Positional Encoding Ablation

    Positional Encoding Ablation

    Hi Luke, thank you for sharing this amazing work.

    In your arxiv document, I cannot find any mention of positional encoding, but I see that you use them in your code. Did you conduct any ablation study on the PE? i.e., how much does it affect the performance, with and without?

    Thank you in advance.

    opened by jjparkcv 0
  • Interaction between patches through a transpose may have a stronger role to play ?

    Interaction between patches through a transpose may have a stronger role to play ?

    Hi, I was going through your exp report. You have made a point that since you are able to get a good performance without using attention layer so good performance of ViT may be more to do with it's embedding layer than attention .

    But I believe, It's also may be to do with how you have established an interaction between patches through a transpose very similar to what was done in MLP-Mixer .

    Would love to know your thoughts on this ?

    opened by rakshith291 0
Owner
Luke Melas-Kyriazi
I'm student at Harvard University studying mathematics and computer science, always open to collaborate on interesting projects!
Luke Melas-Kyriazi
A cool little repl-based simulation written in Python

A cool little repl-based simulation written in Python planned to integrate machine-learning into itself to have AI battle to the death before your eye

Em 6 Sep 17, 2022
This is a GUI interface which can process forest fire detection, smoke detection and fire segmentation

This is a GUI interface which can process forest fire detection, smoke detection and fire segmentation. Yolov5 is used to detect fire and smoke and unet is used to segment fire.

7 Jan 08, 2023
code for our BMVC 2021 paper "HCV: Hierarchy-Consistency Verification for Incremental Implicitly-Refined Classification"

HCV_IIRC code for our BMVC 2021 paper HCV: Hierarchy-Consistency Verification for Incremental Implicitly-Refined Classification by Kai Wang, Xialei Li

kai wang 13 Oct 03, 2022
Price-Prediction-For-a-Dream-Home - A machine learning based linear regression trained model for house price prediction.

Price-Prediction-For-a-Dream-Home ROADMAP TO THIS LINEAR REGRESSION BASED HOUSE PRICE PREDICTION PREDICTION MODEL Import all the dependencies of the p

DIKSHA DESWAL 1 Dec 29, 2021
Repository for self-supervised landmark discovery

self-supervised-landmarks Repository for self-supervised landmark discovery Requirements pytorch pynrrd (for 3d images) Usage The use of this models i

Riddhish Bhalodia 2 Apr 18, 2022
Detectron2-FC a fast construction platform of neural network algorithm based on detectron2

What is Detectron2-FC Detectron2-FC a fast construction platform of neural network algorithm based on detectron2. We have been working hard in two dir

董晋宗 9 Jun 06, 2022
Cl datasets - PyTorch image dataloaders and utility functions to load datasets for supervised continual learning

Continual learning datasets Introduction This repository contains PyTorch image

berjaoui 5 Aug 28, 2022
Implementation of CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification

CrossViT : Cross-Attention Multi-Scale Vision Transformer for Image Classification This is an unofficial PyTorch implementation of CrossViT: Cross-Att

Rishikesh (ऋषिकेश) 103 Nov 25, 2022
Graph-based community clustering approach to extract protein domains from a predicted aligned error matrix

Using a predicted aligned error matrix corresponding to an AlphaFold2 model , returns a series of lists of residue indices, where each list corresponds to a set of residues clustering together into a

Tristan Croll 24 Nov 23, 2022
(CVPR2021) ClassSR: A General Framework to Accelerate Super-Resolution Networks by Data Characteristic

ClassSR (CVPR2021) ClassSR: A General Framework to Accelerate Super-Resolution Networks by Data Characteristic Paper Authors: Xiangtao Kong, Hengyuan

Xiangtao Kong 308 Jan 05, 2023
Single object tracking and segmentation.

Single/Multiple Object Tracking and Segmentation Codes and comparison of recent single/multiple object tracking and segmentation. News 💥 AutoMatch is

ZP ZHANG 385 Jan 02, 2023
PASSL包含 SimCLR,MoCo,BYOL,CLIP等基于对比学习的图像自监督算法以及 Vision-Transformer,Swin-Transformer,BEiT,CVT,T2T,MLP_Mixer等视觉Transformer算法

PASSL Introduction PASSL is a Paddle based vision library for state-of-the-art Self-Supervised Learning research with PaddlePaddle. PASSL aims to acce

186 Dec 29, 2022
Forecasting directional movements of stock prices for intraday trading using LSTM and random forest

Forecasting directional movements of stock-prices for intraday trading using LSTM and random-forest https://arxiv.org/abs/2004.10178 Pushpendu Ghosh,

Pushpendu Ghosh 270 Dec 24, 2022
Easy and comprehensive assessment of predictive power, with support for neuroimaging features

Documentation: https://raamana.github.io/neuropredict/ News As of v0.6, neuropredict now supports regression applications i.e. predicting continuous t

Pradeep Reddy Raamana 93 Nov 29, 2022
Code for Deep Single-image Portrait Image Relighting

Deep Single-Image Portrait Relighting [Project Page] Hao Zhou, Sunil Hadap, Kalyan Sunkavalli, David W. Jacobs. In ICCV, 2019 Overview Test script for

438 Jan 05, 2023
HomoInterpGAN - Homomorphic Latent Space Interpolation for Unpaired Image-to-image Translation

HomoInterpGAN Homomorphic Latent Space Interpolation for Unpaired Image-to-image Translation (CVPR 2019, oral) Installation The implementation is base

Ying-Cong Chen 99 Nov 15, 2022
Hooks for VCOCO

Verbs in COCO (V-COCO) Dataset This repository hosts the Verbs in COCO (V-COCO) dataset and associated code to evaluate models for the Visual Semantic

Saurabh Gupta 131 Nov 24, 2022
An official source code for paper Deep Graph Clustering via Dual Correlation Reduction, accepted by AAAI 2022

Dual Correlation Reduction Network An official source code for paper Deep Graph Clustering via Dual Correlation Reduction, accepted by AAAI 2022. Any

yueliu1999 109 Dec 23, 2022
(CVPR 2022 Oral) Official implementation for "Surface Representation for Point Clouds"

RepSurf - Surface Representation for Point Clouds [CVPR 2022 Oral] By Haoxi Ran* , Jun Liu, Chengjie Wang ( * : corresponding contact) The pytorch off

Haoxi Ran 264 Dec 23, 2022
buildseg is a building extraction plugin of QGIS based on PaddlePaddle.

buildseg buildseg is a building extraction plugin of QGIS based on PaddlePaddle. TODO Extract building on 512x512 remote sensing images. Extract build

Yizhou Chen 11 Sep 26, 2022