提交 957d6c81 编写于 作者: G guru4elephant

Add multiple GPU RPC services for web service startup

上级 32d1b406
...@@ -14,16 +14,13 @@ ...@@ -14,16 +14,13 @@
from paddle_serving_server_gpu.web_service import WebService from paddle_serving_server_gpu.web_service import WebService
import sys import sys
import os import cv2
import base64 import base64
import numpy as np
from image_reader import ImageReader from image_reader import ImageReader
class ImageService(WebService): class ImageService(WebService):
"""
preprocessing function for image classification
"""
def preprocess(self, feed={}, fetch=[]): def preprocess(self, feed={}, fetch=[]):
reader = ImageReader() reader = ImageReader()
if "image" not in feed: if "image" not in feed:
...@@ -37,9 +34,7 @@ class ImageService(WebService): ...@@ -37,9 +34,7 @@ class ImageService(WebService):
image_service = ImageService(name="image") image_service = ImageService(name="image")
image_service.load_model_config(sys.argv[1]) image_service.load_model_config(sys.argv[1])
gpu_ids = os.environ["CUDA_VISIBLE_DEVICES"] image_service.set_gpus("0,1,2,3")
gpus = [int(x) for x in gpu_ids.split(",")]
image_service.set_gpus(gpus)
image_service.prepare_server( image_service.prepare_server(
workdir=sys.argv[2], port=int(sys.argv[3]), device="gpu") workdir=sys.argv[2], port=int(sys.argv[3]), device="gpu")
image_service.run_server() image_service.run_server()
...@@ -16,6 +16,7 @@ import requests ...@@ -16,6 +16,7 @@ import requests
import base64 import base64
import json import json
import time import time
import os
def predict(image_path, server): def predict(image_path, server):
...@@ -23,13 +24,17 @@ def predict(image_path, server): ...@@ -23,13 +24,17 @@ def predict(image_path, server):
req = json.dumps({"image": image, "fetch": ["score"]}) req = json.dumps({"image": image, "fetch": ["score"]})
r = requests.post( r = requests.post(
server, data=req, headers={"Content-Type": "application/json"}) server, data=req, headers={"Content-Type": "application/json"})
return r
if __name__ == "__main__": if __name__ == "__main__":
server = "http://127.0.0.1:9393/image/prediction" server = "http://127.0.0.1:9295/image/prediction"
image_path = "./data/n01440764_10026.JPEG" #image_path = "./data/n01440764_10026.JPEG"
image_list = os.listdir("./data/image_data/n01440764/")
start = time.time() start = time.time()
for i in range(1000): for img in image_list:
predict(image_path, server) image_file = "./data/image_data/n01440764/" + img
res = predict(image_file, server)
print(res.json()["score"][0])
end = time.time() end = time.time()
print(end - start) print(end - start)
g # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -24,16 +24,27 @@ import time ...@@ -24,16 +24,27 @@ import time
import random import random
def producers(input_queue, output_queue, endpoint):
pass
class WebService(object): class WebService(object):
def __init__(self, name="default_service"): def __init__(self, name="default_service"):
self.name = name self.name = name
self.gpus = [] self.gpus = []
self.rpc_service_list = [] self.rpc_service_list = []
def producers(self, input_queue, output_queue, endpoint):
client = Client()
client.load_client_config("{}/serving_server_conf.prototxt".format(
self.model_config))
client.connect([endpoint])
while True:
request_json = input_queue.get()
feed, fetch = self.preprocess(request_json, request_json["fetch"])
if "fetch" in feed:
del feed["fetch"]
fetch_map = client.predict(feed=feed, fetch=fetch)
fetch_map = self.postprocess(
feed=request.json, fetch=fetch, fetch_map=fetch_map)
output_queue.put(fetch_map)
def load_model_config(self, model_config): def load_model_config(self, model_config):
self.model_config = model_config self.model_config = model_config
...@@ -82,6 +93,21 @@ class WebService(object): ...@@ -82,6 +93,21 @@ class WebService(object):
self.default_rpc_service( self.default_rpc_service(
self.workdir, self.port + 1, -1, thread_num=10)) self.workdir, self.port + 1, -1, thread_num=10))
else: else:
self.producer_pros = []
for i, gpuid in enumerate(self.gpus):
producer_p = Process(
target=self.producers,
args=(
self,
self.input_queue[i],
self.output_queue,
self.port + 1 + i,
gpuid, ))
self.producer_pros.append(producer_p)
for p in producer_pros:
p.start()
for i, gpuid in enumerate(self.gpus): for i, gpuid in enumerate(self.gpus):
self.rpc_service_list.append( self.rpc_service_list.append(
self.default_rpc_service( self.default_rpc_service(
...@@ -90,18 +116,23 @@ class WebService(object): ...@@ -90,18 +116,23 @@ class WebService(object):
gpuid, gpuid,
thread_num=10)) thread_num=10))
def producers(self, inputqueue, endpoint):
client = Client()
client.load_client_config("{}/serving_server_conf.prototxt".format(
self.model_config))
client.connect([endpoint])
while True:
request_json = input_queue.get()
feed, fetch = self.preprocess(request_json, request_json["fetch"])
if "fetch" in feed:
del feed["fetch"]
fetch_map = client.predict(feed=feed, fetch=fetch)
fetch_map = self.postprocess(
feed=request_json, fetch=fetch, fetch_map=fetch_map)
self.output_queue.put(fetch_map)
def _launch_web_service(self, gpu_num): def _launch_web_service(self, gpu_num):
app_instance = Flask(__name__) app_instance = Flask(__name__)
client_list = []
if gpu_num > 1:
gpu_num = 0
for i in range(gpu_num):
client_service = Client()
client_service.load_client_config(
"{}/serving_server_conf.prototxt".format(self.model_config))
client_service.connect(["0.0.0.0:{}".format(self.port + i + 1)])
client_list.append(client_service)
time.sleep(1)
service_name = "/" + self.name + "/prediction" service_name = "/" + self.name + "/prediction"
input_queues = [] input_queues = []
...@@ -109,12 +140,18 @@ class WebService(object): ...@@ -109,12 +140,18 @@ class WebService(object):
for i in range(gpu_num): for i in range(gpu_num):
input_queues.append(Queue()) input_queues.append(Queue())
@app_instance.route("{}_batch".format(service_name), methods['POST']) producer_list = []
def get_prediction(): for i, input_q in enumerate(input_queues):
if not request.json: producer_processes = Process(
abort(400) target=self.producers,
if "fetch" not in request.json: input_q,
abort(400) "0.0.0.0:{}".format(self.port + 1 + i))
producer_list.append(producer_processes)
for p in producer_list:
p.start()
idx = 0
@app_instance.route(service_name, methods=['POST']) @app_instance.route(service_name, methods=['POST'])
def get_prediction(): def get_prediction():
...@@ -122,40 +159,41 @@ class WebService(object): ...@@ -122,40 +159,41 @@ class WebService(object):
abort(400) abort(400)
if "fetch" not in request.json: if "fetch" not in request.json:
abort(400) abort(400)
feed, fetch = self.preprocess(request.json, request.json["fetch"])
if "fetch" in feed: input_queues[idx].put(request.json)
del feed["fetch"] result = output_queue.get()
fetch_map = client_list[0].predict(feed=feed, fetch=fetch) idx += 1
fetch_map = self.postprocess( if idx >= len(self.gpus):
feed=request.json, fetch=fetch, fetch_map=fetch_map) idx = 0
return fetch_map return result
app_instance.run(host="0.0.0.0", app_instance.run(host="0.0.0.0",
port=self.port, port=self.port,
threaded=False, threaded=False,
processes=1) processes=1)
for p in producer_list:
p.join()
def run_server(self): def run_server(self):
import socket import socket
localIP = socket.gethostbyname(socket.gethostname()) localIP = socket.gethostbyname(socket.gethostname())
print("web service address:") print("web service address:")
print("http://{}:{}/{}/prediction".format(localIP, self.port, print("http://{}:{}/{}/prediction".format(localIP, self.port,
self.name)) self.name))
server_pros = []
rpc_processes = [] for i, service in enumerate(self.rpc_service_list):
for idx in range(len(self.rpc_service_list)): p = Process(target=_launch_rpc_service, args=(i, ))
p_rpc = Process(target=self._launch_rpc_service, args=(idx, )) server_pros.append(p)
rpc_processes.append(p_rpc) for p in server_pros:
for p in rpc_processes:
p.start() p.start()
p_web = Process( p_web = Process(
target=self._launch_web_service, args=(len(self.gpus), )) target=self._launch_web_service, args=(len(self.gpus), ))
p_web.start() p_web.start()
for p in rpc_processes:
p.join()
p_web.join() p_web.join()
for p in server_pros:
p.join()
def preprocess(self, feed={}, fetch=[]): def preprocess(self, feed={}, fetch=[]):
return feed, fetch return feed, fetch
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册