提交 796264df 编写于 作者: H HexToString

fix gpuid and device and time

上级 bf5a47be
...@@ -391,7 +391,8 @@ int InferManager::proc_initialize(const char* path, ...@@ -391,7 +391,8 @@ int InferManager::proc_initialize(const char* path,
return -1; return -1;
} }
uint32_t engine_num = model_toolkit_conf.engines_size(); uint32_t engine_num = model_toolkit_conf.engines_size();
im::bsf::TaskExecutorVector<TaskT>::instance().resize(*engine_index_ptr+engine_num); im::bsf::TaskExecutorVector<TaskT>::instance().resize(*engine_index_ptr +
engine_num);
for (uint32_t ei = 0; ei < engine_num; ++ei) { for (uint32_t ei = 0; ei < engine_num; ++ei) {
LOG(INFO) << "model_toolkit_conf.engines(" << ei LOG(INFO) << "model_toolkit_conf.engines(" << ei
<< ").name: " << model_toolkit_conf.engines(ei).name(); << ").name: " << model_toolkit_conf.engines(ei).name();
......
...@@ -79,7 +79,7 @@ class SDKConfig(object): ...@@ -79,7 +79,7 @@ class SDKConfig(object):
self.tag_list = [] self.tag_list = []
self.cluster_list = [] self.cluster_list = []
self.variant_weight_list = [] self.variant_weight_list = []
self.rpc_timeout_ms = 20000 self.rpc_timeout_ms = 200000
self.load_balance_strategy = "la" self.load_balance_strategy = "la"
def add_server_variant(self, tag, cluster, variant_weight): def add_server_variant(self, tag, cluster, variant_weight):
...@@ -142,7 +142,7 @@ class Client(object): ...@@ -142,7 +142,7 @@ class Client(object):
self.profile_ = _Profiler() self.profile_ = _Profiler()
self.all_numpy_input = True self.all_numpy_input = True
self.has_numpy_input = False self.has_numpy_input = False
self.rpc_timeout_ms = 20000 self.rpc_timeout_ms = 200000
from .serving_client import PredictorRes from .serving_client import PredictorRes
self.predictorres_constructor = PredictorRes self.predictorres_constructor = PredictorRes
......
...@@ -31,6 +31,67 @@ elif sys.version_info.major == 3: ...@@ -31,6 +31,67 @@ elif sys.version_info.major == 3:
from http.server import BaseHTTPRequestHandler, HTTPServer from http.server import BaseHTTPRequestHandler, HTTPServer
def format_gpu_to_strlist(unformatted_gpus):
gpus_strlist = []
if isinstance(unformatted_gpus, int):
gpus_strlist = [str(unformatted_gpus)]
elif isinstance(unformatted_gpus, list):
if unformatted_gpus == [""]:
gpus_strlist = ["-1"]
elif len(unformatted_gpus) == 0:
gpus_strlist = ["-1"]
else:
gpus_strlist = [str(x) for x in unformatted_gpus]
elif isinstance(unformatted_gpus, str):
if unformatted_gpus == "":
gpus_strlist = ["-1"]
else:
gpus_strlist = [unformatted_gpus]
elif unformatted_gpus == None:
gpus_strlist = ["-1"]
else:
raise ValueError("error input of set_gpus")
# check cuda visible
if "CUDA_VISIBLE_DEVICES" in os.environ:
env_gpus = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
for op_gpus_str in gpus_strlist:
op_gpu_list = op_gpus_str.split(",")
# op_gpu_list == ["-1"] means this op use CPU
# so don`t check cudavisible.
if op_gpu_list == ["-1"]:
continue
for ids in op_gpu_list:
if ids not in env_gpus:
print("gpu_ids is not in CUDA_VISIBLE_DEVICES.")
exit(-1)
# check gpuid is valid
for op_gpus_str in gpus_strlist:
op_gpu_list = op_gpus_str.split(",")
use_gpu = False
for ids in op_gpu_list:
if int(ids) < -1:
raise ValueError("The input of gpuid error.")
if int(ids) >= 0:
use_gpu = True
if int(ids) == -1 and use_gpu:
raise ValueError("You can not use CPU and GPU in one model.")
return gpus_strlist
def is_gpu_mode(unformatted_gpus):
gpus_strlist = format_gpu_to_strlist(unformatted_gpus)
for op_gpus_str in gpus_strlist:
op_gpu_list = op_gpus_str.split(",")
for ids in op_gpu_list:
if int(ids) >= 0:
return True
return False
def serve_args(): def serve_args():
parser = argparse.ArgumentParser("serve") parser = argparse.ArgumentParser("serve")
parser.add_argument( parser.add_argument(
...@@ -211,34 +272,15 @@ def start_gpu_card_model(gpu_mode, port, args): # pylint: disable=doc-string-mi ...@@ -211,34 +272,15 @@ def start_gpu_card_model(gpu_mode, port, args): # pylint: disable=doc-string-mi
def start_multi_card(args, serving_port=None): # pylint: disable=doc-string-missing def start_multi_card(args, serving_port=None): # pylint: disable=doc-string-missing
gpus = []
if serving_port == None: if serving_port == None:
serving_port = args.port serving_port = args.port
if args.gpu_ids == "":
gpus = []
else:
#check the gpu_id is valid or not.
gpus = args.gpu_ids
if isinstance(gpus, str):
gpus = [gpus]
if "CUDA_VISIBLE_DEVICES" in os.environ:
env_gpus = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
for op_gpus_str in gpus:
op_gpu_list = op_gpus_str.split(",")
for ids in op_gpu_list:
if ids not in env_gpus:
print("gpu_ids is not in CUDA_VISIBLE_DEVICES.")
exit(-1)
if args.use_lite: if args.use_lite:
print("run using paddle-lite.") print("run using paddle-lite.")
start_gpu_card_model(False, serving_port, args) start_gpu_card_model(False, serving_port, args)
elif len(gpus) <= 0:
print("gpu_ids not set, going to run cpu service.")
start_gpu_card_model(False, serving_port, args)
else: else:
start_gpu_card_model(True, serving_port, args) start_gpu_card_model(is_gpu_mode(args.gpu_ids), serving_port, args)
class MainService(BaseHTTPRequestHandler): class MainService(BaseHTTPRequestHandler):
...@@ -320,7 +362,9 @@ class MainService(BaseHTTPRequestHandler): ...@@ -320,7 +362,9 @@ class MainService(BaseHTTPRequestHandler):
if __name__ == "__main__": if __name__ == "__main__":
# args.device is not used at all.
# just keep the interface.
# so --device should not be recommended at the HomePage.
args = serve_args() args = serve_args()
for single_model_config in args.model: for single_model_config in args.model:
if os.path.isdir(single_model_config): if os.path.isdir(single_model_config):
...@@ -346,29 +390,10 @@ if __name__ == "__main__": ...@@ -346,29 +390,10 @@ if __name__ == "__main__":
web_service = WebService(name=args.name) web_service = WebService(name=args.name)
web_service.load_model_config(args.model) web_service.load_model_config(args.model)
if args.gpu_ids == "":
gpus = []
else:
#check the gpu_id is valid or not.
gpus = args.gpu_ids
if isinstance(gpus, str):
gpus = [gpus]
if "CUDA_VISIBLE_DEVICES" in os.environ:
env_gpus = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
for op_gpus_str in gpus:
op_gpu_list = op_gpus_str.split(",")
for ids in op_gpu_list:
if ids not in env_gpus:
print("gpu_ids is not in CUDA_VISIBLE_DEVICES.")
exit(-1)
if len(gpus) > 0:
web_service.set_gpus(gpus)
workdir = "{}_{}".format(args.workdir, args.port) workdir = "{}_{}".format(args.workdir, args.port)
web_service.prepare_server( web_service.prepare_server(
workdir=workdir, workdir=workdir,
port=args.port, port=args.port,
device=args.device,
use_lite=args.use_lite, use_lite=args.use_lite,
use_xpu=args.use_xpu, use_xpu=args.use_xpu,
ir_optim=args.ir_optim, ir_optim=args.ir_optim,
...@@ -378,7 +403,8 @@ if __name__ == "__main__": ...@@ -378,7 +403,8 @@ if __name__ == "__main__":
use_trt=args.use_trt, use_trt=args.use_trt,
gpu_multi_stream=args.gpu_multi_stream, gpu_multi_stream=args.gpu_multi_stream,
op_num=args.op_num, op_num=args.op_num,
op_max_batch=args.op_max_batch) op_max_batch=args.op_max_batch,
gpuid=args.gpu_ids)
web_service.run_rpc_service() web_service.run_rpc_service()
app_instance = Flask(__name__) app_instance = Flask(__name__)
......
...@@ -17,6 +17,7 @@ import tarfile ...@@ -17,6 +17,7 @@ import tarfile
import socket import socket
import paddle_serving_server as paddle_serving_server import paddle_serving_server as paddle_serving_server
from paddle_serving_server.rpc_service import MultiLangServerServiceServicer from paddle_serving_server.rpc_service import MultiLangServerServiceServicer
from paddle_serving_server.serve import format_gpu_to_strlist
from .proto import server_configure_pb2 as server_sdk from .proto import server_configure_pb2 as server_sdk
from .proto import general_model_config_pb2 as m_config from .proto import general_model_config_pb2 as m_config
from .proto import multi_lang_general_model_service_pb2_grpc from .proto import multi_lang_general_model_service_pb2_grpc
...@@ -171,12 +172,7 @@ class Server(object): ...@@ -171,12 +172,7 @@ class Server(object):
self.device = device self.device = device
def set_gpuid(self, gpuid): def set_gpuid(self, gpuid):
if isinstance(gpuid, int): self.gpuid = format_gpu_to_strlist(gpuid)
self.gpuid = str(gpuid)
elif isinstance(gpuid, list):
self.gpuid = [str(x) for x in gpuid]
else:
self.gpuid = gpuid
def set_op_num(self, op_num): def set_op_num(self, op_num):
self.op_num = op_num self.op_num = op_num
...@@ -197,23 +193,20 @@ class Server(object): ...@@ -197,23 +193,20 @@ class Server(object):
self.use_xpu = True self.use_xpu = True
def _prepare_engine(self, model_config_paths, device, use_encryption_model): def _prepare_engine(self, model_config_paths, device, use_encryption_model):
self.device = device
if self.model_toolkit_conf == None: if self.model_toolkit_conf == None:
self.model_toolkit_conf = [] self.model_toolkit_conf = []
self.device = device
# Generally, self.gpuid = str[] or str.
# such as "0" or ["0"] or ["0,1"] or ["0,1" , "1,2"]
if isinstance(self.gpuid, str):
self.gpuid = [self.gpuid]
# Generally, self.gpuid = str[] or [].
# when len(self.gpuid) means no gpuid is specified. # when len(self.gpuid) means no gpuid is specified.
# if self.device == "gpu" or self.use_trt: # if self.device == "gpu" or self.use_trt:
# we assume you forget to set gpuid, so set gpuid = ['0']; # we assume you forget to set gpuid, so set gpuid = ['0'];
if len(self.gpuid) == 0: if len(self.gpuid) == 0 or self.gpuid == ["-1"]:
if self.device == "gpu" or self.use_trt: if self.device == "gpu" or self.use_trt or self.gpu_multi_stream:
self.gpuid.append("0") self.gpuid = ["0"]
self.device = "gpu"
else: else:
self.gpuid.append("-1") self.gpuid = ["-1"]
if isinstance(self.op_num, int): if isinstance(self.op_num, int):
self.op_num = [self.op_num] self.op_num = [self.op_num]
...@@ -254,12 +247,14 @@ class Server(object): ...@@ -254,12 +247,14 @@ class Server(object):
for ids in op_gpu_list: for ids in op_gpu_list:
engine.gpu_ids.extend([int(ids)]) engine.gpu_ids.extend([int(ids)])
if self.device == "gpu" or self.use_trt: if self.device == "gpu" or self.use_trt or self.gpu_multi_stream:
engine.use_gpu = True engine.use_gpu = True
# this is for Mixed use of GPU and CPU # this is for Mixed use of GPU and CPU
# if model-1 use GPU and set the device="gpu" # if model-1 use GPU and set the device="gpu"
# but gpuid[1] = "-1" which means use CPU in Model-2 # but gpuid[1] = "-1" which means use CPU in Model-2
# so config about GPU should be False. # so config about GPU should be False.
# op_gpu_list = gpuid[index].split(",")
# which is the gpuid for each engine.
if len(op_gpu_list) == 1: if len(op_gpu_list) == 1:
if int(op_gpu_list[0]) == -1: if int(op_gpu_list[0]) == -1:
engine.use_gpu = False engine.use_gpu = False
...@@ -500,10 +495,17 @@ class Server(object): ...@@ -500,10 +495,17 @@ class Server(object):
def prepare_server(self, def prepare_server(self,
workdir=None, workdir=None,
port=9292, port=9292,
device="cpu", device=None,
use_encryption_model=False, use_encryption_model=False,
cube_conf=None): cube_conf=None):
self.device = device # if `device` is not set, use self.device
# self.device may not be changed.
# or self.device may have changed by set_device.
if device == None:
device = self.device
# if `device` is set, let self.device = device.
else:
self.device = device
if workdir == None: if workdir == None:
workdir = "./tmp" workdir = "./tmp"
os.system("mkdir -p {}".format(workdir)) os.system("mkdir -p {}".format(workdir))
...@@ -602,6 +604,7 @@ class MultiLangServer(object): ...@@ -602,6 +604,7 @@ class MultiLangServer(object):
self.body_size_ = 64 * 1024 * 1024 self.body_size_ = 64 * 1024 * 1024
self.concurrency_ = 100000 self.concurrency_ = 100000
self.is_multi_model_ = False # for model ensemble, which is not useful right now. self.is_multi_model_ = False # for model ensemble, which is not useful right now.
self.device = "cpu" # this is the default value for multilang `device`.
def set_max_concurrency(self, concurrency): def set_max_concurrency(self, concurrency):
self.concurrency_ = concurrency self.concurrency_ = concurrency
...@@ -609,6 +612,7 @@ class MultiLangServer(object): ...@@ -609,6 +612,7 @@ class MultiLangServer(object):
def set_device(self, device="cpu"): def set_device(self, device="cpu"):
self.device = device self.device = device
self.bserver_.set_device(device)
def set_num_threads(self, threads): def set_num_threads(self, threads):
self.worker_num_ = threads self.worker_num_ = threads
...@@ -727,10 +731,18 @@ class MultiLangServer(object): ...@@ -727,10 +731,18 @@ class MultiLangServer(object):
def prepare_server(self, def prepare_server(self,
workdir=None, workdir=None,
port=9292, port=9292,
device="cpu", device=None,
use_encryption_model=False, use_encryption_model=False,
cube_conf=None): cube_conf=None):
self.device = device # if `device` is not set, use self.device
# self.device may not be changed.
# or self.device may have changed by set_device.
if device == None:
device = self.device
# if `device` is set, let self.device = device.
else:
self.device = device
if not self._port_is_available(port): if not self._port_is_available(port):
raise SystemExit("Port {} is already used".format(port)) raise SystemExit("Port {} is already used".format(port))
default_port = 12000 default_port = 12000
......
...@@ -26,6 +26,7 @@ import numpy as np ...@@ -26,6 +26,7 @@ import numpy as np
import os import os
from paddle_serving_server import pipeline from paddle_serving_server import pipeline
from paddle_serving_server.pipeline import Op from paddle_serving_server.pipeline import Op
from paddle_serving_server.serve import format_gpu_to_strlist
def port_is_available(port): def port_is_available(port):
...@@ -44,7 +45,7 @@ class WebService(object): ...@@ -44,7 +45,7 @@ class WebService(object):
# pipeline # pipeline
self._server = pipeline.PipelineServer(self.name) self._server = pipeline.PipelineServer(self.name)
self.gpus = [] # deprecated self.gpus = ["-1"] # deprecated
self.rpc_service_list = [] # deprecated self.rpc_service_list = [] # deprecated
def get_pipeline_response(self, read_op): def get_pipeline_response(self, read_op):
...@@ -103,19 +104,24 @@ class WebService(object): ...@@ -103,19 +104,24 @@ class WebService(object):
if client_config_path == None: if client_config_path == None:
self.client_config_path = file_path_list self.client_config_path = file_path_list
# after this function, self.gpus should be a list of str or [].
def set_gpus(self, gpus): def set_gpus(self, gpus):
print("This API will be deprecated later. Please do not use it") print("This API will be deprecated later. Please do not use it")
if isinstance(gpus, int): self.gpus = format_gpu_to_strlist(gpus)
self.gpus = str(gpus)
elif isinstance(gpus, list): # this function can be called by user
self.gpus = [str(x) for x in gpus] # or by Function create_rpc_config
else: # if by user, user can set_gpus or pass the `gpus`
self.gpus = gpus # if `gpus` == None, which means it`s not set at all.
# at this time, we should use self.gpus instead.
# otherwise, we should use the `gpus` first.
# which means if set_gpus and `gpus` is both set.
# `gpus` will be used.
def default_rpc_service(self, def default_rpc_service(self,
workdir, workdir,
port=9292, port=9292,
gpus=-1, gpus=None,
thread_num=2, thread_num=2,
mem_optim=True, mem_optim=True,
use_lite=False, use_lite=False,
...@@ -127,16 +133,23 @@ class WebService(object): ...@@ -127,16 +133,23 @@ class WebService(object):
gpu_multi_stream=False, gpu_multi_stream=False,
op_num=None, op_num=None,
op_max_batch=None): op_max_batch=None):
device = "gpu" device = "gpu"
server = Server() server = Server()
# only when `gpus == None`, which means it`s not set at all
# we will use the self.gpus.
if gpus == None:
gpus = self.gpus
if gpus == -1 or gpus == "-1": gpus = format_gpu_to_strlist(gpus)
server.set_gpuid(gpus)
if len(gpus) == 0 or gpus == ["-1"]:
if use_lite: if use_lite:
device = "arm" device = "arm"
else: else:
device = "cpu" device = "cpu"
else:
server.set_gpuid(gpus)
op_maker = OpMaker() op_maker = OpMaker()
op_seq_maker = OpSeqMaker() op_seq_maker = OpSeqMaker()
...@@ -190,40 +203,26 @@ class WebService(object): ...@@ -190,40 +203,26 @@ class WebService(object):
def _launch_rpc_service(self, service_idx): def _launch_rpc_service(self, service_idx):
self.rpc_service_list[service_idx].run_server() self.rpc_service_list[service_idx].run_server()
# if use this function, self.gpus must be set before.
# if not, we will use the default value, self.gpus = ["-1"].
# so we always pass the `gpus` = self.gpus.
def create_rpc_config(self): def create_rpc_config(self):
if len(self.gpus) == 0: self.rpc_service_list.append(
# init cpu service self.default_rpc_service(
self.rpc_service_list.append( self.workdir,
self.default_rpc_service( self.port_list[0],
self.workdir, self.gpus,
self.port_list[0], thread_num=self.thread_num,
-1, mem_optim=self.mem_optim,
thread_num=self.thread_num, use_lite=self.use_lite,
mem_optim=self.mem_optim, use_xpu=self.use_xpu,
use_lite=self.use_lite, ir_optim=self.ir_optim,
use_xpu=self.use_xpu, precision=self.precision,
ir_optim=self.ir_optim, use_calib=self.use_calib,
precision=self.precision, use_trt=self.use_trt,
use_calib=self.use_calib, gpu_multi_stream=self.gpu_multi_stream,
op_num=self.op_num, op_num=self.op_num,
op_max_batch=self.op_max_batch)) op_max_batch=self.op_max_batch))
else:
self.rpc_service_list.append(
self.default_rpc_service(
self.workdir,
self.port_list[0],
self.gpus,
thread_num=self.thread_num,
mem_optim=self.mem_optim,
use_lite=self.use_lite,
use_xpu=self.use_xpu,
ir_optim=self.ir_optim,
precision=self.precision,
use_calib=self.use_calib,
use_trt=self.use_trt,
gpu_multi_stream=self.gpu_multi_stream,
op_num=self.op_num,
op_max_batch=self.op_max_batch))
def prepare_server(self, def prepare_server(self,
workdir, workdir,
...@@ -240,12 +239,13 @@ class WebService(object): ...@@ -240,12 +239,13 @@ class WebService(object):
gpu_multi_stream=False, gpu_multi_stream=False,
op_num=None, op_num=None,
op_max_batch=None, op_max_batch=None,
gpuid=-1): gpuid=None):
print("This API will be deprecated later. Please do not use it") print("This API will be deprecated later. Please do not use it")
self.workdir = workdir self.workdir = workdir
self.port = port self.port = port
self.thread_num = thread_num self.thread_num = thread_num
self.device = device # self.device is not used at all.
# device is set by gpuid.
self.precision = precision self.precision = precision
self.use_calib = use_calib self.use_calib = use_calib
self.use_lite = use_lite self.use_lite = use_lite
...@@ -257,12 +257,14 @@ class WebService(object): ...@@ -257,12 +257,14 @@ class WebService(object):
self.gpu_multi_stream = gpu_multi_stream self.gpu_multi_stream = gpu_multi_stream
self.op_num = op_num self.op_num = op_num
self.op_max_batch = op_max_batch self.op_max_batch = op_max_batch
if isinstance(gpuid, int):
self.gpus = str(gpuid) # if gpuid != None, we will use gpuid first.
elif isinstance(gpuid, list): # otherwise, keep the self.gpus unchanged.
self.gpus = [str(x) for x in gpuid] # maybe self.gpus is set by the Function set_gpus.
if gpuid != None:
self.gpus = format_gpu_to_strlist(gpuid)
else: else:
self.gpus = gpuid pass
default_port = 12000 default_port = 12000
for i in range(1000): for i in range(1000):
...@@ -359,8 +361,8 @@ class WebService(object): ...@@ -359,8 +361,8 @@ class WebService(object):
if gpu: if gpu:
# if user forget to call function `set_gpus` to set self.gpus. # if user forget to call function `set_gpus` to set self.gpus.
# default self.gpus = [0]. # default self.gpus = [0].
if len(self.gpus) == 0: if len(self.gpus) == 0 or self.gpus == ["-1"]:
self.gpus.append(0) self.gpus = ["0"]
# right now, local Predictor only support 1 card. # right now, local Predictor only support 1 card.
# no matter how many gpu_id is in gpus, we only use the first one. # no matter how many gpu_id is in gpus, we only use the first one.
gpu_id = (self.gpus[0].split(","))[0] gpu_id = (self.gpus[0].split(","))[0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册