-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain_basic.py
74 lines (55 loc) · 2.97 KB
/
train_basic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import argparse
import os
import json
import random
import numpy as np
import tensorflow as tf
from tensorflow.keras import optimizers
from train_lib import train
from data import read_instances, build_vocabulary, \
save_vocabulary, index_instances, load_glove_embeddings
from model import MyBasicAttentiveBiGRU
# Hard limits used when reading and indexing the data.
MAX_TOKENS = 250
VOCAB_SIZE = 10000
GLOVE_COMMON_WORDS_PATH = os.path.join("data", "glove_common_words.txt")


def _parse_args():
    """Parse the command-line options for training the basic model."""
    parser = argparse.ArgumentParser(description='Train Model')
    parser.add_argument('--data-file', type=str, help="Location of data", default="./data/train.txt")
    parser.add_argument('--val-file', type=str, help="Location of val data", default="./data/val.txt")
    parser.add_argument('--batch-size', type=int, help="size of batch", default=10)
    parser.add_argument('--epochs', type=int, help="num epochs", default=10)
    # NOTE(review): the public GloVe 6B release names this file
    # "glove.6B.100d.txt" (lowercase "d"). On a case-sensitive filesystem this
    # default only resolves if the local copy uses this exact casing -- confirm.
    parser.add_argument('--embed-file', type=str, help="embedding location", default='./data/glove.6B.100D.txt')
    parser.add_argument('--embed-dim', type=int, help="size of embeddings", default=100)
    parser.add_argument('--hidden-size', type=int, help="size of hidden dimension", default=128)
    return parser.parse_args()


def _read_word_list(path):
    """Return the whitespace-stripped, non-blank lines of the UTF-8 file at *path*."""
    # An explicit encoding avoids depending on the platform's locale default.
    with open(path, encoding="utf-8") as word_file:
        return [word for word in (line.strip() for line in word_file) if word]


def main():
    """Train a MyBasicAttentiveBiGRU model and write its artifacts to disk.

    Steps:
    1. Read the train and validation instances.
    2. Build a vocabulary from the training instances and the GloVe
       common-words list, then index both datasets with it.
    3. Build the model and overwrite its embeddings with GloVe vectors.
    4. Run ``train``.
    5. Write ``config.json`` and ``vocab.txt`` into ``serialization_dirs/basic``.
    """
    args = _parse_args()

    # Seed every RNG in use so runs are repeatable.
    tf.random.set_seed(1337)
    np.random.seed(1337)
    # NOTE(review): differs from the 1337 used above. The value is kept so
    # results stay comparable with earlier runs -- confirm whether intentional.
    random.seed(13370)

    print("\nReading Train Instances")
    train_instances = read_instances(args.data_file, MAX_TOKENS)
    print("\nReading Val Instances")
    val_instances = read_instances(args.val_file, MAX_TOKENS)

    glove_common_words = _read_word_list(GLOVE_COMMON_WORDS_PATH)
    vocab_token_to_id, vocab_id_to_token = build_vocabulary(train_instances, VOCAB_SIZE,
                                                            glove_common_words)

    # The vocabulary is built from training data only; validation instances
    # are indexed with that same mapping.
    train_instances = index_instances(train_instances, vocab_token_to_id)
    val_instances = index_instances(val_instances, vocab_token_to_id)

    vocab_size = len(vocab_token_to_id)
    config = {'vocab_size': vocab_size, 'embed_dim': args.embed_dim, 'training': True,
              'hidden_size': args.hidden_size}
    model = MyBasicAttentiveBiGRU(**config)
    # 'type' is added only after the model is built, so it is recorded in
    # config.json but never passed to the model constructor.
    config['type'] = 'basic'
    optimizer = optimizers.Adam()

    # Replace the model's embedding variable with pretrained GloVe vectors.
    # Presumably shaped (vocab_size, embed_dim) -- TODO confirm in data.py.
    embeddings = load_glove_embeddings(args.embed_file, args.embed_dim, vocab_id_to_token)
    model.embeddings.assign(tf.convert_to_tensor(embeddings))

    save_serialization_dir = os.path.join('serialization_dirs', 'basic')
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(save_serialization_dir, exist_ok=True)

    train(model, optimizer, train_instances, val_instances,
          args.epochs, args.batch_size, save_serialization_dir)

    config_path = os.path.join(save_serialization_dir, "config.json")
    with open(config_path, 'w', encoding='utf-8') as config_file:
        json.dump(config, config_file)

    vocab_path = os.path.join(save_serialization_dir, "vocab.txt")
    save_vocabulary(vocab_id_to_token, vocab_path)

    print(f"\nModel stored in directory: {save_serialization_dir}")


if __name__ == '__main__':
    main()