diff --git a/python/paddle_serving_server_gpu/web_service.py b/python/paddle_serving_server_gpu/web_service.py old mode 100755 new mode 100644 index 757f3eb1987258f6f656ca07546cd90e58a1b1a4..b4c911b7d805dab376a9709ce1ca5d719ea49794 --- a/python/paddle_serving_server_gpu/web_service.py +++ b/python/paddle_serving_server_gpu/web_service.py @@ -90,77 +90,26 @@ class WebService(object): self.client = Client() self.client.load_client_config("{}/serving_server_conf.prototxt".format( self.model_config)) + endpoints = "" + if gpu_num > 0: + for i in range(gpu_num): + endpoints += "127.0.0.1:{},".format(self.port + i + 1) + else: + endpoints = "127.0.0.1:{}".format(self.port + 1) + self.client.connect([endpoints]) - client.connect([endpoint]) - while True: - request_json = inputqueue.get() - try: - feed, fetch = self.preprocess(request_json, - request_json["fetch"]) - if isinstance(feed, list): - fetch_map_batch = client.predict( - feed_batch=feed, fetch=fetch) - fetch_map_batch = self.postprocess( - feed=request_json, - fetch=fetch, - fetch_map=fetch_map_batch) - for key in fetch_map_batch: - fetch_map_batch[key] = fetch_map_batch[key].tolist() - result = {"result": fetch_map_batch} - elif isinstance(feed, dict): - if "fetch" in feed: - del feed["fetch"] - fetch_map = client.predict(feed=feed, fetch=fetch) - for key in fetch_map: - fetch_map[key] = fetch_map[key][0].tolist() - result = self.postprocess( - feed=request_json, fetch=fetch, fetch_map=fetch_map) - self.output_queue.put(result) - except ValueError: - self.output_queue.put(-1) - - def _launch_web_service(self, gpu_num): - app_instance = Flask(__name__) - service_name = "/" + self.name + "/prediction" - - self.input_queues = [] - self.output_queue = Queue() - for i in range(gpu_num): - self.input_queues.append(Queue()) - - producer_list = [] - for i, input_q in enumerate(self.input_queues): - producer_processes = Process( - target=self.producers, - args=( - input_q, - "0.0.0.0:{}".format(self.port + 1 + i), )) - producer_list.append(producer_processes) - - for p in producer_list: - p.start() - - client = Client() - client.load_client_config("{}/serving_server_conf.prototxt".format( - self.model_config)) - client.connect(["0.0.0.0:{}".format(self.port + 1)]) - - self.idx = 0 - - def get_prediction(): + def get_prediction(self, request): if not request.json: abort(400) if "fetch" not in request.json: abort(400) - - self.input_queues[self.idx].put(request.json) - - self.idx += 1 - if self.idx >= len(self.gpus): - self.idx = 0 - result = self.output_queue.get() - if not isinstance(result, dict) and result == -1: - result = {"result": "Request Value Error"} + feed, fetch = self.preprocess(request.json, request.json["fetch"]) + fetch_map_batch = self.client.predict(feed=feed, fetch=fetch) + fetch_map_batch = self.postprocess( + feed=request.json, fetch=fetch, fetch_map=fetch_map_batch) + for key in fetch_map_batch: + fetch_map_batch[key] = fetch_map_batch[key].tolist() + result = {"result": fetch_map_batch} return result def run_server(self):