未验证 提交 f1ca5fe2 编写于 作者: T TeslaZhao 提交者: GitHub

Merge pull request #1259 from HexToString/v0.6.0

cherry-pick #1258
......@@ -263,10 +263,8 @@ def start_multi_card(args, serving_port=None): # pylint: disable=doc-string-mis
if "CUDA_VISIBLE_DEVICES" in os.environ:
env_gpus = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
for ids in gpus:
if int(ids) >= len(env_gpus):
print(
" Max index of gpu_ids out of range, the number of CUDA_VISIBLE_DEVICES is {}."
.format(len(env_gpus)))
if ids not in env_gpus:
print("gpu_ids is not in CUDA_VISIBLE_DEVICES.")
exit(-1)
else:
env_gpus = []
......
......@@ -166,6 +166,36 @@ class WebService(object):
def _launch_rpc_service(self, service_idx):
self.rpc_service_list[service_idx].run_server()
def create_rpc_config(self):
if len(self.gpus) == 0:
# init cpu service
self.rpc_service_list.append(
self.default_rpc_service(
self.workdir,
self.port_list[0],
-1,
thread_num=self.thread_num,
mem_optim=self.mem_optim,
use_lite=self.use_lite,
use_xpu=self.use_xpu,
ir_optim=self.ir_optim,
precision=self.precision,
use_calib=self.use_calib))
else:
for i, gpuid in enumerate(self.gpus):
self.rpc_service_list.append(
self.default_rpc_service(
"{}_{}".format(self.workdir, i),
self.port_list[i],
gpuid,
thread_num=self.thread_num,
mem_optim=self.mem_optim,
use_lite=self.use_lite,
use_xpu=self.use_xpu,
ir_optim=self.ir_optim,
precision=self.precision,
use_calib=self.use_calib))
def prepare_server(self,
workdir="",
port=9393,
......@@ -183,6 +213,12 @@ class WebService(object):
self.port = port
self.thread_num = thread_num
self.device = device
self.precision = precision
self.use_calib = use_calib
self.use_lite = use_lite
self.use_xpu = use_xpu
self.ir_optim = ir_optim
self.mem_optim = mem_optim
self.gpuid = gpuid
self.port_list = []
default_port = 12000
......@@ -192,35 +228,6 @@ class WebService(object):
if len(self.port_list) > len(self.gpus):
break
if len(self.gpus) == 0:
# init cpu service
self.rpc_service_list.append(
self.default_rpc_service(
self.workdir,
self.port_list[0],
-1,
thread_num=self.thread_num,
mem_optim=mem_optim,
use_lite=use_lite,
use_xpu=use_xpu,
ir_optim=ir_optim,
precision=precision,
use_calib=use_calib))
else:
for i, gpuid in enumerate(self.gpus):
self.rpc_service_list.append(
self.default_rpc_service(
"{}_{}".format(self.workdir, i),
self.port_list[i],
gpuid,
thread_num=self.thread_num,
mem_optim=mem_optim,
use_lite=use_lite,
use_xpu=use_xpu,
ir_optim=ir_optim,
precision=precision,
use_calib=use_calib))
def _launch_web_service(self):
gpu_num = len(self.gpus)
self.client = Client()
......@@ -262,6 +269,7 @@ class WebService(object):
print("http://{}:{}/{}/prediction".format(localIP, self.port,
self.name))
server_pros = []
self.create_rpc_config()
for i, service in enumerate(self.rpc_service_list):
p = Process(target=self._launch_rpc_service, args=(i, ))
server_pros.append(p)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册