From 7bd372c794e32f545f149609192121b1c007737f Mon Sep 17 00:00:00 2001 From: MRXLT Date: Thu, 19 Mar 2020 01:16:30 +0800 Subject: [PATCH] add batch predict for HTTP service --- .../imagenet/image_classification_service.py | 20 ++++++++++++---- .../image_classification_service_gpu.py | 23 ++++++++++++++----- python/examples/imagenet/image_http_client.py | 17 ++++++++++---- python/paddle_serving_server/web_service.py | 21 +++++++++++------ .../paddle_serving_server_gpu/web_service.py | 22 ++++++++++++------ 5 files changed, 74 insertions(+), 29 deletions(-) diff --git a/python/examples/imagenet/image_classification_service.py b/python/examples/imagenet/image_classification_service.py index c78ae1d2..2776eb1b 100644 --- a/python/examples/imagenet/image_classification_service.py +++ b/python/examples/imagenet/image_classification_service.py @@ -25,11 +25,21 @@ class ImageService(WebService): reader = ImageReader() if "image" not in feed: raise ("feed data error!") - sample = base64.b64decode(feed["image"]) - img = reader.process_image(sample) - res_feed = {} - res_feed["image"] = img.reshape(-1) - return res_feed, fetch + if isinstance(feed["image"], list): + feed_batch = [] + for image in feed["image"]: + sample = base64.b64decode(image) + img = reader.process_image(sample) + res_feed = {} + res_feed["image"] = img.reshape(-1) + feed_batch.append(res_feed) + return feed_batch, fetch + else: + sample = base64.b64decode(feed["image"]) + img = reader.process_image(sample) + res_feed = {} + res_feed["image"] = img.reshape(-1) + return res_feed, fetch image_service = ImageService(name="image") diff --git a/python/examples/imagenet/image_classification_service_gpu.py b/python/examples/imagenet/image_classification_service_gpu.py index 8a0bea93..287392e4 100644 --- a/python/examples/imagenet/image_classification_service_gpu.py +++ b/python/examples/imagenet/image_classification_service_gpu.py @@ -25,16 +25,27 @@ class ImageService(WebService): reader = ImageReader() if "image" not in feed: raise ("feed data error!") - sample = base64.b64decode(feed["image"]) - img = reader.process_image(sample) - res_feed = {} - res_feed["image"] = img.reshape(-1) - return res_feed, fetch + print(type(feed["image"]), isinstance(feed["image"], list)) + if isinstance(feed["image"], list): + feed_batch = [] + for image in feed["image"]: + sample = base64.b64decode(image) + img = reader.process_image(sample) + res_feed = {} + res_feed["image"] = img.reshape(-1) + feed_batch.append(res_feed) + return feed_batch, fetch + else: + sample = base64.b64decode(feed["image"]) + img = reader.process_image(sample) + res_feed = {} + res_feed["image"] = img.reshape(-1) + return res_feed, fetch image_service = ImageService(name="image") image_service.load_model_config(sys.argv[1]) -image_service.set_gpus("0,1,2,3") +image_service.set_gpus("0,1") image_service.prepare_server( workdir=sys.argv[2], port=int(sys.argv[3]), device="gpu") image_service.run_server() diff --git a/python/examples/imagenet/image_http_client.py b/python/examples/imagenet/image_http_client.py index b61f0dd7..c567b900 100644 --- a/python/examples/imagenet/image_http_client.py +++ b/python/examples/imagenet/image_http_client.py @@ -24,17 +24,26 @@ def predict(image_path, server): req = json.dumps({"image": image, "fetch": ["score"]}) r = requests.post( server, data=req, headers={"Content-Type": "application/json"}) + print(r.json()["score"][0]) + return r + + +def batch_predict(image_path, server): + image = base64.b64encode(open(image_path).read()) + req = json.dumps({"image": [image, image], "fetch": ["score"]}) + r = requests.post( + server, data=req, headers={"Content-Type": "application/json"}) + print(r.json()["result"][1]["score"][0]) return r if __name__ == "__main__": - server = "http://127.0.0.1:9295/image/prediction" + server = "http://127.0.0.1:9393/image/prediction" #image_path = "./data/n01440764_10026.JPEG" - image_list = os.listdir("./data/image_data/n01440764/") + image_list = os.listdir("./image_data/n01440764/") start = time.time() for img in image_list: - image_file = "./data/image_data/n01440764/" + img + image_file = "./image_data/n01440764/" + img res = predict(image_file, server) - print(res.json()["score"][0]) end = time.time() print(end - start) diff --git a/python/paddle_serving_server/web_service.py b/python/paddle_serving_server/web_service.py index 71614129..298e65e7 100755 --- a/python/paddle_serving_server/web_service.py +++ b/python/paddle_serving_server/web_service.py @@ -64,12 +64,19 @@ class WebService(object): if "fetch" not in request.json: abort(400) feed, fetch = self.preprocess(request.json, request.json["fetch"]) - if "fetch" in feed: - del feed["fetch"] - fetch_map = client_service.predict(feed=feed, fetch=fetch) - fetch_map = self.postprocess( - feed=request.json, fetch=fetch, fetch_map=fetch_map) - return fetch_map + if isinstance(feed, list): + fetch_map_batch = client_service.batch_predict( + feed_batch=feed, fetch=fetch) + fetch_map_batch = self.postprocess( + feed=request.json, fetch=fetch, fetch_map=fetch_map_batch) + result = {"result": fetch_map_batch} + elif isinstance(feed, dict): + if "fetch" in feed: + del feed["fetch"] + fetch_map = client_service.predict(feed=feed, fetch=fetch) + result = self.postprocess( + feed=request.json, fetch=fetch, fetch_map=fetch_map) + return result app_instance.run(host="0.0.0.0", port=self.port, @@ -92,5 +99,5 @@ class WebService(object): def preprocess(self, feed={}, fetch=[]): return feed, fetch - def postprocess(self, feed={}, fetch=[], fetch_map={}): + def postprocess(self, feed={}, fetch=[], fetch_map=None): return fetch_map diff --git a/python/paddle_serving_server_gpu/web_service.py b/python/paddle_serving_server_gpu/web_service.py index 4d88994c..22b534dd 100755 --- a/python/paddle_serving_server_gpu/web_service.py +++ b/python/paddle_serving_server_gpu/web_service.py @@ -95,12 +95,20 @@ class WebService(object): while True: request_json = inputqueue.get() feed, fetch = self.preprocess(request_json, request_json["fetch"]) - if "fetch" in feed: - del feed["fetch"] - fetch_map = client.predict(feed=feed, fetch=fetch) - fetch_map = self.postprocess( - feed=request_json, fetch=fetch, fetch_map=fetch_map) - self.output_queue.put(fetch_map) + if isinstance(feed, list): + fetch_map_batch = client.batch_predict( + feed_batch=feed, fetch=fetch) + fetch_map_batch = self.postprocess( + feed=request_json, fetch=fetch, fetch_map=fetch_map_batch) + result = {"result": fetch_map_batch} + elif isinstance(feed, dict): + if "fetch" in feed: + del feed["fetch"] + fetch_map = client.predict(feed=feed, fetch=fetch) + result = self.postprocess( + feed=request_json, fetch=fetch, fetch_map=fetch_map) + + self.output_queue.put(result) def _launch_web_service(self, gpu_num): app_instance = Flask(__name__) @@ -186,5 +194,5 @@ class WebService(object): def preprocess(self, feed={}, fetch=[]): return feed, fetch - def postprocess(self, feed={}, fetch=[], fetch_map={}): + def postprocess(self, feed={}, fetch=[], fetch_map=None): return fetch_map -- GitLab