Commit 19dea60d authored by TeslaZhao

fix one bad explanation

Parent dc512ce7
@@ -35,10 +35,10 @@ else:
raise Exception("Error Python version")
from .error_catch import ErrorCatch, CustomException, CustomExceptionCode, ParamChecker, ParamVerify
check_feed_dict=ParamVerify.check_feed_dict
check_fetch_list=ParamVerify.check_fetch_list
check_feed_dict = ParamVerify.check_feed_dict
check_fetch_list = ParamVerify.check_fetch_list
from .proto import pipeline_service_pb2
from .channel import (ThreadChannel, ProcessChannel,ChannelData,
from .channel import (ThreadChannel, ProcessChannel, ChannelData,
ChannelDataType, ChannelStopError, ChannelTimeoutError)
from .error_catch import ProductErrCode
from .error_catch import CustomExceptionCode as ChannelDataErrcode
@@ -142,7 +142,6 @@ class Op(object):
fetch_names = client.fetch_names_
return feed_names, fetch_names
def init_from_dict(self, conf):
"""
Initializing one Op from config.yaml. If server_endpoints exist,
@@ -164,9 +163,10 @@ class Op(object):
if self._client_config is None:
self._client_config = conf.get("client_config")
if self._use_encryption_model is None:
print ("config use_encryption model here", conf.get("use_encryption_model"))
print("config use_encryption model here",
conf.get("use_encryption_model"))
self._use_encryption_model = conf.get("use_encryption_model")
if self._encryption_key is None or self._encryption_key=="":
if self._encryption_key is None or self._encryption_key == "":
self._encryption_key = conf.get("encryption_key")
if self._timeout is None:
self._timeout = conf["timeout"]
@@ -401,14 +401,16 @@ class Op(object):
if self.client_type == 'brpc':
client = Client()
client.load_client_config(client_config)
self.right_feed_names, self.right_fetch_names = self.get_feed_fetch_list(client)
self.right_feed_names, self.right_fetch_names = self.get_feed_fetch_list(
client)
elif self.client_type == 'pipeline_grpc':
client = PPClient()
elif self.client_type == 'local_predictor':
if self.local_predictor is None:
raise ValueError("local predictor not yet created")
client = self.local_predictor
self.right_feed_names, self.right_fetch_names = self.get_feed_fetch_list(client)
self.right_feed_names, self.right_fetch_names = self.get_feed_fetch_list(
client)
else:
raise ValueError("Failed to init client: unknow client "
"type {}".format(self.client_type))
@@ -422,7 +424,8 @@ class Op(object):
print("connect to encryption rpc client")
client.use_key(self._encryption_key)
client.connect(server_endpoints, encryption=True)
_LOGGER.info("init_client, feed_list:{}, fetch_list: {}".format(self.right_feed_names, self.right_fetch_names))
_LOGGER.info("init_client, feed_list:{}, fetch_list: {}".format(
self.right_feed_names, self.right_fetch_names))
return client
def get_input_ops(self):
@@ -616,13 +619,17 @@ class Op(object):
call_result = None
err_code = ChannelDataErrcode.OK.value
err_info = ""
@ErrorCatch
@ParamChecker
def feed_fetch_list_check_helper(feed_batch : lambda feed_batch: check_feed_dict(feed_batch[0], self.right_feed_names),
fetch_list : lambda fetch_list: check_fetch_list(fetch_list, self.right_fetch_names),
def feed_fetch_list_check_helper(
feed_batch: lambda feed_batch: check_feed_dict(feed_batch[0], self.right_feed_names),
fetch_list: lambda fetch_list: check_fetch_list(fetch_list, self.right_fetch_names),
log_id):
return None
_, resp = feed_fetch_list_check_helper(feed_batch, self._fetch_names, log_id=typical_logid)
_, resp = feed_fetch_list_check_helper(
feed_batch, self._fetch_names, log_id=typical_logid)
if resp.err_no != CustomExceptionCode.OK.value:
err_code = resp.err_no
err_info = resp.err_msg
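# A minimal, self-contained sketch of the annotation-driven argument checking
# seen in feed_fetch_list_check_helper above: the helper annotates its
# parameters with checking lambdas, which a decorator such as @ParamChecker
# can invoke against the actual arguments. The decorator below is a simplified
# stand-in, not PaddleServing's ParamChecker/ErrorCatch, and right_feed_names /
# check_feed_dict here are illustrative re-implementations.
import inspect
from functools import wraps

def param_checker(func):
    """Call each callable annotation with the bound argument value."""
    sig = inspect.signature(func)

    @wraps(func)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            check = func.__annotations__.get(name)
            if callable(check):
                check(value)  # raises if the argument is invalid
        return func(*args, **kwargs)

    return wrapper

def check_feed_dict(feed, right_feed_names):
    """Reject a feed dict that is missing any expected input variable."""
    missing = set(right_feed_names) - set(feed)
    if missing:
        raise KeyError("missing feed vars: {}".format(sorted(missing)))

right_feed_names = ["image"]

@param_checker
def run_check(feed_batch: lambda fb: check_feed_dict(fb[0], right_feed_names),
              log_id=0):
    return "feed ok, log_id={}".format(log_id)

print(run_check([{"image": [0.0, 1.0]}]))  # passes the feed-name check
# run_check([{"img": []}]) would raise KeyError: missing feed vars: ['image']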
@@ -873,6 +880,7 @@ class Op(object):
preped_data_dict = collections.OrderedDict()
err_channeldata_dict = collections.OrderedDict()
skip_process_dict = {}
@ErrorCatch
def preprocess_help(self, parsed_data, data_id, logid_dict):
preped_data, is_skip_process, prod_errcode, prod_errinfo = self.preprocess(
@@ -884,14 +892,16 @@ class Op(object):
is_skip_process = False
prod_errcode, prod_errinfo = None, None
log_id = logid_dict.get(data_id)
process_res, resp = preprocess_help(self, parsed_data, data_id = data_id,
logid_dict = logid_dict)
process_res, resp = preprocess_help(
self, parsed_data, data_id=data_id, logid_dict=logid_dict)
if resp.err_no == CustomExceptionCode.OK.value:
preped_data, is_skip_process, prod_errcode, prod_errinfo = process_res
if is_skip_process is True:
skip_process_dict[data_id] = True
if prod_errcode is not None:
_LOGGER.error("data_id: {} return product error. Product ErrNo:{}, Product ErrMsg: {}".format(data_id, prod_errcode, prod_errinfo))
_LOGGER.error(
"data_id: {} return product error. Product ErrNo:{}, Product ErrMsg: {}".
format(data_id, prod_errcode, prod_errinfo))
error_channeldata = ChannelData(
error_code=ChannelDataErrcode.PRODUCT_ERROR.value,
error_info="",
@@ -1162,12 +1172,16 @@ class Op(object):
_LOGGER.debug("{} Running postprocess".format(op_info_prefix))
postped_data_dict = collections.OrderedDict()
err_channeldata_dict = collections.OrderedDict()
@ErrorCatch
def postprocess_help(self, parsed_data_dict, midped_data, data_id, logid_dict):
postped_data, prod_errcode, prod_errinfo = self.postprocess(parsed_data_dict[data_id],
midped_data, data_id, logid_dict.get(data_id))
def postprocess_help(self, parsed_data_dict, midped_data, data_id,
logid_dict):
postped_data, prod_errcode, prod_errinfo = self.postprocess(
parsed_data_dict[data_id], midped_data, data_id,
logid_dict.get(data_id))
if not isinstance(postped_data, dict):
raise CustomException(CustomExceptionCode.TYPE_ERROR, "postprocess should return dict", True)
raise CustomException(CustomExceptionCode.TYPE_ERROR,
"postprocess should return dict", True)
return postped_data, prod_errcode, prod_errinfo
for data_id, midped_data in midped_data_dict.items():
@@ -1175,8 +1189,12 @@ class Op(object):
postped_data, err_channeldata = None, None
prod_errcode, prod_errinfo = None, None
post_res, resp = postprocess_help(self, parsed_data_dict, midped_data, data_id
= data_id, logid_dict = logid_dict)
post_res, resp = postprocess_help(
self,
parsed_data_dict,
midped_data,
data_id=data_id,
logid_dict=logid_dict)
if resp.err_no == CustomExceptionCode.OK.value:
postped_data, prod_errcode, prod_errinfo = post_res
if prod_errcode is not None:
@@ -1303,10 +1321,10 @@ class Op(object):
def _run(self, concurrency_idx, input_channel, output_channels,
is_thread_op, trace_buffer, model_config, workdir, thread_num,
device_type, devices, mem_optim, ir_optim, precision,
use_mkldnn, mkldnn_cache_capacity, mkldnn_op_list,
mkldnn_bf16_op_list, is_jump_op, output_channels_of_jump_ops,
min_subgraph_size, dynamic_shape_info, use_calib):
device_type, devices, mem_optim, ir_optim, precision, use_mkldnn,
mkldnn_cache_capacity, mkldnn_op_list, mkldnn_bf16_op_list,
is_jump_op, output_channels_of_jump_ops, min_subgraph_size,
dynamic_shape_info, use_calib):
"""
_run() is the entry function of OP process / thread model. When client
type is local_predictor in process mode, the CUDA environment needs to
@@ -1344,11 +1362,13 @@ class Op(object):
# init ops
profiler = None
@ErrorCatch
def check_helper(self, is_thread_op, model_config, workdir,
thread_num, device_type, devices, mem_optim, ir_optim,
precision, use_mkldnn, mkldnn_cache_capacity, mkldnn_op_list,
mkldnn_bf16_op_list, min_subgraph_size, dynamic_shape_info):
def check_helper(self, is_thread_op, model_config, workdir, thread_num,
device_type, devices, mem_optim, ir_optim, precision,
use_mkldnn, mkldnn_cache_capacity, mkldnn_op_list,
mkldnn_bf16_op_list, min_subgraph_size,
dynamic_shape_info):
if is_thread_op == False and self.client_type == "local_predictor":
self.service_handler = local_service_handler.LocalServiceHandler(
@@ -1377,17 +1397,19 @@ class Op(object):
profiler = self._initialize(is_thread_op, concurrency_idx)
return profiler
profiler, resp = check_helper(self, is_thread_op, model_config, workdir,
thread_num, device_type, devices, mem_optim, ir_optim,
precision, use_mkldnn, mkldnn_cache_capacity, mkldnn_op_list,
mkldnn_bf16_op_list, min_subgraph_size, dynamic_shape_info)
profiler, resp = check_helper(
self, is_thread_op, model_config, workdir, thread_num, device_type,
devices, mem_optim, ir_optim, precision, use_mkldnn,
mkldnn_cache_capacity, mkldnn_op_list, mkldnn_bf16_op_list,
min_subgraph_size, dynamic_shape_info)
if resp.err_no != CustomExceptionCode.OK.value:
_LOGGER.critical(
"{} failed to init op: {}".format(op_info_prefix, resp.err_msg),
exc_info=False)
print("{} failed to init op: {}".format(op_info_prefix, resp.err_msg))
print("{} failed to init op: {}".format(op_info_prefix,
resp.err_msg))
kill_stop_process_by_pid("kill", os.getpgid(os.getpid()))
_LOGGER.info("{} Succ init".format(op_info_prefix))
@@ -1583,6 +1605,7 @@ class Op(object):
Returns:
TimeProfiler
"""
@ErrorCatch
def init_helper(self, is_thread_op, concurrency_idx):
if is_thread_op:
@@ -1910,8 +1933,8 @@ class VirtualOp(Op):
\-> E ----------/
DAG view: [[A], [B, E], [C], [D], [F]]
BUILD DAG: [A -> B -> C -> D -> E -> F]
\-> E -> V1-> V2-> V3/
BUILD DAG: [A -> B -> C -> D -> F]
\-> E -> V1-> V2->/
"""
def __init__(self, name, concurrency=1):
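# A minimal standalone sketch (not PaddleServing's actual DAG-building code)
# of the idea in the VirtualOp docstring above: when an op's output has to
# skip layers of the DAG view (E feeds F directly), virtual ops are inserted
# so that every channel only connects adjacent layers.
def pad_view_with_virtual_ops(view, edges):
    """view: list of layers (lists of op names); edges: (src, dst) pairs.
    Returns (new_view, new_edges) where every edge spans adjacent layers."""
    layer_of = {op: i for i, layer in enumerate(view) for op in layer}
    new_view = [list(layer) for layer in view]
    new_edges = []
    vcount = 0
    for src, dst in edges:
        prev = src
        # bridge every skipped layer with one virtual op
        for layer_idx in range(layer_of[src] + 1, layer_of[dst]):
            vcount += 1
            vop = "V{}".format(vcount)
            new_view[layer_idx].append(vop)
            new_edges.append((prev, vop))
            prev = vop
        new_edges.append((prev, dst))
    return new_view, new_edges

if __name__ == "__main__":
    view = [["A"], ["B", "E"], ["C"], ["D"], ["F"]]
    edges = [("A", "B"), ("A", "E"), ("B", "C"), ("C", "D"),
             ("D", "F"), ("E", "F")]
    padded_view, padded_edges = pad_view_with_virtual_ops(view, edges)
    print(padded_view)   # [['A'], ['B', 'E'], ['C', 'V1'], ['D', 'V2'], ['F']]
    print(padded_edges)  # ..., ('E', 'V1'), ('V1', 'V2'), ('V2', 'F'), which
                         # matches the corrected diagram: \-> E -> V1-> V2->/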