-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: complete_ann.py
126 lines (101 loc) · 4.49 KB
/
complete_ann.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import nltk
import json
import operator
import os
import numpy as np
import csv
import re
import random
import pickle
from annoy import AnnoyIndex
from sklearn.feature_extraction.text import TfidfVectorizer
from bratreader.repomodel import RepoModel
# Load the brat-annotated corpus from the 'wlpdata' directory
# (presumably the Wet Lab Protocols corpus — confirm against the repo data).
r = RepoModel('wlpdata')
print('Preprocessing')
# Shuffle document keys and split them 80/20 into train/test partitions.
documents = list(r.documents.keys())
random.shuffle(documents)
trainlen = int(len(documents) * 0.8)
traindata = documents[:trainlen]
testdata = documents[trainlen:]
# Module-level counters updated by get_sentences(): number of single-word
# mentions whose missing annotation was recovered, and how many documents
# had at least one such recovery.
corrected_ann = 0
docs_corrected_ann = 0
def get_annotation_map(doc):
annotation_map = dict()
for annotation in set(doc.annotations):
text = get_words(annotation)
entity_type = list(annotation.labels.keys())[0]
entity_spans = annotation.spans # list of lists (having two values each). Length of this list gives total spans
if text not in annotation_map:
annotation_map[text] = dict()
annotation_map[text][entity_type] = 1
elif entity_type not in annotation_map[text]:
annotation_map[text][entity_type] = 1
else:
annotation_map[text][entity_type] += 1
return annotation_map
def get_words(annotation):
return ' '.join([x.form for x in annotation.words])
def get_sentences(docs):
    """Build per-sentence records for every document key in *docs*.

    Each record holds the raw sentence text, a 'replaced' version where
    entity mentions are substituted by '~~<EntityType>~~' placeholders,
    the intra-sentence relations, and the sentence key.  As a side effect,
    increments the module-level counters corrected_ann (single-word
    mentions whose missing annotation was recovered from annotation_map)
    and docs_corrected_ann (documents with at least one such recovery).
    Reads the module-level corpus object `r`.
    """
    global corrected_ann, docs_corrected_ann
    all_docs = []
    for fil in docs:
        previous_corrected_ann = corrected_ann
        doc = r.documents[fil]
        annotation_map = get_annotation_map(doc)
        sentinfo = []
        for s in doc.sentences:
            sents_repl = []
            i = 0
            while i < len(s.words):
                if not len(s.words[i].annotations):
                    # Word carries no annotation; check whether the same
                    # surface form was annotated elsewhere in this document.
                    text = s.words[i].form
                    if text not in annotation_map:
                        # Handles only one word missing annotations, these
                        # might also occur when mention spans > 1 word.
                        sents_repl.append(s.words[i].form)
                    else:
                        corrected_ann += 1
                        annotation_dict = annotation_map[text]
                        if len(list(annotation_dict.keys())) > 1:
                            # Multiple type of annotation for same word.
                            print(text, annotation_dict)
                        # Find entity_type with max count, because many words
                        # have duplicate annotations in the same document.
                        # Resolve by finding the majority type.
                        entity_type = max(annotation_dict.items(), key=operator.itemgetter(1))[0]
                        sents_repl.append('~~' + entity_type + '~~')
                    i += 1
                else:
                    # Annotated word: emit a placeholder for the first
                    # annotation's first label, then skip the remaining
                    # words of that mention.
                    sents_repl.append('~~' +
                                      list(s.words[i].annotations[0].labels.keys())[0]
                                      + '~~')
                    i += len(s.words[i].annotations[0].words)
            sentinfo.append({'sentence': get_words(s),
                             'replaced': ' '.join(sents_repl),
                             'relations': [], 'key': s.key})
        if corrected_ann > previous_corrected_ann:
            previous_corrected_ann = corrected_ann
            docs_corrected_ann += 1
        # Attach relations, keeping only those whose head and tail lie in
        # the same sentence.
        for annotation in set(doc.annotations):
            for linktype in annotation.links:
                for tail in annotation.links[linktype]:
                    sent = annotation.words[0].sentkey
                    if sent != tail.words[0].sentkey:
                        continue
                    # NOTE(review): assumes sentkey is a 0-based index into
                    # this document's sentence list — confirm in bratreader.
                    sentinfo[sent]['relations'].append({'head': {'text': get_words(annotation),
                                                                 'type':
                                                                 list(annotation.labels.keys())[0],
                                                                 'spans': annotation.spans},
                                                        'tail': {'text': get_words(tail),
                                                                 'type':
                                                                 list(tail.labels.keys())[0],
                                                                 'spans': tail.spans},
                                                        'relation_type': linktype})
        all_docs.extend(sentinfo)
    return all_docs
# Build sentence records for both splits, then report how many missing
# annotations were recovered and in how many documents.
train_json = get_sentences(traindata)
test_json = get_sentences(testdata)
print('Corrected annotations', corrected_ann)
print('Corrected Docs', docs_corrected_ann)
# Persist each split as pretty-printed JSON.
with open('wetlabs_train.json', 'w') as f:
    json.dump(train_json, f, indent = 4)
with open('wetlabs_test.json', 'w') as f:
    json.dump(test_json, f, indent = 4)