Commit e6263b1f authored by TeslaZhao

fix python bad code style by yapf

Parent 14c9f0c0
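Every hunk below is a formatter-only pass: string literals are re-attached to their `.format(...)` calls, keyword arguments are aligned under the opening parenthesis, and long calls are rewrapped. As a minimal sketch, a pass like this can be reproduced through yapf's Python API; the target path and the 'pep8' style name are assumptions for illustration, since the repository may carry its own .style.yapf that style_config should point at instead.

# Minimal sketch: run yapf programmatically over one file.
# Assumptions: the path below is hypothetical, and 'pep8' stands in for
# whatever style configuration the repository actually uses.
from yapf.yapflib.yapf_api import FormatFile

# FormatFile returns (reformatted_source_or_diff, encoding, changed).
diff_text, _encoding, changed = FormatFile(
    "python/paddle_serving_server/serve.py",  # hypothetical target file
    style_config="pep8",
    in_place=False,   # set True to rewrite the file, as this commit did
    print_diff=True,  # with False, the full formatted source is returned
)
if changed:
    print(diff_text)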
@@ -54,8 +54,8 @@ class OpMaker(object):
     def create(self, node_type, engine_name=None, inputs=[], outputs=[]):
         if node_type not in self.op_dict:
-            raise Exception("Op type {} is not supported right now".format(
-                node_type))
+            raise Exception(
+                "Op type {} is not supported right now".format(node_type))
         node = server_sdk.DAGNode()
         # node.name will be used as the infer engine name
         if engine_name:
@@ -103,9 +103,9 @@ class OpSeqMaker(object):
         elif len(node.dependencies) == 1:
             if node.dependencies[0].name != self.workflow.nodes[-1].name:
                 raise Exception(
-                    'You must add op in order in OpSeqMaker. The previous op is {}, but the current op is followed by {}.'.
-                    format(node.dependencies[0].name, self.workflow.nodes[
-                        -1].name))
+                    'You must add op in order in OpSeqMaker. The previous op is {}, but the current op is followed by {}.'
+                    .format(node.dependencies[0].name,
+                            self.workflow.nodes[-1].name))
         self.workflow.nodes.extend([node])

     def get_op_sequence(self):
@@ -308,8 +308,8 @@ class Server(object):
                 self.model_config_paths[node.name] = path
             print("You have specified multiple model paths, please ensure "
                   "that the input and output of multiple models are the same.")
-            workflow_oi_config_path = list(self.model_config_paths.items())[0][
-                1]
+            workflow_oi_config_path = list(
+                self.model_config_paths.items())[0][1]
         else:
             raise Exception("The type of model_config_paths must be str or "
                             "dict({op: model_path}), not {}.".format(
@@ -367,8 +367,8 @@ class Server(object):
                 if os.path.exists(tar_name):
                     os.remove(tar_name)
                 raise SystemExit(
-                    'Download failed, please check your network or permission of {}.'.
-                    format(self.module_path))
+                    'Download failed, please check your network or permission of {}.'
+                    .format(self.module_path))
             else:
                 try:
                     print('Decompressing files ..')
@@ -379,8 +379,8 @@ class Server(object):
                     if os.path.exists(exe_path):
                         os.remove(exe_path)
                     raise SystemExit(
-                        'Decompressing failed, please check your permission of {} or disk space left.'.
-                        format(self.module_path))
+                        'Decompressing failed, please check your permission of {} or disk space left.'
+                        .format(self.module_path))
                 finally:
                     os.remove(tar_name)
             #release lock
@@ -569,20 +569,20 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
                     tensor.data = model_result[name].tobytes()
                 else:
                     if v_type == 0:  # int64
-                        tensor.int64_data.extend(model_result[name].reshape(-1)
-                                                 .tolist())
+                        tensor.int64_data.extend(
+                            model_result[name].reshape(-1).tolist())
                     elif v_type == 1:  # float32
-                        tensor.float_data.extend(model_result[name].reshape(-1)
-                                                 .tolist())
+                        tensor.float_data.extend(
+                            model_result[name].reshape(-1).tolist())
                     elif v_type == 2:  # int32
-                        tensor.int_data.extend(model_result[name].reshape(-1)
-                                               .tolist())
+                        tensor.int_data.extend(
+                            model_result[name].reshape(-1).tolist())
                     else:
                         raise Exception("error type.")
                 tensor.shape.extend(list(model_result[name].shape))
                 if name in self.lod_tensor_set_:
-                    tensor.lod.extend(model_result["{}.lod".format(name)]
-                                      .tolist())
+                    tensor.lod.extend(
+                        model_result["{}.lod".format(name)].tolist())
                 inst.tensor_array.append(tensor)
             model_output.insts.append(inst)
         model_output.engine_name = model_name
@@ -601,11 +601,10 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
     def Inference(self, request, context):
         feed_dict, fetch_names, is_python, log_id = \
             self._unpack_inference_request(request)
-        ret = self.bclient_.predict(
-            feed=feed_dict,
-            fetch=fetch_names,
-            need_variant_tag=True,
-            log_id=log_id)
+        ret = self.bclient_.predict(feed=feed_dict,
+                                    fetch=fetch_names,
+                                    need_variant_tag=True,
+                                    log_id=log_id)
         return self._pack_inference_response(ret, fetch_names, is_python)

     def GetClientConfig(self, request, context):
@@ -685,15 +684,14 @@ class MultiLangServer(object):
         default_port = 12000
         self.port_list_ = []
         for i in range(1000):
-            if default_port + i != port and self._port_is_available(default_port
-                                                                    + i):
+            if default_port + i != port and self._port_is_available(
+                    default_port + i):
                 self.port_list_.append(default_port + i)
                 break
-        self.bserver_.prepare_server(
-            workdir=workdir,
-            port=self.port_list_[0],
-            device=device,
-            cube_conf=cube_conf)
+        self.bserver_.prepare_server(workdir=workdir,
+                                     port=self.port_list_[0],
+                                     device=device,
+                                     cube_conf=cube_conf)
         self.set_port(port)

     def _launch_brpc_service(self, bserver):
@@ -706,8 +704,8 @@ class MultiLangServer(object):
         return result != 0

     def run_server(self):
-        p_bserver = Process(
-            target=self._launch_brpc_service, args=(self.bserver_, ))
+        p_bserver = Process(target=self._launch_brpc_service,
+                            args=(self.bserver_, ))
         p_bserver.start()
         options = [('grpc.max_send_message_length', self.body_size_),
                    ('grpc.max_receive_message_length', self.body_size_)]
...
@@ -24,50 +24,58 @@ from flask import Flask, request
 def parse_args():  # pylint: disable=doc-string-missing
     parser = argparse.ArgumentParser("serve")
-    parser.add_argument(
-        "--thread", type=int, default=10, help="Concurrency of server")
-    parser.add_argument(
-        "--model", type=str, default="", help="Model for serving")
-    parser.add_argument(
-        "--port", type=int, default=9292, help="Port the server")
-    parser.add_argument(
-        "--name", type=str, default="None", help="Web service name")
-    parser.add_argument(
-        "--workdir",
-        type=str,
-        default="workdir",
-        help="Working dir of current service")
-    parser.add_argument(
-        "--device", type=str, default="cpu", help="Type of device")
-    parser.add_argument(
-        "--mem_optim_off",
-        default=False,
-        action="store_true",
-        help="Memory optimize")
-    parser.add_argument(
-        "--ir_optim", default=False, action="store_true", help="Graph optimize")
-    parser.add_argument(
-        "--use_mkl", default=False, action="store_true", help="Use MKL")
-    parser.add_argument(
-        "--max_body_size",
-        type=int,
-        default=512 * 1024 * 1024,
-        help="Limit sizes of messages")
-    parser.add_argument(
-        "--use_multilang",
-        default=False,
-        action="store_true",
-        help="Use Multi-language-service")
-    parser.add_argument(
-        "--product_name",
-        type=str,
-        default=None,
-        help="product_name for authentication")
-    parser.add_argument(
-        "--container_id",
-        type=str,
-        default=None,
-        help="container_id for authentication")
+    parser.add_argument("--thread",
+                        type=int,
+                        default=10,
+                        help="Concurrency of server")
+    parser.add_argument("--model",
+                        type=str,
+                        default="",
+                        help="Model for serving")
+    parser.add_argument("--port",
+                        type=int,
+                        default=9292,
+                        help="Port the server")
+    parser.add_argument("--name",
+                        type=str,
+                        default="None",
+                        help="Web service name")
+    parser.add_argument("--workdir",
+                        type=str,
+                        default="workdir",
+                        help="Working dir of current service")
+    parser.add_argument("--device",
+                        type=str,
+                        default="cpu",
+                        help="Type of device")
+    parser.add_argument("--mem_optim_off",
+                        default=False,
+                        action="store_true",
+                        help="Memory optimize")
+    parser.add_argument("--ir_optim",
+                        default=False,
+                        action="store_true",
+                        help="Graph optimize")
+    parser.add_argument("--use_mkl",
+                        default=False,
+                        action="store_true",
+                        help="Use MKL")
+    parser.add_argument("--max_body_size",
+                        type=int,
+                        default=512 * 1024 * 1024,
+                        help="Limit sizes of messages")
+    parser.add_argument("--use_multilang",
+                        default=False,
+                        action="store_true",
+                        help="Use Multi-language-service")
+    parser.add_argument("--product_name",
+                        type=str,
+                        default=None,
+                        help="product_name for authentication")
+    parser.add_argument("--container_id",
+                        type=str,
+                        default=None,
+                        help="container_id for authentication")
     return parser.parse_args()
@@ -129,8 +137,9 @@ if __name__ == "__main__":
     else:
         service = WebService(name=args.name)
     service.load_model_config(args.model)
-    service.prepare_server(
-        workdir=args.workdir, port=args.port, device=args.device)
+    service.prepare_server(workdir=args.workdir,
+                           port=args.port,
+                           device=args.device)
     service.run_rpc_service()

     app_instance = Flask(__name__)
...
@@ -40,49 +40,55 @@ from concurrent import futures
 def serve_args():
     parser = argparse.ArgumentParser("serve")
-    parser.add_argument(
-        "--thread", type=int, default=2, help="Concurrency of server")
-    parser.add_argument(
-        "--model", type=str, default="", help="Model for serving")
-    parser.add_argument(
-        "--port", type=int, default=9292, help="Port of the starting gpu")
-    parser.add_argument(
-        "--workdir",
-        type=str,
-        default="workdir",
-        help="Working dir of current service")
-    parser.add_argument(
-        "--device", type=str, default="gpu", help="Type of device")
+    parser.add_argument("--thread",
+                        type=int,
+                        default=2,
+                        help="Concurrency of server")
+    parser.add_argument("--model",
+                        type=str,
+                        default="",
+                        help="Model for serving")
+    parser.add_argument("--port",
+                        type=int,
+                        default=9292,
+                        help="Port of the starting gpu")
+    parser.add_argument("--workdir",
+                        type=str,
+                        default="workdir",
+                        help="Working dir of current service")
+    parser.add_argument("--device",
+                        type=str,
+                        default="gpu",
+                        help="Type of device")
     parser.add_argument("--gpu_ids", type=str, default="", help="gpu ids")
-    parser.add_argument(
-        "--name", type=str, default="None", help="Default service name")
-    parser.add_argument(
-        "--mem_optim_off",
-        default=False,
-        action="store_true",
-        help="Memory optimize")
-    parser.add_argument(
-        "--ir_optim", default=False, action="store_true", help="Graph optimize")
-    parser.add_argument(
-        "--max_body_size",
-        type=int,
-        default=512 * 1024 * 1024,
-        help="Limit sizes of messages")
-    parser.add_argument(
-        "--use_multilang",
-        default=False,
-        action="store_true",
-        help="Use Multi-language-service")
-    parser.add_argument(
-        "--product_name",
-        type=str,
-        default=None,
-        help="product_name for authentication")
-    parser.add_argument(
-        "--container_id",
-        type=str,
-        default=None,
-        help="container_id for authentication")
+    parser.add_argument("--name",
+                        type=str,
+                        default="None",
+                        help="Default service name")
+    parser.add_argument("--mem_optim_off",
+                        default=False,
+                        action="store_true",
+                        help="Memory optimize")
+    parser.add_argument("--ir_optim",
+                        default=False,
+                        action="store_true",
+                        help="Graph optimize")
+    parser.add_argument("--max_body_size",
+                        type=int,
+                        default=512 * 1024 * 1024,
+                        help="Limit sizes of messages")
+    parser.add_argument("--use_multilang",
+                        default=False,
+                        action="store_true",
+                        help="Use Multi-language-service")
+    parser.add_argument("--product_name",
+                        type=str,
+                        default=None,
+                        help="product_name for authentication")
+    parser.add_argument("--container_id",
+                        type=str,
+                        default=None,
+                        help="container_id for authentication")
     return parser.parse_args()
@@ -102,8 +108,8 @@ class OpMaker(object):
     def create(self, node_type, engine_name=None, inputs=[], outputs=[]):
         if node_type not in self.op_dict:
-            raise Exception("Op type {} is not supported right now".format(
-                node_type))
+            raise Exception(
+                "Op type {} is not supported right now".format(node_type))
         node = server_sdk.DAGNode()
         # node.name will be used as the infer engine name
         if engine_name:
@@ -151,9 +157,9 @@ class OpSeqMaker(object):
         elif len(node.dependencies) == 1:
             if node.dependencies[0].name != self.workflow.nodes[-1].name:
                 raise Exception(
-                    'You must add op in order in OpSeqMaker. The previous op is {}, but the current op is followed by {}.'.
-                    format(node.dependencies[0].name, self.workflow.nodes[
-                        -1].name))
+                    'You must add op in order in OpSeqMaker. The previous op is {}, but the current op is followed by {}.'
+                    .format(node.dependencies[0].name,
+                            self.workflow.nodes[-1].name))
         self.workflow.nodes.extend([node])

     def get_op_sequence(self):
@@ -366,8 +372,8 @@ class Server(object):
                 self.model_config_paths[node.name] = path
             print("You have specified multiple model paths, please ensure "
                   "that the input and output of multiple models are the same.")
-            workflow_oi_config_path = list(self.model_config_paths.items())[0][
-                1]
+            workflow_oi_config_path = list(
+                self.model_config_paths.items())[0][1]
         else:
             raise Exception("The type of model_config_paths must be str or "
                             "dict({op: model_path}), not {}.".format(
@@ -419,8 +425,8 @@ class Server(object):
                 if os.path.exists(tar_name):
                     os.remove(tar_name)
                 raise SystemExit(
-                    'Download failed, please check your network or permission of {}.'.
-                    format(self.module_path))
+                    'Download failed, please check your network or permission of {}.'
+                    .format(self.module_path))
             else:
                 try:
                     print('Decompressing files ..')
@@ -431,8 +437,8 @@ class Server(object):
                     if os.path.exists(exe_path):
                         os.remove(exe_path)
                     raise SystemExit(
-                        'Decompressing failed, please check your permission of {} or disk space left.'.
-                        format(self.module_path))
+                        'Decompressing failed, please check your permission of {} or disk space left.'
+                        .format(self.module_path))
                 finally:
                     os.remove(tar_name)
             #release lock
@@ -630,20 +636,20 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
                     tensor.data = model_result[name].tobytes()
                 else:
                     if v_type == 0:  # int64
-                        tensor.int64_data.extend(model_result[name].reshape(-1)
-                                                 .tolist())
+                        tensor.int64_data.extend(
+                            model_result[name].reshape(-1).tolist())
                     elif v_type == 1:  # float32
-                        tensor.float_data.extend(model_result[name].reshape(-1)
-                                                 .tolist())
+                        tensor.float_data.extend(
+                            model_result[name].reshape(-1).tolist())
                     elif v_type == 2:  # int32
-                        tensor.int_data.extend(model_result[name].reshape(-1)
-                                               .tolist())
+                        tensor.int_data.extend(
+                            model_result[name].reshape(-1).tolist())
                     else:
                         raise Exception("error type.")
                 tensor.shape.extend(list(model_result[name].shape))
                 if name in self.lod_tensor_set_:
-                    tensor.lod.extend(model_result["{}.lod".format(name)]
-                                      .tolist())
+                    tensor.lod.extend(
+                        model_result["{}.lod".format(name)].tolist())
                 inst.tensor_array.append(tensor)
             model_output.insts.append(inst)
         model_output.engine_name = model_name
@@ -662,11 +668,10 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
     def Inference(self, request, context):
         feed_dict, fetch_names, is_python, log_id \
             = self._unpack_inference_request(request)
-        ret = self.bclient_.predict(
-            feed=feed_dict,
-            fetch=fetch_names,
-            need_variant_tag=True,
-            log_id=log_id)
+        ret = self.bclient_.predict(feed=feed_dict,
+                                    fetch=fetch_names,
+                                    need_variant_tag=True,
+                                    log_id=log_id)
         return self._pack_inference_response(ret, fetch_names, is_python)

     def GetClientConfig(self, request, context):
@@ -743,15 +748,14 @@ class MultiLangServer(object):
         default_port = 12000
         self.port_list_ = []
         for i in range(1000):
-            if default_port + i != port and self._port_is_available(default_port
-                                                                    + i):
+            if default_port + i != port and self._port_is_available(
+                    default_port + i):
                 self.port_list_.append(default_port + i)
                 break
-        self.bserver_.prepare_server(
-            workdir=workdir,
-            port=self.port_list_[0],
-            device=device,
-            cube_conf=cube_conf)
+        self.bserver_.prepare_server(workdir=workdir,
+                                     port=self.port_list_[0],
+                                     device=device,
+                                     cube_conf=cube_conf)
         self.set_port(port)

     def _launch_brpc_service(self, bserver):
def _launch_brpc_service(self, bserver): def _launch_brpc_service(self, bserver):
...@@ -764,8 +768,8 @@ class MultiLangServer(object): ...@@ -764,8 +768,8 @@ class MultiLangServer(object):
return result != 0 return result != 0
def run_server(self): def run_server(self):
p_bserver = Process( p_bserver = Process(target=self._launch_brpc_service,
target=self._launch_brpc_service, args=(self.bserver_, )) args=(self.bserver_, ))
p_bserver.start() p_bserver.start()
options = [('grpc.max_send_message_length', self.body_size_), options = [('grpc.max_send_message_length', self.body_size_),
('grpc.max_receive_message_length', self.body_size_)] ('grpc.max_receive_message_length', self.body_size_)]
...
@@ -88,8 +88,8 @@ def start_multi_card(args):  # pylint: disable=doc-string-missing
         for ids in gpus:
             if int(ids) >= len(env_gpus):
                 print(
-                    " Max index of gpu_ids out of range, the number of CUDA_VISIBLE_DEVICES is {}.".
-                    format(len(env_gpus)))
+                    " Max index of gpu_ids out of range, the number of CUDA_VISIBLE_DEVICES is {}."
+                    .format(len(env_gpus)))
                 exit(-1)
     else:
         env_gpus = []
@@ -99,11 +99,11 @@ def start_multi_card(args):  # pylint: disable=doc-string-missing
     else:
         gpu_processes = []
         for i, gpu_id in enumerate(gpus):
-            p = Process(
-                target=start_gpu_card_model, args=(
-                    i,
-                    gpu_id,
-                    args, ))
+            p = Process(target=start_gpu_card_model, args=(
+                i,
+                gpu_id,
+                args,
+            ))
             gpu_processes.append(p)
         for p in gpu_processes:
             p.start()
@@ -125,8 +125,9 @@ if __name__ == "__main__":
     gpu_ids = os.environ["CUDA_VISIBLE_DEVICES"]
     if len(gpu_ids) > 0:
         web_service.set_gpus(gpu_ids)
-    web_service.prepare_server(
-        workdir=args.workdir, port=args.port, device=args.device)
+    web_service.prepare_server(workdir=args.workdir,
+                               port=args.port,
+                               device=args.device)
     web_service.run_rpc_service()

     app_instance = Flask(__name__)
...
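Since every hunk in this commit is formatting-only, a quick way to review it is to confirm that each touched file parses to the same abstract syntax tree before and after. A minimal sketch follows, assuming the old and new sources are available as strings; the helper name and file names are made up for illustration.

# Sketch: verify a formatter-only change preserves program structure by
# comparing ASTs. Whitespace and line wrapping are not part of the AST,
# so a pure yapf pass should leave the dump unchanged.
import ast

def same_ast(old_src: str, new_src: str) -> bool:
    """Return True when two sources parse to identical ASTs."""
    return ast.dump(ast.parse(old_src)) == ast.dump(ast.parse(new_src))

# Hypothetical usage with two checkouts of the same module:
# assert same_ast(open("serve_old.py").read(), open("serve_new.py").read())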