提交 9256114c 编写于 作者: H HexToString

fix code style

上级 9bba1f72
...@@ -42,19 +42,30 @@ from concurrent import futures ...@@ -42,19 +42,30 @@ from concurrent import futures
class Server(object): class Server(object):
def __init__(self): def __init__(self):
"""
self.model_toolkit_conf:'list'=[] # The quantity of self.model_toolkit_conf is equal to the InferOp quantity/Engine--OP
self.model_conf:'collections.OrderedDict()' # Save the serving_server_conf.prototxt content (feed and fetch information) this is a map for multi-model in a workflow
self.workflow_fn:'str'="workflow.prototxt" # Only one for one Service/Workflow
self.resource_fn:'str'="resource.prototxt" # Only one for one Service,model_toolkit_fn and general_model_config_fn is recorded in this file
self.infer_service_fn:'str'="infer_service.prototxt" # Only one for one Service,Service--Workflow
self.model_toolkit_fn:'list'=[] # ["general_infer_0/model_toolkit.prototxt"]The quantity is equal to the InferOp quantity,Engine--OP
self.general_model_config_fn:'list'=[] # ["general_infer_0/general_model.prototxt"]The quantity is equal to the InferOp quantity,Feed and Fetch --OP
self.subdirectory:'list'=[] # The quantity is equal to the InferOp quantity, and name = node.name = engine.name
self.model_config_paths:'collections.OrderedDict()' # Save the serving_server_conf.prototxt path (feed and fetch information) this is a map for multi-model in a workflow
"""
self.server_handle_ = None self.server_handle_ = None
self.infer_service_conf = None self.infer_service_conf = None
self.model_toolkit_conf = []#The quantity is equal to the InferOp quantity,Engine--OP self.model_toolkit_conf = []
self.resource_conf = None self.resource_conf = None
self.memory_optimization = False self.memory_optimization = False
self.ir_optimization = False self.ir_optimization = False
self.model_conf = collections.OrderedDict()# save the serving_server_conf.prototxt content (feed and fetch information) this is a map for multi-model in a workflow self.model_conf = collections.OrderedDict()
self.workflow_fn = "workflow.prototxt"#only one for one Service,Workflow--Op self.workflow_fn = "workflow.prototxt"
self.resource_fn = "resource.prototxt"#only one for one Service,model_toolkit_fn and general_model_config_fn is recorded in this file self.resource_fn = "resource.prototxt"
self.infer_service_fn = "infer_service.prototxt"#only one for one Service,Service--Workflow self.infer_service_fn = "infer_service.prototxt"
self.model_toolkit_fn = []#["general_infer_0/model_toolkit.prototxt"]The quantity is equal to the InferOp quantity,Engine--OP self.model_toolkit_fn = []
self.general_model_config_fn = []#["general_infer_0/general_model.prototxt"]The quantity is equal to the InferOp quantity,Feed and Fetch --OP self.general_model_config_fn = []
self.subdirectory = []#The quantity is equal to the InferOp quantity, and name = node.name = engine.name self.subdirectory = []
self.cube_config_fn = "cube.conf" self.cube_config_fn = "cube.conf"
self.workdir = "" self.workdir = ""
self.max_concurrency = 0 self.max_concurrency = 0
...@@ -71,12 +82,15 @@ class Server(object): ...@@ -71,12 +82,15 @@ class Server(object):
self.use_trt = False self.use_trt = False
self.use_lite = False self.use_lite = False
self.use_xpu = False self.use_xpu = False
self.model_config_paths = collections.OrderedDict() # save the serving_server_conf.prototxt path (feed and fetch information) this is a map for multi-model in a workflow self.model_config_paths = collections.OrderedDict()
self.product_name = None self.product_name = None
self.container_id = None self.container_id = None
def get_fetch_list(self,infer_node_idx = -1 ): def get_fetch_list(self, infer_node_idx=-1):
fetch_names = [var.alias_name for var in list(self.model_conf.values())[infer_node_idx].fetch_var] fetch_names = [
var.alias_name
for var in list(self.model_conf.values())[infer_node_idx].fetch_var
]
return fetch_names return fetch_names
def set_max_concurrency(self, concurrency): def set_max_concurrency(self, concurrency):
...@@ -195,9 +209,10 @@ class Server(object): ...@@ -195,9 +209,10 @@ class Server(object):
self.workdir = workdir self.workdir = workdir
if self.resource_conf == None: if self.resource_conf == None:
self.resource_conf = server_sdk.ResourceConf() self.resource_conf = server_sdk.ResourceConf()
for idx, op_general_model_config_fn in enumerate(self.general_model_config_fn): for idx, op_general_model_config_fn in enumerate(
self.general_model_config_fn):
with open("{}/{}".format(workdir, op_general_model_config_fn), with open("{}/{}".format(workdir, op_general_model_config_fn),
"w") as fout: "w") as fout:
fout.write(str(list(self.model_conf.values())[idx])) fout.write(str(list(self.model_conf.values())[idx]))
for workflow in self.workflow_conf.workflows: for workflow in self.workflow_conf.workflows:
for node in workflow.nodes: for node in workflow.nodes:
...@@ -212,9 +227,11 @@ class Server(object): ...@@ -212,9 +227,11 @@ class Server(object):
if "quant" in node.name: if "quant" in node.name:
self.resource_conf.cube_quant_bits = 8 self.resource_conf.cube_quant_bits = 8
self.resource_conf.model_toolkit_path.extend([workdir]) self.resource_conf.model_toolkit_path.extend([workdir])
self.resource_conf.model_toolkit_file.extend([self.model_toolkit_fn[idx]]) self.resource_conf.model_toolkit_file.extend(
[self.model_toolkit_fn[idx]])
self.resource_conf.general_model_path.extend([workdir]) self.resource_conf.general_model_path.extend([workdir])
self.resource_conf.general_model_file.extend([op_general_model_config_fn]) self.resource_conf.general_model_file.extend(
[op_general_model_config_fn])
#TODO:figure out the meaning of product_name and container_id. #TODO:figure out the meaning of product_name and container_id.
if self.product_name != None: if self.product_name != None:
self.resource_conf.auth_product_name = self.product_name self.resource_conf.auth_product_name = self.product_name
...@@ -237,15 +254,18 @@ class Server(object): ...@@ -237,15 +254,18 @@ class Server(object):
if os.path.isdir(single_model_config): if os.path.isdir(single_model_config):
pass pass
elif os.path.isfile(single_model_config): elif os.path.isfile(single_model_config):
raise ValueError("The input of --model should be a dir not file.") raise ValueError(
"The input of --model should be a dir not file.")
if isinstance(model_config_paths_args, list): if isinstance(model_config_paths_args, list):
# If there is only one model path, use the default infer_op. # If there is only one model path, use the default infer_op.
# Because there are several infer_op type, we need to find # Because there are several infer_op type, we need to find
# it from workflow_conf. # it from workflow_conf.
default_engine_types = [ default_engine_types = [
'GeneralInferOp', 'GeneralDistKVInferOp', 'GeneralInferOp',
'GeneralDistKVQuantInferOp','GeneralDetectionOp', 'GeneralDistKVInferOp',
'GeneralDistKVQuantInferOp',
'GeneralDetectionOp',
] ]
# now only support single-workflow. # now only support single-workflow.
# TODO:support multi-workflow # TODO:support multi-workflow
...@@ -256,16 +276,24 @@ class Server(object): ...@@ -256,16 +276,24 @@ class Server(object):
raise Exception( raise Exception(
"You have set the engine_name of Op. Please use the form {op: model_path} to configure model path" "You have set the engine_name of Op. Please use the form {op: model_path} to configure model path"
) )
f = open("{}/serving_server_conf.prototxt".format( f = open("{}/serving_server_conf.prototxt".format(
model_config_paths_args[model_config_paths_list_idx]), 'r') model_config_paths_args[model_config_paths_list_idx]),
self.model_conf[node.name] = google.protobuf.text_format.Merge(str(f.read()), m_config.GeneralModelConfig()) 'r')
self.model_config_paths[node.name] = model_config_paths_args[model_config_paths_list_idx] self.model_conf[
self.general_model_config_fn.append(node.name+"/general_model.prototxt") node.name] = google.protobuf.text_format.Merge(
self.model_toolkit_fn.append(node.name+"/model_toolkit.prototxt") str(f.read()), m_config.GeneralModelConfig())
self.model_config_paths[
node.name] = model_config_paths_args[
model_config_paths_list_idx]
self.general_model_config_fn.append(
node.name + "/general_model.prototxt")
self.model_toolkit_fn.append(node.name +
"/model_toolkit.prototxt")
self.subdirectory.append(node.name) self.subdirectory.append(node.name)
model_config_paths_list_idx += 1 model_config_paths_list_idx += 1
if model_config_paths_list_idx == len(model_config_paths_args): if model_config_paths_list_idx == len(
model_config_paths_args):
break break
#Right now, this is not useful. #Right now, this is not useful.
elif isinstance(model_config_paths_args, dict): elif isinstance(model_config_paths_args, dict):
...@@ -278,11 +306,12 @@ class Server(object): ...@@ -278,11 +306,12 @@ class Server(object):
"that the input and output of multiple models are the same.") "that the input and output of multiple models are the same.")
f = open("{}/serving_server_conf.prototxt".format(path), 'r') f = open("{}/serving_server_conf.prototxt".format(path), 'r')
self.model_conf[node.name] = google.protobuf.text_format.Merge( self.model_conf[node.name] = google.protobuf.text_format.Merge(
str(f.read()), m_config.GeneralModelConfig()) str(f.read()), m_config.GeneralModelConfig())
else: else:
raise Exception("The type of model_config_paths must be str or list or " raise Exception(
"dict({op: model_path}), not {}.".format( "The type of model_config_paths must be str or list or "
type(model_config_paths_args))) "dict({op: model_path}), not {}.".format(
type(model_config_paths_args)))
# check config here # check config here
# print config here # print config here
...@@ -409,7 +438,7 @@ class Server(object): ...@@ -409,7 +438,7 @@ class Server(object):
resource_fn = "{}/{}".format(workdir, self.resource_fn) resource_fn = "{}/{}".format(workdir, self.resource_fn)
self._write_pb_str(resource_fn, self.resource_conf) self._write_pb_str(resource_fn, self.resource_conf)
for idx,single_model_toolkit_fn in enumerate(self.model_toolkit_fn): for idx, single_model_toolkit_fn in enumerate(self.model_toolkit_fn):
model_toolkit_fn = "{}/{}".format(workdir, single_model_toolkit_fn) model_toolkit_fn = "{}/{}".format(workdir, single_model_toolkit_fn)
self._write_pb_str(model_toolkit_fn, self.model_toolkit_conf[idx]) self._write_pb_str(model_toolkit_fn, self.model_toolkit_conf[idx])
...@@ -498,6 +527,7 @@ class Server(object): ...@@ -498,6 +527,7 @@ class Server(object):
os.system(command) os.system(command)
class MultiLangServer(object): class MultiLangServer(object):
def __init__(self): def __init__(self):
self.bserver_ = Server() self.bserver_ = Server()
...@@ -553,22 +583,23 @@ class MultiLangServer(object): ...@@ -553,22 +583,23 @@ class MultiLangServer(object):
def set_gpuid(self, gpuid=0): def set_gpuid(self, gpuid=0):
self.bserver_.set_gpuid(gpuid) self.bserver_.set_gpuid(gpuid)
def load_model_config(self, server_config_dir_paths, client_config_path=None): def load_model_config(self,
server_config_dir_paths,
client_config_path=None):
if isinstance(server_config_dir_paths, str): if isinstance(server_config_dir_paths, str):
server_config_dir_paths = [server_config_dir_paths] server_config_dir_paths = [server_config_dir_paths]
elif isinstance(server_config_dir_paths, list): elif isinstance(server_config_dir_paths, list):
pass pass
else: else:
raise Exception("The type of model_config_paths must be str or list" raise Exception("The type of model_config_paths must be str or list"
", not {}.".format( ", not {}.".format(type(server_config_dir_paths)))
type(server_config_dir_paths)))
for single_model_config in server_config_dir_paths: for single_model_config in server_config_dir_paths:
if os.path.isdir(single_model_config): if os.path.isdir(single_model_config):
pass pass
elif os.path.isfile(single_model_config): elif os.path.isfile(single_model_config):
raise ValueError("The input of --model should be a dir not file.") raise ValueError(
"The input of --model should be a dir not file.")
self.bserver_.load_model_config(server_config_dir_paths) self.bserver_.load_model_config(server_config_dir_paths)
if client_config_path is None: if client_config_path is None:
...@@ -576,27 +607,30 @@ class MultiLangServer(object): ...@@ -576,27 +607,30 @@ class MultiLangServer(object):
if isinstance(server_config_dir_paths, dict): if isinstance(server_config_dir_paths, dict):
self.is_multi_model_ = True self.is_multi_model_ = True
client_config_path = [] client_config_path = []
for server_config_path_items in list(server_config_dir_paths.items()): for server_config_path_items in list(
client_config_path.append( server_config_path_items[1] ) server_config_dir_paths.items()):
client_config_path.append(server_config_path_items[1])
elif isinstance(server_config_dir_paths, list): elif isinstance(server_config_dir_paths, list):
self.is_multi_model_ = False self.is_multi_model_ = False
client_config_path = server_config_dir_paths client_config_path = server_config_dir_paths
else: else:
raise Exception("The type of model_config_paths must be str or list or " raise Exception(
"dict({op: model_path}), not {}.".format( "The type of model_config_paths must be str or list or "
type(server_config_dir_paths))) "dict({op: model_path}), not {}.".format(
type(server_config_dir_paths)))
if isinstance(client_config_path, str): if isinstance(client_config_path, str):
client_config_path = [client_config_path] client_config_path = [client_config_path]
elif isinstance(client_config_path, list): elif isinstance(client_config_path, list):
pass pass
else:# dict is not support right now. else: # dict is not support right now.
raise Exception("The type of client_config_path must be str or list or " raise Exception(
"dict({op: model_path}), not {}.".format( "The type of client_config_path must be str or list or "
type(client_config_path))) "dict({op: model_path}), not {}.".format(
type(client_config_path)))
if len(client_config_path) != len(server_config_dir_paths): if len(client_config_path) != len(server_config_dir_paths):
raise Warning("The len(client_config_path) is {}, != len(server_config_dir_paths) {}." raise Warning(
.format( len(client_config_path), len(server_config_dir_paths) ) "The len(client_config_path) is {}, != len(server_config_dir_paths) {}."
) .format(len(client_config_path), len(server_config_dir_paths)))
self.bclient_config_path_list = client_config_path self.bclient_config_path_list = client_config_path
def prepare_server(self, def prepare_server(self,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册