-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathparse_responses.py
69 lines (56 loc) · 2.2 KB
/
parse_responses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import json
import os
import argparse
import pandas as pd
from utils.metric_fns import METRIC_FNS
from utils.dataset_loading import get_dataframe_from_local_file
from utils.task_metadata import get_metadata_for_task
def get_pbase_response(response_data):
    """Return the generated text stored in a parsed response record.

    ``response_data`` is indexed positionally; element 1 is the response
    payload. On success it is presumably a dict; confirm against the
    producer of the JSONL file.

    Returns:
        ``response_data[1]["generated_text"]`` when the payload contains a
        ``"generated_text"`` entry, otherwise ``None``. Callers treat
        ``None`` as "skip this record".
    """
    payload = response_data[1]
    if "generated_text" not in payload:
        return None
    return payload["generated_text"]
def _write_text(path, text):
    """Write *text* to *path* as UTF-8, overwriting any existing file."""
    with open(path, "w", encoding="utf-8") as file:
        file.write(text)


def main(main_args):
    """Score Predibase responses against the task's reference targets.

    Reads one JSON record per line from ``main_args.jsonl_responses_path``.
    Each record is indexed positionally:

    - element 1 is handed to ``get_pbase_response`` to obtain the generated
      text;
    - element 2 is a metadata dict whose ``"df_index"`` entry selects the
      matching row of the task dataframe.

    Records without generated text are skipped. For the others, the task's
    metric function is called as ``metric_fn(generated_text, target_text)``.

    The mean score is printed and written to ``<metric_name>.txt``. The
    skipped ``df_index`` values are printed and written to ``skipped.txt``.
    Both files go in the directory that contains the responses file.

    Args:
        main_args: Namespace with ``task``, ``num_examples`` and
            ``jsonl_responses_path`` attributes.

    Raises:
        ValueError: From ``.item()``, when a ``df_index`` does not match
            exactly one dataframe row.
        json.JSONDecodeError: When a non-blank line is not valid JSON.
    """
    task_metadata = get_metadata_for_task(main_args.task)
    df = get_dataframe_from_local_file(task_metadata, main_args.num_examples)
    metric_name = task_metadata["metric_name"]
    target_col = task_metadata["target_col"]
    metric_fn = METRIC_FNS[metric_name]

    scores = []
    skipped = []
    # JSON text is UTF-8; don't depend on the platform's default encoding.
    with open(main_args.jsonl_responses_path, encoding="utf-8") as file:
        for line in file:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing empty line) instead
                # of failing in json.loads.
                continue
            response_data = json.loads(line)
            df_index = response_data[2]["df_index"]
            generated_text = get_pbase_response(response_data)
            # Boolean-mask lookup + .item(): requires exactly one matching row.
            target_text = df.loc[df.index == df_index][target_col].item()
            if generated_text is not None:
                scores.append(metric_fn(generated_text, target_text))
            else:
                skipped.append(df_index)

    if scores:
        overall_score = sum(scores) / len(scores)
    else:
        # Nothing was scored (every record skipped, or an empty file).
        # Report NaN instead of raising ZeroDivisionError, so the skipped
        # list below is still printed and written to disk.
        overall_score = float("nan")

    score_report = f"Overall {metric_name}: {overall_score:.3f}"
    skipped_report = f"Skipped {len(skipped)} examples: {skipped}"
    print(score_report)
    print(skipped_report)

    output_dir = os.path.dirname(main_args.jsonl_responses_path)
    _write_text(os.path.join(output_dir, "skipped.txt"), skipped_report)
    _write_text(os.path.join(output_dir, f"{metric_name}.txt"), score_report)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="Parse responses from Predibase.",
        description="Parse responses from Predibase.",
    )
    parser.add_argument(
        "--jsonl_responses_path",
        required=True,
        help="Path to the JSONL file of Predibase responses; reports are "
        "written next to it.",
    )
    parser.add_argument(
        "--task",
        required=True,
        help="Task name passed to get_metadata_for_task.",
    )
    # type=int: without it a CLI value reaches get_dataframe_from_local_file
    # as a str (e.g. "100") rather than a count.
    parser.add_argument(
        "--num_examples",
        type=int,
        default=None,
        help="Number of examples to load (default: None).",
    )
    main_args = parser.parse_args()
    main(main_args)