提交 19dea60d 编写于 作者: T TeslaZhao

fix one bad explanation

上级 dc512ce7
...@@ -35,12 +35,12 @@ else: ...@@ -35,12 +35,12 @@ else:
raise Exception("Error Python version") raise Exception("Error Python version")
from .error_catch import ErrorCatch, CustomException, CustomExceptionCode, ParamChecker, ParamVerify from .error_catch import ErrorCatch, CustomException, CustomExceptionCode, ParamChecker, ParamVerify
check_feed_dict=ParamVerify.check_feed_dict check_feed_dict = ParamVerify.check_feed_dict
check_fetch_list=ParamVerify.check_fetch_list check_fetch_list = ParamVerify.check_fetch_list
from .proto import pipeline_service_pb2 from .proto import pipeline_service_pb2
from .channel import (ThreadChannel, ProcessChannel,ChannelData, from .channel import (ThreadChannel, ProcessChannel, ChannelData,
ChannelDataType, ChannelStopError, ChannelTimeoutError) ChannelDataType, ChannelStopError, ChannelTimeoutError)
from .error_catch import ProductErrCode from .error_catch import ProductErrCode
from .error_catch import CustomExceptionCode as ChannelDataErrcode from .error_catch import CustomExceptionCode as ChannelDataErrcode
from .util import NameGenerator from .util import NameGenerator
from .profiler import UnsafeTimeProfiler as TimeProfiler from .profiler import UnsafeTimeProfiler as TimeProfiler
...@@ -119,9 +119,9 @@ class Op(object): ...@@ -119,9 +119,9 @@ class Op(object):
self._for_close_op_lock = threading.Lock() self._for_close_op_lock = threading.Lock()
self._succ_init_op = False self._succ_init_op = False
self._succ_close_op = False self._succ_close_op = False
self.dynamic_shape_info = {} self.dynamic_shape_info = {}
self.set_dynamic_shape_info() self.set_dynamic_shape_info()
def set_dynamic_shape_info(self): def set_dynamic_shape_info(self):
""" """
when opening tensorrt(configure in config.yml) and each time the input shape when opening tensorrt(configure in config.yml) and each time the input shape
...@@ -141,7 +141,6 @@ class Op(object): ...@@ -141,7 +141,6 @@ class Op(object):
feed_names = client.feed_names_ feed_names = client.feed_names_
fetch_names = client.fetch_names_ fetch_names = client.fetch_names_
return feed_names, fetch_names return feed_names, fetch_names
def init_from_dict(self, conf): def init_from_dict(self, conf):
""" """
...@@ -164,9 +163,10 @@ class Op(object): ...@@ -164,9 +163,10 @@ class Op(object):
if self._client_config is None: if self._client_config is None:
self._client_config = conf.get("client_config") self._client_config = conf.get("client_config")
if self._use_encryption_model is None: if self._use_encryption_model is None:
print ("config use_encryption model here", conf.get("use_encryption_model")) print("config use_encryption model here",
conf.get("use_encryption_model"))
self._use_encryption_model = conf.get("use_encryption_model") self._use_encryption_model = conf.get("use_encryption_model")
if self._encryption_key is None or self._encryption_key=="": if self._encryption_key is None or self._encryption_key == "":
self._encryption_key = conf.get("encryption_key") self._encryption_key = conf.get("encryption_key")
if self._timeout is None: if self._timeout is None:
self._timeout = conf["timeout"] self._timeout = conf["timeout"]
...@@ -401,14 +401,16 @@ class Op(object): ...@@ -401,14 +401,16 @@ class Op(object):
if self.client_type == 'brpc': if self.client_type == 'brpc':
client = Client() client = Client()
client.load_client_config(client_config) client.load_client_config(client_config)
self.right_feed_names, self.right_fetch_names = self.get_feed_fetch_list(client) self.right_feed_names, self.right_fetch_names = self.get_feed_fetch_list(
client)
elif self.client_type == 'pipeline_grpc': elif self.client_type == 'pipeline_grpc':
client = PPClient() client = PPClient()
elif self.client_type == 'local_predictor': elif self.client_type == 'local_predictor':
if self.local_predictor is None: if self.local_predictor is None:
raise ValueError("local predictor not yet created") raise ValueError("local predictor not yet created")
client = self.local_predictor client = self.local_predictor
self.right_feed_names, self.right_fetch_names = self.get_feed_fetch_list(client) self.right_feed_names, self.right_fetch_names = self.get_feed_fetch_list(
client)
else: else:
raise ValueError("Failed to init client: unknow client " raise ValueError("Failed to init client: unknow client "
"type {}".format(self.client_type)) "type {}".format(self.client_type))
...@@ -417,12 +419,13 @@ class Op(object): ...@@ -417,12 +419,13 @@ class Op(object):
_LOGGER.info("Op({}) has no fetch name set. So fetch all vars") _LOGGER.info("Op({}) has no fetch name set. So fetch all vars")
if self.client_type != "local_predictor": if self.client_type != "local_predictor":
if self._use_encryption_model is None or self._use_encryption_model is False: if self._use_encryption_model is None or self._use_encryption_model is False:
client.connect(server_endpoints) client.connect(server_endpoints)
else: else:
print("connect to encryption rpc client") print("connect to encryption rpc client")
client.use_key(self._encryption_key) client.use_key(self._encryption_key)
client.connect(server_endpoints, encryption=True) client.connect(server_endpoints, encryption=True)
_LOGGER.info("init_client, feed_list:{}, fetch_list: {}".format(self.right_feed_names, self.right_fetch_names)) _LOGGER.info("init_client, feed_list:{}, fetch_list: {}".format(
self.right_feed_names, self.right_fetch_names))
return client return client
def get_input_ops(self): def get_input_ops(self):
...@@ -599,7 +602,7 @@ class Op(object): ...@@ -599,7 +602,7 @@ class Op(object):
(_, input_dict), = input_dicts.items() (_, input_dict), = input_dicts.items()
return input_dict, False, None, "" return input_dict, False, None, ""
def process(self, feed_batch, typical_logid=0): def process(self, feed_batch, typical_logid=0):
""" """
In process stage, send requests to the inference server or predict locally. In process stage, send requests to the inference server or predict locally.
...@@ -616,19 +619,23 @@ class Op(object): ...@@ -616,19 +619,23 @@ class Op(object):
call_result = None call_result = None
err_code = ChannelDataErrcode.OK.value err_code = ChannelDataErrcode.OK.value
err_info = "" err_info = ""
@ErrorCatch
@ErrorCatch
@ParamChecker @ParamChecker
def feed_fetch_list_check_helper(feed_batch : lambda feed_batch: check_feed_dict(feed_batch[0], self.right_feed_names), def feed_fetch_list_check_helper(
fetch_list : lambda fetch_list: check_fetch_list(fetch_list, self.right_fetch_names), feed_batch: lambda feed_batch: check_feed_dict(feed_batch[0], self.right_feed_names),
log_id): fetch_list: lambda fetch_list: check_fetch_list(fetch_list, self.right_fetch_names),
log_id):
return None return None
_, resp = feed_fetch_list_check_helper(feed_batch, self._fetch_names, log_id=typical_logid)
_, resp = feed_fetch_list_check_helper(
feed_batch, self._fetch_names, log_id=typical_logid)
if resp.err_no != CustomExceptionCode.OK.value: if resp.err_no != CustomExceptionCode.OK.value:
err_code = resp.err_no err_code = resp.err_no
err_info = resp.err_msg err_info = resp.err_msg
call_result = None call_result = None
return call_result, err_code, err_info return call_result, err_code, err_info
if self.client_type == "local_predictor": if self.client_type == "local_predictor":
err, err_info = ChannelData.check_batch_npdata(feed_batch) err, err_info = ChannelData.check_batch_npdata(feed_batch)
if err != 0: if err != 0:
...@@ -804,7 +811,7 @@ class Op(object): ...@@ -804,7 +811,7 @@ class Op(object):
self.mkldnn_cache_capacity, self.mkldnn_op_list, self.mkldnn_cache_capacity, self.mkldnn_op_list,
self.mkldnn_bf16_op_list, self.is_jump_op(), self.mkldnn_bf16_op_list, self.is_jump_op(),
self.get_output_channels_of_jump_ops(), self.get_output_channels_of_jump_ops(),
self.min_subgraph_size, self.dynamic_shape_info, self.min_subgraph_size, self.dynamic_shape_info,
self.use_calib)) self.use_calib))
p.daemon = True p.daemon = True
p.start() p.start()
...@@ -839,9 +846,9 @@ class Op(object): ...@@ -839,9 +846,9 @@ class Op(object):
self._get_output_channels(), True, trace_buffer, self._get_output_channels(), True, trace_buffer,
self.model_config, self.workdir, self.thread_num, self.model_config, self.workdir, self.thread_num,
self.device_type, self.devices, self.mem_optim, self.device_type, self.devices, self.mem_optim,
self.ir_optim, self.precision, self.use_mkldnn, self.ir_optim, self.precision, self.use_mkldnn,
self.mkldnn_cache_capacity, self.mkldnn_op_list, self.mkldnn_cache_capacity, self.mkldnn_op_list,
self.mkldnn_bf16_op_list, self.is_jump_op(), self.mkldnn_bf16_op_list, self.is_jump_op(),
self.get_output_channels_of_jump_ops(), self.get_output_channels_of_jump_ops(),
self.min_subgraph_size, self.dynamic_shape_info, self.min_subgraph_size, self.dynamic_shape_info,
self.use_calib)) self.use_calib))
...@@ -873,40 +880,43 @@ class Op(object): ...@@ -873,40 +880,43 @@ class Op(object):
preped_data_dict = collections.OrderedDict() preped_data_dict = collections.OrderedDict()
err_channeldata_dict = collections.OrderedDict() err_channeldata_dict = collections.OrderedDict()
skip_process_dict = {} skip_process_dict = {}
@ErrorCatch @ErrorCatch
def preprocess_help(self, parsed_data, data_id, logid_dict): def preprocess_help(self, parsed_data, data_id, logid_dict):
preped_data, is_skip_process, prod_errcode, prod_errinfo = self.preprocess( preped_data, is_skip_process, prod_errcode, prod_errinfo = self.preprocess(
parsed_data, data_id, logid_dict.get(data_id)) parsed_data, data_id, logid_dict.get(data_id))
return preped_data, is_skip_process, prod_errcode, prod_errinfo return preped_data, is_skip_process, prod_errcode, prod_errinfo
for data_id, parsed_data in parsed_data_dict.items(): for data_id, parsed_data in parsed_data_dict.items():
preped_data, error_channeldata = None, None preped_data, error_channeldata = None, None
is_skip_process = False is_skip_process = False
prod_errcode, prod_errinfo = None, None prod_errcode, prod_errinfo = None, None
log_id = logid_dict.get(data_id) log_id = logid_dict.get(data_id)
process_res, resp = preprocess_help(self, parsed_data, data_id = data_id, process_res, resp = preprocess_help(
logid_dict = logid_dict) self, parsed_data, data_id=data_id, logid_dict=logid_dict)
if resp.err_no == CustomExceptionCode.OK.value: if resp.err_no == CustomExceptionCode.OK.value:
preped_data, is_skip_process, prod_errcode, prod_errinfo = process_res preped_data, is_skip_process, prod_errcode, prod_errinfo = process_res
if is_skip_process is True: if is_skip_process is True:
skip_process_dict[data_id] = True skip_process_dict[data_id] = True
if prod_errcode is not None: if prod_errcode is not None:
_LOGGER.error("data_id: {} return product error. Product ErrNo:{}, Product ErrMsg: {}".format(data_id, prod_errcode, prod_errinfo)) _LOGGER.error(
"data_id: {} return product error. Product ErrNo:{}, Product ErrMsg: {}".
format(data_id, prod_errcode, prod_errinfo))
error_channeldata = ChannelData( error_channeldata = ChannelData(
error_code=ChannelDataErrcode.PRODUCT_ERROR.value, error_code=ChannelDataErrcode.PRODUCT_ERROR.value,
error_info="", error_info="",
prod_error_code=prod_errcode, prod_error_code=prod_errcode,
prod_error_info=prod_errinfo, prod_error_info=prod_errinfo,
data_id=data_id, data_id=data_id,
log_id=log_id) log_id=log_id)
else: else:
error_channeldata = ChannelData( error_channeldata = ChannelData(
error_code=resp.err_no, error_code=resp.err_no,
error_info=resp.err_msg, error_info=resp.err_msg,
data_id=data_id, data_id=data_id,
log_id=log_id) log_id=log_id)
skip_process_dict[data_id] = True skip_process_dict[data_id] = True
if error_channeldata is not None: if error_channeldata is not None:
err_channeldata_dict[data_id] = error_channeldata err_channeldata_dict[data_id] = error_channeldata
...@@ -1086,8 +1096,8 @@ class Op(object): ...@@ -1086,8 +1096,8 @@ class Op(object):
# 2 kinds of errors # 2 kinds of errors
if error_code != ChannelDataErrcode.OK.value or midped_batch is None: if error_code != ChannelDataErrcode.OK.value or midped_batch is None:
error_info = "[{}] failed to predict. {}. Please check the input dict and checkout PipelineServingLogs/pipeline.log for more details.".format( error_info = "[{}] failed to predict. {}. Please check the input dict and checkout PipelineServingLogs/pipeline.log for more details.".format(
self.name, error_info) self.name, error_info)
_LOGGER.error(error_info) _LOGGER.error(error_info)
for data_id in data_ids: for data_id in data_ids:
err_channeldata_dict[data_id] = ChannelData( err_channeldata_dict[data_id] = ChannelData(
...@@ -1162,12 +1172,16 @@ class Op(object): ...@@ -1162,12 +1172,16 @@ class Op(object):
_LOGGER.debug("{} Running postprocess".format(op_info_prefix)) _LOGGER.debug("{} Running postprocess".format(op_info_prefix))
postped_data_dict = collections.OrderedDict() postped_data_dict = collections.OrderedDict()
err_channeldata_dict = collections.OrderedDict() err_channeldata_dict = collections.OrderedDict()
@ErrorCatch @ErrorCatch
def postprocess_help(self, parsed_data_dict, midped_data, data_id, logid_dict): def postprocess_help(self, parsed_data_dict, midped_data, data_id,
postped_data, prod_errcode, prod_errinfo = self.postprocess(parsed_data_dict[data_id], logid_dict):
midped_data, data_id, logid_dict.get(data_id)) postped_data, prod_errcode, prod_errinfo = self.postprocess(
parsed_data_dict[data_id], midped_data, data_id,
logid_dict.get(data_id))
if not isinstance(postped_data, dict): if not isinstance(postped_data, dict):
raise CustomException(CustomExceptionCode.TYPE_ERROR, "postprocess should return dict", True) raise CustomException(CustomExceptionCode.TYPE_ERROR,
"postprocess should return dict", True)
return postped_data, prod_errcode, prod_errinfo return postped_data, prod_errcode, prod_errinfo
for data_id, midped_data in midped_data_dict.items(): for data_id, midped_data in midped_data_dict.items():
...@@ -1175,19 +1189,23 @@ class Op(object): ...@@ -1175,19 +1189,23 @@ class Op(object):
postped_data, err_channeldata = None, None postped_data, err_channeldata = None, None
prod_errcode, prod_errinfo = None, None prod_errcode, prod_errinfo = None, None
post_res, resp = postprocess_help(self, parsed_data_dict, midped_data, data_id post_res, resp = postprocess_help(
= data_id, logid_dict = logid_dict) self,
parsed_data_dict,
midped_data,
data_id=data_id,
logid_dict=logid_dict)
if resp.err_no == CustomExceptionCode.OK.value: if resp.err_no == CustomExceptionCode.OK.value:
postped_data, prod_errcode, prod_errinfo = post_res postped_data, prod_errcode, prod_errinfo = post_res
if prod_errcode is not None: if prod_errcode is not None:
# product errors occured # product errors occured
err_channeldata = ChannelData( err_channeldata = ChannelData(
error_code=ChannelDataErrcode.PRODUCT_ERROR.value, error_code=ChannelDataErrcode.PRODUCT_ERROR.value,
error_info="", error_info="",
prod_error_code=prod_errcode, prod_error_code=prod_errcode,
prod_error_info=prod_errinfo, prod_error_info=prod_errinfo,
data_id=data_id, data_id=data_id,
log_id=log_id) log_id=log_id)
else: else:
err_channeldata = ChannelData( err_channeldata = ChannelData(
error_code=resp.err_no, error_code=resp.err_no,
...@@ -1203,16 +1221,16 @@ class Op(object): ...@@ -1203,16 +1221,16 @@ class Op(object):
err, _ = ChannelData.check_npdata(postped_data) err, _ = ChannelData.check_npdata(postped_data)
if err == 0: if err == 0:
output_data = ChannelData( output_data = ChannelData(
ChannelDataType.CHANNEL_NPDATA.value, ChannelDataType.CHANNEL_NPDATA.value,
npdata=postped_data, npdata=postped_data,
data_id=data_id, data_id=data_id,
log_id=log_id) log_id=log_id)
else: else:
output_data = ChannelData( output_data = ChannelData(
ChannelDataType.DICT.value, ChannelDataType.DICT.value,
dictdata=postped_data, dictdata=postped_data,
data_id=data_id, data_id=data_id,
log_id=log_id) log_id=log_id)
postped_data_dict[data_id] = output_data postped_data_dict[data_id] = output_data
_LOGGER.debug("{} Succ postprocess".format(op_info_prefix)) _LOGGER.debug("{} Succ postprocess".format(op_info_prefix))
return postped_data_dict, err_channeldata_dict return postped_data_dict, err_channeldata_dict
...@@ -1303,10 +1321,10 @@ class Op(object): ...@@ -1303,10 +1321,10 @@ class Op(object):
def _run(self, concurrency_idx, input_channel, output_channels, def _run(self, concurrency_idx, input_channel, output_channels,
is_thread_op, trace_buffer, model_config, workdir, thread_num, is_thread_op, trace_buffer, model_config, workdir, thread_num,
device_type, devices, mem_optim, ir_optim, precision, device_type, devices, mem_optim, ir_optim, precision, use_mkldnn,
use_mkldnn, mkldnn_cache_capacity, mkldnn_op_list, mkldnn_cache_capacity, mkldnn_op_list, mkldnn_bf16_op_list,
mkldnn_bf16_op_list, is_jump_op, output_channels_of_jump_ops, is_jump_op, output_channels_of_jump_ops, min_subgraph_size,
min_subgraph_size, dynamic_shape_info, use_calib): dynamic_shape_info, use_calib):
""" """
_run() is the entry function of OP process / thread model.When client _run() is the entry function of OP process / thread model.When client
type is local_predictor in process mode, the CUDA environment needs to type is local_predictor in process mode, the CUDA environment needs to
...@@ -1344,12 +1362,14 @@ class Op(object): ...@@ -1344,12 +1362,14 @@ class Op(object):
# init ops # init ops
profiler = None profiler = None
@ErrorCatch @ErrorCatch
def check_helper(self, is_thread_op, model_config, workdir, def check_helper(self, is_thread_op, model_config, workdir, thread_num,
thread_num, device_type, devices, mem_optim, ir_optim, device_type, devices, mem_optim, ir_optim, precision,
precision, use_mkldnn, mkldnn_cache_capacity, mkldnn_op_list, use_mkldnn, mkldnn_cache_capacity, mkldnn_op_list,
mkldnn_bf16_op_list, min_subgraph_size, dynamic_shape_info): mkldnn_bf16_op_list, min_subgraph_size,
dynamic_shape_info):
if is_thread_op == False and self.client_type == "local_predictor": if is_thread_op == False and self.client_type == "local_predictor":
self.service_handler = local_service_handler.LocalServiceHandler( self.service_handler = local_service_handler.LocalServiceHandler(
model_config=model_config, model_config=model_config,
...@@ -1377,17 +1397,19 @@ class Op(object): ...@@ -1377,17 +1397,19 @@ class Op(object):
profiler = self._initialize(is_thread_op, concurrency_idx) profiler = self._initialize(is_thread_op, concurrency_idx)
return profiler return profiler
profiler, resp = check_helper(self, is_thread_op, model_config, workdir, profiler, resp = check_helper(
thread_num, device_type, devices, mem_optim, ir_optim, self, is_thread_op, model_config, workdir, thread_num, device_type,
precision, use_mkldnn, mkldnn_cache_capacity, mkldnn_op_list, devices, mem_optim, ir_optim, precision, use_mkldnn,
mkldnn_bf16_op_list, min_subgraph_size, dynamic_shape_info) mkldnn_cache_capacity, mkldnn_op_list, mkldnn_bf16_op_list,
min_subgraph_size, dynamic_shape_info)
if resp.err_no != CustomExceptionCode.OK.value: if resp.err_no != CustomExceptionCode.OK.value:
_LOGGER.critical( _LOGGER.critical(
"{} failed to init op: {}".format(op_info_prefix, resp.err_msg), "{} failed to init op: {}".format(op_info_prefix, resp.err_msg),
exc_info=False) exc_info=False)
print("{} failed to init op: {}".format(op_info_prefix, resp.err_msg)) print("{} failed to init op: {}".format(op_info_prefix,
resp.err_msg))
kill_stop_process_by_pid("kill", os.getpgid(os.getpid())) kill_stop_process_by_pid("kill", os.getpgid(os.getpid()))
_LOGGER.info("{} Succ init".format(op_info_prefix)) _LOGGER.info("{} Succ init".format(op_info_prefix))
...@@ -1583,6 +1605,7 @@ class Op(object): ...@@ -1583,6 +1605,7 @@ class Op(object):
Returns: Returns:
TimeProfiler TimeProfiler
""" """
@ErrorCatch @ErrorCatch
def init_helper(self, is_thread_op, concurrency_idx): def init_helper(self, is_thread_op, concurrency_idx):
if is_thread_op: if is_thread_op:
...@@ -1592,7 +1615,7 @@ class Op(object): ...@@ -1592,7 +1615,7 @@ class Op(object):
self.concurrency_idx = None self.concurrency_idx = None
# init client # init client
self.client = self.init_client(self._client_config, self.client = self.init_client(self._client_config,
self._server_endpoints) self._server_endpoints)
# user defined # user defined
self.init_op() self.init_op()
self._succ_init_op = True self._succ_init_op = True
...@@ -1601,10 +1624,10 @@ class Op(object): ...@@ -1601,10 +1624,10 @@ class Op(object):
self.concurrency_idx = concurrency_idx self.concurrency_idx = concurrency_idx
# init client # init client
self.client = self.init_client(self._client_config, self.client = self.init_client(self._client_config,
self._server_endpoints) self._server_endpoints)
# user defined # user defined
self.init_op() self.init_op()
init_helper(self, is_thread_op, concurrency_idx) init_helper(self, is_thread_op, concurrency_idx)
print("[OP Object] init success") print("[OP Object] init success")
# use a separate TimeProfiler per thread or process # use a separate TimeProfiler per thread or process
...@@ -1910,8 +1933,8 @@ class VirtualOp(Op): ...@@ -1910,8 +1933,8 @@ class VirtualOp(Op):
\-> E ----------/ \-> E ----------/
DAG view: [[A], [B, E], [C], [D], [F]] DAG view: [[A], [B, E], [C], [D], [F]]
BUILD DAG: [A -> B -> C -> D -> E -> F] BUILD DAG: [A -> B -> C -> D -> F]
\-> E -> V1-> V2-> V3/ \-> E -> V1-> V2->/
""" """
def __init__(self, name, concurrency=1): def __init__(self, name, concurrency=1):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册