extract_keywords.py
import os
import random
import shutil

import pandas as pd
import yake

# GENSIM_DATA_DIR must be set before gensim.downloader is imported,
# because the download directory is resolved at import time.
os.environ['GENSIM_DATA_DIR'] = '/cluster/scratch/goezsoy/nlp_lss_datasets'
import gensim.downloader as api
data_dir = '/cluster/scratch/goezsoy/nlp_lss_datasets'
keyword_file_name = 'valid_df_keywords_20k.txt'

# GloVe embeddings are used only to filter out keywords that have no pretrained vector.
encoder = api.load("glove-wiki-gigaword-300")

number_of_keywords = 3  # keywords kept per speech
max_ngram_size = 1      # unigram keywords only
total_shards = 10       # number of output shards
prompt_len = 25         # prompt length in words (used only by the commented-out variant below)

custom_kw_extractor = yake.KeywordExtractor(n=max_ngram_size)

valid_df = pd.read_csv(os.path.join(data_dir, 'processed_df_valid.csv'))
# Extract keywords for every speech and write one line per speech to the keyword file.
with open(keyword_file_name, 'w') as f:
    for _, row in valid_df.iterrows():
        text = row.speech

        # YAKE returns a list of (keyword, score) tuples.
        keyword_scores = custom_kw_extractor.extract_keywords(text)
        keywords = [pair[0].lower() for pair in keyword_scores]

        # Keep only keywords that have a GloVe vector.
        filtered_keywords = [kw for kw in keywords if kw in encoder.key_to_index]

        # Randomly sample up to `number_of_keywords` of them.
        selected_keywords = ', '.join(
            random.sample(filtered_keywords, min(len(filtered_keywords), number_of_keywords)))

        # no prompt: the prefix is the salutation (text before the first comma) when the
        # speech opens with "Mr." or "Madam", otherwise the <|endoftext|> token
        if text.split()[0] in ['Mr.', 'Madam']:
            selected_keywords = text.split(',')[0] + '|| ' + selected_keywords + '\n'
        else:
            selected_keywords = '<|endoftext|>' + '|| ' + selected_keywords + '\n'

        # prompt: alternatively, use the first `prompt_len` words of the speech as the prefix
        #selected_keywords = ' '.join(text.split()[:prompt_len]) + '|| ' + selected_keywords + '\n'

        f.write(selected_keywords)
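
# Illustrative output lines (hypothetical speeches and keywords, shown only to
# document the "<prefix>|| keyword1, keyword2, ..." format written above):
#   Mr. Speaker|| economy, budget, inflation
#   <|endoftext|>|| healthcare, reform, funding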
# Split the keyword file into `total_shards` shards, one per data/shard<i>/keywords.txt.
# If the row count is not divisible by total_shards, the remainder spills into one extra shard.
speech_per_shard = len(valid_df) // total_shards

with open(keyword_file_name, 'r') as fp:
    speech_counter = 0
    shard_counter = 0
    shard_path = None

    for line in fp:
        if speech_counter == speech_per_shard:
            shard_counter += 1
            speech_counter = 0

        if speech_counter == 0:
            # Start a fresh shard directory (assumes the 'data' directory already exists).
            shard_path = os.path.join('data', 'shard' + str(shard_counter))
            if os.path.exists(shard_path):
                shutil.rmtree(shard_path)
            os.mkdir(shard_path)

        with open(os.path.join(shard_path, 'keywords.txt'), 'a') as f:
            f.write(line)

        speech_counter += 1
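
For reference, a minimal sketch of how a downstream script could read one shard back; the helper name load_keyword_shard is an illustrative assumption and is not part of this repository.

# Sketch only (not part of extract_keywords.py): read one shard produced above.
import os

def load_keyword_shard(shard_dir):
    """Return a list of (prefix, [keywords]) pairs from <shard_dir>/keywords.txt."""
    pairs = []
    with open(os.path.join(shard_dir, 'keywords.txt'), 'r') as f:
        for line in f:
            prefix, _, kw_str = line.rstrip('\n').partition('|| ')
            keywords = [kw.strip() for kw in kw_str.split(',') if kw.strip()]
            pairs.append((prefix, keywords))
    return pairs

# Example usage: load_keyword_shard(os.path.join('data', 'shard0'))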