Official implementation for ICDAR 2021 paper "Handwritten Mathematical Expression Recognition with Bidirectionally Trained Transformer"

Overview

Handwritten Mathematical Expression Recognition with Bidirectionally Trained Transformer

Description

Convert offline handwritten mathematical expressions to LaTeX sequences using a bidirectionally trained transformer.
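To make "bidirectionally trained" more concrete, here is a minimal sketch of how one LaTeX token sequence could be turned into a left-to-right and a right-to-left training pair for a single shared decoder. It is an illustration under assumptions, not the repository's code: the token names `<sos>` and `<eos>` and the function `bidirectional_targets` are hypothetical, and the convention that the reversed sequence starts with `<eos>` and ends with `<sos>` reflects one reading of the paper.

```python
from typing import List, Tuple

# Hypothetical special tokens; the actual vocabulary may name them differently.
SOS = "<sos>"
EOS = "<eos>"

Pair = Tuple[List[str], List[str]]


def bidirectional_targets(tokens: List[str]) -> Tuple[Pair, Pair]:
    """Build (decoder_input, decoder_target) pairs for both directions.

    Left-to-right: input starts with SOS, target ends with EOS.
    Right-to-left: the tokens are reversed and the roles of the two
    special tokens are swapped (assumed convention).
    """
    l2r_in = [SOS] + tokens
    l2r_out = tokens + [EOS]

    reversed_tokens = tokens[::-1]
    r2l_in = [EOS] + reversed_tokens
    r2l_out = reversed_tokens + [SOS]

    return (l2r_in, l2r_out), (r2l_in, r2l_out)


if __name__ == "__main__":
    (l2r_in, l2r_out), (r2l_in, r2l_out) = bidirectional_targets(["x", "^", "2"])
    print(l2r_in, l2r_out)
    print(r2l_in, r2l_out)
```

Under this sketch both pairs would be fed to the same decoder during training, so the start token alone tells the model which direction to generate.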

How to run

First, install the dependencies:

# clone project   
git clone https://github.com/Green-Wood/BTTR

# install project   
cd BTTR
conda create -y -n bttr python=3.7
conda activate bttr
conda install --yes -c pytorch pytorch=1.7.0 torchvision cudatoolkit=<your-cuda-version>
pip install -e .   

Next, start training. It may take 6~7 hours to converge on 4 GPUs using DDP.

# module folder
cd BTTR

# train bttr model using 4 gpus and ddp
python train.py --config config.yaml  

For single-GPU users, change the config.yaml file to:

gpus: 1
# gpus: 4
# accelerator: ddp

Imports

This project is set up as a package, which means you can easily import any file into any other file like so:

from bttr.datamodule import CROHMEDatamodule
from bttr import LitBTTR
from pytorch_lightning import Trainer

# model
model = LitBTTR()

# data (test_year selects the CROHME test set; "2014" is only an example value)
test_year = "2014"
dm = CROHMEDatamodule(test_year=test_year)

# train
trainer = Trainer()
trainer.fit(model, datamodule=dm)

# test using the best model!
trainer.test(datamodule=dm)

Note

The metrics used in validation are not accurate.

For more accurate metrics:

  1. Use test.py to generate result.zip.
  2. Download and install the crohmelib, lgeval, and tex2symlg tools.
  3. Convert the tex files to symLg files using the tex2symlg command.
  4. Evaluate the two folders using the evaluate command.
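For intuition about what an expression recognition rate (ExpRate) measures, the sketch below computes a simplified version: the fraction of predictions whose token sequence matches the reference exactly. This is not the official procedure listed above, and the function name `exp_rate` is hypothetical. A plausible reason a string-level check like this can differ from the official tools (stated as an assumption, not a confirmed explanation) is that two different LaTeX strings may describe the same symbol layout, which only the symLg-based comparison would treat as equal.

```python
from typing import Sequence


def exp_rate(
    predictions: Sequence[Sequence[str]],
    references: Sequence[Sequence[str]],
) -> float:
    """Fraction of predictions that match their reference token-for-token."""
    if len(predictions) != len(references):
        raise ValueError("predictions and references must have the same length")
    if not references:
        return 0.0
    # Count exact sequence matches; a single wrong token fails the whole expression.
    correct = sum(
        1 for pred, ref in zip(predictions, references) if list(pred) == list(ref)
    )
    return correct / len(references)


if __name__ == "__main__":
    preds = [["x", "+", "1"], ["a"]]
    refs = [["x", "+", "1"], ["b"]]
    print(exp_rate(preds, refs))
```

Because the measure is all-or-nothing per expression, small token-level differences can move it sharply, which is why the symLg-based evaluation is recommended for reported numbers.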

Citation

@article{zhao2021handwritten,
  title={Handwritten Mathematical Expression Recognition with Bidirectionally Trained Transformer},
  author={Zhao, Wenqi and Gao, Liangcai and Yan, Zuoyu and Peng, Shuai and Du, Lin and Zhang, Ziyin},
  journal={arXiv preprint arXiv:2105.02412},
  year={2021}
}
Comments
  • can you provide predict.py code?

    Hi ~ @Green-Wood.

    I am grateful for your help. I would like a predict.py script that prints the LaTeX for an input image. If this code were provided, it would be very useful to others as well.

    Best regards.

    opened by ai-motive 17
  • val_exprate=0 and save checkpoint

    Hello! Thanks for your time! Whether I change some code in the decoder or use it as is, val_exprate is always 0.000, and I don't know why. Another problem: I noticed that this code doesn't seem to have a function to save checkpoints. Can you give me some help? Thanks again!

    opened by Ashleyyyi 6
  • Val_exprate = 0

    When I retrained the model according to the instructions, val_exprate was always 0.00. Did anyone else encounter this problem? Thank you! (I have not modified any code.) @Green-Wood

    opened by qingqianshuying 4
  • test.py error occurs

    When I run test.py, the following error occurs. Can I get some help?

    In test.py, I set test_year = "2016" and ckp_path = "pretrained model".

    GPU available: True, used: True
    TPU available: False, using: 0 TPU cores
    Load data from: /home/motive/PycharmProjects/BTTR/bttr/datamodule/../../data.zip
    Extract data from: 2016, with data size: 1147
    total  1147 batch data loaded
    LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
    Testing: 100%|██████████| 1147/1147 [07:34<00:00,  2.01s/it]ExpRate: 0.32258063554763794
    length of total file: 1147
    Testing: 100%|██████████| 1147/1147 [07:34<00:00,  2.52it/s]
    --------------------------------------------------------------------------------
    DATALOADER:0 TEST RESULTS
    {}
    --------------------------------------------------------------------------------
    Traceback (most recent call last):
      File "/home/motive/PycharmProjects/BTTR/test.py", line 17, in <module>
        trainer.test(model, datamodule=dm)
      File "/home/motive/anaconda3/envs/bttr/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 579, in test
        results = self._run(model)
      File "/home/motive/anaconda3/envs/bttr/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 759, in _run
        self.post_dispatch()
      File "/home/motive/anaconda3/envs/bttr/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 789, in post_dispatch
        self.accelerator.teardown()
      File "/home/motive/anaconda3/envs/bttr/lib/python3.7/site-packages/pytorch_lightning/accelerators/gpu.py", line 51, in teardown
        self.lightning_module.cpu()
      File "/home/motive/anaconda3/envs/bttr/lib/python3.7/site-packages/pytorch_lightning/utilities/device_dtype_mixin.py", line 141, in cpu
        return super().cpu()
      File "/home/motive/anaconda3/envs/bttr/lib/python3.7/site-packages/torch/nn/modules/module.py", line 471, in cpu
        return self._apply(lambda t: t.cpu())
      File "/home/motive/anaconda3/envs/bttr/lib/python3.7/site-packages/torch/nn/modules/module.py", line 359, in _apply
        module._apply(fn)
      File "/home/motive/anaconda3/envs/bttr/lib/python3.7/site-packages/torchmetrics/metric.py", line 317, in _apply
        setattr(this, key, [fn(cur_v) for cur_v in current_val])
      File "/home/motive/anaconda3/envs/bttr/lib/python3.7/site-packages/torchmetrics/metric.py", line 317, in <listcomp>
        setattr(this, key, [fn(cur_v) for cur_v in current_val])
      File "/home/motive/anaconda3/envs/bttr/lib/python3.7/site-packages/torch/nn/modules/module.py", line 471, in <lambda>
        return self._apply(lambda t: t.cpu())
    AttributeError: 'tuple' object has no attribute 'cpu'
    
    opened by ai-motive 3
  • How long does BTTR take to train?

    Hi, thank you for the great repository!

    How long did training take for the experiments in the paper? I mean training on CROHME 2014/2016/2019 on four NVIDIA 1080Ti GPUs.

    Thanks,

    opened by RyosukeFukatani 2
  • can you provide transfer learning code?

    Hi~ @Green-Wood

    I want to apply transfer learning using the pretrained model, but LightningCLI() is wrapped and difficult to customize.

    Thanks & best regards.

    opened by ai-motive 1
  • How can it get pretrained model ?

    Hi, I want to test your BTTR model, but it requires a training process that will take a lot of time. Could you give me a link to a pretrained model?

    Best regards.

    opened by ai-motive 1
  • After adding new token in dictionary getting error .

    Hi, I am getting an error after adding new tokens to dictionary.txt:

    Error(s) in loading state_dict for LitBTTR:
        size mismatch for bttr.decoder.word_embed.0.weight: copying a param with shape torch.Size([113, 256]) from checkpoint, the shape in current model is torch.Size([115, 256]).
        size mismatch for bttr.decoder.proj.weight: copying a param with shape torch.Size([113, 256]) from checkpoint, the shape in current model is torch.Size([115, 256]).
        size mismatch for bttr.decoder.proj.bias: copying a param with shape torch.Size([113]) from checkpoint, the shape in current model is torch.Size([115]).

    Kindly help me out: how can I fix this error?

    opened by shivankaraditi 0
  • About dataset

    Could you tell me how to generate the offline math expression images from the InkML files? My experiments show that larger-scale images improve the results noticeably, so I'd like to know whether there is a unified offline dataset for academic research.

    opened by lightflash7 0
  • predicting on gpu is slower

    Hi,

    This model is a bit slower on CPU than the existing state-of-the-art model, so I tried to make predictions on GPU. Surprisingly, it is slower on GPU than on CPU as well.

    I am attaching a code snippet here:

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    model = LitBTTR.load_from_checkpoint('pretrained-2014.ckpt', map_location=device)

    img = Image.open(img_path)
    img = ToTensor()(img)
    img.to(device)

    t1 = time.time()
    hyp = model.beam_search(img)
    t2 = time.time()

    Kindly help me out here: how can I reduce the prediction time?

    FYI: I am using the GPU on an AWS g4dn.xlarge machine.

    opened by Suma3 1
  • how to use TensorBoard?

    Hello, I don't know how to add a scalar to TensorBoard. I want to work on this kind of topic, hoping to improve ExpRate a bit, but I don't know much about Lightning's TensorBoard support.

    opened by win5923 9
Releases(v2.0)
Owner
Wenqi Zhao
Student in Nanjing University