-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathactive_learner.py
210 lines (174 loc) · 7.43 KB
/
active_learner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
import numpy as np
import pandas as pd
from tqdm import tqdm
class ActiveLearner:
    """
    Main active learner object to handle AL training for systematic reviews.

    Methods:
    - train: handler function for training
    - initialise: initialise active learner attributes for training on a dataset
    - initial_sampling: executes initial stage of training
    - active_learn: executes main active learning loop
    - random_learn: executes a random sampling learning loop if desired by the stopping criteria

    The dataset is a DataFrame with an 'x' column (each entry is expanded into
    feature columns via ``pd.Series``; presumably list-like vectors) and a label
    column 'y' (the two-class check in ``initial_sampling`` assumes 0/1 labels).
    Rows are addressed positionally (``iloc``) through ``data_indices``.
    """

    def __init__(self, model, selector, stopper, batch_size=10, max_iter=100, evaluator=None, verbose=False):
        """
        :param model: object exposing ``train(X, y)`` and ``test(X, y)``; the value returned by
            ``test`` is passed to ``selector.select`` as predictions
        :param selector: object exposing ``initial_select(data, indices)`` and ``select(indices, preds)``,
            each returning the positional indices to screen next
        :param stopper: object exposing ``initialise(sample)``, ``stopping_criteria(sample)`` and a
            ``stop`` attribute (1 = stop, -1 = switch to random sampling)
        :param batch_size: stored on the instance; not read by this class itself
        :param max_iter: maximum number of iterations of each learning loop
        :param evaluator: optional object exposing ``initialise(sample, data)``,
            ``update(model, sample, data)`` and a ``recall`` sequence (printed when ``verbose``)
        :param verbose: show a tqdm progress bar and print the latest recall after each iteration
        """
        self.relevant_mask = None
        self.indice_mask = None
        self.N = None
        self.data = None
        self.data_indices = None
        self.model = model
        self.selector = selector
        self.stopper = stopper
        self.max_iter = max_iter
        self.batch_size = batch_size
        self.verbose = verbose
        # the progress bar is created per training run in initialise(); None otherwise
        self.pbar = None
        # always defined so progress reporting can check for it safely
        self.evaluator = evaluator

        # handling progress output
        if verbose:
            def progress(active_learner):
                if active_learner.pbar is not None:
                    active_learner.pbar.update()
                # recall can only be reported when an evaluator was supplied
                if active_learner.evaluator is not None:
                    print('Recall:', active_learner.evaluator.recall[-1])

            def end_progress(active_learner):
                if active_learner.pbar is not None:
                    active_learner.pbar.close()
                    # forget the closed bar so a second call (random_learn, then active_learn) is a no-op
                    active_learner.pbar = None
        else:
            progress = end_progress = lambda *a: None
        self.progress = progress
        self.end_progress = end_progress

        # handling evaluator hooks
        if evaluator is not None:
            def initialise_evaluator(sample, test_data):
                evaluator.initialise(sample, test_data)

            def update_evaluator(m, sample, test_data):
                evaluator.update(m, sample, test_data)
        else:
            initialise_evaluator = update_evaluator = lambda *a: None
        self.initialise_evaluator = initialise_evaluator
        self.update_evaluator = update_evaluator

    # train (and test) active learner
    def train(self, data):
        """
        Training handler: initialise state, draw the initial sample, then run the active learning loop.

        :param data: training dataset DataFrame
        """
        self.initialise(data)
        self.initial_sampling()
        self.active_learn()

    # initialise active learner parameters
    def initialise(self, data):
        """
        Initialise active learner parameters for a new dataset.

        :param data: full dataset DataFrame
        """
        self.data = data
        self.N = len(data)
        # use positional dataset indices instead of handling the data directly
        self.data_indices = np.arange(self.N)
        # indice_mask: 1 = screened (training), 0 = unscreened (test)
        self.indice_mask = np.zeros(self.N, dtype=np.uint8)
        # relevant_mask: label recorded for each screened instance
        self.relevant_mask = np.zeros(self.N, dtype=np.uint8)
        if self.verbose:
            # close any bar left over from a previous run before starting a new one
            if self.pbar is not None:
                self.pbar.close()
            self.pbar = tqdm(total=self.max_iter)

    def _split_indices(self):
        """Return ``(unscreened_indices, screened_indices)`` according to ``indice_mask``."""
        return self.data_indices[self.indice_mask == 0], self.data_indices[self.indice_mask == 1]

    def _screen(self, sample_indices):
        """Mark ``sample_indices`` as screened, record their labels, and return the sampled rows."""
        sample = self.data.iloc[sample_indices]
        self.indice_mask[sample_indices] = 1
        self.relevant_mask[sample_indices] = sample['y']
        return sample

    @staticmethod
    def _features(frame):
        """Expand the 'x' column of ``frame`` into one DataFrame column per element."""
        return frame['x'].apply(pd.Series)

    # initial sampling from test set to be training instances
    def initial_sampling(self):
        """
        Handles the initial selection / sampling of training instances. Keeps sampling until the
        training sample contains instances of both classes, the stopper signals stop, or no
        unscreened instances remain.
        """
        test_data = self.data
        test_indices = self.data_indices
        while True:
            # initial sampling
            sample = self._screen(self.selector.initial_select(test_data, test_indices))
            test_indices, train_indices = self._split_indices()
            # new test dataset excludes screened instances
            test_data = self.data.iloc[test_indices]
            train_labels = self.data['y'].iloc[train_indices]
            # NOTE(review): stopper and evaluator are re-initialised with only the latest sample
            # on every retry - confirm this is intended
            self.stopper.initialise(sample)
            if self.stopper.stop:
                break
            self.initialise_evaluator(sample, self.data)
            # check if the training sample has two classes
            label_sum = train_labels.sum()
            if label_sum != len(train_labels) and label_sum != 0:
                break
            # everything has been screened without finding both classes; sampling cannot progress
            if len(test_indices) == 0:
                break

    # active learning loop
    def active_learn(self):
        """
        Handles the active learning loop: model training, sample selection, stopping.

        Stops when the pool is exhausted, after ``max_iter`` iterations, or when the stopper sets
        ``stop == 1``; hands over to ``random_learn`` when it sets ``stop == -1``. Finally retrains
        the model on every screened instance and tests it on the full dataset.
        """
        for _ in range(self.max_iter):
            test_indices, train_indices = self._split_indices()
            if len(test_indices) == 0:
                break
            train_data = self.data.iloc[train_indices]
            test_data = self.data.iloc[test_indices]
            # train and test model
            self.model.train(self._features(train_data), train_data['y'])
            # NOTE(original author): possibly -model.test(test_data)[:, 1] ??
            preds = self.model.test(self._features(test_data), test_data['y'])
            # screen the selected test instances, adding them to the training set
            sample = self._screen(self.selector.select(test_indices, preds))
            self.update_evaluator(self.model, sample, self.data)
            self.progress(self)
            # stopping criteria
            self.stopper.stopping_criteria(sample)
            if self.stopper.stop == 1:
                break
            # commence random sampling, no active learning
            elif self.stopper.stop == -1:
                self.random_learn()
                break
        # final model: recompute the screened set so the last batch (and any random-sampling
        # batches) are included rather than the set from the start of the last iteration
        _, train_indices = self._split_indices()
        train_data = self.data.iloc[train_indices]
        self.model.train(self._features(train_data), train_data['y'])
        self.model.test(self._features(self.data), self.data['y'])
        self.end_progress(self)

    # random learning loop
    def random_learn(self):
        """
        Handles the random sampling loop: screens batches drawn with ``selector.initial_select``
        (presumably random sampling) without retraining the model, until the pool is exhausted,
        ``max_iter`` iterations pass, or the stopper sets ``stop == 1``.
        """
        for _ in range(self.max_iter):
            test_indices, _ = self._split_indices()
            if len(test_indices) == 0:
                break
            # new test dataset excludes screened instances
            test_data = self.data.iloc[test_indices]
            sample = self._screen(self.selector.initial_select(test_data, test_indices))
            self.update_evaluator(self.model, sample, self.data)
            self.progress(self)
            self.stopper.stopping_criteria(sample)
            # -1 is the "switch to random sampling" signal that led here; only 1 means stop
            if self.stopper.stop == 1:
                break
        self.end_progress(self)