提交 32863e68 编写于 作者: G guru4elephant

Add support for starting one serving server process per GPU card (multi-GPU startup).

上级 0814608f
...@@ -18,6 +18,7 @@ Usage: ...@@ -18,6 +18,7 @@ Usage:
python -m paddle_serving_server.serve --model ./serving_server_model --port 9292 python -m paddle_serving_server.serve --model ./serving_server_model --port 9292
""" """
import argparse import argparse
from multiprocessing import Pool, Process
def parse_args():
    """Parse command-line flags for the serving server launcher.

    Returns:
        argparse.Namespace with thread, model, port, workdir, device and
        gpu_ids attributes.
    """
    # NOTE(review): the parser construction and --thread flag are elided in
    # this view and reconstructed from the surrounding usage (args.thread is
    # read later) — confirm against the full file.
    parser = argparse.ArgumentParser("serve")
    parser.add_argument(
        "--thread", type=int, default=10, help="Concurrency of server")
    parser.add_argument(
        "--model", type=str, default="", help="Model for serving")
    parser.add_argument(
        "--port", type=int, default=9292, help="Port of the starting gpu")
    parser.add_argument(
        "--workdir",
        type=str,
        default="workdir",
        help="Working dir of current service")
    parser.add_argument(
        "--device", type=str, default="gpu", help="Type of device")
    # gpu_ids is later consumed as args.gpu_ids.split(","), so it must be a
    # comma-separated string; type=int/default=0 would crash at split().
    parser.add_argument(
        "--gpu_ids",
        type=str,
        default="",
        help="Comma separated gpu ids, e.g. 0,1,2 (empty means CPU)")
    return parser.parse_args()
# Parsed once at module import time so every worker process spawned below
# inherits the same CLI flags.
args = parse_args()
def start_gpu_card_model(gpuid):
    """Start one serving server bound to a single GPU card.

    Args:
        gpuid: index of the GPU card for this worker; -1 selects CPU mode.

    Each GPU worker listens on args.port + gpuid so every card gets its own
    port; the CPU fallback keeps args.port unchanged.
    """
    device = "gpu"
    port = args.port
    if gpuid == -1:
        # CPU fallback: keep the base port. Unconditionally adding gpuid
        # here would shift the CPU server onto args.port - 1.
        device = "cpu"
    else:
        port = args.port + gpuid
    thread_num = args.thread
    model = args.model
    workdir = args.workdir

    if model == "":
        print("You must specify your serving model")
        exit(-1)

    # NOTE(review): the op/server construction below is partially elided in
    # this view and reconstructed from context; `serving` is the
    # paddle_serving server module imported at file top — confirm the name.
    op_maker = serving.OpMaker()
    read_op = op_maker.create('general_reader')
    general_infer_op = op_maker.create('general_infer')
    general_response_op = op_maker.create('general_response')

    op_seq_maker = serving.OpSeqMaker()
    op_seq_maker.add_op(read_op)
    op_seq_maker.add_op(general_infer_op)
    op_seq_maker.add_op(general_response_op)

    server = serving.Server()
    server.set_op_sequence(op_seq_maker.get_op_sequence())
    server.set_num_threads(thread_num)

    server.load_model_config(model)
    server.prepare_server(workdir=workdir, port=port, device=device)
    if gpuid >= 0:
        server.set_gpuid(gpuid)
    server.run_server()
if __name__ == "__main__":
    # "".split(",") yields [""], so the empty string must be filtered out,
    # otherwise the CPU branch (len(gpus) == 0) is unreachable.
    gpus = [g for g in args.gpu_ids.split(",") if g != ""]
    if len(gpus) == 0:
        start_gpu_card_model(-1)
    else:
        gpu_processes = []
        for i, gpu_id in enumerate(gpus):
            # One process per card; the index also selects the port offset
            # inside start_gpu_card_model. Process requires the args=
            # keyword — a bare positional tuple after target= is a
            # SyntaxError.
            # NOTE(review): the card index i, not int(gpu_id), is passed —
            # confirm whether set_gpuid should receive the raw id instead.
            p = Process(target=start_gpu_card_model, args=(i, ))
            gpu_processes.append(p)
        for p in gpu_processes:
            p.start()
        for p in gpu_processes:
            p.join()
...@@ -37,6 +37,8 @@ def parse_args(): ...@@ -37,6 +37,8 @@ def parse_args():
help="Working dir of current service") help="Working dir of current service")
parser.add_argument( parser.add_argument(
"--device", type=str, default="cpu", help="Type of device") "--device", type=str, default="cpu", help="Type of device")
parser.add_argument(
"--gpu_ids", type=str, default="", help="GPU ids of current service")
parser.add_argument( parser.add_argument(
"--name", type=str, default="default", help="Default service name") "--name", type=str, default="default", help="Default service name")
return parser.parse_args() return parser.parse_args()
...@@ -48,4 +50,6 @@ if __name__ == "__main__": ...@@ -48,4 +50,6 @@ if __name__ == "__main__":
service.load_model_config(args.model) service.load_model_config(args.model)
service.prepare_server( service.prepare_server(
workdir=args.workdir, port=args.port, device=args.device) workdir=args.workdir, port=args.port, device=args.device)
service.run_server() service.run_server(args.gpu_ids)
...@@ -25,7 +25,12 @@ class WebService(object): ...@@ -25,7 +25,12 @@ class WebService(object):
def load_model_config(self, model_config):
    # Remember the serving model configuration location; read later by
    # _launch_rpc_service via server.load_model_config().
    self.model_config = model_config
def _launch_rpc_service(self, gpuid):
    """Start one RPC inference server for a single device.

    Args:
        gpuid: GPU card index for this worker; any negative value selects
            CPU mode.
    """
    if gpuid < 0:
        device = "cpu"
    else:
        device = "gpu"
    op_maker = OpMaker()
    read_op = op_maker.create('general_reader')
    general_infer_op = op_maker.create('general_infer')
    general_response_op = op_maker.create('general_response')

    # NOTE(review): the op-sequence assembly is partially elided in this
    # view and reconstructed from context — confirm against the full file.
    op_seq_maker = OpSeqMaker()
    op_seq_maker.add_op(read_op)
    op_seq_maker.add_op(general_infer_op)
    op_seq_maker.add_op(general_response_op)

    server = Server()
    server.set_op_sequence(op_seq_maker.get_op_sequence())
    server.set_num_threads(10)
    if gpuid >= 0:
        server.set_gpuid(gpuid)
    server.load_model_config(self.model_config)
    # Give every worker a port strictly above the web port. The previous
    # expression self.port + gpuid + 1 collapsed to self.port for the CPU
    # worker (gpuid == -1), colliding with the web service itself.
    port_offset = gpuid if gpuid >= 0 else 0
    server.prepare_server(
        workdir="{}_{}".format(self.workdir, gpuid),
        port=self.port + port_offset + 1,
        device=device)
    server.run_server()
def prepare_server(self, workdir="", port=9393, device="gpu", gpuid=0): def prepare_server(self, workdir="", port=9393, device="gpu", gpuid=0):
...@@ -74,18 +81,27 @@ class WebService(object): ...@@ -74,18 +81,27 @@ class WebService(object):
threaded=False, threaded=False,
processes=1) processes=1)
def run_server(self, gpu_ids):
    """Launch the RPC worker process(es) plus the HTTP web service.

    Args:
        gpu_ids: comma separated GPU card ids; an empty string selects
            CPU-only mode with a single RPC worker.
    """
    import socket
    localIP = socket.gethostbyname(socket.gethostname())
    print("web service address:")
    print("http://{}:{}/{}/prediction".format(localIP, self.port,
                                              self.name))

    # "".split(",") yields [""]; filter empties or the CPU branch is
    # unreachable.
    gpus = [g for g in gpu_ids.split(",") if g != ""]
    rpc_processes = []
    if len(gpus) == 0:
        # CPU-only: run the single RPC worker in its own process so the
        # web service below still starts (the old code blocked here and
        # never reached the web service).
        rpc_processes.append(
            Process(target=self._launch_rpc_service, args=(-1, )))
    else:
        for i, gpu_id in enumerate(gpus):
            # enumerate() was missing (unpacking raw strings raises
            # ValueError) and Process needs the args= keyword.
            rpc_processes.append(
                Process(target=self._launch_rpc_service, args=(i, )))
    for p in rpc_processes:
        p.start()
    p_web = Process(target=self._launch_web_service)
    p_web.start()  # the old code created this process but never started it
    for p in rpc_processes:
        p.join()
    p_web.join()
def preprocess(self, feed=None, fetch=None):
    """Hook for subclasses to transform a request before inference.

    The default implementation returns the (feed, fetch) pair unchanged.
    The mutable defaults ({} / []) were replaced by None sentinels so a
    single dict/list object is not shared across calls.

    Args:
        feed: mapping of input names to values (defaults to empty dict).
        fetch: list of output names to retrieve (defaults to empty list).

    Returns:
        The (feed, fetch) pair, unmodified.
    """
    if feed is None:
        feed = {}
    if fetch is None:
        fetch = []
    return feed, fetch
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册