-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathtrain_cub.py
110 lines (84 loc) · 3.53 KB
/
train_cub.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""End-to-end training code for cross-modal retrieval tasks.
PCME
Copyright (c) 2021-present NAVER Corp.
MIT license
"""
import os
import fire
import torch.backends.cudnn as cudnn
from config import parse_config
from datasets import prepare_cub_dataloaders
from engine import TrainerEngine
from engine import CUBEvaluator
from logger import PythonLogger
def pretrain(config, dataloaders, vocab, logger):
    """Run the first training stage with image/text fine-tuning disabled.

    Sets ``config.model.img_finetune`` and ``config.model.txt_finetune`` to
    ``False`` on the shared config object, builds a ``TrainerEngine`` with a
    ``CUBEvaluator`` and trains it on ``dataloaders['train']``. Every other
    entry of ``dataloaders`` is handed to the engine as a validation loader.

    Args:
        config: parsed configuration. Reads ``config.model`` and
            ``config.train`` (``pretrain_epochs``, ``pretrain_save_path``,
            ``best_pretrain_save_path`` and the optional
            ``pretrain_val_epochs``, default 1); overwrites the two
            ``*_finetune`` flags on ``config.model``.
        dataloaders: mapping of loaders that must contain a ``'train'`` key.
            Only a shallow copy is popped, so the caller's mapping is left
            intact.
        vocab: vocabulary object; its ``word2idx`` is passed to the engine.
        logger: logger exposing ``log``; it is also attached to the engine.
    """
    logger.log('start pretrain')
    trainer = TrainerEngine()
    trainer.set_logger(logger)

    model_cfg = config.model
    train_cfg = config.train

    # Disable image/text fine-tuning for this stage.
    model_cfg.img_finetune = False
    model_cfg.txt_finetune = False

    # Copy before popping 'train': ``main`` hands the same mapping to the
    # finetune stage afterwards, so it must keep its 'train' entry.
    val_loaders = dataloaders.copy()
    n_val_epochs = train_cfg.get('pretrain_val_epochs', 1)

    evaluator = CUBEvaluator(
        eval_method=model_cfg.get('eval_method', 'matmul'),
        verbose=False,
        eval_device='cuda',
    )
    trainer.create(config, vocab.word2idx, evaluator)

    train_loader = val_loaders.pop('train')
    trainer.train(
        tr_loader=train_loader,
        n_epochs=train_cfg.pretrain_epochs,
        val_loaders=val_loaders,
        val_epochs=n_val_epochs,
        model_save_to=train_cfg.pretrain_save_path,
        best_model_save_to=train_cfg.best_pretrain_save_path,
    )
def finetune(config, pretrain_path, dataloaders, vocab, logger):
    """Run the second training stage with image/text fine-tuning enabled.

    Sets ``config.model.img_finetune`` and ``config.model.txt_finetune`` to
    ``True``, scales ``config.optimizer.learning_rate`` in place by
    ``config.train.finetune_lr_decay`` (default 0.1), builds a fresh
    ``TrainerEngine`` and, when ``pretrain_path`` points to an existing file,
    restores its ``'model'`` and ``'criterion'`` entries. It then trains on
    ``dataloaders['train']`` and validates on the remaining loaders.

    Args:
        config: parsed configuration, mutated as described above. Because the
            learning rate is scaled in place, calling this function twice on
            the same config compounds the decay.
        pretrain_path: checkpoint to initialise from. If it is ``None``/empty
            or the file does not exist, a message is logged and training
            proceeds without loading it.
        dataloaders: mapping of loaders that must contain a ``'train'`` key.
            Only a shallow copy is popped, so the caller's mapping is left
            intact.
        vocab: vocabulary object; its ``word2idx`` is passed to the engine.
        logger: logger exposing ``log``; it is also attached to the engine.
    """
    logger.log('start finetune')
    engine = TrainerEngine()
    engine.set_logger(logger)

    # Enable image/text fine-tuning and scale the learning rate for this stage.
    config.model.img_finetune = True
    config.model.txt_finetune = True
    config.optimizer.learning_rate *= config.train.get('finetune_lr_decay', 0.1)

    _dataloaders = dataloaders.copy()
    val_epochs = config.train.get('val_epochs', 1)
    evaluator = CUBEvaluator(eval_method=config.model.get('eval_method', 'matmul'),
                             verbose=False,
                             eval_device='cuda')
    # Called after the config edits above; presumably create() reads them
    # when building the model/optimizer -- confirm in engine.
    engine.create(config, vocab.word2idx, evaluator)

    # Guard against None: os.path.exists(None) raises TypeError.
    if pretrain_path and os.path.exists(pretrain_path):
        engine.load_models(pretrain_path,
                           load_keys=['model', 'criterion'])
    else:
        # A missing checkpoint used to be skipped silently, which made an
        # accidental run without pretrained weights invisible in the logs.
        logger.log('pretrained checkpoint not found ({!r}); finetuning '
                   'without loading pretrained weights'.format(pretrain_path))

    engine.train(tr_loader=_dataloaders.pop('train'),
                 n_epochs=config.train.finetune_epochs,
                 val_loaders=_dataloaders,
                 val_epochs=val_epochs,
                 model_save_to=config.train.model_save_path,
                 best_model_save_to=config.train.best_model_save_path)
def main(config_path,
         dataset_root,
         caption_root,
         dataset_name='cub',
         vocab_path='./datasets/vocabs/cub_vocab.pkl',
         **kwargs):
    """Main interface for the training: run ``pretrain`` then ``finetune``.

    Args:
        config_path: path to the configuration file.
        dataset_root: root directory for the dataset.
        caption_root: root directory for the captions.
        dataset_name: dataset name forwarded to ``prepare_cub_dataloaders``
            (default ``'cub'``).
        vocab_path: vocab filename.
        **kwargs: configuration overrides forwarded to ``parse_config``.

    Other configurations:
        You can override any PCME configuration value on the command line
        with ``--<depth1>__<depth2>``, e.g. ``--dataloader__batch_size 32``.
    """
    logger = PythonLogger()

    config = parse_config(config_path,
                          strict_cast=False,
                          **kwargs)

    # Let cuDNN benchmark convolution algorithms and pick the fastest.
    cudnn.benchmark = True

    dataloaders, vocab = prepare_cub_dataloaders(config.dataloader,
                                                 dataset_name,
                                                 dataset_root,
                                                 caption_root,
                                                 vocab_path)

    # Stage 1: image/text fine-tuning disabled.
    pretrain(config, dataloaders, vocab, logger)
    # Stage 2: image/text fine-tuning enabled, initialised from stage 1.
    # NOTE(review): this loads ``pretrain_save_path`` rather than
    # ``best_pretrain_save_path``; confirm the last (not best) pretrain
    # checkpoint is the intended starting point.
    finetune(config, config.train.pretrain_save_path,
             dataloaders, vocab, logger)
if __name__ == '__main__':
    # Expose ``main`` as a command-line interface; python-fire maps
    # positional arguments and ``--flag value`` pairs onto its parameters.
    fire.Fire(component=main)