#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Provide MRC service for TOP1 short answer extraction system
Note the services here share some global pre/post process objects, which
are **NOT THREAD SAFE**. Try to use multi-process instead of multi-thread
for deployment.
"""
import json
import sys
import logging
logging.basicConfig(
    level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
import requests
from flask import Flask
from flask import Response
from flask import request
import mrc_service
import model_wrapper
import argparse


# --- Command-line handling -------------------------------------------------
# Usage: python start_service.py <model_dir> <port> [process_mode]
# NOTE: `assert` is unsuitable for argument validation because it is stripped
# under `python -O`; use an explicit check with a usage message instead.
if len(sys.argv) not in (3, 4):
    sys.exit("Usage: python serve.py <model_dir> <port> [process_mode]")
if len(sys.argv) == 3:
    _, model_dir, port = sys.argv
    mode = 'parallel'  # default processing mode when none is given
else:
    _, model_dir, port, mode = sys.argv

# app.run() expects an integer port; sys.argv only yields strings.
try:
    port = int(port)
except ValueError:
    sys.exit("port must be an integer, got %r" % port)

# Upper bound on the number of requests batched into a single model call.
max_batch_size = 5

# --- Shared per-process service objects ------------------------------------
# These are global and shared by every request this process handles. Per the
# module docstring, the pre/post-processing objects are NOT thread safe, hence
# the single-threaded, multi-process deployment recommendation.
app = Flask(__name__)
app.logger.setLevel(logging.INFO)
# Model wrapper loaded from <model_dir> (project-local helper class).
model = model_wrapper.BertModelWrapper(model_dir=model_dir)
# MRQA request handler (project-local); called once per POST by the route.
server = mrc_service.MRQAService('MRQA service', app.logger)

@app.route('/', methods=['POST'])
def mrqa_service():
    """Serve one MRC request.

    Delegates entirely to the shared ``server`` callable, passing the
    process-wide model, processing mode and batch-size cap; whatever it
    returns is handed back to Flask as the HTTP response.
    """
    response = server(model, process_mode=mode, max_batch_size=max_batch_size)
    return response


if __name__ == '__main__':
    # Single-threaded, single-process server: the shared pre/post-processing
    # objects are not thread safe (see module docstring); scale out with
    # multiple processes instead of threads.
    # Cast port here: sys.argv yields strings, but app.run() needs an int.
    app.run(port=int(port), debug=False, threaded=False, processes=1)