# macro_avg.py: calculate macro-averaged F1 scores for MRQA score files.
import numpy as np
import argparse
import json
import re

def extract_score(line):
    """Parse one JSON score line and return its F1 and exact-match scores.

    Args:
        line: A JSON object string containing at least the keys ``"f1"``
            and ``"exact_match"``.

    Returns:
        A ``(f1, em)`` tuple, both converted to ``float``.

    Raises:
        json.JSONDecodeError: If ``line`` is not valid JSON.
        KeyError: If either score key is missing.
    """
    score_json = json.loads(line)
    f1 = score_json['f1']
    em = score_json['exact_match']
    return float(f1), float(em)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Calculate macro average for MRQA')
    parser.add_argument('input_file', help='Score file')
    args = parser.parse_args()

    # The score file alternates a dataset-name line with a JSON score line.
    # Lines 0-11 hold the 6 in-domain datasets and lines 12-23 hold the
    # 6 out-of-domain datasets.
    with open(args.input_file, encoding='utf-8') as fin:
        lines = [line.strip() for line in fin]
    if len(lines) < 24:
        parser.error('expected at least 24 lines in {}, got {}'.format(
            args.input_file, len(lines)))

    in_domain_scores = {}
    for dataset_id in range(0, 12, 2):
        f1, _ = extract_score(lines[dataset_id + 1])
        in_domain_scores[lines[dataset_id]] = f1
    out_of_domain_scores = {}
    for dataset_id in range(12, 24, 2):
        f1, _ = extract_score(lines[dataset_id + 1])
        out_of_domain_scores[lines[dataset_id]] = f1

    # Each average uses its own group's size as the denominator. The
    # out-of-domain average previously divided by the in-domain count.
    print('In domain avg: {}'.format(
        sum(in_domain_scores.values()) / len(in_domain_scores)))
    print('Out of domain avg: {}'.format(
        sum(out_of_domain_scores.values()) / len(out_of_domain_scores)))