From 99fdb7e9a8e1b963238e072fe99f35acf6c3cd07 Mon Sep 17 00:00:00 2001 From: MRXLT Date: Fri, 24 Apr 2020 20:07:54 +0800 Subject: [PATCH] fix bug && unify http service --- python/paddle_serving_server/web_service.py | 32 +++++++------------ .../paddle_serving_server_gpu/web_service.py | 19 +++++++---- 2 files changed, 24 insertions(+), 27 deletions(-) diff --git a/python/paddle_serving_server/web_service.py b/python/paddle_serving_server/web_service.py index adf0d68c..a0364972 100755 --- a/python/paddle_serving_server/web_service.py +++ b/python/paddle_serving_server/web_service.py @@ -67,10 +67,10 @@ class WebService(object): break def _launch_web_service(self): - self.client_service = Client() - self.client_service.load_client_config( - "{}/serving_server_conf.prototxt".format(self.model_config)) - self.client_service.connect(["0.0.0.0:{}".format(self.port_list[0])]) + self.client = Client() + self.client.load_client_config("{}/serving_server_conf.prototxt".format( + self.model_config)) + self.client.connect(["0.0.0.0:{}".format(self.port_list[0])]) def get_prediction(self, request): if not request.json: @@ -79,22 +79,14 @@ class WebService(object): abort(400) try: feed, fetch = self.preprocess(request.json, request.json["fetch"]) - if isinstance(feed, list): - fetch_map_batch = self.client_service.predict( - feed_batch=feed, fetch=fetch) - fetch_map_batch = self.postprocess( - feed=request.json, fetch=fetch, fetch_map=fetch_map_batch) - for key in fetch_map_batch: - fetch_map_batch[key] = fetch_map_batch[key].tolist() - result = {"result": fetch_map_batch} - elif isinstance(feed, dict): - if "fetch" in feed: - del feed["fetch"] - fetch_map = self.client_service.predict(feed=feed, fetch=fetch) - for key in fetch_map: - fetch_map[key] = fetch_map[key][0].tolist() - result = self.postprocess( - feed=request.json, fetch=fetch, fetch_map=fetch_map) + if isinstance(feed, dict) and "fetch" in feed: + del feed["fetch"] + fetch_map = self.client.predict(feed=feed, fetch=fetch) + fetch_map = self.postprocess( + feed=request.json, fetch=fetch, fetch_map=fetch_map) + for key in fetch_map: + fetch_map[key] = fetch_map[key].tolist() + result = {"result": fetch_map} except ValueError: result = {"result": "Request Value Error"} return result diff --git a/python/paddle_serving_server_gpu/web_service.py b/python/paddle_serving_server_gpu/web_service.py index ed7f16c5..6841220f 100644 --- a/python/paddle_serving_server_gpu/web_service.py +++ b/python/paddle_serving_server_gpu/web_service.py @@ -121,13 +121,18 @@ class WebService(object): abort(400) if "fetch" not in request.json: abort(400) - feed, fetch = self.preprocess(request.json, request.json["fetch"]) - fetch_map_batch = self.client.predict(feed=feed, fetch=fetch) - fetch_map_batch = self.postprocess( - feed=request.json, fetch=fetch, fetch_map=fetch_map_batch) - for key in fetch_map_batch: - fetch_map_batch[key] = fetch_map_batch[key].tolist() - result = {"result": fetch_map_batch} + try: + feed, fetch = self.preprocess(request.json, request.json["fetch"]) + if isinstance(feed, dict) and "fetch" in feed: + del feed["fetch"] + fetch_map = self.client.predict(feed=feed, fetch=fetch) + fetch_map = self.postprocess( + feed=request.json, fetch=fetch, fetch_map=fetch_map) + for key in fetch_map: + fetch_map[key] = fetch_map[key].tolist() + result = {"result": fetch_map} + except ValueError: + result = {"result": "Request Value Error"} return result def run_server(self): -- GitLab