From bbb40430913cdcd4602ef1c2461da9b286876e3e Mon Sep 17 00:00:00 2001 From: barrierye Date: Mon, 20 Apr 2020 18:00:42 +0800 Subject: [PATCH] unify predict return --- python/paddle_serving_server/web_service.py | 24 +++++++------------ .../paddle_serving_server_gpu/web_service.py | 24 +++++++------------ 2 files changed, 16 insertions(+), 32 deletions(-) diff --git a/python/paddle_serving_server/web_service.py b/python/paddle_serving_server/web_service.py index ca43426c..4a033cbc 100755 --- a/python/paddle_serving_server/web_service.py +++ b/python/paddle_serving_server/web_service.py @@ -62,22 +62,14 @@ class WebService(object): abort(400) try: feed, fetch = self.preprocess(request.json, request.json["fetch"]) - if isinstance(feed, list): - fetch_map_batch = self.client_service.predict( - feed_batch=feed, fetch=fetch) - fetch_map_batch = self.postprocess( - feed=request.json, fetch=fetch, fetch_map=fetch_map_batch) - for key in fetch_map_batch: - fetch_map_batch[key] = fetch_map_batch[key].tolist() - result = {"result": fetch_map_batch} - elif isinstance(feed, dict): - if "fetch" in feed: - del feed["fetch"] - fetch_map = self.client_service.predict(feed=feed, fetch=fetch) - for key in fetch_map: - fetch_map[key] = fetch_map[key][0].tolist() - result = self.postprocess( - feed=request.json, fetch=fetch, fetch_map=fetch_map) + if isinstance(feed, dict) and "fetch" in feed: + del feed["fetch"] + fetch_map = self.client_service.predict(feed=feed, fetch=fetch) + for key in fetch_map: + fetch_map[key] = fetch_map[key][0].tolist() + result = self.postprocess( + feed=request.json, fetch=fetch, fetch_map=fetch_map) + result = {"result": result} except ValueError: result = {"result": "Request Value Error"} return result diff --git a/python/paddle_serving_server_gpu/web_service.py b/python/paddle_serving_server_gpu/web_service.py index 7624ab25..cb833ba3 100644 --- a/python/paddle_serving_server_gpu/web_service.py +++ b/python/paddle_serving_server_gpu/web_service.py @@ -106,22 +106,14 @@ class WebService(object): abort(400) try: feed, fetch = self.preprocess(request.json, request.json["fetch"]) - if isinstance(feed, list): - fetch_map_batch = self.client.predict( - feed_batch=feed, fetch=fetch) - fetch_map_batch = self.postprocess( - feed=request.json, fetch=fetch, fetch_map=fetch_map_batch) - for key in fetch_map_batch: - fetch_map_batch[key] = fetch_map_batch[key].tolist() - result = {"result": fetch_map_batch} - elif isinstance(feed, dict): - if "fetch" in feed: - del feed["fetch"] - fetch_map = self.client.predict(feed=feed, fetch=fetch) - for key in fetch_map: - fetch_map[key] = fetch_map[key][0].tolist() - result = self.postprocess( - feed=request.json, fetch=fetch, fetch_map=fetch_map) + if isinstance(feed, dict) and "fetch" in feed: + del feed["fetch"] + fetch_map = self.client.predict(feed=feed, fetch=fetch) + for key in fetch_map: + fetch_map[key] = fetch_map[key][0].tolist() + result = self.postprocess( + feed=request.json, fetch=fetch, fetch_map=fetch_map) + result = {"result": result} except ValueError: result = {"result": "Request Value Error"} return result -- GitLab