diff --git a/python/paddle_serving_server/web_service.py b/python/paddle_serving_server/web_service.py index adf0d68c08c040140f4e4e5bb1b4f6b7743599a7..a03649725b1c41ca94b8ef495a2fc80e8293aba0 100755 --- a/python/paddle_serving_server/web_service.py +++ b/python/paddle_serving_server/web_service.py @@ -67,10 +67,10 @@ class WebService(object): break def _launch_web_service(self): - self.client_service = Client() - self.client_service.load_client_config( - "{}/serving_server_conf.prototxt".format(self.model_config)) - self.client_service.connect(["0.0.0.0:{}".format(self.port_list[0])]) + self.client = Client() + self.client.load_client_config("{}/serving_server_conf.prototxt".format( + self.model_config)) + self.client.connect(["0.0.0.0:{}".format(self.port_list[0])]) def get_prediction(self, request): if not request.json: @@ -79,22 +79,14 @@ class WebService(object): abort(400) try: feed, fetch = self.preprocess(request.json, request.json["fetch"]) - if isinstance(feed, list): - fetch_map_batch = self.client_service.predict( - feed_batch=feed, fetch=fetch) - fetch_map_batch = self.postprocess( - feed=request.json, fetch=fetch, fetch_map=fetch_map_batch) - for key in fetch_map_batch: - fetch_map_batch[key] = fetch_map_batch[key].tolist() - result = {"result": fetch_map_batch} - elif isinstance(feed, dict): - if "fetch" in feed: - del feed["fetch"] - fetch_map = self.client_service.predict(feed=feed, fetch=fetch) - for key in fetch_map: - fetch_map[key] = fetch_map[key][0].tolist() - result = self.postprocess( - feed=request.json, fetch=fetch, fetch_map=fetch_map) + if isinstance(feed, dict) and "fetch" in feed: + del feed["fetch"] + fetch_map = self.client.predict(feed=feed, fetch=fetch) + fetch_map = self.postprocess( + feed=request.json, fetch=fetch, fetch_map=fetch_map) + for key in fetch_map: + fetch_map[key] = fetch_map[key].tolist() + result = {"result": fetch_map} except ValueError: result = {"result": "Request Value Error"} return result diff --git a/python/paddle_serving_server_gpu/web_service.py b/python/paddle_serving_server_gpu/web_service.py index ed7f16c5b822335d83937f013f44d90b01c0a19a..6841220f9f4e52a23bc7b0a0176c58672fc4b675 100644 --- a/python/paddle_serving_server_gpu/web_service.py +++ b/python/paddle_serving_server_gpu/web_service.py @@ -121,13 +121,18 @@ class WebService(object): abort(400) if "fetch" not in request.json: abort(400) - feed, fetch = self.preprocess(request.json, request.json["fetch"]) - fetch_map_batch = self.client.predict(feed=feed, fetch=fetch) - fetch_map_batch = self.postprocess( - feed=request.json, fetch=fetch, fetch_map=fetch_map_batch) - for key in fetch_map_batch: - fetch_map_batch[key] = fetch_map_batch[key].tolist() - result = {"result": fetch_map_batch} + try: + feed, fetch = self.preprocess(request.json, request.json["fetch"]) + if isinstance(feed, dict) and "fetch" in feed: + del feed["fetch"] + fetch_map = self.client.predict(feed=feed, fetch=fetch) + fetch_map = self.postprocess( + feed=request.json, fetch=fetch, fetch_map=fetch_map) + for key in fetch_map: + fetch_map[key] = fetch_map[key].tolist() + result = {"result": fetch_map} + except ValueError: + result = {"result": "Request Value Error"} return result def run_server(self):