Knowledge Distillation for Multi-task Learning

This is the implementation of Knowledge Distillation for Multi-task Learning, introduced by Wei-Hong Li and Hakan Bilen. We provide code for our method, which performs semantic segmentation, depth estimation and surface normal estimation on the NYU-v2 dataset using SegNet and MTAN. The repository also contains code for the baselines compared in our paper. All methods are implemented in PyTorch.

Figure 1. Diagram of our method. We first train a task-specific model for each task in an offline stage and freeze its parameters ((a) and (c)). We then optimize the parameters of the multi-task network to minimize the sum of task-specific losses while also producing features similar to those of the single-task networks ((b)).
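The objective in Figure 1 can be sketched in a few lines. This is a minimal, framework-free illustration under stated assumptions, not the code in this repository: it assumes one linear adaptor per task that maps the shared multi-task feature into the space of that task's single-task feature, compares L2-normalized features with a squared Euclidean distance, and uses a single hypothetical weight lam. Features are plain vectors here (in the real networks they would be feature maps), and all function names are illustrative.

```python
import math
from typing import Dict, List

Vector = List[float]
Matrix = List[List[float]]


def l2_normalize(v: Vector, eps: float = 1e-12) -> Vector:
    """Scale a feature vector to unit L2 norm (eps avoids division by zero)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / max(norm, eps) for x in v]


def apply_adaptor(weight: Matrix, feat: Vector) -> Vector:
    """Hypothetical task-specific adaptor, modelled as a linear map (matrix-vector product)."""
    return [sum(w * x for w, x in zip(row, feat)) for row in weight]


def distill_distance(mt_feat: Vector, st_feat: Vector, adaptor: Matrix) -> float:
    """Squared L2 distance between the adapted multi-task feature and the frozen
    single-task feature, both L2-normalized (an assumed choice of distance)."""
    a = l2_normalize(apply_adaptor(adaptor, mt_feat))
    b = l2_normalize(st_feat)
    return sum((x - y) ** 2 for x, y in zip(a, b))


def kd_mtl_objective(task_losses: Dict[str, float], mt_feat: Vector,
                     st_feats: Dict[str, Vector], adaptors: Dict[str, Matrix],
                     lam: float = 1.0) -> float:
    """Sum of task-specific losses plus lam times the per-task feature distances.

    st_feats come from the frozen single-task networks, so in a real training loop
    only the multi-task network and the adaptors would receive gradients.
    """
    total = sum(task_losses.values())
    for task, st_feat in st_feats.items():
        total += lam * distill_distance(mt_feat, st_feat, adaptors[task])
    return total


if __name__ == "__main__":
    identity = [[1.0, 0.0], [0.0, 1.0]]
    loss = kd_mtl_objective(
        task_losses={"semantic": 1.0, "depth": 0.5},
        mt_feat=[1.0, 0.0],
        st_feats={"semantic": [2.0, 0.0], "depth": [0.0, 3.0]},
        adaptors={"semantic": identity, "depth": identity},
    )
    print(loss)  # 1.5 + 0.0 + 2.0 = 3.5
```

The semantic feature already points in the same direction as the adapted multi-task feature and adds nothing, while the orthogonal depth feature adds the maximum distance of 2 for unit vectors.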

Updates

  • April 2023: This work has been extended with applications to multi-task dense prediction, multi-domain learning and cross-domain few-shot learning, along with implementations of more loss-balancing strategies and multi-task/multi-domain learning backbones. Check out our latest version, Universal Representations, for more details!

Requirements

  • Python 3.6+
  • PyTorch 1.0 (or newer version)
  • torchvision 0.2.2 (or newer version)
  • progress
  • matplotlib
  • numpy
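To check an environment against the list above, a small script like the following may help. It is a hypothetical helper, not part of the repository; it assumes the import names torch (for PyTorch), torchvision, progress, matplotlib and numpy, and it only checks that each package can be located, not its version.

```python
import importlib.util
import sys
from typing import Iterable, List, Tuple

# Assumed import names for the requirements listed above (PyTorch imports as "torch").
REQUIRED = ("torch", "torchvision", "progress", "matplotlib", "numpy")


def python_ok(minimum: Tuple[int, int] = (3, 6)) -> bool:
    """True if the running interpreter is at least the given (major, minor) version."""
    return sys.version_info[:2] >= minimum


def missing_packages(names: Iterable[str]) -> List[str]:
    """Return the top-level module names that cannot be located (nothing is imported)."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    print("Python >= 3.6:", python_ok())
    missing = missing_packages(REQUIRED)
    print("Missing packages:", ", ".join(missing) if missing else "none")
```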

Usage

Prepare dataset

We use the preprocessed NYUv2 dataset provided by this repo. Download the dataset and place the dataset folder in ./data/ (the commands below expect it at ./data/nyuv2).
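Before training, the layout of ./data/nyuv2 can be sanity-checked. The sketch below is hypothetical and rests on an assumption this README does not state: that the preprocessed dataset contains train and val splits, each with image, label, depth and normal subfolders of .npy files. Adjust SPLITS and MODALITIES if the download is organized differently.

```python
import os
from typing import List

# Assumed layout: <root>/<split>/<modality>/*.npy  (verify against the actual download).
SPLITS = ("train", "val")
MODALITIES = ("image", "label", "depth", "normal")


def find_problems(root: str) -> List[str]:
    """Return human-readable layout problems; an empty list means the layout looks OK."""
    problems = []
    for split in SPLITS:
        counts = {}
        for modality in MODALITIES:
            path = os.path.join(root, split, modality)
            if not os.path.isdir(path):
                problems.append(f"missing folder: {path}")
                continue
            counts[modality] = sum(1 for f in os.listdir(path) if f.endswith(".npy"))
        # Each modality should hold one file per sample, so the counts should agree.
        if len(set(counts.values())) > 1:
            problems.append(f"{split}: mismatched .npy counts {counts}")
    return problems


if __name__ == "__main__":
    issues = find_problems("./data/nyuv2")
    print("\n".join(issues) if issues else "dataset layout looks OK")
```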

Our method

  • Train the single-task learning models with SegNet:
python model_segnet_single.py --gpu <gpu_id> --out SegNet-single --dataroot ./data/nyuv2 --task <task: semantic, depth, normal>

or with MTAN:

python model_mtan_single.py --gpu <gpu_id> --out mtan-single --dataroot ./data/nyuv2 --task <task: semantic, depth, normal>
  • Train the multi-task learning model with our KD-MTL method using SegNet:
python model_segnet_kdmtl.py --gpu <gpu_id> --alr 1e-1 --out SegNet-KD-MTL --single-dir ./SegNet-single/ --dataroot ./data/nyuv2

or with MTAN:

python model_mtan_kdmtl.py --gpu <gpu_id> --alr 1e-1 --out MTAN-KD-MTL --single-dir ./mtan-single/ --dataroot ./data/nyuv2
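Since the KD-MTL commands point --single-dir at the single-task --out folders, the single-task runs for all three tasks presumably need to finish first. The hypothetical launcher below (not part of the repository) chains the commands above using only the flags shown in this README. It assumes the scripts are run from the repository root and that each single-task run writes to the folder given by --out.

```python
import subprocess
import sys
from typing import List

TASKS = ("semantic", "depth", "normal")
DATAROOT = "./data/nyuv2"

# (script, --out) per backbone, copied from the commands above.
SINGLE = {"segnet": ("model_segnet_single.py", "SegNet-single"),
          "mtan": ("model_mtan_single.py", "mtan-single")}
# (script, --out, --single-dir) per backbone; --single-dir matches the single-task --out folder.
KDMTL = {"segnet": ("model_segnet_kdmtl.py", "SegNet-KD-MTL", "./SegNet-single/"),
         "mtan": ("model_mtan_kdmtl.py", "MTAN-KD-MTL", "./mtan-single/")}


def pipeline_cmds(gpu: str = "0", backbone: str = "segnet") -> List[List[str]]:
    """Three single-task runs (the frozen teachers), followed by one KD-MTL run."""
    script, out = SINGLE[backbone]
    cmds = [[sys.executable, script, "--gpu", gpu, "--out", out,
             "--dataroot", DATAROOT, "--task", task] for task in TASKS]
    script, out, single_dir = KDMTL[backbone]
    cmds.append([sys.executable, script, "--gpu", gpu, "--alr", "1e-1", "--out", out,
                 "--single-dir", single_dir, "--dataroot", DATAROOT])
    return cmds


def run(cmds: List[List[str]], dry_run: bool = True) -> None:
    """Print each command; execute them one after another unless dry_run is set."""
    for cmd in cmds:
        print(" ".join(cmd))
        if not dry_run:
            # check=True raises on failure, so KD-MTL never starts without its teachers.
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run(pipeline_cmds(gpu="0", backbone="segnet"), dry_run=True)
```

Running it with dry_run=True only prints the four commands, which is a cheap way to review them before committing GPU time.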

Baselines

We provide code, model_segnet_baselines.py and model_mtan_baselines.py, for several loss-weighting approaches: uniform weighting, MGDA (adapted from the source code), GradNorm, and DWA (source code).
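Two of these weighting rules are simple enough to sketch in a few lines. The code below is an illustrative, framework-free sketch rather than the code in model_segnet_baselines.py or model_mtan_baselines.py: dwa_weights follows the Dynamic Weight Average rule from the MTAN paper (a softmax over the ratios of each task's last two epoch losses with temperature T = 2, scaled so the weights sum to the number of tasks), and mgda_two_task_alpha is the closed-form two-task case of MGDA's min-norm problem. The general multi-task MGDA solver (Frank-Wolfe) and GradNorm are not shown; all function names are hypothetical.

```python
import math
from typing import List, Sequence


def dwa_weights(loss_history: Sequence[Sequence[float]], num_tasks: int,
                temperature: float = 2.0) -> List[float]:
    """Dynamic Weight Average: per-task loss weights for the next epoch.

    loss_history[e][k] is the average loss of task k in finished epoch e.
    With fewer than two finished epochs every task gets weight 1.
    """
    if len(loss_history) < 2:
        return [1.0] * num_tasks
    # Relative descending rate w_k = L_k(t-1) / L_k(t-2): larger means slower progress.
    rates = [last / before for last, before in zip(loss_history[-1], loss_history[-2])]
    exps = [math.exp(r / temperature) for r in rates]
    total = sum(exps)
    # Softmax over the rates, scaled so the weights sum to num_tasks.
    return [num_tasks * e / total for e in exps]


def mgda_two_task_alpha(g1: Sequence[float], g2: Sequence[float]) -> float:
    """MGDA min-norm weight for two tasks.

    Returns alpha in [0, 1] minimizing ||alpha * g1 + (1 - alpha) * g2||^2, where g1 and g2
    are the flattened gradients of the two task losses w.r.t. the shared parameters.
    """
    denom = sum((a - b) ** 2 for a, b in zip(g1, g2))
    if denom == 0.0:
        return 0.5  # identical gradients: every alpha gives the same combination
    # Unconstrained minimizer (g2 - g1) . g2 / ||g1 - g2||^2, then clipped to [0, 1].
    alpha = sum((b - a) * b for a, b in zip(g1, g2)) / denom
    return min(max(alpha, 0.0), 1.0)


if __name__ == "__main__":
    history = [[1.0, 2.0, 4.0], [0.5, 2.0, 1.0]]  # two finished epochs, three tasks
    print(dwa_weights(history, num_tasks=3))      # the stalled second task gets the largest weight
    print(mgda_two_task_alpha([1.0, 0.0], [0.0, 1.0]))  # orthogonal, equal-norm gradients -> 0.5
```

The MGDA weights alpha and 1 - alpha would scale the two task losses before the backward pass, so the shared parameters move along the minimum-norm point of the convex hull of the two gradients.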

  • Train the multi-task learning model using MGDA with SegNet:
python model_segnet_baselines.py --gpu <gpu_id> --out baselines --dataroot ./data/nyuv2 --weight mgda #weight: uniform, mgda, gradnorm, dwa

or with MTAN:

python model_mtan_baselines.py --gpu <gpu_id> --out baselines --dataroot ./data/nyuv2 --weight mgda #weight: uniform, mgda, gradnorm, dwa

Acknowledgements

We thank Shikun Liu and Ozan Sener for the source code of MTAN and MGDA, respectively.

Contact

For any questions, please contact Wei-Hong Li.

Citation

If you find this code useful in your work and use it in a publication, please cite:

@inproceedings{li2020knowledge,
  title={Knowledge Distillation for Multi-task Learning},
  author={Li, Wei-Hong and Bilen, Hakan},
  booktitle={Proceedings of the European Conference on Computer Vision Workshop on Imbalance Problems in Computer Vision},
  year={2020}
}