提交 933ac626 编写于 作者: M MRXLT

refine cpu web service

上级 60a375c4
......@@ -19,6 +19,7 @@ Usage:
"""
import argparse
from .web_service import WebService
from flask import Flask, request
def parse_args(): # pylint: disable=doc-string-missing
......@@ -88,3 +89,20 @@ if __name__ == "__main__":
service.prepare_server(
workdir=args.workdir, port=args.port, device=args.device)
service.run_server()
app_instance = Flask(__name__)
@app_instance.before_first_request
def init():
service._launch_web_service()
service_name = "/" + service.name + "/prediction"
@app_instance.route(service_name, methods=["POST"])
def run():
return service.get_prediction(request)
app_instance.run(host="0.0.0.0",
port=service.port,
threaded=False,
processes=4)
......@@ -50,44 +50,33 @@ class WebService(object):
self.device = device
def _launch_web_service(self):
app_instance = Flask(__name__)
client_service = Client()
client_service.load_client_config(
self.client_service = Client()
self.client_service.load_client_config(
"{}/serving_server_conf.prototxt".format(self.model_config))
client_service.connect(["0.0.0.0:{}".format(self.port + 1)])
service_name = "/" + self.name + "/prediction"
self.client_service.connect(["0.0.0.0:{}".format(self.port + 1)])
@app_instance.route(service_name, methods=['POST'])
def get_prediction():
if not request.json:
abort(400)
if "fetch" not in request.json:
abort(400)
try:
feed, fetch = self.preprocess(request.json,
request.json["fetch"])
if isinstance(feed, list):
fetch_map_batch = client_service.predict(
feed_batch=feed, fetch=fetch)
fetch_map_batch = self.postprocess(
feed=request.json,
fetch=fetch,
fetch_map=fetch_map_batch)
result = {"result": fetch_map_batch}
elif isinstance(feed, dict):
if "fetch" in feed:
del feed["fetch"]
fetch_map = client_service.predict(feed=feed, fetch=fetch)
result = self.postprocess(
feed=request.json, fetch=fetch, fetch_map=fetch_map)
except ValueError:
result = {"result": "Request Value Error"}
return result
app_instance.run(host="0.0.0.0",
port=self.port,
threaded=False,
processes=1)
def get_prediction(self, request):
if not request.json:
abort(400)
if "fetch" not in request.json:
abort(400)
try:
feed, fetch = self.preprocess(request.json, request.json["fetch"])
if isinstance(feed, list):
fetch_map_batch = self.client_service.predict(
feed_batch=feed, fetch=fetch)
fetch_map_batch = self.postprocess(
feed=request.json, fetch=fetch, fetch_map=fetch_map_batch)
result = {"result": fetch_map_batch}
elif isinstance(feed, dict):
if "fetch" in feed:
del feed["fetch"]
fetch_map = self.client_service.predict(feed=feed, fetch=fetch)
result = self.postprocess(
feed=request.json, fetch=fetch, fetch_map=fetch_map)
except ValueError:
result = {"result": "Request Value Error"}
return result
def run_server(self):
import socket
......@@ -96,11 +85,7 @@ class WebService(object):
print("http://{}:{}/{}/prediction".format(localIP, self.port,
self.name))
p_rpc = Process(target=self._launch_rpc_service)
p_web = Process(target=self._launch_web_service)
p_rpc.start()
p_web.start()
p_web.join()
p_rpc.join()
def preprocess(self, feed={}, fetch=[]):
return feed, fetch
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册