-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
400 lines (283 loc) · 14.8 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
import matplotlib.pyplot as plt
#######################################################################################################################
""" Groups all cases by metric, creating a dictionary with each group value as a key
@:param test_data: List of lists, the full data set to pull groups from
@:param categories: List of column titles, stored separately from the numerical data
@:param metric: String used to determine groups, i.e 'race'
@:param mappings: Dictionary mappings from training, used to convert the data between numerical and string format
@:param predictions: List of predicted values produced from a machine learning model.
@:param labels: List of labels for the test data
@:note: test_data, predictions, and labels should correspond to one another
@:note: Groups with less than 50 members are not considered sufficiently representative and are excluded.
@:returns total_cases: Dictionary with each group value as keys. Each key has a list of (prediction, label) tuples
representing all of the data points within that group
"""
def get_cases_by_metric(test_data, categories, metric, mappings, predictions, labels):
total_cases = {}
index = -1
for i in range(len(categories)):
if metric in categories[i]:
index = i
break
for value in mappings[metric].keys():
cases = []
for i in range(len(test_data)):
if test_data[i][index] == mappings[metric][value]:
cases.append((float(predictions[i]), int(labels[i])))
# Only include groups that have more than 50 members
if len(cases) > 50:
total_cases[value] = cases
return total_cases
#######################################################################################################################
"""Applies a threshold to real-valued model predictions to make them either 0 or 1. Values above the threshold become
1's, values below or equal to the threshold become 0's.
@:param predictions: Tuples of the form (prediction, label), such as those returned by get_cases_by_metric
@:param threshold: Float or Int value used to calculate the predicted value
@:returns predictions: The thresholded version of the same input (prediction, label) tuples
"""
def apply_threshold(prediction_label_pairs, threshold):
threshed = [(0, 0)] * len(prediction_label_pairs)
for i in range(len(prediction_label_pairs)):
if prediction_label_pairs[i][0] <= threshold:
threshed[i] = (0, prediction_label_pairs[i][1])
else:
threshed[i] = (1, prediction_label_pairs[i][1])
return threshed
#######################################################################################################################
"""Gets the total accuracy of a set of classifications
@:param classifications: a dictionary of all the classifications, separated into groups. Each group contains
a list of (prediction, label) tuples
@:note: assumes that the predictions have been already thresholded
@:returns total_accuracy: the total accuracy of the classifications
"""
def get_total_accuracy(classifications):
total_correct = 0.0
total_num_cases = 0.0
for group in classifications.keys():
for prediction, label in classifications[group]:
total_num_cases += 1.0
if prediction == label:
total_correct += 1.0
return total_correct / total_num_cases
#######################################################################################################################
"""Determines the number of correct predictions in a group
@:param prediction_label_pairs: List of (prediction, label) tuples
@:note: Assumes predictions have already been thresholded
@:returns num_correct: Int value of correct predictions. Dividing this by len(category) would give the
accuracy for the group
"""
def get_num_correct(prediction_label_pairs):
num_correct = 0
for pair in prediction_label_pairs:
prediction = int(pair[0])
label = int(pair[1])
if prediction == label:
num_correct += 1
return num_correct
#######################################################################################################################
"""Determines the number of false positives in a group
@:param prediction_label_pairs: List of (prediction, label) tuples
@:note: Assumes predictions have already been thresholded
@:returns false_positives: The number of false positives (prediction == 1, label == 0)
"""
def get_num_false_positives(prediction_label_pairs):
false_positives = 0
for pair in prediction_label_pairs:
prediction = int(pair[0])
label = int(pair[1])
if prediction == 1 and label == 0:
false_positives += 1
return false_positives
#######################################################################################################################
"""Determines the rate of false positives in a group
@:param prediction_label_pairs: List of (prediction, label) tuples
@:note: Assumes predictions have already been thresholded
@:returns FPR: The number of false positives divided by the number of labelled negatives. Will
return 0 to avoid divide by 0, but in practice there should be no instances of no
labelled negatives.
"""
def get_false_positive_rate(prediction_label_pairs):
false_positives = 0
labelled_negatives = 0
for pair in prediction_label_pairs:
prediction = int(pair[0])
label = int(pair[1])
if label == 0:
labelled_negatives += 1
if prediction == 1:
false_positives += 1
if labelled_negatives != 0:
return false_positives / labelled_negatives
else:
return 0
#######################################################################################################################
"""Determines the number of true negatives in a group
@:param prediction_label_pairs: List of (prediction, label) tuples
@:note: Assumes predictions have already been thresholded
@:returns true_negatives The number of true negatives (prediction == 0, label == 0)
"""
def get_num_true_negatives(prediction_label_pairs):
true_negatives = 0
for pair in prediction_label_pairs:
prediction = int(pair[0])
label = int(pair[1])
if prediction == 0 and label == 0:
true_negatives += 1
return true_negatives
#######################################################################################################################
"""Determines the rate of true negatives in a group
@:param prediction_label_pairs: List of (prediction, label) tuples
@:note: Assumes predictions have already been thresholded
@:returns TNR: 1 - false_positive_rate.
"""
def get_true_negative_rate(prediction_labels_pairs):
return 1 - get_false_positive_rate(prediction_labels_pairs)
#######################################################################################################################
"""Determines the number of false negatives in a group
@:param prediction_label_pairs: List of (prediction, label) tuples
@:note: Assumes predictions have already been thresholded
@:returns false_negatives The number of false negatives (prediction == 0, label == 1)
"""
def get_num_false_negatives(prediction_label_pairs):
false_negatives = 0
for pair in prediction_label_pairs:
prediction = int(pair[0])
label = int(pair[1])
if prediction == 0 and label == 1:
false_negatives += 1
return false_negatives
#######################################################################################################################
"""Determines the rate of false negatives in a group
@:param prediction_label_pairs: List of (prediction, label) tuples
@:note: Assumes predictions have already been thresholded
@:returns FNR: The number of false negatives divided by the number of labelled positives. Will
return 0 to avoid divide by 0, but in practice there should be no instances of no
labelled positives.
"""
def get_false_negative_rate(prediction_label_pairs):
false_negatives = 0
labelled_positives = 0
for pair in prediction_label_pairs:
prediction = int(pair[0])
label = int(pair[1])
if label == 1:
labelled_positives += 1
if prediction == 0:
false_negatives += 1
if labelled_positives != 0:
return false_negatives / labelled_positives
else:
return 0
#######################################################################################################################
"""Determines the number of true positives in a group
@:param prediction_label_pairs: List of (prediction, label) tuples
@:note: Assumes predictions have already been thresholded
@:returns true_positives The number of true positives (prediction == 1, label == 1)
"""
def get_num_true_positives(prediction_label_pairs):
true_positives = 0
for pair in prediction_label_pairs:
prediction = int(pair[0])
label = int(pair[1])
if prediction == 1 and label == 1:
true_positives += 1
return true_positives
#######################################################################################################################
"""Determines the rate of true positives in a group
@:param prediction_label_pairs: List of (prediction, label) tuples
@:note: Assumes predictions have already been thresholded
@:returns TPR: 1 - false_negative_rate.
"""
def get_true_positive_rate(category):
return 1 - get_false_negative_rate(category)
#######################################################################################################################
"""Determines the number of samples that have a positive prediction
@:param prediction_label_pairs: List of (prediction, label) tuples
@:note: Assumes predictions have already been thresholded
@:returns predicted_positives Number of samples with a positive prediction"""
def get_num_predicted_positives(prediction_label_pairs):
predicted_positives = 0
for pair in prediction_label_pairs:
prediction = int(pair[0])
if prediction == 1:
predicted_positives += 1
return predicted_positives
#######################################################################################################################
"""Determines the positive predictive value of a group, defined as the number of true positives divided by the
number of predicted positives
@:param prediction_label_pairs: List of (prediction, label) tuples
@:note: Assumes predictions have already been thresholded
@:returns PPV: true positives / predicted positives
"""
def get_positive_predictive_value(prediction_label_pairs):
true_positives = get_num_true_positives(prediction_label_pairs)
predicted_positives = get_num_predicted_positives(prediction_label_pairs)
if predicted_positives == 0:
return 0
else:
return true_positives / predicted_positives
#######################################################################################################################
"""Calculates the Fscore (or harmonic mean) of a group. Used as a substitute for accuracy when data is skewed
@:param prediction_label_pairs: List of (prediction, label) tuples
@:note: Assumes predictions have already been thresholded
@:returns Fscore Harmonic mean, defined as 2 * (precision * recall) + (precision + recall)
"""
def calculate_Fscore(prediction_label_pairs):
precision = get_positive_predictive_value(prediction_label_pairs)
recall = get_true_positive_rate(prediction_label_pairs)
numerator = precision * recall
denominator = precision + recall
return 2 * (numerator/denominator)
#######################################################################################################################
def get_ROC_data(prediction_label_pairs, group):
true_positives = []
false_positives = []
for i in range(1, 101):
threshold = float(i) / 100.0
eval_copy = list.copy(prediction_label_pairs)
eval_copy = apply_threshold(eval_copy, threshold)
TPR = get_true_positive_rate(eval_copy)
FPR = get_false_positive_rate(eval_copy)
true_positives.append(TPR)
false_positives.append(FPR)
return (true_positives, false_positives, group)
#######################################################################################################################
def plot_ROC_data(ROC_data_list):
for curve in ROC_data_list:
TPR = curve[0]
FPR = curve[1]
title = curve[2]
plt.plot(FPR, TPR, label=title)
plt.legend()
axes = plt.gca()
axes.set_xlim([0.0, 1.0])
axes.set_ylim([0.0, 1.0])
plt.ylabel("True Positive Rate")
plt.xlabel("False Positive Rate")
plt.show()
#######################################################################################################################
def apply_financials(data, group_level=False):
# Costs for the various categories
tp_val = -60076
tn_val = 23088
fp_val = -110076
fn_val = -202330
full_list = []
if group_level:
full_list = data
else:
for group in data.keys():
full_list += data[group]
num_tp = get_num_true_positives(full_list)
num_tn = get_num_true_negatives(full_list)
num_fp = get_num_false_positives(full_list)
num_fn = get_num_false_negatives(full_list)
total = 0.0
total += num_tp * tp_val
total += num_tn * tn_val
total += num_fp * fp_val
total += num_fn * fn_val
return total
#######################################################################################################################
def compare_probs(prob1, prob2, epsilon):
return abs(prob1 - prob2) <= epsilon