From 60a375c47bb8532b4464edc8f3ba0e57b22d05e3 Mon Sep 17 00:00:00 2001 From: MRXLT Date: Tue, 14 Apr 2020 16:30:21 +0800 Subject: [PATCH] refine gpu web service --- python/paddle_serving_server_gpu/__init__.py | 4 +- python/paddle_serving_server_gpu/serve.py | 18 +++ .../paddle_serving_server_gpu/web_service.py | 128 ++++-------------- 3 files changed, 48 insertions(+), 102 deletions(-) diff --git a/python/paddle_serving_server_gpu/__init__.py b/python/paddle_serving_server_gpu/__init__.py index cfd5eee9..d7d1fe10 100644 --- a/python/paddle_serving_server_gpu/__init__.py +++ b/python/paddle_serving_server_gpu/__init__.py @@ -306,6 +306,8 @@ class Server(object): self.check_local_bin() if not self.use_local_bin: self.download_bin() + while not os.path.exists(self.server_path): + time.sleep(1) else: print("Use local bin : {}".format(self.bin_path)) command = "{} " \ @@ -338,7 +340,5 @@ class Server(object): print("Going to Run Comand") print(command) # wait for other process to download server bin - while not os.path.exists(self.server_path): - time.sleep(1) os.system(command) diff --git a/python/paddle_serving_server_gpu/serve.py b/python/paddle_serving_server_gpu/serve.py index cb82e02c..916af05a 100644 --- a/python/paddle_serving_server_gpu/serve.py +++ b/python/paddle_serving_server_gpu/serve.py @@ -21,6 +21,7 @@ import argparse import os from multiprocessing import Pool, Process from paddle_serving_server_gpu import serve_args +from flask import Flask, request def start_gpu_card_model(index, gpuid, args): # pylint: disable=doc-string-missing @@ -114,3 +115,20 @@ if __name__ == "__main__": web_service.prepare_server( workdir=args.workdir, port=args.port, device=args.device) web_service.run_server() + + app_instance = Flask(__name__) + + @app_instance.before_first_request + def init(): + web_service._launch_web_service() + + service_name = "/" + web_service.name + "/prediction" + + @app_instance.route(service_name, methods=["POST"]) + def run(): + return web_service.get_prediction(request) + + app_instance.run(host="0.0.0.0", + port=web_service.port, + threaded=False, + processes=4) diff --git a/python/paddle_serving_server_gpu/web_service.py b/python/paddle_serving_server_gpu/web_service.py index 5d507c94..1bb8e93b 100755 --- a/python/paddle_serving_server_gpu/web_service.py +++ b/python/paddle_serving_server_gpu/web_service.py @@ -11,17 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -#!flask/bin/python -# pylint: disable=doc-string-missing from flask import Flask, request, abort -from multiprocessing import Pool, Process, Queue from paddle_serving_server_gpu import OpMaker, OpSeqMaker, Server import paddle_serving_server_gpu as serving +from multiprocessing import Pool, Process, Queue from paddle_serving_client import Client -from .serve import start_multi_card -import time -import random +from paddle_serving_server_gpu.serve import start_multi_card + +import sys +import numpy as np class WebService(object): @@ -29,7 +28,6 @@ class WebService(object): self.name = name self.gpus = [] self.rpc_service_list = [] - self.input_queues = [] def load_model_config(self, model_config): self.model_config = model_config @@ -66,12 +64,6 @@ class WebService(object): return server def _launch_rpc_service(self, service_idx): - if service_idx == 0: - self.rpc_service_list[service_idx].check_local_bin() - if not self.rpc_service_list[service_idx].use_local_bin: - self.rpc_service_list[service_idx].download_bin() - else: - time.sleep(3) self.rpc_service_list[service_idx].run_server() def prepare_server(self, workdir="", port=9393, device="gpu", gpuid=0): @@ -93,87 +85,30 @@ class WebService(object): gpuid, thread_num=10)) - def producers(self, inputqueue, endpoint): - client = Client() - client.load_client_config("{}/serving_server_conf.prototxt".format( + def _launch_web_service(self): + gpu_num = len(self.gpus) + self.client = Client() + self.client.load_client_config("{}/serving_server_conf.prototxt".format( self.model_config)) - client.connect([endpoint]) - while True: - request_json = inputqueue.get() - try: - feed, fetch = self.preprocess(request_json, - request_json["fetch"]) - if isinstance(feed, list): - fetch_map_batch = client.predict( - feed_batch=feed, fetch=fetch) - fetch_map_batch = self.postprocess( - feed=request_json, - fetch=fetch, - fetch_map=fetch_map_batch) - result = {"result": fetch_map_batch} - elif isinstance(feed, dict): - if "fetch" in feed: - del feed["fetch"] - fetch_map = client.predict(feed=feed, fetch=fetch) - result = self.postprocess( - feed=request_json, fetch=fetch, fetch_map=fetch_map) - self.output_queue.put(result) - except ValueError: - self.output_queue.put(-1) - - def _launch_web_service(self, gpu_num): - app_instance = Flask(__name__) - service_name = "/" + self.name + "/prediction" - - self.input_queues = [] - self.output_queue = Queue() - for i in range(gpu_num): - self.input_queues.append(Queue()) - - producer_list = [] - for i, input_q in enumerate(self.input_queues): - producer_processes = Process( - target=self.producers, - args=( - input_q, - "0.0.0.0:{}".format(self.port + 1 + i), )) - producer_list.append(producer_processes) - - for p in producer_list: - p.start() - - client = Client() - client.load_client_config("{}/serving_server_conf.prototxt".format( - self.model_config)) - client.connect(["0.0.0.0:{}".format(self.port + 1)]) - - self.idx = 0 - - @app_instance.route(service_name, methods=['POST']) - def get_prediction(): - if not request.json: - abort(400) - if "fetch" not in request.json: - abort(400) - - self.input_queues[self.idx].put(request.json) - - #self.input_queues[0].put(request.json) - self.idx += 1 - if self.idx >= len(self.gpus): - self.idx = 0 - result = self.output_queue.get() - if not isinstance(result, dict) and result == -1: - result = {"result": "Request Value Error"} - return result - - app_instance.run(host="0.0.0.0", - port=self.port, - threaded=False, - processes=1) - - for p in producer_list: - p.join() + endpoints = "" + if gpu_num > 0: + for i in range(gpu_num): + endpoints += "127.0.0.1:{},".format(self.port + i + 1) + else: + endpoints = "127.0.0.1:{}".format(self.port + 1) + self.client.connect([endpoints]) + + def get_prediction(self, request): + if not request.json: + abort(400) + if "fetch" not in request.json: + abort(400) + feed, fetch = self.preprocess(request.json, request.json["fetch"]) + fetch_map_batch = self.client.predict(feed=feed, fetch=fetch) + fetch_map_batch = self.postprocess( + feed=request.json, fetch=fetch, fetch_map=fetch_map_batch) + result = {"result": fetch_map_batch} + return result def run_server(self): import socket @@ -188,13 +123,6 @@ class WebService(object): for p in server_pros: p.start() - p_web = Process( - target=self._launch_web_service, args=(len(self.gpus), )) - p_web.start() - p_web.join() - for p in server_pros: - p.join() - def preprocess(self, feed={}, fetch=[]): return feed, fetch -- GitLab