Summary of this patch (text before the first "diff --git" header is ignored by patch tools):
  * serve.py: the __main__ block now creates the Flask app, registers the
    "/<service.name>/prediction" POST route, and runs Flask with processes=4.
  * web_service.py: _launch_web_service only creates and connects the Client
    (kept on self.client_service); get_prediction becomes a method that takes
    the request; run_server starts only the RPC process and no longer joins.

diff --git a/python/paddle_serving_server/serve.py b/python/paddle_serving_server/serve.py
index 088e3928f4409eaac4d42d771a72ecc9d13fdbce..b57d2253dbe1f14caff50eb79543f224b8d0ec45 100644
--- a/python/paddle_serving_server/serve.py
+++ b/python/paddle_serving_server/serve.py
@@ -19,6 +19,7 @@ Usage:
 """
 import argparse
 from .web_service import WebService
+from flask import Flask, request
 
 
 def parse_args():  # pylint: disable=doc-string-missing
@@ -88,3 +89,20 @@ if __name__ == "__main__":
     service.prepare_server(
         workdir=args.workdir, port=args.port, device=args.device)
     service.run_server()
+
+    app_instance = Flask(__name__)
+
+    @app_instance.before_first_request
+    def init():
+        service._launch_web_service()
+
+    service_name = "/" + service.name + "/prediction"
+
+    @app_instance.route(service_name, methods=["POST"])
+    def run():
+        return service.get_prediction(request)
+
+    app_instance.run(host="0.0.0.0",
+                     port=service.port,
+                     threaded=False,
+                     processes=4)
diff --git a/python/paddle_serving_server/web_service.py b/python/paddle_serving_server/web_service.py
index e94916ccf371022544707e7bb8e03d37045e54b5..c1a86eaecc899c987bd346f8a747fb486d4789ee 100755
--- a/python/paddle_serving_server/web_service.py
+++ b/python/paddle_serving_server/web_service.py
@@ -50,44 +50,33 @@ class WebService(object):
         self.device = device
 
     def _launch_web_service(self):
-        app_instance = Flask(__name__)
-        client_service = Client()
-        client_service.load_client_config(
+        self.client_service = Client()
+        self.client_service.load_client_config(
             "{}/serving_server_conf.prototxt".format(self.model_config))
-        client_service.connect(["0.0.0.0:{}".format(self.port + 1)])
-        service_name = "/" + self.name + "/prediction"
+        self.client_service.connect(["0.0.0.0:{}".format(self.port + 1)])
 
-        @app_instance.route(service_name, methods=['POST'])
-        def get_prediction():
-            if not request.json:
-                abort(400)
-            if "fetch" not in request.json:
-                abort(400)
-            try:
-                feed, fetch = self.preprocess(request.json,
-                                              request.json["fetch"])
-                if isinstance(feed, list):
-                    fetch_map_batch = client_service.predict(
-                        feed_batch=feed, fetch=fetch)
-                    fetch_map_batch = self.postprocess(
-                        feed=request.json,
-                        fetch=fetch,
-                        fetch_map=fetch_map_batch)
-                    result = {"result": fetch_map_batch}
-                elif isinstance(feed, dict):
-                    if "fetch" in feed:
-                        del feed["fetch"]
-                    fetch_map = client_service.predict(feed=feed, fetch=fetch)
-                    result = self.postprocess(
-                        feed=request.json, fetch=fetch, fetch_map=fetch_map)
-            except ValueError:
-                result = {"result": "Request Value Error"}
-            return result
-
-        app_instance.run(host="0.0.0.0",
-                         port=self.port,
-                         threaded=False,
-                         processes=1)
+    def get_prediction(self, request):
+        if not request.json:
+            abort(400)
+        if "fetch" not in request.json:
+            abort(400)
+        try:
+            feed, fetch = self.preprocess(request.json, request.json["fetch"])
+            if isinstance(feed, list):
+                fetch_map_batch = self.client_service.predict(
+                    feed_batch=feed, fetch=fetch)
+                fetch_map_batch = self.postprocess(
+                    feed=request.json, fetch=fetch, fetch_map=fetch_map_batch)
+                result = {"result": fetch_map_batch}
+            elif isinstance(feed, dict):
+                if "fetch" in feed:
+                    del feed["fetch"]
+                fetch_map = self.client_service.predict(feed=feed, fetch=fetch)
+                result = self.postprocess(
+                    feed=request.json, fetch=fetch, fetch_map=fetch_map)
+        except ValueError:
+            result = {"result": "Request Value Error"}
+        return result
 
     def run_server(self):
         import socket
@@ -96,11 +85,7 @@ class WebService(object):
         print("http://{}:{}/{}/prediction".format(localIP, self.port,
                                                   self.name))
         p_rpc = Process(target=self._launch_rpc_service)
-        p_web = Process(target=self._launch_web_service)
         p_rpc.start()
-        p_web.start()
-        p_web.join()
-        p_rpc.join()
 
     def preprocess(self, feed={}, fetch=[]):
         return feed, fetch