diff --git a/python/paddle_serving_server_gpu/web_service.py b/python/paddle_serving_server_gpu/web_service.py index 37425acd6b7209fbdcf38a52c5a78e0c15b4cf61..7624ab259111dcba6e53fd178ed53cdafbdec61e 100644 --- a/python/paddle_serving_server_gpu/web_service.py +++ b/python/paddle_serving_server_gpu/web_service.py @@ -104,13 +104,26 @@ class WebService(object): abort(400) if "fetch" not in request.json: abort(400) - feed, fetch = self.preprocess(request.json, request.json["fetch"]) - fetch_map_batch = self.client.predict(feed=feed, fetch=fetch) - fetch_map_batch = self.postprocess( - feed=request.json, fetch=fetch, fetch_map=fetch_map_batch) - for key in fetch_map_batch: - fetch_map_batch[key] = fetch_map_batch[key].tolist() - result = {"result": fetch_map_batch} + try: + feed, fetch = self.preprocess(request.json, request.json["fetch"]) + if isinstance(feed, list): + fetch_map_batch = self.client.predict( + feed_batch=feed, fetch=fetch) + fetch_map_batch = self.postprocess( + feed=request.json, fetch=fetch, fetch_map=fetch_map_batch) + for key in fetch_map_batch: + fetch_map_batch[key] = fetch_map_batch[key].tolist() + result = {"result": fetch_map_batch} + elif isinstance(feed, dict): + if "fetch" in feed: + del feed["fetch"] + fetch_map = self.client.predict(feed=feed, fetch=fetch) + for key in fetch_map: + fetch_map[key] = fetch_map[key][0].tolist() + result = self.postprocess( + feed=request.json, fetch=fetch, fetch_map=fetch_map) + except ValueError: + result = {"result": "Request Value Error"} return result def run_server(self):