Incremental Transformer Structure Enhanced Image Inpainting with Masking Positional Encoding (CVPR2022)

Overview

Incremental Transformer Structure Enhanced Image Inpainting with Masking Positional Encoding

by Qiaole Dong*, Chenjie Cao*, Yanwei Fu

Paper and Supplemental Material (arXiv)

LICENSE

Pipeline

Click to expand

The overview of our ZITS. At first, the TSR model is used to restore structures with low resolutions. Then the simple CNN based upsampler is leveraged to upsample edge and line maps. Moreover, the upsampled sketch space is encoded and added to the FTR through ZeroRA to restore the textures.

TO DO

We have updated weights of TSR!

Our project page is available at https://dqiaole.github.io/ZITS_inpainting/.

  • Releasing inference codes.
  • Releasing pre-trained moodel.
  • Releasing training codes.

Preparation

Click to expand
  1. Preparing the environment:

    as there are some bugs when using GP loss with DDP (link), we strongly recommend installing Apex without CUDA extensions via torch1.9.0 for the multi-gpu training

    conda create -n train_env python=3.6
    conda activate train_env
    pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
    pip install -r requirement.txt
    git clone https://github.com/NVIDIA/apex
    cd apex
    pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" ./
    
  2. For training, MST provide irregular and segmentation masks (download) with different masking rates. And you should define the mask file list before the training as in MST.

  3. Download the pretrained masked wireframe detection model to the './ckpt' fold: LSM-HAWP (MST ICCV2021 retrained from HAWP CVPR2020).

  4. Prepare the wireframes:

    as the MST train the LSM-HAWP in Pytorch 1.3.1 and it causes problem (link) when tested in Pytorch 1.9, we recommand to inference the lines(wireframes) with torch==1.3.1. If the line detection is not based on torch1.3.1, the performance may drop a little.

    conda create -n wireframes_inference_env python=3.6
    conda activate wireframes_inference_env
    pip install torch==1.3.1 torchvision==0.4.2
    pip install -r requirement.txt
    

    then extract wireframes with following code

    python lsm_hawp_inference.py --ckpt_path <best_lsm_hawp.pth> --input_path <input image path> --output_path <output image path> --gpu_ids '0'
    
  5. If you need to train the model, please download the pretrained models for perceptual loss, provided by LaMa:

    mkdir -p ade20k/ade20k-resnet50dilated-ppm_deepsup/
    wget -P ade20k/ade20k-resnet50dilated-ppm_deepsup/ http://sceneparsing.csail.mit.edu/model/pytorch/ade20k-resnet50dilated-ppm_deepsup/encoder_epoch_20.pth
    

Eval

Click to expand

Download pretrained models on Places2 here.

Link for BaiduDrive, password:qnm5

Batch Test

For batch test, you need to complete steps 3 and 4 above.

Put the pretrained models to the './ckpt' fold. Then modify the config file according to you image, mask and wireframes path.

Test on 256 images:

conda activate train_env
python FTR_inference.py --path ./ckpt/zits_places2 --config_file ./config_list/config_ZITS_places2.yml --GPU_ids '0'

Test on 512 images:

conda activate train_env
python FTR_inference.py --path ./ckpt/zits_places2_hr --config_file ./config_list/config_ZITS_HR_places2.yml --GPU_ids '0'

Single Image Test

Note: For single image test, environment 'wireframes_inference_env' in step 4 is recommended for a better line detection. This code only supports squared images (or they will be center cropped).

conda activate wireframes_inference_env
python single_image_test.py --path <ckpt_path> --config_file <config_path> \
 --GPU_ids '0' --img_path ./image.png --mask_path ./mask.png --save_path ./

Training

Click to expand

⚠️ Warning: The training codes is not fully tested yet after refactoring

Training TSR

python TSR_train.py --name places2_continous_edgeline --data_path [training_data_path] \
 --train_line_path [training_wireframes_path] \
 --mask_path ['irregular_mask_list.txt', 'coco_mask_list.txt'] \
 --train_epoch 12 --validation_path [validation_data_path] \
 --val_line_path [validation_wireframes_path] \
 --valid_mask_path [validation_mask] --nodes 1 --gpus 1 --GPU_ids '0' --AMP
python TSR_train.py --name places2_continous_edgeline --data_path [training_data_path] \
 --train_line_path [training_wireframes_path] \
 --mask_path ['irregular_mask_list.txt', 'coco_mask_list.txt'] \
 --train_epoch 15 --validation_path [validation_data_path] \
 --val_line_path [validation_wireframes_path] \
 --valid_mask_path [validation_mask] --nodes 1 --gpus 1 --GPU_ids '0' --AMP --MaP

Train SSU

We recommend to use the pretrained SSU. You can also train your SSU refered to https://github.com/ewrfcas/StructureUpsampling.

Training LaMa First

python FTR_train.py --nodes 1 --gpus 1 --GPU_ids '0' --path ./ckpt/lama_places2 \
--config_file ./config_list/config_LAMA.yml --lama

Training FTR

256:

python FTR_train.py --nodes 1 --gpus 2 --GPU_ids '0,1' --path ./ckpt/places2 \
--config_file ./config_list/config_ZITS_places2.yml --DDP

256~512:

python FTR_train.py --nodes 1 --gpus 2 --GPU_ids '0,1' --path ./ckpt/places2_HR \
--config_file ./config_list/config_ZITS_HR_places2.yml --DDP

More 1K Results

Click to expand

Acknowledgments

Cite

If you found our program helpful, please consider citing:

@inproceedings{dong2022incremental,
      title={Incremental Transformer Structure Enhanced Image Inpainting with Masking Positional Encoding}, 
      author={Qiaole Dong and Chenjie Cao and Yanwei Fu},
      booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
      year={2022}
}
Owner
Qiaole Dong
Qiaole Dong
Conservative and Adaptive Penalty for Model-Based Safe Reinforcement Learning

Conservative and Adaptive Penalty for Model-Based Safe Reinforcement Learning This is the official repository for Conservative and Adaptive Penalty fo

7 Nov 22, 2022
Open & Efficient for Framework for Aspect-based Sentiment Analysis

PyABSA - Open & Efficient for Framework for Aspect-based Sentiment Analysis Fast & Low Memory requirement & Enhanced implementation of Local Context F

YangHeng 567 Jan 07, 2023
Noise Conditional Score Networks (NeurIPS 2019, Oral)

Generative Modeling by Estimating Gradients of the Data Distribution This repo contains the official implementation for the NeurIPS 2019 paper Generat

451 Dec 26, 2022
Official Implementation of SimIPU: Simple 2D Image and 3D Point Cloud Unsupervised Pre-Training for Spatial-Aware Visual Representations

Official Implementation of SimIPU SimIPU: Simple 2D Image and 3D Point Cloud Unsupervised Pre-Training for Spatial-Aware Visual Representations Since

Zhyever 37 Dec 01, 2022
AlgoVision - A Framework for Differentiable Algorithms and Algorithmic Supervision

NeurIPS 2021 Paper "Learning with Algorithmic Supervision via Continuous Relaxations"

Felix Petersen 76 Jan 01, 2023
End-to-End Referring Video Object Segmentation with Multimodal Transformers

End-to-End Referring Video Object Segmentation with Multimodal Transformers This repo contains the official implementation of the paper: End-to-End Re

608 Dec 30, 2022
A simple root calculater for python

Root A simple root calculater Usage/Examples python3 root.py 9 3 4 # Order: number - grid - number of decimals # Output: 2.08

Reza Hosseinzadeh 5 Feb 10, 2022
BboxToolkit is a tiny library of special bounding boxes.

BboxToolkit is a light codebase collecting some practical functions for the special-shape detection, such as oriented detection

jbwang1997 73 Jan 01, 2023
Bot developed in Python that automates races in pegaxy.

español | português About it: This is a fork from pega-racing-bot. This bot, developed in Python, is to automate races in pegaxy. The game developers

4 Apr 08, 2022
Code for the paper "Balancing Training for Multilingual Neural Machine Translation, ACL 2020"

Balancing Training for Multilingual Neural Machine Translation Implementation of the paper Balancing Training for Multilingual Neural Machine Translat

Xinyi Wang 21 May 18, 2022
PlenOctrees: NeRF-SH Training & Conversion

PlenOctrees Official Repo: NeRF-SH training and conversion This repository contains code to train NeRF-SH and to extract the PlenOctree, constituting

Alex Yu 323 Dec 29, 2022
optimization routines for hyperparameter tuning

Hyperopt: Distributed Hyperparameter Optimization Hyperopt is a Python library for serial and parallel optimization over awkward search spaces, which

Marc Claesen 398 Nov 09, 2022
Vpw analyzer - A visual J1850 VPW analyzer written in Python

VPW Analyzer A visual J1850 VPW analyzer written in Python Requires Tkinter, Pan

7 May 01, 2022
CondNet: Conditional Classifier for Scene Segmentation

CondNet: Conditional Classifier for Scene Segmentation Introduction The fully convolutional network (FCN) has achieved tremendous success in dense vis

ycszen 31 Jul 22, 2022
Implementation of "Learning to Match Features with Seeded Graph Matching Network" ICCV2021

SGMNet Implementation PyTorch implementation of SGMNet for ICCV'21 paper "Learning to Match Features with Seeded Graph Matching Network", by Hongkai C

87 Dec 11, 2022
Research shows Google collects 20x more data from Android than Apple collects from iOS. Block this non-consensual telemetry using pihole blocklists.

pihole-antitelemetry Research shows Google collects 20x more data from Android than Apple collects from iOS. Block both using these pihole lists. Proj

Adrian Edwards 290 Jan 09, 2023
Yolov5 + Deep Sort with PyTorch

딥소트 수정중 Yolov5 + Deep Sort with PyTorch Introduction This repository contains a two-stage-tracker. The detections generated by YOLOv5, a family of obj

1 Nov 26, 2021
Learning RGB-D Feature Embeddings for Unseen Object Instance Segmentation

Unseen Object Clustering: Learning RGB-D Feature Embeddings for Unseen Object Instance Segmentation Introduction In this work, we propose a new method

NVIDIA Research Projects 132 Dec 13, 2022
Group-Free 3D Object Detection via Transformers

Group-Free 3D Object Detection via Transformers By Ze Liu, Zheng Zhang, Yue Cao, Han Hu, Xin Tong. This repo is the official implementation of "Group-

Ze Liu 213 Dec 07, 2022
An implementation of an abstract algebra for music tones (pitches).

nbdev template Use this template to more easily create your nbdev project. If you are using an older version of this template, and want to upgrade to

Open Music Kit 0 Oct 10, 2022