Conservative Q Learning for Offline Reinforcement Reinforcement Learning in JAX

Related tags

Deep Learningcql-jax
Overview

CQL-JAX

This repository implements Conservative Q Learning for Offline Reinforcement Reinforcement Learning in JAX (FLAX). Implementation is built on top of the SAC base of JAX-RL.

Usage

Install Dependencies-

pip install -r requirements.txt
pip install "jax[cuda111]<=0.21.1" -f https://storage.googleapis.com/jax-releases/jax_releases.html

Run CQL-

python train_offline.py --env_name=hopper-expert-v0 --min_q_weight=5

Please use the following values of min_q_weight on MuJoCo tasks to reproduce CQL results from IQL paper-

Domain medium medium-replay medium-expert
walker 10 1 10
hopper 5 5 1
cheetah 90 80 100

For antmaze tasks min_q_weight=10 is found to work best.

In case of Out-Of Memory errors in JAX, try running with the following env variables-

XLA_PYTHON_CLIENT_MEM_FRACTION=0.80 python ...
XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1 python ...

Performance & Runtime

Returns are more or less same as the torch implementation and comparable to IQL-

Task CQL(PyTorch) CQL(JAX) IQL
hopper-medium-v2 58.5 74.6 66.3
hopper-medium-replay-v2 95.0 92.1 94.7
hopper-medium-expert-v2 105.4 83.2 91.5
antmaze-umaze-v0 74.0 69.5 87.5
antmaze-umaze-diverse-v0 84.0 78.7 62.2
antmaze-medium-play-v0 61.2 14.2 71.2
antmaze-medium-diverse-v0 53.7 10.7 70.2
antmaze-large-play-v0 15.8 0.0 39.6
antmaze-large-diverse-v0 14.9 0.0 47.5

Wall-clock time averages to ~50 mins, improving over IQL paper's 80 min CQL and closing the gap with IQL's 20 min.

Task CQL(JAX) IQL
hopper-medium-v2 52 27
hopper-medium-replay-v2 54 30
hopper-medium-expert-v2 57 29

Time efficiency over the original torch implementation is more than 4 times.

For more offline RL algorithm implementations, check out the JAX-RL, IQL and rlkit repositories.

Citation

In case you use CQL-JAX for your research, please cite the following-

@misc{cqljax,
  author = {Suri, Karush},
  title = {{Conservative Q Learning in JAX.}},
  url = {https://github.com/karush17/cql-jax},
  year = {2021}
}

References

Owner
Karush Suri
Deep Learning Researcher at Huawei Noah's Ark Lab, Toronto.
Karush Suri
Code for binary and multiclass model change active learning, with spectral truncation implementation.

Model Change Active Learning Paper (To Appear) Python code for doing active learning in graph-based semi-supervised learning (GBSSL) paradigm. Impleme

Kevin Miller 1 Jul 24, 2022
BirdCLEF 2021 - Birdcall Identification 4th place solution

BirdCLEF 2021 - Birdcall Identification 4th place solution My solution detail kaggle discussion Inference Notebook (best submission) Environment Use K

tattaka 42 Jan 02, 2023
Stereo Hybrid Event-Frame (SHEF) Cameras for 3D Perception, IROS 2021

For academic use only. Stereo Hybrid Event-Frame (SHEF) Cameras for 3D Perception Ziwei Wang, Liyuan Pan, Yonhon Ng, Zheyu Zhuang and Robert Mahony Th

Ziwei Wang 11 Jan 04, 2023
공공장소에서 눈만 돌리면 CCTV가 보인다는 말이 과언이 아닐 정도로 CCTV가 우리 생활에 깊숙이 자리 잡았습니다.

ObsCare_Main 소개 공공장소에서 눈만 돌리면 CCTV가 보인다는 말이 과언이 아닐 정도로 CCTV가 우리 생활에 깊숙이 자리 잡았습니다. CCTV의 대수가 급격히 늘어나면서 관리와 효율성 문제와 더불어, 곳곳에 설치된 CCTV를 개별 관제하는 것으로는 응급 상

5 Jul 07, 2022
g9.py - Torch interactive graphics

g9.py - Torch interactive graphics A Torch toy in the browser. Demo at https://srush.github.io/g9py/ This is a shameless copy of g9.js, written in Pyt

Sasha Rush 13 Nov 16, 2022
ADOP: Approximate Differentiable One-Pixel Point Rendering

ADOP: Approximate Differentiable One-Pixel Point Rendering Abstract: We present a novel point-based, differentiable neural rendering pipeline for scen

Darius Rückert 1.9k Jan 06, 2023
Official Keras Implementation for UNet++ in IEEE Transactions on Medical Imaging and DLMIA 2018

UNet++: A Nested U-Net Architecture for Medical Image Segmentation UNet++ is a new general purpose image segmentation architecture for more accurate i

Zongwei Zhou 1.8k Dec 27, 2022
Lipstick ain't enough: Beyond Color-Matching for In-the-Wild Makeup Transfer (CVPR 2021)

Table of Content Introduction Datasets Getting Started Requirements Usage Example Training & Evaluation CPM: Color-Pattern Makeup Transfer CPM is a ho

VinAI Research 248 Dec 13, 2022
Contrastive Multi-View Representation Learning on Graphs

Contrastive Multi-View Representation Learning on Graphs This work introduces a self-supervised approach based on contrastive multi-view learning to l

Kaveh 208 Dec 23, 2022
Offical code for the paper: "Growing 3D Artefacts and Functional Machines with Neural Cellular Automata" https://arxiv.org/abs/2103.08737

Growing 3D Artefacts and Functional Machines with Neural Cellular Automata Video of more results: https://www.youtube.com/watch?v=-EzztzKoPeo Requirem

Robotics Evolution and Art Lab 51 Jan 01, 2023
Efficient Sparse Attacks on Videos using Reinforcement Learning

EARL This repository provides a simple implementation of the work "Efficient Sparse Attacks on Videos using Reinforcement Learning" Example: Demo: Her

12 Dec 05, 2021
2021搜狐校园文本匹配算法大赛 分比我们低的都是帅哥队

sohu_text_matching 2021搜狐校园文本匹配算法大赛Top2:分比我们低的都是帅哥队 本repo包含了本次大赛决赛环节提交的代码文件及答辩PPT,提交的模型文件可在百度网盘获取(链接:https://pan.baidu.com/s/1T9FtwiGFZhuC8qqwXKZSNA ,

hflserdaniel 43 Oct 01, 2022
Collection of NLP model explanations and accompanying analysis tools

Thermostat is a large collection of NLP model explanations and accompanying analysis tools. Combines explainability methods from the captum library wi

126 Nov 22, 2022
Official repository for the paper "GN-Transformer: Fusing AST and Source Code information in Graph Networks".

GN-Transformer AST This is the official repository for the paper "GN-Transformer: Fusing AST and Source Code information in Graph Networks". Data Prep

Cheng Jun-Yan 10 Nov 26, 2022
Simple implementation of OpenAI CLIP model in PyTorch.

It was in January of 2021 that OpenAI announced two new models: DALL-E and CLIP, both multi-modality models connecting texts and images in some way. In this article we are going to implement CLIP mod

Moein Shariatnia 226 Jan 05, 2023
PyTorch implementation of "ContextNet: Improving Convolutional Neural Networks for Automatic Speech Recognition with Global Context" (INTERSPEECH 2020)

ContextNet ContextNet has CNN-RNN-transducer architecture and features a fully convolutional encoder that incorporates global context information into

Sangchun Ha 24 Nov 24, 2022
Emotional conditioned music generation using transformer-based model.

This is the official repository of EMOPIA: A Multi-Modal Pop Piano Dataset For Emotion Recognition and Emotion-based Music Generation. The paper has b

hung anna 96 Nov 09, 2022
Keqing Chatbot With Python

KeqingChatbot A public running instance can be found on telegram as @keqingchat_bot. Requirements Python 3.8 or higher. A bot token. Local Deploy git

Rikka-Chan 2 Jan 16, 2022
Deep Learning Models for Causal Inference

Extensive tutorials for learning how to build deep learning models for causal inference using selection on observables in Tensorflow 2.

Bernard J Koch 151 Dec 31, 2022
AirPose: Multi-View Fusion Network for Aerial 3D Human Pose and Shape Estimation

AirPose AirPose: Multi-View Fusion Network for Aerial 3D Human Pose and Shape Estimation Check the teaser video This repository contains the code of A

Robot Perception Group 41 Dec 05, 2022