-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_facetvae.py
84 lines (64 loc) · 2.61 KB
/
run_facetvae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import argparse
from datetime import datetime
import os
from logging import getLogger
import pandas as pd
import yaml
from recbole.config import Config
from recbole.data import create_dataset, data_preparation, save_split_dataloaders
from recbole.utils import init_logger, init_seed, set_color, get_trainer
from core import FacetVAE
def load_config(config_file_full_path):
    """Load a YAML experiment config file and return its top-level mapping.

    Args:
        config_file_full_path: Path to the YAML file to read.

    Returns:
        dict: The parsed top-level mapping.

    Raises:
        ValueError: If no path is given (``None``), or if the file does not
            contain a top-level YAML mapping (e.g. it is empty or holds a list).
        FileNotFoundError: If the file does not exist.
    """
    if config_file_full_path is None:
        # Without this check open(None) fails with an opaque TypeError.
        raise ValueError('a config file path is required, got None')
    with open(config_file_full_path, 'r', encoding='utf-8') as reader:
        # safe_load only builds plain Python objects (no arbitrary constructors).
        config_dict = yaml.safe_load(reader)
    # safe_load returns None for an empty file and may return a list or scalar;
    # callers index into and assign onto the result, so insist on a mapping.
    if not isinstance(config_dict, dict):
        raise ValueError(
            f'config file {config_file_full_path!r} must contain a YAML mapping, '
            f'got {type(config_dict).__name__}'
        )
    return config_dict
def run(args):
    """Train FacetVAE on the configured dataset and evaluate it on the test split.

    Args:
        args: Parsed CLI namespace providing ``config_file`` (path to the
            experiment YAML) and ``device_id`` (GPU id).

    Returns:
        dict: ``best_valid_score`` and ``best_valid_result`` as returned by
        ``trainer.fit``, and ``test_result`` as returned by ``trainer.evaluate``.

    Raises:
        KeyError: If the experiment config does not define ``dataset``.
    """
    config_dict = load_config(args.config_file)
    config_dict['gpu_id'] = str(args.device_id)
    if 'dataset' not in config_dict:
        raise KeyError(f"'dataset' is not set in config file {args.config_file!r}")

    # NOTE(review): the shared defaults path is relative to the current working
    # directory, so the script must be launched from the repository root.
    config = Config(model=FacetVAE,
                    dataset=config_dict['dataset'],
                    config_file_list=['./configs/common.yaml'],
                    config_dict=config_dict)
    init_seed(config['seed'], config['reproducibility'])

    # logger initialization
    init_logger(config)
    logger = getLogger()
    logger.info(config)

    # dataset filtering
    dataset = create_dataset(config)
    if config['save_dataset']:
        dataset.save()
    logger.info(dataset)

    # dataset splitting. The model name is temporarily switched to 'MacridVAE'
    # only for data_preparation and restored right after; presumably this makes
    # RecBole build the dataloaders it would for MacridVAE, since FacetVAE is not
    # a built-in RecBole model -- TODO confirm.
    original_model = config['model']
    config['model'] = 'MacridVAE'
    train_data, valid_data, test_data = data_preparation(config, dataset)
    config['model'] = original_model
    if config['save_dataloaders']:
        save_split_dataloaders(config, dataloaders=(train_data, valid_data, test_data))

    # model loading and initialization
    model = FacetVAE(config, train_data.dataset).to(config['device'])
    logger.info(model)

    # trainer loading and initialization
    trainer = get_trainer(config['MODEL_TYPE'], config['model'])(config, model)

    # model training: the best checkpoint is saved so evaluate() can reload it,
    # and the file is deleted once the run is over.
    saved = True
    try:
        best_valid_score, best_valid_result = trainer.fit(
            train_data, valid_data, saved=saved, show_progress=config['show_progress']
        )
        test_result = trainer.evaluate(
            test_data, load_best_model=saved, show_progress=config['show_progress']
        )
        logger.info(set_color('best valid ', 'yellow') + f': {best_valid_result}')
        logger.info(set_color('test result', 'yellow') + f': {test_result}')
    finally:
        # Clean up even when fit/evaluate raise so checkpoints don't accumulate.
        # The file may not exist if training failed before the first save; ignore
        # that case so the original exception is not masked.
        try:
            os.remove(trainer.saved_model_file)
        except FileNotFoundError:
            pass

    return {
        'best_valid_score': best_valid_score,
        'best_valid_result': best_valid_result,
        'test_result': test_result,
    }
if __name__ == '__main__':
    # CLI entry point: parse the GPU id and experiment config path, then run.
    parser = argparse.ArgumentParser(description='Train and evaluate FacetVAE with RecBole.')
    parser.add_argument('-device', '--device_id', type=int, default=2,
                        help='GPU id (default: 2)')
    # Required: run() cannot proceed without a config, and a None default would
    # only fail later with an opaque error when the file is opened.
    parser.add_argument('-cfg', '--config_file', type=str, required=True,
                        help='Path to the experiment YAML config (must define `dataset`)')
    args = parser.parse_args()
    run(args)