This project addresses the problem of retrieving images from a large database using ambiguous sketches, in the zero-shot setting: the test sketches and images come from classes unseen during training. A deep feature extractor embeds both sketches and images into a common space, and the top-k images whose embeddings are closest to the query sketch's embedding are retrieved from the image database of the unseen classes.
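The retrieval step can be sketched as follows. This is an illustrative example, not the project's actual code; the function and array names are hypothetical, and embeddings are assumed to be NumPy arrays:

```python
import numpy as np

def retrieve_top_k(sketch_emb, image_embs, k=5):
    """Return indices of the k images whose embeddings are closest
    (Euclidean distance) to the query sketch embedding."""
    dists = np.linalg.norm(image_embs - sketch_emb, axis=1)  # shape (N,)
    return np.argsort(dists)[:k]

# Toy example: 4 image embeddings in 3-D; image 0 matches the sketch exactly
image_embs = np.array([[1.0, 0.0, 0.0],
                       [0.0, 1.0, 0.0],
                       [0.9, 0.1, 0.0],
                       [0.0, 0.0, 1.0]])
sketch_emb = np.array([1.0, 0.0, 0.0])
print(retrieve_top_k(sketch_emb, image_embs, k=2))  # → [0 2]
```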
Training uses the standard triplet loss together with a domain loss, whose classifier is trained to differentiate between sketches and images.
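As an illustration of the triplet objective (not the exact training code), PyTorch's built-in `TripletMarginLoss` penalizes an anchor embedding that is closer to a negative than to its positive by at least the margin; the toy embeddings below are made up:

```python
import torch
import torch.nn as nn

triplet_loss = nn.TripletMarginLoss(margin=1.0)

anchor   = torch.tensor([[1.0, 0.0]])  # e.g. a sketch embedding
positive = torch.tensor([[1.0, 0.0]])  # image of the same class
negative = torch.tensor([[0.0, 5.0]])  # image of a different class

# d(a, p) = 0 and d(a, n) is large, so the hinge is inactive:
loss = triplet_loss(anchor, positive, negative)
print(loss.item())  # 0.0
```

When the negative coincides with the positive, the loss equals the margin (1.0), which is the gradient signal that pushes same-class pairs together and different-class pairs apart.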
The embedding is passed into a gradient reversal layer (GRL), then into a domain-classifier network, whose output feeds the domain loss. The gradient reversal layer acts as an identity in the forward pass and multiplies the gradient by -1 in the backward pass, so the main network learns a domain-agnostic representation, i.e. one that fools the domain classifier (Ganin et al., 2016).
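A minimal gradient reversal layer can be written as a custom autograd function. This is an illustrative PyTorch sketch (the class and function names are our own, not necessarily those used in this repository); `lambd` scales the reversed gradient:

```python
import torch

class GradReverse(torch.autograd.Function):
    """Identity in the forward pass; multiplies the incoming
    gradient by -lambd in the backward pass."""
    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse (and scale) the gradient flowing back to the
        # feature extractor; no gradient w.r.t. lambd itself.
        return -ctx.lambd * grad_output, None

def grad_reverse(x, lambd=1.0):
    return GradReverse.apply(x, lambd)

# Sanity check: forward is identity, backward negates the gradient
x = torch.ones(3, requires_grad=True)
y = grad_reverse(x, lambd=1.0)
y.sum().backward()
print(x.grad)  # tensor([-1., -1., -1.])
```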
† Diagram created by Arjun Ashok using app.diagrams.net
The tables below present a few qualitative results of our model on unseen test classes. For each query sketch, the top-5 retrieved images are shown from left to right.
| Query Sketch 1 |
|---|

| Retrieved Image 1 | Retrieved Image 2 | Retrieved Image 3 | Retrieved Image 4 | Retrieved Image 5 |
|---|---|---|---|---|

| Query Sketch 2 |
|---|

| Retrieved Image 1 | Retrieved Image 2 | Retrieved Image 3 | Retrieved Image 4 | Retrieved Image 5 |
|---|---|---|---|---|
## Installation
Please execute the following command to install the required libraries:

```shell
pip install -r requirements.txt
```
## Data

To download the data, execute:

```shell
bash download_data.sh
```
## Training
The file `train.py` can be invoked with the following arguments:
```
usage: train.py [-h] --data_dir DATA_DIR --batch_size BATCH_SIZE
                --checkpoint_dir CHECKPOINT_DIR --epochs EPOCHS
                [--domain_loss_ratio DOMAIN_LOSS_RATIO]
                [--triplet_loss_ratio TRIPLET_LOSS_RATIO]
                [--grl_threshold_epoch GRL_THRESHOLD_EPOCH]
                [--print_every PRINT_EVERY]

Training of SBIR

optional arguments:
  -h, --help            show this help message and exit
  --data_dir DATA_DIR   Data directory path. Directory should contain two
                        folders - sketches and photos, along with 2 .txt files
                        for the labels
  --batch_size BATCH_SIZE
                        Batch size to process the train sketches/photos
  --checkpoint_dir CHECKPOINT_DIR
                        Directory to save checkpoints
  --epochs EPOCHS       Number of epochs
  --domain_loss_ratio DOMAIN_LOSS_RATIO
                        Domain loss weight
  --triplet_loss_ratio TRIPLET_LOSS_RATIO
                        Triplet loss weight
  --grl_threshold_epoch GRL_THRESHOLD_EPOCH
                        Threshold epoch for GRL lambda
  --print_every PRINT_EVERY
                        Logging interval in iterations
```
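The `--grl_threshold_epoch` flag suggests that the GRL coefficient λ is ramped up over the first epochs rather than held constant. Ganin et al. (2016) use the schedule λ(p) = 2/(1 + e^(−10p)) − 1, where p is training progress in [0, 1]. The sketch below illustrates that schedule; the function name is hypothetical and the exact schedule used by `train.py` may differ:

```python
import math

def grl_lambda(epoch, threshold_epoch):
    """GRL coefficient following the schedule of Ganin et al. (2016):
    ramps smoothly from 0 toward 1, saturating at threshold_epoch."""
    p = min(epoch / threshold_epoch, 1.0)  # training progress in [0, 1]
    return 2.0 / (1.0 + math.exp(-10.0 * p)) - 1.0

print(grl_lambda(0, 20))   # 0.0 at the start of training
print(grl_lambda(20, 20))  # ~1.0 once the threshold epoch is reached
```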
It is advised to use a GPU for training. The code automatically detects and uses a GPU, if available.
## Inference
The file `evaluate.py` can be invoked with the following arguments:
```
usage: evaluate.py [-h] [--model MODEL] --data DATA [--num_images NUM_IMAGES]
                   [--num_sketches NUM_SKETCHES] [--batch_size BATCH_SIZE]
                   [--output_dir OUTPUT_DIR]

Evaluation of SBIR

optional arguments:
  -h, --help            show this help message and exit
  --model MODEL         Model checkpoint path
  --data DATA           Data directory path. Directory should contain two
                        folders - sketches and photos, along with 2 .txt files
                        for the labels
  --num_images NUM_IMAGES
                        Number of random images to output for every sketch
  --num_sketches NUM_SKETCHES
                        Number of random sketches to output
  --batch_size BATCH_SIZE
                        Batch size to process the test sketches/photos
  --output_dir OUTPUT_DIR
                        Directory to save output sketch and images
```
It is advised to use a GPU for evaluation. The code automatically detects and uses a GPU, if available.
## References

- Ganin, Yaroslav, et al. "Domain-Adversarial Training of Neural Networks." Journal of Machine Learning Research, vol. 17, 2016, pp. 1-35. http://jmlr.org/papers/v17/15-239.html