-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathmain.py
167 lines (137 loc) · 8.42 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
""" Code for the Main function of MetaPred. """
import os, csv
import numpy as np
import random
import pickle as pkl
import tensorflow as tf
import copy
from tensorflow.python.platform import flags
from data_loader import DataLoader
import model, finetune
# Command-line configuration. Each DEFINE_* call registers one option
# (name, default value, help text) that is later read as FLAGS.<name>.
# Only some of these flags are read in this file (source, target, true_target,
# method, meta_lr, update_lr, meta_batch_size, update_batch_size, num_updates,
# train, test, finetune); the rest are presumably consumed by the imported
# data_loader / model / finetune modules -- confirm there before removing any.
FLAGS = flags.FLAGS
# Task selection: each flag holds a single task name; main() wraps each one
# in a one-element list.
flags.DEFINE_string('source', 'AD', 'source task')
flags.DEFINE_string('target', 'MCI', 'simulated task')
flags.DEFINE_string('true_target', 'PD', 'true task')
## Dataset/method options
flags.DEFINE_integer('n_classes', 2, 'number of classes used in classification (e.g. binary classification)')
## Training options
# 'method' selects the fine-tuning network in fine_tune() ("cnn" or "rnn") and
# is embedded in the names of the saved weight files.
flags.DEFINE_string('method', 'rnn', 'deep learning methods for modeling')
flags.DEFINE_integer('pretrain_iterations', 20000, 'number of pre-training iterations')
# NOTE(review): the omniglot/sinusoid remarks on this flag and on
# train_update_lr refer to datasets not used anywhere in this file -- possibly stale.
flags.DEFINE_integer('metatrain_iterations', 10000, 'number of metatraining iterations') # 15k for omniglot, 50k for sinusoid
flags.DEFINE_integer('meta_batch_size', 8, 'number of tasks sampled per meta-update')
flags.DEFINE_integer('update_batch_size', 16, 'number of samples used for inner gradient update (K for K-shot learning)')
# meta_lr and update_lr are passed to model.MetaPred in train().
flags.DEFINE_float('meta_lr', 0.0001, 'the base learning rate of the generator')
flags.DEFINE_float('update_lr', 1e-3, 'step size alpha for inner gradient update')
flags.DEFINE_integer('num_updates', 4, 'number of inner gradient updates during training')
flags.DEFINE_integer('n_total_batches', 100000, 'total batches generated by random sampling')
## Model options
# 'norm' defaults to the string 'None', not the None object.
flags.DEFINE_string('norm', 'None', 'batch_norm, layer_norm, or None')
flags.DEFINE_bool('stop_grad', False, 'if True, do not use second derivatives in meta-optimization (for speed)')
flags.DEFINE_bool('isReg', True, 'if True, compute regularization of weights and bias')
flags.DEFINE_float('dropout', 0.5, 'drop out when modeling, with probability keep_prob')
## Logging, saving, and testing options
flags.DEFINE_integer('run_time', 1, 're-run for stable analysis')
# train / test / finetune gate the three per-fold stages in main().
flags.DEFINE_bool('train', True, 'True to train, False to test directly')
flags.DEFINE_bool('test', True, 'True to test, no matter the model is trained')
flags.DEFINE_bool('finetune', False, 'True to finetunning furthermore, after meta-learning')
flags.DEFINE_bool('log', True, 'if false, do not log summaries, for debugging code')
flags.DEFINE_string('logdir', 'model/', 'directory for summaries and checkpoints')
flags.DEFINE_bool('resume', False, 'resume training if there is a model available')
flags.DEFINE_integer('test_iter', -1, 'iteration to load model (-1 for latest model)')
# For the two flags below, -1 is the default; the help text suggests it means
# "not set" -- confirm against the code that reads them.
flags.DEFINE_integer('train_update_batch_size', -1, 'number of examples used for gradient update during training (use if you want to test with a different number)')
flags.DEFINE_float('train_update_lr', -1, 'value of inner gradient step step during training. (use if you want to test with a different value)') # 0.1 for omniglot
def train(data_loader, ifold, exp_string):
    """Build a MetaPred model and meta-train it for one fold.

    Args:
        data_loader: object providing ``episode`` (training episodes) and
            ``episode_val`` (validation episodes, indexed by fold).
        ifold: index of the fold being trained.
        exp_string: experiment identifier, forwarded unchanged to ``fit``.

    Returns:
        A ``(meta_model, sess)`` pair: the constructed ``model.MetaPred``
        instance and the value returned by its ``fit`` method.
    """
    # Model construction: learning rates come from the command-line flags.
    print ("constructing MetaPred model ...")
    meta_pred = model.MetaPred(data_loader, FLAGS.meta_lr, FLAGS.update_lr)

    # Meta-training on all training episodes, validating on this fold only.
    print ("model training...")
    session = meta_pred.fit(
        data_loader.episode,
        data_loader.episode_val[ifold],
        ifold,
        exp_string,
    )
    return meta_pred, session
def test(data_loader, ifold, m, sess, exp_string):
    """Meta-test model ``m`` on fold ``ifold`` and log the four metrics.

    Args:
        data_loader: object providing ``data_s``, ``label_s`` and the per-fold
            sequences ``data_tt_val``, ``label_tt_val`` and ``episode_val``.
        ifold: index of the fold to evaluate.
        m: model exposing ``evaluate(episode, data_tuple, sess=..., prefix=...)``.
        sess: session object handed to ``m.evaluate``.
        exp_string: experiment identifier (not used inside this function).

    Returns:
        The 4-tuple produced by ``m.evaluate``, in the order it is logged:
        tAcc, tAuc, tAP, tF1.
    """
    print ("model test...")

    # Source data plus this fold's held-out target data, with matching labels.
    data_tuple_val = (
        data_loader.data_s,
        data_loader.data_tt_val[ifold],
        data_loader.label_s,
        data_loader.label_tt_val[ifold],
    )
    test_accs, test_aucs, test_ap, test_f1s = m.evaluate(
        data_loader.episode_val[ifold],
        data_tuple_val,
        sess=sess,
        prefix="metatest_",
    )

    # %s applies str() to every value, matching plain string concatenation.
    summary = "Test results: ifold: %s: tAcc: %s, tAuc: %s, tAP: %s, tF1: %s" % (
        ifold, test_accs, test_aucs, test_ap, test_f1s)
    print(summary)
    return test_accs, test_aucs, test_ap, test_f1s
def fine_tune(data_loader, ifold, meta_m, weights_for_finetune, exp_string, freeze_opt=None):
    """Fine-tune a CNN/RNN on the true-target data of one fold.

    The network class is chosen by ``FLAGS.method`` and initialised from the
    meta-learned weights.

    Args:
        data_loader: object providing the per-fold sequences ``tt_sample``,
            ``tt_label``, ``tt_sample_val`` and ``tt_label_val``.
        ifold: index of the fold to fine-tune on.
        meta_m: meta-learned model (kept for interface compatibility; not
            used inside this function).
        weights_for_finetune: weights handed to the fine-tuning network.
        exp_string: experiment identifier (not used inside this function).
        freeze_opt: forwarded to the network constructor as ``freeze_opt``.
            Previously this name was referenced without ever being defined,
            so every call raised ``NameError``.
            NOTE(review): ``None`` is assumed to be an acceptable "no
            freezing" value for finetune.CNN / finetune.RNN -- confirm.

    Returns:
        A ``(model, sess)`` pair: the fine-tuning network and the first
        element of the triple returned by its ``fit`` method.

    Raises:
        ValueError: if ``FLAGS.method`` is neither ``"cnn"`` nor ``"rnn"``.
    """
    print ("finetunning MetaPred model ...")
    if FLAGS.method == "cnn":
        network_cls = finetune.CNN
    elif FLAGS.method == "rnn":
        network_cls = finetune.RNN
    else:
        # Fail with a clear message instead of an unbound-variable error below.
        raise ValueError(
            "unsupported FLAGS.method %r for fine-tuning; expected 'cnn' or 'rnn'"
            % (FLAGS.method,))
    m2 = network_cls(data_loader, weights_for_finetune,
                     freeze_opt=freeze_opt, is_finetune=True)

    print ("model finetunning...")
    # fit() returns a triple; only the session is needed by callers.
    sess, _, _ = m2.fit(data_loader.tt_sample[ifold], data_loader.tt_label[ifold],
                        data_loader.tt_sample_val[ifold], data_loader.tt_label_val[ifold])
    return m2, sess
def save_results(metatest, exp_string):
    """Write the mean and standard deviation of every metric to a CSV file.

    The file is ``results/res_<exp_string>``. For each key of ``metatest``
    (in the dict's iteration order) two single-column rows are written: the
    mean of the collected scores, then their standard deviation. Metric
    names are not written to the file.

    Args:
        metatest: mapping from metric name to a list of per-fold scores.
        exp_string: experiment identifier used as the file-name suffix.
    """
    out_filename = "results/res_" + exp_string

    # Create the output directory on first use instead of failing in open().
    out_dir = os.path.dirname(out_filename)
    if out_dir and not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    with open(out_filename, 'w') as f:
        writer = csv.writer(f, delimiter=',')
        for key in metatest:
            # Convert once and reuse for both statistics.
            scores = np.array(metatest[key])
            writer.writerow([np.mean(scores)])
            writer.writerow([np.std(scores)])
    print ("results saved")
def save_weights(meta_m, source, target, true_target, data_loader, ifold):
    """Pickle the meta-learned weights and this fold's true-target data.

    Three files are written under ``weights/``, all sharing the same
    task-dependent suffix and differing only in their tag:

    * ``.weights``  -- ``meta_m.weights_for_finetune``
    * ``.tt_train`` -- ``(tt_sample[ifold], tt_label[ifold])``
    * ``.tt_val``   -- ``(tt_sample_val[ifold], tt_label_val[ifold])``

    The file names do not contain ``ifold``, so each call overwrites the
    files written for the previous fold.

    Args:
        meta_m: trained model exposing ``weights_for_finetune``.
        source: list of source task names (joined with ``-`` in the name).
        target: list of simulated target task names.
        true_target: list of true target task names.
        data_loader: object providing the per-fold ``tt_*`` sequences.
        ifold: index of the fold whose data is saved.
    """
    prefix = "weights/meta-" + FLAGS.method
    task_suffix = (".source_" + "-".join(source)
                   + ".starget_" + "".join(target)
                   + ".ttarget_" + "".join(true_target) + ".pkl")
    payloads = [
        (".weights", meta_m.weights_for_finetune),
        (".tt_train", (data_loader.tt_sample[ifold], data_loader.tt_label[ifold])),
        (".tt_val", (data_loader.tt_sample_val[ifold], data_loader.tt_label_val[ifold])),
    ]

    # Create the output directory on first use instead of failing in open().
    out_dir = os.path.dirname(prefix)
    if out_dir and not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    for tag, payload in payloads:
        # The with-statement closes the file; no explicit close() is needed.
        with open(prefix + tag + task_suffix, 'wb') as f:
            pkl.dump(payload, f, protocol=2)
    print("model weights saved")
def main():
    """Run the cross-validated MetaPred experiment configured by FLAGS.

    For every fold reported by the data loader this optionally (a) meta-trains
    a model and saves its weights, (b) fine-tunes a network from the saved
    weights, and (c) meta-tests the meta-trained model, collecting AUC-ROC,
    average precision and F1. Afterwards the settings and the mean/std of
    AUC-ROC and F1 are printed and all collected metrics are written to disk
    via ``save_results``.

    Raises:
        RuntimeError: if ``FLAGS.test`` is set but no meta-model was trained
            in this run (``FLAGS.train`` is False). Loading a previously
            trained model for a test-only run is not implemented here.
    """
    print (FLAGS.method)
    # set source and simulated target for training
    print ('task setting: ')
    source = [FLAGS.source]
    target = [FLAGS.target]
    true_target = [FLAGS.true_target]
    print ("The applied source tasks are: ", " ".join(source))
    print ("The simulated target task is: ", " ".join(target))
    print ("The true target task is: ", " ".join(true_target))
    n_tasks = len(source) + len(target)

    # load ehrs data
    data_loader = DataLoader(source, target, true_target, n_tasks,
                             FLAGS.update_batch_size, FLAGS.meta_batch_size)
    exp_string = 'stsk_'+str('&'.join(source))+'ttsk_'+str('&'.join(target))+'.mbs_'+str(FLAGS.meta_batch_size) + \
        '.ubs_' + str(FLAGS.update_batch_size) + '.numstep' + str(FLAGS.num_updates) + '.updatelr' + str(FLAGS.update_lr)

    # Same path that save_weights() uses for the ".weights" file.
    weights_path = ("weights/meta-" + FLAGS.method + ".weights"
                    + ".source_" + "-".join(source)
                    + ".starget_" + "".join(target)
                    + ".ttarget_" + "".join(true_target) + ".pkl")

    metatest = {'aucroc': [], 'avepre': [], 'f1score': []}  # n_fold result
    n_fold = data_loader.n_fold
    for ifold in range(n_fold):
        print ("----------The %d-th fold-----------" %(ifold+1))
        meta_model = None
        sess = None
        if FLAGS.train:
            meta_model, sess = train(data_loader, ifold, exp_string)
            save_weights(meta_model, source, target, true_target, data_loader, ifold)
        if FLAGS.finetune:
            # The pickle is produced by save_weights(); never point this at
            # a file from an untrusted source.
            with open(weights_path, 'rb') as f:
                weights_for_finetune = pkl.load(f)
            # Bound to a dedicated name so the imported `model` module is not
            # shadowed. The fine-tuned network itself is currently unused.
            # NOTE(review): the session returned here replaces the training
            # session and is then used to evaluate meta_model below -- confirm
            # that this is intended.
            finetuned_model, sess = fine_tune(data_loader, ifold, meta_model, weights_for_finetune, exp_string)
        if FLAGS.test:
            if meta_model is None:
                raise RuntimeError(
                    "FLAGS.test requires a meta-model trained in this run; "
                    "enable FLAGS.train (test-only runs are not implemented)")
            _, test_aucs, test_ap, test_f1s = test(data_loader, ifold, meta_model, sess, exp_string)
            metatest['aucroc'].append(test_aucs)
            metatest['avepre'].append(test_ap)
            metatest['f1score'].append(test_f1s)

    # show results
    print ('--------------- model setting ---------------')
    print('source: ', " ".join(source), 'simulated target: ', " ".join(target), 'true target: ', " ".join(true_target))
    print('method:', 'meta-' + FLAGS.method, 'meta-bz:', FLAGS.meta_batch_size, 'update-bz:', FLAGS.update_batch_size, \
          'num update:', FLAGS.num_updates, 'meta-lr:', FLAGS.meta_lr, 'update-lr:', FLAGS.update_lr)
    # Report the actual number of folds rather than a hard-coded 5.
    print ('--------------- %dfold results ---------------' % n_fold)
    print ('aucroc mean: ', np.mean(np.array(metatest['aucroc'])))
    print ('aucroc std: ', np.std(np.array(metatest['aucroc'])))
    print ('f1score mean: ', np.mean(np.array(metatest['f1score'])))
    print ('f1score std: ', np.std(np.array(metatest['f1score'])))
    save_results(metatest, exp_string)
# Run the experiment only when this file is executed as a script, not when
# it is imported.
# NOTE(review): main() is called directly rather than through an app runner,
# so flag parsing presumably happens implicitly on first FLAGS access --
# confirm for the installed TensorFlow version.
if __name__ == "__main__":
    main()