From c9b1e9ec55336062a00869d7611eede59c97e979 Mon Sep 17 00:00:00 2001 From: guru4elephant Date: Tue, 17 Mar 2020 14:49:42 +0800 Subject: [PATCH] refine web_service.py --- .../paddle_serving_server_gpu/web_service.py | 80 ++++++++----------- 1 file changed, 34 insertions(+), 46 deletions(-) diff --git a/python/paddle_serving_server_gpu/web_service.py b/python/paddle_serving_server_gpu/web_service.py index 657c1db5..4d88994c 100755 --- a/python/paddle_serving_server_gpu/web_service.py +++ b/python/paddle_serving_server_gpu/web_service.py @@ -15,7 +15,7 @@ # pylint: disable=doc-string-missing from flask import Flask, request, abort -from multiprocessing import Pool, Process +from multiprocessing import Pool, Process, Queue from paddle_serving_server_gpu import OpMaker, OpSeqMaker, Server import paddle_serving_server_gpu as serving from paddle_serving_client import Client @@ -29,27 +29,13 @@ class WebService(object): self.name = name self.gpus = [] self.rpc_service_list = [] - - def producers(self, input_queue, output_queue, endpoint): - client = Client() - client.load_client_config("{}/serving_server_conf.prototxt".format( - self.model_config)) - client.connect([endpoint]) - while True: - request_json = input_queue.get() - feed, fetch = self.preprocess(request_json, request_json["fetch"]) - if "fetch" in feed: - del feed["fetch"] - fetch_map = client.predict(feed=feed, fetch=fetch) - fetch_map = self.postprocess( - feed=request.json, fetch=fetch, fetch_map=fetch_map) - output_queue.put(fetch_map) + self.input_queues = [] def load_model_config(self, model_config): self.model_config = model_config def set_gpus(self, gpus): - self.gpus = gpus + self.gpus = [int(x) for x in gpus.split(",")] def default_rpc_service(self, workdir="conf", @@ -93,21 +79,6 @@ class WebService(object): self.default_rpc_service( self.workdir, self.port + 1, -1, thread_num=10)) else: - self.producer_pros = [] - for i, gpuid in enumerate(self.gpus): - producer_p = Process( - target=self.producers, - args=( - self, - self.input_queue[i], - self.output_queue, - self.port + 1 + i, - gpuid, )) - - self.producer_pros.append(producer_p) - for p in producer_pros: - p.start() - for i, gpuid in enumerate(self.gpus): self.rpc_service_list.append( self.default_rpc_service( @@ -122,7 +93,7 @@ class WebService(object): self.model_config)) client.connect([endpoint]) while True: - request_json = input_queue.get() + request_json = inputqueue.get() feed, fetch = self.preprocess(request_json, request_json["fetch"]) if "fetch" in feed: del feed["fetch"] @@ -135,23 +106,29 @@ class WebService(object): app_instance = Flask(__name__) service_name = "/" + self.name + "/prediction" - input_queues = [] - output_queue = Queue() + self.input_queues = [] + self.output_queue = Queue() for i in range(gpu_num): - input_queues.append(Queue()) + self.input_queues.append(Queue()) producer_list = [] - for i, input_q in enumerate(input_queues): + for i, input_q in enumerate(self.input_queues): producer_processes = Process( target=self.producers, - input_q, - "0.0.0.0:{}".format(self.port + 1 + i)) + args=( + input_q, + "0.0.0.0:{}".format(self.port + 1 + i), )) producer_list.append(producer_processes) for p in producer_list: p.start() - idx = 0 + client = Client() + client.load_client_config("{}/serving_server_conf.prototxt".format( + self.model_config)) + client.connect(["0.0.0.0:{}".format(self.port + 1)]) + + self.idx = 0 @app_instance.route(service_name, methods=['POST']) def get_prediction(): @@ -160,12 +137,23 @@ class WebService(object): if "fetch" not in request.json: abort(400) - input_queues[idx].put(request.json) - result = output_queue.get() - idx += 1 - if idx >= len(self.gpus): - idx = 0 + self.input_queues[self.idx].put(request.json) + + #self.input_queues[0].put(request.json) + self.idx += 1 + if self.idx >= len(self.gpus): + self.idx = 0 + result = self.output_queue.get() return result + ''' + feed, fetch = self.preprocess(request.json, request.json["fetch"]) + if "fetch" in feed: + del feed["fetch"] + fetch_map = client.predict(feed=feed, fetch=fetch) + fetch_map = self.postprocess( + feed=request.json, fetch=fetch, fetch_map=fetch_map) + return fetch_map + ''' app_instance.run(host="0.0.0.0", port=self.port, @@ -183,7 +171,7 @@ class WebService(object): self.name)) server_pros = [] for i, service in enumerate(self.rpc_service_list): - p = Process(target=_launch_rpc_service, args=(i, )) + p = Process(target=self._launch_rpc_service, args=(i, )) server_pros.append(p) for p in server_pros: p.start() -- GitLab