gen.py
import torch
from fairseq.models.bart import BARTModel

folder_name = "grammarly"

# Load the fine-tuned BART checkpoint together with its binarized data directory.
bart = BARTModel.from_pretrained(
    'checkpoints/',
    checkpoint_file='checkpoint_best.pt',
    data_name_or_path=folder_name + '-bin'
)

bart.cuda()
bart.eval()
bart.half()

count = 1
bsz = 32  # number of source sentences decoded per batch

with open(folder_name + '/test.source') as source, open(folder_name + '/test.hypo', 'w') as fout:
    sline = source.readline().strip()
    slines = [sline]
    for sline in source:
        # Once a full batch has been collected, decode it and write the hypotheses.
        if count % bsz == 0:
            with torch.no_grad():
                hypotheses_batch = bart.sample(slines, beam=10, lenpen=2.0, max_len_b=40, min_len=5, no_repeat_ngram_size=3)
            for hypothesis in hypotheses_batch:
                fout.write(hypothesis + '\n')
                fout.flush()
            slines = []

        slines.append(sline.strip())
        count += 1

    # Decode whatever is left over after the last full batch.
    if slines != []:
        with torch.no_grad():
            hypotheses_batch = bart.sample(slines, beam=10, lenpen=2.0, max_len_b=40, min_len=5, no_repeat_ngram_size=3)
        for hypothesis in hypotheses_batch:
            fout.write(hypothesis + '\n')
            fout.flush()
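
For a quick sanity check, the same model can also be queried on a single sentence instead of the whole test split. The following is a minimal sketch under the same assumptions as the script above (a 'checkpoints/' directory with checkpoint_best.pt and a 'grammarly-bin' data directory); the input sentence is a made-up example.

from fairseq.models.bart import BARTModel

bart = BARTModel.from_pretrained(
    'checkpoints/',
    checkpoint_file='checkpoint_best.pt',
    data_name_or_path='grammarly-bin'
)
bart.cuda()
bart.eval()
bart.half()

# bart.sample takes a list of raw source sentences and returns the decoded outputs,
# using the same beam-search settings as the batch script above.
corrected = bart.sample(
    ['She go to school yesterday .'],  # hypothetical input sentence
    beam=10, lenpen=2.0, max_len_b=40, min_len=5, no_repeat_ngram_size=3
)
print(corrected[0])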