
fastmachinelearning/SR_Mobile_Quantization

 
 


Introduction

A winning solution for the MAI2021 Competition (CVPR 2021 Workshop). Our model outperforms the other participants by a large margin in terms of both inference speed and reconstruction performance.

Challenge report: Mobile AI 2021 Real-Time Image Super-Resolution Challenge.

Our paper: Anchor-based Plain Net for Mobile Image Super-Resolution.

Contributions for INT8-Quantized Mobile SR Networks

Investigation of meta-node latency

We conducted an experiment on meta-node latency by decomposing lightweight SR architectures; the results determine which portable operations we can use. This step is crucial if you want to deploy your model across mobile devices.
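A minimal sketch of how per-operation latency could be compared, assuming each candidate meta-node is wrapped as a zero-argument callable. The operations in the usage part are hypothetical pure-Python stand-ins; on a phone one would instead time single-op TFLite models, but the measurement logic (discard warm-up runs, then take the median of repeated timings) carries over.

```python
import statistics
import time
from typing import Callable, Dict


def median_latency_ms(fn: Callable[[], object], warmup: int = 3, runs: int = 20) -> float:
    """Return the median wall-clock latency of fn() in milliseconds."""
    for _ in range(warmup):  # warm-up runs are discarded (caches, lazy initialisation)
        fn()
    samples = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(samples)  # the median is robust to scheduling outliers


def profile_ops(ops: Dict[str, Callable[[], object]]) -> Dict[str, float]:
    """Time every candidate op and return {name: median ms}, fastest first."""
    timings = {name: median_latency_ms(fn) for name, fn in ops.items()}
    return dict(sorted(timings.items(), key=lambda kv: kv[1]))


if __name__ == "__main__":
    # Hypothetical stand-ins for meta-nodes; replace with real on-device ops.
    data = list(range(10_000))
    candidate_ops = {
        "add": lambda: [v + 1 for v in data],
        "multiply": lambda: [v * 3 for v in data],
        "repeat_concat": lambda: data * 9,
    }
    for name, ms in profile_ops(candidate_ops).items():
        print(f"{name:>14}: {ms:.3f} ms")
```

Using the median rather than the mean keeps a single preempted run from skewing the ranking of otherwise similar operations.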

Anchor-based residual learning

For full-integer quantization, where all weights and activations are int8, it is clearly better to learn a residual (which is always close to zero) than to map the low-resolution image directly to the high-resolution image. Existing methods use one of two kinds of residual learning: (1) image-space residual learning, which adds the interpolated input (bilinear, bicubic) to the network output; (2) feature-space residual learning, which adds the output of a shallow convolutional layer to the network output.

For the float32 model, feature-space residual learning is slightly better (+0.08dB). For the int8 quantized model, image-space residual learning is always better (+0.3dB), because it forces the whole network to learn subtle changes, so the continuous real values it must produce can be represented more accurately by a fixed, discrete set of int8 values. However, bilinear and nearest-neighbor resizing are slow on mobile devices because of the per-pixel multiplications needed for coordinate mapping.

Our anchor-based residual learning keeps the good property of image-space residual learning while being as fast as feature-space residual learning. The core operation is repeating the input nine times (for x3 scale) and adding it to the feature map before depth-to-space. See our architecture in model.
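The anchor operation can be checked with a small pure-Python sketch. It assumes NHWC tensors and the DCR channel ordering used by tf.nn.depth_to_space; repeating the input nine times along the channel axis is written here as list repetition, which corresponds to something like tf.concat([x] * 9, axis=-1) in TensorFlow (an assumption, not necessarily the repository's exact code). With all-zero features the output is exactly nearest-neighbor upsampling of the input, so the convolutional branch only has to learn the residual on top of it and no resize op is needed.

```python
from typing import List

Image = List[List[List[float]]]  # H x W x C nested lists (NHWC without the batch axis)


def depth_to_space(x: Image, block: int) -> Image:
    """Rearrange channels into block x block spatial tiles (DCR ordering):
    out[h*block + i][w*block + j][c] = x[h][w][(i*block + j)*C + c]."""
    h, w, cin = len(x), len(x[0]), len(x[0][0])
    c = cin // (block * block)
    out = [[[0.0] * c for _ in range(w * block)] for _ in range(h * block)]
    for y in range(h):
        for xx in range(w):
            for i in range(block):
                for j in range(block):
                    for ch in range(c):
                        out[y * block + i][xx * block + j][ch] = x[y][xx][(i * block + j) * c + ch]
    return out


def anchor(x: Image, scale: int) -> Image:
    """Repeat each pixel's channels scale**2 times (9 times for x3): [r,g,b] -> [r,g,b,r,g,b,...]."""
    return [[pixel * (scale * scale) for pixel in row] for row in x]


def add(a: Image, b: Image) -> Image:
    """Element-wise sum of two equally shaped images."""
    return [[[p + q for p, q in zip(pa, pb)] for pa, pb in zip(ra, rb)] for ra, rb in zip(a, b)]


def upsample_with_anchor(features: Image, x: Image, scale: int) -> Image:
    """Add the anchor to the last feature map (3 * scale**2 channels), then depth-to-space."""
    return depth_to_space(add(features, anchor(x, scale)), scale)


def nearest_upsample(x: Image, scale: int) -> Image:
    """Reference nearest-neighbor upsampling, used only for comparison."""
    return [[x[r // scale][col // scale][:] for col in range(len(x[0]) * scale)]
            for r in range(len(x) * scale)]


if __name__ == "__main__":
    lr = [[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]]               # 1 x 2 RGB input
    zero_features = [[[0.0] * 27 for _ in range(2)]]       # network output before the anchor
    sr = upsample_with_anchor(zero_features, lr, scale=3)  # 3 x 6 x 3 output
    print(sr == nearest_upsample(lr, 3))                   # True
```

Because the anchor is just a channel repetition plus an addition, it maps to cheap, quantization-friendly ops, unlike a bilinear or nearest-neighbor resize layer.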

One more convolution after deep feature extraction

After deep feature extraction, existing methods use one convolution to map features back to the original image space, followed by a depth-to-space layer (PixelShuffle in PyTorch). We find that adding one more convolution in image space performs better than adding one more convolution in the deep feature extraction stage (+0.11dB).
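To put the extra convolution in perspective, a 3x3 convolution with bias has 9 * C_in * C_out + C_out parameters. The sketch below assumes 28 feature channels (guessed from "C28" in the run name, not read from the model definition) and 27 = 3 * 3^2 image-space channels feeding depth-to-space for an RGB x3 model. Under these assumptions both placements cost roughly the same number of parameters, so the +0.11dB comparison is close to like-for-like.

```python
def conv_params(c_in: int, c_out: int, k: int = 3, bias: bool = True) -> int:
    """Number of weights (plus biases) in a k x k 2-D convolution."""
    return k * k * c_in * c_out + (c_out if bias else 0)


SCALE = 3
FEATURE_CH = 28                  # assumption: taken from "C28" in the run name
IMAGE_CH = 3 * SCALE * SCALE     # 27 channels feed depth-to-space for RGB x3

extra_in_image_space = conv_params(IMAGE_CH, IMAGE_CH)        # 27 -> 27
extra_in_feature_space = conv_params(FEATURE_CH, FEATURE_CH)  # 28 -> 28

print(extra_in_image_space, extra_in_feature_space)  # 6588 7084
```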

Requirements

Note that the TensorFlow version matters a lot: old versions lack some layers such as depth-to-space, so make sure your TF version is newer than 2.4.0. In addition, only tf-nightly builds newer than 2.5.0 can quantize models with arbitrary input shapes. I provide two conda environments: tf.yaml for training, and tfnightly.yaml for Post-Training Quantization (PTQ) and Quantization-Aware Training (QAT). You can create the two separate conda environments with the following commands:

conda env create -f tf.yaml
conda env create -f tfnightly.yaml
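After creating and activating an environment, it can be worth checking the installed TensorFlow version programmatically. The helpers below are a hypothetical pure-Python sketch (parse_version and is_at_least are illustrative names, not part of the repository). They compare only the numeric major.minor.patch prefix, so a nightly string such as "2.5.0-dev20210310" is treated as 2.5.0, and whether the boundary version itself is acceptable is left to the caller.

```python
import re
from typing import Tuple


def parse_version(version: str) -> Tuple[int, int, int]:
    """Extract (major, minor, patch) from e.g. '2.4.1' or a nightly like '2.5.0-dev20210310'."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    major, minor, patch = (int(part) for part in match.groups())
    return major, minor, patch


def is_at_least(version: str, minimum: Tuple[int, int, int]) -> bool:
    """True if version >= minimum; pre-release suffixes such as '-dev...' are ignored."""
    return parse_version(version) >= minimum


if __name__ == "__main__":
    # In practice: import tensorflow as tf; is_at_least(tf.__version__, (2, 4, 0))
    print(is_at_least("2.4.1", (2, 4, 0)))              # True
    print(is_at_least("2.3.0", (2, 4, 0)))              # False
    print(is_at_least("2.5.0-dev20210310", (2, 5, 0)))  # True
```

Comparing integer tuples avoids the classic string-comparison bug where "2.10.0" would sort before "2.4.0".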

Pipeline

  1. Train and validate on DIV2K. We can achieve 30.22dB with 42.54K parameters.
  2. Post-Training Quantization: after int8 quantization, PSNR drops to 30.09dB.
  3. Quantization-Aware Training: insert fake quantization nodes during training. PSNR rises to 30.15dB, so the int8 model is 4x smaller than the float32 model with only a 0.07dB loss.
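The sketch below shows per-tensor affine int8 quantization in pure Python. It is a simplified illustration; the real TFLite converter differs in details such as calibration and per-channel weight scales. It also makes the range argument from the anchor-based residual section concrete: a tensor with a narrow range around zero, like a residual, gets a much smaller quantization step than a wide-range tensor, so its quantize/dequantize round-trip error is smaller. The 4x size reduction is simply 1 byte per int8 value versus 4 bytes per float32 value.

```python
from typing import List, Tuple

QMIN, QMAX = -128, 127  # signed int8 range


def quant_params(values: List[float]) -> Tuple[float, int]:
    """Per-tensor scale and zero point mapping the value range onto [QMIN, QMAX]."""
    lo, hi = min(min(values), 0.0), max(max(values), 0.0)  # keep 0.0 inside the range
    scale = (hi - lo) / (QMAX - QMIN) or 1.0               # guard against an all-zero tensor
    zero_point = int(round(QMIN - lo / scale))
    return scale, zero_point


def quantize(values: List[float], scale: float, zero_point: int) -> List[int]:
    """Map floats to clamped int8 codes."""
    return [max(QMIN, min(QMAX, int(round(v / scale)) + zero_point)) for v in values]


def dequantize(q: List[int], scale: float, zero_point: int) -> List[float]:
    """Map int8 codes back to floats."""
    return [(qi - zero_point) * scale for qi in q]


def max_roundtrip_error(values: List[float]) -> float:
    """Largest absolute error after an int8 quantize -> dequantize round trip."""
    scale, zero_point = quant_params(values)
    restored = dequantize(quantize(values, scale, zero_point), scale, zero_point)
    return max(abs(v - r) for v, r in zip(values, restored))


if __name__ == "__main__":
    image = [i * 0.37 for i in range(700)]               # wide range: about [0, 258.6]
    residual = [(i - 350) * 0.0037 for i in range(700)]  # narrow range around zero
    print(f"image    max error: {max_roundtrip_error(image):.4f}")
    print(f"residual max error: {max_roundtrip_error(residual):.4f}")
```

The error is bounded by half the quantization step, and the step is proportional to the tensor's range, which is why learning a near-zero residual suits full-integer inference.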

Prepare DIV2K Data

Download DIV2K and put it in the data folder. You can use (or edit) the provided script:

cd data
make download

Then the structure should look like:

data
└── DIV2K
    ├── DIV2K_train_HR
    │   ├── 0001.png
    │   ├── ...
    │   └── 0900.png
    └── DIV2K_train_LR_bicubic
        └── X2
            ├── 0001x2.png
            ├── ...
            └── 0900x2.png
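Before training, it can help to verify that every HR image has its LR counterpart. The sketch below builds the expected (HR, LR) path pairs from the naming scheme in the directory tree and reports missing files; div2k_pairs and missing_files are hypothetical helpers, not the repository's data loader. For x3 training it assumes the usual DIV2K layout X3/0001x3.png.

```python
from pathlib import Path
from typing import Iterable, List, Tuple


def div2k_pairs(root: str, scale: int, ids: Iterable[int]) -> List[Tuple[Path, Path]]:
    """Return (HR, LR) path pairs following the DIV2K naming scheme."""
    base = Path(root) / "DIV2K"
    pairs = []
    for i in ids:
        hr = base / "DIV2K_train_HR" / f"{i:04d}.png"
        lr = base / "DIV2K_train_LR_bicubic" / f"X{scale}" / f"{i:04d}x{scale}.png"
        pairs.append((hr, lr))
    return pairs


def missing_files(pairs: List[Tuple[Path, Path]]) -> List[Path]:
    """List every expected file that is not on disk."""
    return [p for pair in pairs for p in pair if not p.exists()]


if __name__ == "__main__":
    pairs = div2k_pairs("data", scale=2, ids=range(1, 901))  # images 0001 .. 0900
    missing = missing_files(pairs)
    print(f"{len(pairs)} pairs, {len(missing)} missing files")
```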

Pre-trained model

The training script saves the models as .pb files in the experiment directory. The authors also provide a pre-trained model in that directory.

You can also convert the .pb files to .hdf5 files. Edit experiment/Makefile (if necessary) and then run:

cd experiment
make run

Training

You can edit (if necessary) and run the Makefile:

make train-base7_D4C28_bs16ps64_lr1e-3

If you run into issues training on a GPU, try:

make train-nogpu-base7_D4C28_bs16ps64_lr1e-3

or

python train.py --opt options/train/base7.yaml --name base7_D4C28_bs16ps64_lr1e-3 --scale 3 --bs 16 --ps 64 --lr 1e-3 --gpu_ids 0

Note: the --name argument determines the following save paths:

  • Log file will be saved in log/{name}.log
  • Checkpoint and current best weights will be saved in experiment/{name}/best_status/
  • Training and validation visualizations will be saved in Tensorboard/{name}/
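The sketch below shows how the flags in the training command could be parsed and how the three save paths follow from --name. It is a hypothetical reconstruction for illustration (build_parser, parse_args and save_paths are not taken from train.py), and reading --bs and --ps as batch size and patch size is an assumption based on the run name.

```python
import argparse
from pathlib import Path
from typing import Dict, List, Optional


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the flags used in the training command."""
    parser = argparse.ArgumentParser(description="SR training (illustrative sketch)")
    parser.add_argument("--opt", required=True, help="path to the YAML options file")
    parser.add_argument("--name", required=True, help="run name used for all save paths")
    parser.add_argument("--scale", type=int, help="upscaling factor, e.g. 3")
    parser.add_argument("--bs", type=int, help="assumed: batch size")
    parser.add_argument("--ps", type=int, help="assumed: training patch size")
    parser.add_argument("--lr", type=float, help="learning rate")
    parser.add_argument("--gpu_ids", type=str, help="GPU id(s) to use")
    return parser


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    return build_parser().parse_args(argv)


def save_paths(name: str) -> Dict[str, Path]:
    """Save locations derived from --name, as listed in the note on save paths."""
    return {
        "log": Path("log") / f"{name}.log",
        "best_status": Path("experiment") / name / "best_status",
        "tensorboard": Path("Tensorboard") / name,
    }


if __name__ == "__main__":
    args = parse_args(["--opt", "options/train/base7.yaml",
                       "--name", "base7_D4C28_bs16ps64_lr1e-3", "--scale", "3"])
    for kind, path in save_paths(args.name).items():
        print(f"{kind:>12}: {path}")
```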

You can monitor training and validation with TensorBoard:

tensorboard --logdir Tensorboard

Quantization-Aware Training

If you haven't worked with TensorFlow Lite and network quantization before, please refer to the official guide. QAT inserts fake quantization nodes so that the weights are aware that they will be quantized. For this model, you can simply use the following command to perform QAT:

python train.py --opt options/train/base7_qat.yaml --name base7_D4C28_bs16ps64_lr1e-3_qat --scale 3 --bs 16 --ps 64 --lr 1e-3 --gpu_ids 0 --qat --qat_path experiment/base7_D4C28_bs16ps64_lr1e-3/best_status
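Conceptually, a fake quantization node rounds a tensor to the 8-bit grid in the forward pass but lets gradients pass through unchanged (the straight-through estimator), except where values were clipped. The pure-Python sketch below illustrates that idea for a fixed range [lo, hi]; it is not TensorFlow's implementation, which among other differences tracks the range during training and nudges it so that zero is exactly representable.

```python
from typing import List, Tuple


def fake_quant(x: List[float], lo: float, hi: float,
               num_bits: int = 8) -> Tuple[List[float], List[float]]:
    """Quantize-then-dequantize x over [lo, hi]; return (outputs, gradient mask)."""
    levels = 2 ** num_bits - 1                # 255 steps between lo and hi for 8 bits
    scale = (hi - lo) / levels
    outputs, grad_mask = [], []
    for v in x:
        clipped = min(max(v, lo), hi)         # values outside [lo, hi] saturate
        q = round((clipped - lo) / scale)     # integer level in [0, levels]
        outputs.append(lo + q * scale)        # dequantized value used by the next layer
        grad_mask.append(1.0 if lo <= v <= hi else 0.0)  # straight-through estimator
    return outputs, grad_mask


def backward(upstream: List[float], grad_mask: List[float]) -> List[float]:
    """Gradient w.r.t. the input: pass upstream gradients through where not clipped."""
    return [g * m for g, m in zip(upstream, grad_mask)]


if __name__ == "__main__":
    out, mask = fake_quant([-1.5, -0.3, 0.0, 0.42, 1.2], lo=-1.0, hi=1.0)
    print(out)
    print(mask)  # [0.0, 1.0, 1.0, 1.0, 0.0]
```

Because training already sees the rounded values, the weights adapt to the int8 grid, which is how QAT recovers part of the PSNR lost by post-training quantization.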

Convert to TFLite for mobile devices

python generate_tflite.py

The converted TFLite models are saved in TFMODEL/. TFMODEL/{name}.tflite is used to predict high-resolution images (arbitrary low-resolution input shapes are allowed), while TFMODEL/{name}_time.tflite fixes the input shape to [1, 360, 640, 3] for measuring inference time.
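As a quick check on what the timing model measures: an x3 model fed the fixed [1, 360, 640, 3] input produces a [1, 1080, 1920, 3] output, i.e. a Full HD image. The hypothetical helper below just spells out that NHWC shape arithmetic.

```python
from typing import List


def sr_output_shape(input_shape: List[int], scale: int) -> List[int]:
    """NHWC output shape of an x`scale` super-resolution model."""
    n, h, w, c = input_shape
    return [n, h * scale, w * scale, c]  # only height and width are upscaled


print(sr_output_shape([1, 360, 640, 3], 3))  # [1, 1080, 1920, 3]
```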

Run TFLite Model on your own devices

  1. Download AI Benchmark from Google Play or its website and run its standard tests.
  2. After the tests finish, enter PRO Mode and select the Custom Model tab there.
  3. Copy your tflite model to your device, note its location, and then run the model.

Contact

:) If you have any questions, feel free to contact 151220022@smail.nju.edu.cn

About

Winner solution of mobile AI (CVPRW 2021).
