
A PyTorch implementation of EGFF based on NPL 2022 paper "Energy-Guided Feature Fusion for Zero-Shot Sketch-Based Image Retrieval"


leftthomas/EGFF


EGFF

A PyTorch implementation of EGFF based on the NPL 2022 paper "Energy-Guided Feature Fusion for Zero-Shot Sketch-Based Image Retrieval".

Network Architecture

Requirements

conda install pytorch=1.10.1 torchvision cudatoolkit -c pytorch
pip install pytorch-metric-learning
pip install timm
pip install opencv-python
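Before training, it can help to confirm that Python can locate every package installed above. The snippet below is a minimal sketch and is not part of the repo. `missing_modules` is a hypothetical helper, and it only checks that each module can be found, not its version (for example, it does not verify `pytorch=1.10.1`). `cv2` and `pytorch_metric_learning` are the import names of `opencv-python` and `pytorch-metric-learning`.

```python
import importlib.util

# Import names of the packages installed by the commands above.
REQUIRED_MODULES = ["torch", "torchvision", "pytorch_metric_learning", "timm", "cv2"]


def missing_modules(names):
    """Return the module names that cannot be located on this interpreter."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("All required modules can be located.")
```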

Dataset

This repo uses the Sketchy Extended and TU-Berlin Extended datasets. You can download them from their official websites or from Google Drive. The expected data directory structure is as follows:

├── sketchy
│   ├── train
│   │   ├── sketch
│   │   │   ├── airplane
│   │   │   │   ├── n02691156_58-1.jpg
│   │   │   │   └── ...
│   │   │   └── ...
│   │   └── photo
│   │       (same structure as sketch)
│   └── val
│       (same structure as train)
└── tuberlin
    (same structure as sketchy)
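To check that a downloaded split matches this layout, you can count the images per class in the `sketch` and `photo` folders. The following is a minimal sketch, not code from the repo. `summarize_split` is a hypothetical helper, and it assumes exactly the layout above: one sub-folder per class, with image files directly inside.

```python
from pathlib import Path


def summarize_split(split_dir):
    """Count image files per class under <split_dir>/sketch and <split_dir>/photo.

    Assumes one sub-folder per class with the images directly inside it.
    Returns {class_name: (num_sketches, num_photos)}. A class that exists in
    only one domain gets 0 for the other domain.
    """
    split_dir = Path(split_dir)
    counts = {}
    for domain_index, domain in enumerate(("sketch", "photo")):
        domain_dir = split_dir / domain
        if not domain_dir.is_dir():
            raise FileNotFoundError(f"missing folder: {domain_dir}")
        for class_dir in sorted(p for p in domain_dir.iterdir() if p.is_dir()):
            n_files = sum(1 for f in class_dir.iterdir() if f.is_file())
            pair = counts.setdefault(class_dir.name, [0, 0])
            pair[domain_index] = n_files
    return {name: tuple(pair) for name, pair in counts.items()}
```

For example, `summarize_split("/home/data/sketchy/train")` (using the default `--data_root`) lists every class. Any entry with a zero on one side points to a class that is missing sketches or photos.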

Usage

Train Model

python train.py --data_name tuberlin
optional arguments:
--data_root                   Datasets root path [default value is '/home/data']
--data_name                   Dataset name [default value is 'sketchy'] (choices=['sketchy', 'tuberlin'])
--backbone_type               Backbone type [default value is 'resnet50'] (choices=['resnet50', 'vgg16'])
--proj_dim                    Projected embedding dim [default value is 512]
--batch_size                  Number of images in each mini-batch [default value is 64]
--epochs                      Number of training epochs [default value is 10]
--warmup                      Number of warm-up epochs [default value is 1]
--save_root                   Root path for saving results [default value is 'result']
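The options above map naturally onto an `argparse` parser. The sketch below is a hypothetical reconstruction for illustration, not the repo's actual `train.py`. The `output_paths` helper and its file-naming scheme are also assumptions, inferred from the default paths used by `test.py` and `vis.py` (`result/sketchy_resnet50_512_vectors.pth`, `result/sketchy_resnet50_512_model.pth`).

```python
import argparse
import os


def build_parser():
    # Hypothetical reconstruction of the documented train.py options;
    # the real script may define them differently.
    parser = argparse.ArgumentParser(description="Train EGFF")
    parser.add_argument("--data_root", default="/home/data", type=str, help="Datasets root path")
    parser.add_argument("--data_name", default="sketchy", type=str,
                        choices=["sketchy", "tuberlin"], help="Dataset name")
    parser.add_argument("--backbone_type", default="resnet50", type=str,
                        choices=["resnet50", "vgg16"], help="Backbone type")
    parser.add_argument("--proj_dim", default=512, type=int, help="Projected embedding dim")
    parser.add_argument("--batch_size", default=64, type=int, help="Number of images in each mini-batch")
    parser.add_argument("--epochs", default=10, type=int, help="Number of training epochs")
    parser.add_argument("--warmup", default=1, type=int, help="Number of warmups")
    parser.add_argument("--save_root", default="result", type=str, help="Root path for saving results")
    return parser


def output_paths(args):
    # Assumed naming scheme <data_name>_<backbone_type>_<proj_dim>, inferred
    # from the default model/vector paths shown for test.py and vis.py.
    prefix = f"{args.data_name}_{args.backbone_type}_{args.proj_dim}"
    return (os.path.join(args.save_root, f"{prefix}_model.pth"),
            os.path.join(args.save_root, f"{prefix}_vectors.pth"))


if __name__ == "__main__":
    print(output_paths(build_parser().parse_args()))
```

Under this assumed scheme, `--backbone_type resnet50 --proj_dim 2048` would produce `result/sketchy_resnet50_2048_model.pth`. That is the file passed to `vis.py` in the example below.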

Test Model

python test.py --num 4
optional arguments:
--data_root                   Datasets root path [default value is '/home/data']
--query_name                  Path of the query image [default value is '/home/data/sketchy/val/sketch/cow/n01887787_591-14.jpg']
--data_base                   Database of precomputed embedding vectors to query [default value is 'result/sketchy_resnet50_512_vectors.pth']
--num                         Number of images to retrieve [default value is 8]
--save_root                   Root path for saving results [default value is 'result']
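Conceptually, testing ranks the database embeddings by their similarity to the query's embedding and keeps the top `--num`. The sketch below illustrates that ranking step under stated assumptions. The score is assumed to be cosine similarity (common in ZS-SBIR, but this README does not confirm it), and the database is modeled as a plain `{image_path: embedding}` dict because the actual layout of the `*_vectors.pth` file is not documented here. `cosine` and `retrieve` are hypothetical helpers.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm > 0 else 0.0


def retrieve(query, database, num=8):
    """Return the `num` database keys most similar to `query`, best first.

    `database` is a hypothetical {image_path: embedding} dict standing in for
    the saved *_vectors.pth file.
    """
    ranked = sorted(database, key=lambda key: cosine(query, database[key]), reverse=True)
    return ranked[:num]


if __name__ == "__main__":
    gallery = {
        "photo/cow/a.jpg": [0.9, 0.1, 0.0],
        "photo/cow/b.jpg": [0.8, 0.3, 0.1],
        "photo/airplane/c.jpg": [0.0, 0.2, 0.9],
    }
    print(retrieve([1.0, 0.2, 0.0], gallery, num=2))
```

In the real pipeline, the query embedding would come from the trained sketch branch and the gallery embeddings from the photo branch. The toy vectors here only demonstrate the ranking.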

Vis Model

python vis.py --model_name result/sketchy_resnet50_2048_model.pth
optional arguments:
--vis_name                    Path of the image to visualize [default value is '/home/data/sketchy/val/photo/helicopter/ext_5.jpg']
--model_name                  Path of the trained model [default value is 'result/sketchy_resnet50_512_model.pth']
--save_root                   Root path for saving results [default value is 'result']

Benchmarks

The models are trained on a single NVIDIA GeForce RTX 3090 (24 GB) GPU. AdamW is used as the optimizer with a learning rate of 1e-5 and a weight decay of 5e-4. All other hyper-parameters are left at their default values.
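To make the optimizer setting concrete, here is a scalar sketch of one AdamW update using the stated learning rate and weight decay. It follows the update rule of PyTorch's `torch.optim.AdamW`: decoupled weight decay, then a bias-corrected Adam step. The betas and eps values are PyTorch's defaults and are assumptions, since the README does not list them. In the repo this presumably reduces to a call such as `torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=5e-4)`. `adamw_step` is a hypothetical helper for illustration only.

```python
import math


def adamw_step(param, grad, m, v, step, lr=1e-5, weight_decay=5e-4,
               betas=(0.9, 0.999), eps=1e-8):
    """One AdamW update for a scalar parameter.

    `step` is the 1-based step count; returns (new_param, new_m, new_v).
    betas/eps default to PyTorch's AdamW defaults (assumed here).
    """
    beta1, beta2 = betas
    # Decoupled weight decay: shrink the weight directly, independent of the gradient.
    param = param * (1 - lr * weight_decay)
    # Exponential moving averages of the gradient and the squared gradient.
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    # Bias correction for the zero-initialized moving averages.
    m_hat = m / (1 - beta1 ** step)
    v_hat = v / (1 - beta2 ** step)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v
```

With a zero gradient, only the decay term acts, so the weight shrinks by a factor of `1 - 1e-5 * 5e-4`. With a nonzero gradient, the first step moves the weight by about `lr` regardless of the gradient's scale.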

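| Backbone | Dim | Sketchy Extended mAP@200 | Sketchy Extended mAP@all | Sketchy Extended P@100 | Sketchy Extended P@200 | TU-Berlin Extended mAP@200 | TU-Berlin Extended mAP@all | TU-Berlin Extended P@100 | TU-Berlin Extended P@200 | Download |
|---|---|---|---|---|---|---|---|---|---|---|
| VGG16 | 64 | 36.1 | 39.8 | 52.8 | 48.1 | 44.2 | 39.3 | 57.1 | 53.9 | u7qg |
| VGG16 | 512 | 42.7 | 45.1 | 58.9 | 53.6 | 48.6 | 42.8 | 60.7 | 57.2 | 6up4 |
| VGG16 | 4096 | 44.6 | 47.3 | 60.1 | 55.2 | 50.0 | 44.1 | 61.8 | 58.5 | hznm |
| ResNet50 | 64 | 43.3 | 46.6 | 58.6 | 54.3 | 50.7 | 47.7 | 61.1 | 58.5 | uhkp |
| ResNet50 | 512 | 52.6 | 55.4 | 66.0 | 61.7 | 58.0 | 53.5 | 67.5 | 65.0 | u8ct |
| ResNet50 | 2048 | 53.7 | 56.8 | 66.4 | 62.5 | 60.4 | 56.1 | 69.4 | 67.1 | ipr3 |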
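The table reports mAP@200, mAP@all, P@100 and P@200, scaled to percentages. The sketch below shows one common way to compute these metrics from a ranked list of 0/1 relevance flags, where 1 means the retrieved photo shares the query sketch's class. This is an illustrative assumption, not the repo's evaluation code. In particular, AP@k here is normalized by the number of relevant items found within the top k, a convention used in several ZS-SBIR codebases; the repo may normalize differently. All function names are hypothetical.

```python
def precision_at_k(relevance, k):
    """Fraction of the top-k retrieved items that are relevant (relevance: ranked 0/1 list)."""
    return sum(relevance[:k]) / k


def average_precision_at_k(relevance, k=None):
    """Average precision over the top-k ranks; k=None uses the whole ranking (AP@all).

    Assumption: normalized by the number of relevant items within the top k.
    """
    top = relevance if k is None else relevance[:k]
    hits, total = 0, 0.0
    for rank, rel in enumerate(top, start=1):
        if rel:
            hits += 1
            total += hits / rank  # precision at each rank where a relevant item appears
    return total / hits if hits else 0.0


def mean_over_queries(metric, rankings, **kwargs):
    """Average a per-query metric over all queries (e.g. mAP@200 = mean of AP@200)."""
    return sum(metric(r, **kwargs) for r in rankings) / len(rankings)
```

For instance, `mean_over_queries(average_precision_at_k, rankings, k=200)` corresponds to the mAP@200 columns, and `mean_over_queries(precision_at_k, rankings, k=100)` corresponds to P@100. Multiply by 100 to match the table's scale.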

Results

[Figure: visualization results]
