#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Provide MRC service for TOP1 short answer extraction system
Note the services here share some global pre/post process objects, which
are **NOT THREAD SAFE**. Try to use multi-process instead of multi-thread
for deployment.
"""
import json
import sys
import logging
logging.basicConfig(
    level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
import requests
from flask import Flask
from flask import Response
from flask import request
import server_utils
import wrapper as bert_wrapper

# Command-line arguments: <model_dir> <port> [process_mode].
# `mode` is forwarded to the MRC service below; it defaults to 'parallel'.
# NOTE: `assert` is stripped under `python -O`, so validate explicitly.
if len(sys.argv) not in (3, 4):
    sys.exit("Usage: python serve.py <model_dir> <port> [process_mode]")
if len(sys.argv) == 3:
    _, model_dir, port = sys.argv
    mode = 'parallel'
else:
    _, model_dir, port, mode = sys.argv

# Global, process-wide objects shared by all requests. Per the module
# docstring, the pre/post-process state they hold is NOT thread safe,
# so the server below runs single-threaded.
app = Flask(__name__)
app.logger.setLevel(logging.INFO)
# BERT model wrapper loaded once from the CLI-supplied model directory.
bert_model = bert_wrapper.BertModelWrapper(model_dir=model_dir)
# Callable service object that performs request parsing, inference
# dispatch, and response formatting (see server_utils).
server = server_utils.BasicMRCService('Short answer MRC service', app.logger)

@app.route('/', methods=['POST'])
def mrqa_service():
    """Serve one MRC request by delegating to the shared service object.

    The shared ``server`` callable runs ``bert_model`` with the process
    mode chosen on the command line ('parallel' by default) and caps
    each inference call at 5 examples per batch.
    """
    return server(bert_model, process_mode=mode, max_batch_size=5)


if __name__ == '__main__':
    # Single-threaded, single-process server on purpose: the shared
    # pre/post-process objects are not thread safe (see module docstring).
    # `port` arrives as a string from sys.argv; Flask/Werkzeug expect an
    # integer port, so cast it here.
    app.run(port=int(port), debug=False, threaded=False, processes=1)