From 243ca6ef6f89429d5795dd4f9312cf8c9add4127 Mon Sep 17 00:00:00 2001 From: dongdaxiang Date: Wed, 15 Apr 2020 23:10:32 +0800 Subject: [PATCH] update web service --- python/paddle_serving_client/__init__.py | 4 + python/paddle_serving_server/web_service.py | 57 ++++++---- .../paddle_serving_server_gpu/web_service.py | 101 ++++++++++++++---- 3 files changed, 121 insertions(+), 41 deletions(-) diff --git a/python/paddle_serving_client/__init__.py b/python/paddle_serving_client/__init__.py index 765c368a..91a59299 100644 --- a/python/paddle_serving_client/__init__.py +++ b/python/paddle_serving_client/__init__.py @@ -235,6 +235,8 @@ class Client(object): int_feed_names.append(key) if isinstance(feed_i[key], np.ndarray): int_shape.append(list(feed_i[key].shape)) + else: + int_shape.append(self.feed_shapes_[key]) if isinstance(feed_i[key], np.ndarray): int_slot.append(np.reshape(feed_i[key], (-1)).tolist()) else: @@ -244,6 +246,8 @@ class Client(object): float_feed_names.append(key) if isinstance(feed_i[key], np.ndarray): float_shape.append(list(feed_i[key].shape)) + else: + float_shape.append(self.feed_shapes_[key]) if isinstance(feed_i[key], np.ndarray): float_slot.append( np.reshape(feed_i[key], (-1)).tolist()) diff --git a/python/paddle_serving_server/web_service.py b/python/paddle_serving_server/web_service.py index c1a86eae..d91f82d4 100755 --- a/python/paddle_serving_server/web_service.py +++ b/python/paddle_serving_server/web_service.py @@ -55,28 +55,41 @@ class WebService(object): "{}/serving_server_conf.prototxt".format(self.model_config)) self.client_service.connect(["0.0.0.0:{}".format(self.port + 1)]) - def get_prediction(self, request): - if not request.json: - abort(400) - if "fetch" not in request.json: - abort(400) - try: - feed, fetch = self.preprocess(request.json, request.json["fetch"]) - if isinstance(feed, list): - fetch_map_batch = self.client_service.predict( - feed_batch=feed, fetch=fetch) - fetch_map_batch = self.postprocess( - feed=request.json, fetch=fetch, fetch_map=fetch_map_batch) - result = {"result": fetch_map_batch} - elif isinstance(feed, dict): - if "fetch" in feed: - del feed["fetch"] - fetch_map = self.client_service.predict(feed=feed, fetch=fetch) - result = self.postprocess( - feed=request.json, fetch=fetch, fetch_map=fetch_map) - except ValueError: - result = {"result": "Request Value Error"} - return result + @app_instance.route(service_name, methods=['POST']) + def get_prediction(): + if not request.json: + abort(400) + if "fetch" not in request.json: + abort(400) + try: + feed, fetch = self.preprocess(request.json, + request.json["fetch"]) + if isinstance(feed, list): + fetch_map_batch = client_service.predict( + feed_batch=feed, fetch=fetch) + fetch_map_batch = self.postprocess( + feed=request.json, + fetch=fetch, + fetch_map=fetch_map_batch) + for key in fetch_map_batch: + fetch_map_batch[key] = fetch_map_batch[key].tolist() + result = {"result": fetch_map_batch} + elif isinstance(feed, dict): + if "fetch" in feed: + del feed["fetch"] + fetch_map = client_service.predict(feed=feed, fetch=fetch) + for key in fetch_map: + fetch_map[key] = fetch_map[key][0].tolist() + result = self.postprocess( + feed=request.json, fetch=fetch, fetch_map=fetch_map) + except ValueError: + result = {"result": "Request Value Error"} + return result + + app_instance.run(host="0.0.0.0", + port=self.port, + threaded=False, + processes=1) def run_server(self): import socket diff --git a/python/paddle_serving_server_gpu/web_service.py b/python/paddle_serving_server_gpu/web_service.py index 1bb8e93b..4e0d30f9 100755 --- a/python/paddle_serving_server_gpu/web_service.py +++ b/python/paddle_serving_server_gpu/web_service.py @@ -90,25 +90,88 @@ class WebService(object): self.client = Client() self.client.load_client_config("{}/serving_server_conf.prototxt".format( self.model_config)) - endpoints = "" - if gpu_num > 0: - for i in range(gpu_num): - endpoints += "127.0.0.1:{},".format(self.port + i + 1) - else: - endpoints = "127.0.0.1:{}".format(self.port + 1) - self.client.connect([endpoints]) - - def get_prediction(self, request): - if not request.json: - abort(400) - if "fetch" not in request.json: - abort(400) - feed, fetch = self.preprocess(request.json, request.json["fetch"]) - fetch_map_batch = self.client.predict(feed=feed, fetch=fetch) - fetch_map_batch = self.postprocess( - feed=request.json, fetch=fetch, fetch_map=fetch_map_batch) - result = {"result": fetch_map_batch} - return result + + client.connect([endpoint]) + while True: + request_json = inputqueue.get() + try: + feed, fetch = self.preprocess(request_json, + request_json["fetch"]) + if isinstance(feed, list): + fetch_map_batch = client.predict( + feed_batch=feed, fetch=fetch) + fetch_map_batch = self.postprocess( + feed=request_json, + fetch=fetch, + fetch_map=fetch_map_batch) + for key in fetch_map_batch: + fetch_map_batch[key] = fetch_map_batch[key].tolist() + result = {"result": fetch_map_batch} + elif isinstance(feed, dict): + if "fetch" in feed: + del feed["fetch"] + fetch_map = client.predict(feed=feed, fetch=fetch) + for key in fetch_map: + fetch_map[key] = fetch_map[key][0].tolist() + result = self.postprocess( + feed=request_json, fetch=fetch, fetch_map=fetch_map) + self.output_queue.put(result) + except ValueError: + self.output_queue.put(-1) + + def _launch_web_service(self, gpu_num): + app_instance = Flask(__name__) + service_name = "/" + self.name + "/prediction" + + self.input_queues = [] + self.output_queue = Queue() + for i in range(gpu_num): + self.input_queues.append(Queue()) + + producer_list = [] + for i, input_q in enumerate(self.input_queues): + producer_processes = Process( + target=self.producers, + args=( + input_q, + "0.0.0.0:{}".format(self.port + 1 + i), )) + producer_list.append(producer_processes) + + for p in producer_list: + p.start() + + client = Client() + client.load_client_config("{}/serving_server_conf.prototxt".format( + self.model_config)) + client.connect(["0.0.0.0:{}".format(self.port + 1)]) + + self.idx = 0 + + @app_instance.route(service_name, methods=['POST']) + def get_prediction(): + if not request.json: + abort(400) + if "fetch" not in request.json: + abort(400) + + self.input_queues[self.idx].put(request.json) + + #self.input_queues[0].put(request.json) + self.idx += 1 + if self.idx >= len(self.gpus): + self.idx = 0 + result = self.output_queue.get() + if not isinstance(result, dict) and result == -1: + result = {"result": "Request Value Error"} + return result + + app_instance.run(host="0.0.0.0", + port=self.port, + threaded=False, + processes=1) + + for p in producer_list: + p.join() def run_server(self): import socket -- GitLab