diff --git a/python/examples/imagenet/image_classification_service_gpu.py b/python/examples/imagenet/image_classification_service_gpu.py
index 8fc92d918867142c6c442cb9eba61e2a9fb1f0e5..8a0bea938638c57a609a604181420929c4a9ca59 100644
--- a/python/examples/imagenet/image_classification_service_gpu.py
+++ b/python/examples/imagenet/image_classification_service_gpu.py
@@ -14,16 +14,13 @@
 
 from paddle_serving_server_gpu.web_service import WebService
 import sys
-import os
+import cv2
 import base64
+import numpy as np
 from image_reader import ImageReader
 
 
 class ImageService(WebService):
-    """
-    preprocessing function for image classification
-    """
-
     def preprocess(self, feed={}, fetch=[]):
         reader = ImageReader()
         if "image" not in feed:
@@ -37,9 +34,7 @@ class ImageService(WebService):
 
 image_service = ImageService(name="image")
 image_service.load_model_config(sys.argv[1])
-gpu_ids = os.environ["CUDA_VISIBLE_DEVICES"]
-gpus = [int(x) for x in gpu_ids.split(",")]
-image_service.set_gpus(gpus)
+image_service.set_gpus("0,1,2,3")
 image_service.prepare_server(
     workdir=sys.argv[2], port=int(sys.argv[3]), device="gpu")
 image_service.run_server()
diff --git a/python/examples/imagenet/image_http_client.py b/python/examples/imagenet/image_http_client.py
index f8b15dafc73ffac1e121610204ffe3cce23748a3..b61f0dd7d8d5ed25ecc828b5d0882ba11a116019 100644
--- a/python/examples/imagenet/image_http_client.py
+++ b/python/examples/imagenet/image_http_client.py
@@ -16,6 +16,7 @@
 import requests
 import base64
 import json
 import time
+import os
 
 def predict(image_path, server):
@@ -23,13 +24,17 @@
     req = json.dumps({"image": image, "fetch": ["score"]})
     r = requests.post(
         server, data=req, headers={"Content-Type": "application/json"})
+    return r
 
 
 if __name__ == "__main__":
-    server = "http://127.0.0.1:9393/image/prediction"
-    image_path = "./data/n01440764_10026.JPEG"
+    server = "http://127.0.0.1:9295/image/prediction"
+    #image_path = "./data/n01440764_10026.JPEG"
+    image_list = os.listdir("./data/image_data/n01440764/")
     start = time.time()
-    for i in range(1000):
-        predict(image_path, server)
+    for img in image_list:
+        image_file = "./data/image_data/n01440764/" + img
+        res = predict(image_file, server)
+        print(res.json()["score"][0])
     end = time.time()
     print(end - start)
diff --git a/python/paddle_serving_server_gpu/web_service.py b/python/paddle_serving_server_gpu/web_service.py
index fbb52470d48f45795c6e910600a1368d4bf3d8d2..4d88994cc6094488aaf71ff3e37a74acc93579c4 100755
--- a/python/paddle_serving_server_gpu/web_service.py
+++ b/python/paddle_serving_server_gpu/web_service.py
@@ -15,7 +15,7 @@
 # pylint: disable=doc-string-missing
 
 from flask import Flask, request, abort
-from multiprocessing import Pool, Process
+from multiprocessing import Pool, Process, Queue
 from paddle_serving_server_gpu import OpMaker, OpSeqMaker, Server
 import paddle_serving_server_gpu as serving
 from paddle_serving_client import Client
@@ -29,12 +29,13 @@ class WebService(object):
         self.name = name
         self.gpus = []
         self.rpc_service_list = []
+        self.input_queues = []
 
     def load_model_config(self, model_config):
         self.model_config = model_config
 
     def set_gpus(self, gpus):
-        self.gpus = gpus
+        self.gpus = [int(x) for x in gpus.split(",")]
 
     def default_rpc_service(self,
                             workdir="conf",
@@ -86,60 +87,101 @@ class WebService(object):
                     gpuid,
                     thread_num=10))
 
+    def producers(self, inputqueue, endpoint):
+        client = Client()
+        client.load_client_config("{}/serving_server_conf.prototxt".format(
+            self.model_config))
+        client.connect([endpoint])
+        while True:
+            request_json = inputqueue.get()
+            feed, fetch = self.preprocess(request_json, request_json["fetch"])
+            if "fetch" in feed:
+                del feed["fetch"]
+            fetch_map = client.predict(feed=feed, fetch=fetch)
+            fetch_map = self.postprocess(
+                feed=request_json, fetch=fetch, fetch_map=fetch_map)
+            self.output_queue.put(fetch_map)
+
     def _launch_web_service(self, gpu_num):
         app_instance = Flask(__name__)
-        client_list = []
-        if gpu_num > 1:
-            gpu_num = 0
-        for i in range(gpu_num):
-            client_service = Client()
-            client_service.load_client_config(
-                "{}/serving_server_conf.prototxt".format(self.model_config))
-            client_service.connect(["0.0.0.0:{}".format(self.port + i + 1)])
-            client_list.append(client_service)
-            time.sleep(1)
         service_name = "/" + self.name + "/prediction"
+        self.input_queues = []
+        self.output_queue = Queue()
+        for i in range(gpu_num):
+            self.input_queues.append(Queue())
+
+        producer_list = []
+        for i, input_q in enumerate(self.input_queues):
+            producer_processes = Process(
+                target=self.producers,
+                args=(
+                    input_q,
+                    "0.0.0.0:{}".format(self.port + 1 + i), ))
+            producer_list.append(producer_processes)
+
+        for p in producer_list:
+            p.start()
+
+        client = Client()
+        client.load_client_config("{}/serving_server_conf.prototxt".format(
+            self.model_config))
+        client.connect(["0.0.0.0:{}".format(self.port + 1)])
+
+        self.idx = 0
+
         @app_instance.route(service_name, methods=['POST'])
         def get_prediction():
             if not request.json:
                 abort(400)
             if "fetch" not in request.json:
                 abort(400)
+
+            self.input_queues[self.idx].put(request.json)
+
+            #self.input_queues[0].put(request.json)
+            self.idx += 1
+            if self.idx >= len(self.gpus):
+                self.idx = 0
+            result = self.output_queue.get()
+            return result
+            '''
             feed, fetch = self.preprocess(request.json, request.json["fetch"])
             if "fetch" in feed:
                 del feed["fetch"]
-            fetch_map = client_list[0].predict(feed=feed, fetch=fetch)
+            fetch_map = client.predict(feed=feed, fetch=fetch)
             fetch_map = self.postprocess(
                 feed=request.json, fetch=fetch, fetch_map=fetch_map)
             return fetch_map
+            '''
 
         app_instance.run(host="0.0.0.0",
                          port=self.port,
                          threaded=False,
                          processes=1)
 
+        for p in producer_list:
+            p.join()
+
     def run_server(self):
         import socket
         localIP = socket.gethostbyname(socket.gethostname())
         print("web service address:")
         print("http://{}:{}/{}/prediction".format(localIP, self.port,
                                                   self.name))
-
-        rpc_processes = []
-        for idx in range(len(self.rpc_service_list)):
-            p_rpc = Process(target=self._launch_rpc_service, args=(idx, ))
-            rpc_processes.append(p_rpc)
-
-        for p in rpc_processes:
+        server_pros = []
+        for i, service in enumerate(self.rpc_service_list):
+            p = Process(target=self._launch_rpc_service, args=(i, ))
+            server_pros.append(p)
+        for p in server_pros:
             p.start()
 
         p_web = Process(
             target=self._launch_web_service, args=(len(self.gpus), ))
         p_web.start()
-        for p in rpc_processes:
-            p.join()
         p_web.join()
+        for p in server_pros:
+            p.join()
 
     def preprocess(self, feed={}, fetch=[]):
         return feed, fetch
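Note on the dispatch scheme the web_service.py changes introduce: _launch_web_service now creates one multiprocessing.Queue per GPU-backed RPC endpoint plus a single shared output queue, starts one producers process per input queue, and has the Flask handler round-robin requests across the input queues via self.idx. The following is a minimal standalone sketch of that pattern, not code from the patch; worker and NUM_WORKERS are hypothetical stand-ins for WebService.producers and len(self.gpus).

    # Minimal standalone sketch of the fan-out used above (illustrative only).
    from multiprocessing import Process, Queue

    NUM_WORKERS = 4  # hypothetical stand-in for the number of GPU endpoints

    def worker(input_queue, output_queue, worker_id):
        # Stands in for producers(): each process would own one Client bound
        # to one GPU-backed RPC endpoint and loop on its private input queue.
        while True:
            request = input_queue.get()
            output_queue.put({"worker": worker_id, "echo": request})

    if __name__ == "__main__":
        input_queues = [Queue() for _ in range(NUM_WORKERS)]
        output_queue = Queue()
        for i, q in enumerate(input_queues):
            Process(target=worker, args=(q, output_queue, i), daemon=True).start()

        idx = 0  # round-robin cursor, like self.idx in get_prediction()
        for request in ["a", "b", "c", "d", "e", "f"]:
            input_queues[idx].put(request)
            idx = (idx + 1) % NUM_WORKERS
            print(output_queue.get())

Because each request blocks on the one shared output queue, replies only pair up with their requests while the handler is serialized, which is consistent with the patch keeping app_instance.run(..., threaded=False, processes=1).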