提交 818e707e 编写于 作者: H HexToString

Compatible GPU fields

上级 49f80620
......@@ -169,7 +169,13 @@ class Server(object):
self.device = device
def set_gpuid(self, gpuid):
self.gpuid = gpuid
if isinstance(gpuid, int):
self.gpuid = str(gpuid)
elif isinstance(gpuid, list):
gpu_list = [str(x) for x in gpuid]
self.gpuid = ",".join(gpu_list)
else:
self.gpuid = gpuid
def set_op_num(self, op_num):
self.op_num = op_num
......
......@@ -105,7 +105,13 @@ class WebService(object):
def set_gpus(self, gpus):
print("This API will be deprecated later. Please do not use it")
self.gpus = gpus
if isinstance(gpus, int):
self.gpus = str(gpus)
elif isinstance(gpus, list):
gpu_list = [str(x) for x in gpus]
self.gpus = ",".join(gpu_list)
else:
self.gpus = gpus
def default_rpc_service(self,
workdir,
......@@ -125,7 +131,7 @@ class WebService(object):
device = "gpu"
server = Server()
if gpus == -1:
if gpus == -1 or gpus == "-1":
if use_lite:
device = "arm"
else:
......@@ -234,7 +240,8 @@ class WebService(object):
use_trt=False,
gpu_multi_stream=False,
op_num=None,
op_max_batch=None):
op_max_batch=None,
gpuid=-1):
print("This API will be deprecated later. Please do not use it")
self.workdir = workdir
self.port = port
......@@ -251,6 +258,13 @@ class WebService(object):
self.gpu_multi_stream = gpu_multi_stream
self.op_num = op_num
self.op_max_batch = op_max_batch
if isinstance(gpuid, int):
self.gpus = str(gpuid)
elif isinstance(gpuid, list):
gpu_list = [str(x) for x in gpuid]
self.gpus = ",".join(gpu_list)
else:
self.gpus = gpuid
default_port = 12000
for i in range(1000):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册