#!/usr/bin/env python
# -*- coding: utf-8 -*-
import json
import sys
import logging
logging.basicConfig(
    level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
import requests
from flask import Flask
from flask import Response
from flask import request
import numpy as np
import argparse
from multiprocessing.dummy import Pool as ThreadPool

# Flask application that exposes the ensembling endpoint.
app = Flask(import_name=__name__)

# Logger the request handler uses to record failures.
logger = logging.getLogger(name='flask')

def ensemble_example(answers, n_models=None):
    """Merge several models' n-best lists into one ranked list.

    Each element of ``answers`` is one model's n-best list: a sequence of
    dicts that carry at least ``'text'`` and ``'probability'``. Probabilities
    of identical answer texts are summed across all lists and divided by
    ``n_models``. An answer that a model did not propose therefore contributes
    nothing for that model.

    Args:
        answers: list of n-best prediction lists, one per model.
        n_models: divisor for the averaged probability; defaults to
            ``len(answers)``.

    Returns:
        A new list of ``{'text': ..., 'probability': ...}`` dicts, sorted by
        probability from highest to lowest. Ties keep first-seen order.
    """
    divisor = len(answers) if n_models is None else n_models

    # text -> every probability reported for that text, across all models
    scores_by_text = {}
    for model_nbest in answers:
        for candidate in model_nbest:
            text = candidate['text']
            if text not in scores_by_text:
                scores_by_text[text] = []
            scores_by_text[text].append(candidate['probability'])

    merged = [
        {'text': text, 'probability': np.sum(probs) / divisor}
        for text, probs in scores_by_text.items()
    ]
    # list.sort is stable, so equal probabilities stay in insertion order.
    merged.sort(key=lambda entry: entry['probability'], reverse=True)
    return merged


@app.route('/', methods=['POST'])
def mrqa_main():
    """Forward a request to every model server and return ensembled answers.

    The request's JSON body is posted unchanged to each URL in the
    module-level ``urls`` list, which the ``__main__`` block fills in. Each
    reply is parsed as JSON, and its ``'results'`` entry is indexed by
    question id to get that model's n-best list. For every question id in the
    first model's results, the n-best lists of all models are merged with
    ``ensemble_example`` and the highest-ranked answer text is kept.

    Returns:
        A JSON ``Response`` mapping question id -> answer text. If any step
        raises, the exception is logged and ``'error': 'empty'`` is added
        alongside whatever answers were collected before the failure.
    """
    pred = {}
    try:
        # silent=True: a missing or malformed JSON body yields None instead of raising.
        input_json = request.get_json(silent=True)

        def _call_model(url):
            # One blocking HTTP round-trip to a single model server.
            return requests.post(url, json=input_json)

        n_models = len(urls)
        pool = ThreadPool(n_models)
        try:
            # map() dispatches every request before waiting on any of them,
            # so the model servers are queried concurrently. Results come
            # back in the same order as ``urls``.
            results = pool.map(_call_model, urls)
        finally:
            # Release the worker threads even when a request raised.
            pool.close()
            pool.join()

        nbests = [response.json()['results'] for response in results]
        # The first model's question ids drive the loop. A qid missing from
        # another model raises KeyError, which is reported as an error below.
        for qid in list(nbests[0].keys()):
            ensemble_nbest = ensemble_example(
                [nbest[qid] for nbest in nbests], n_models=n_models)
            pred[qid] = ensemble_nbest[0]['text']
    except Exception as e:
        pred['error'] = 'empty'
        logger.exception(e)

    return Response(json.dumps(pred), mimetype='application/json')


if __name__ == '__main__':
    # Model servers that mrqa_main() posts each incoming request to.
    url_1 = 'http://127.0.0.1:5118'   # url for ernie
    url_2 = 'http://127.0.0.1:5119'   # url for xl-net
    url_3 = 'http://127.0.0.1:5120'   # url for bert
    parser = argparse.ArgumentParser('main server')
    parser.add_argument('--ernie', action='store_true', default=False, help="Include ERNIE")
    parser.add_argument('--xlnet', action='store_true', default=False, help="Include XL-NET")
    parser.add_argument('--bert', action='store_true', default=False, help="Include BERT")
    args = parser.parse_args()
    # Module-level global; mrqa_main() reads it on every request.
    urls = []
    if args.ernie:
        print('Include ERNIE model')
        urls.append(url_1)
    if args.xlnet:
        print('Include XL-NET model')
        urls.append(url_2)
    if args.bert:
        print('Include BERT model')
        urls.append(url_3)
    if not urls:
        # parser.error() prints usage and exits with status 2. Unlike an
        # assert, it is not stripped when Python runs with -O.
        parser.error('At least one model is required')

    # threaded=False / processes=1: Flask serves one request at a time.
    app.run(host='127.0.0.1', port=5121, debug=False, threaded=False, processes=1)