client.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Query the MRQA model server to generate predictions.
"""
import argparse
import json
import requests
import time


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("dataset", help="MRQA-format dataset file, one JSON example per line")
    parser.add_argument("output_file", help="path to write the aggregated predictions to")
    parser.add_argument("port", type=int, help="port of the local model server")
    args = parser.parse_args()

    # Load the dataset: one JSON-encoded example per line, skipping the header record.
    contexts = []
    with open(args.dataset) as f:
        for line in f:
            context = json.loads(line)
            if 'header' in context:
                continue
            contexts.append(context)

    # Send each example to the model server and merge its predictions.
    results = {}
    cnt = 0
    for context in contexts:
        cnt += 1
        start = time.time()
        pred = requests.post('http://127.0.0.1:%d' % args.port, json=context)
        result = pred.json()
        results.update(result)
        end = time.time()
        print('----- request cnt: {}, time elapsed: {:.2f} ms -----'.format(cnt, (end - start) * 1000))
        for qid, answer in result.items():
            print('{}: {}'.format(qid, answer))
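    # Write the aggregated predictions to a single JSON file.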
    with open(args.output_file, 'w') as f:
        json.dump(results, f, indent=1)