Commit 19dea60d authored by TeslaZhao

fix one bad explanation

Parent dc512ce7
@@ -35,12 +35,12 @@ else:
raise Exception("Error Python version")
from .error_catch import ErrorCatch, CustomException, CustomExceptionCode, ParamChecker, ParamVerify
-check_feed_dict=ParamVerify.check_feed_dict
-check_fetch_list=ParamVerify.check_fetch_list
+check_feed_dict = ParamVerify.check_feed_dict
+check_fetch_list = ParamVerify.check_fetch_list
from .proto import pipeline_service_pb2
-from .channel import (ThreadChannel, ProcessChannel,ChannelData,
+from .channel import (ThreadChannel, ProcessChannel, ChannelData,
ChannelDataType, ChannelStopError, ChannelTimeoutError)
-from .error_catch import ProductErrCode
+from .error_catch import ProductErrCode
from .error_catch import CustomExceptionCode as ChannelDataErrcode
from .util import NameGenerator
from .profiler import UnsafeTimeProfiler as TimeProfiler
@@ -119,9 +119,9 @@ class Op(object):
self._for_close_op_lock = threading.Lock()
self._succ_init_op = False
self._succ_close_op = False
-self.dynamic_shape_info = {}
+self.dynamic_shape_info = {}
self.set_dynamic_shape_info()
def set_dynamic_shape_info(self):
"""
when opening tensorrt(configure in config.yml) and each time the input shape
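For context (not part of this diff): a subclass typically overrides `set_dynamic_shape_info` to populate `self.dynamic_shape_info` before the local predictor is created, so TensorRT gets a shape range for inputs that vary per request. A minimal sketch, assuming the conventional `min_input_shape`/`max_input_shape`/`opt_input_shape` keys and a hypothetical input tensor named `x`:

```python
# Illustrative sketch, not part of this commit. Assumes the min/max/opt
# key convention for dynamic_shape_info and an input tensor named "x".
class ExampleOp(Op):
    def set_dynamic_shape_info(self):
        # TensorRT needs a shape range when input shapes differ per request.
        self.dynamic_shape_info = {
            "min_input_shape": {"x": [1, 3, 224, 224]},
            "max_input_shape": {"x": [8, 3, 640, 640]},
            "opt_input_shape": {"x": [4, 3, 320, 320]},
        }
```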
@@ -141,7 +141,6 @@ class Op(object):
feed_names = client.feed_names_
fetch_names = client.fetch_names_
return feed_names, fetch_names
def init_from_dict(self, conf):
"""
@@ -164,9 +163,10 @@ class Op(object):
if self._client_config is None:
self._client_config = conf.get("client_config")
if self._use_encryption_model is None:
-print ("config use_encryption model here", conf.get("use_encryption_model"))
+print("config use_encryption model here",
+conf.get("use_encryption_model"))
self._use_encryption_model = conf.get("use_encryption_model")
-if self._encryption_key is None or self._encryption_key=="":
+if self._encryption_key is None or self._encryption_key == "":
self._encryption_key = conf.get("encryption_key")
if self._timeout is None:
self._timeout = conf["timeout"]
@@ -401,14 +401,16 @@ class Op(object):
if self.client_type == 'brpc':
client = Client()
client.load_client_config(client_config)
-self.right_feed_names, self.right_fetch_names = self.get_feed_fetch_list(client)
+self.right_feed_names, self.right_fetch_names = self.get_feed_fetch_list(
+client)
elif self.client_type == 'pipeline_grpc':
client = PPClient()
elif self.client_type == 'local_predictor':
if self.local_predictor is None:
raise ValueError("local predictor not yet created")
client = self.local_predictor
-self.right_feed_names, self.right_fetch_names = self.get_feed_fetch_list(client)
+self.right_feed_names, self.right_fetch_names = self.get_feed_fetch_list(
+client)
else:
raise ValueError("Failed to init client: unknow client "
"type {}".format(self.client_type))
@@ -417,12 +419,13 @@ class Op(object):
_LOGGER.info("Op({}) has no fetch name set. So fetch all vars")
if self.client_type != "local_predictor":
if self._use_encryption_model is None or self._use_encryption_model is False:
-client.connect(server_endpoints)
+client.connect(server_endpoints)
else:
print("connect to encryption rpc client")
client.use_key(self._encryption_key)
client.connect(server_endpoints, encryption=True)
_LOGGER.info("init_client, feed_list:{}, fetch_list: {}".format(self.right_feed_names, self.right_fetch_names))
print("connect to encryption rpc client")
client.use_key(self._encryption_key)
client.connect(server_endpoints, encryption=True)
_LOGGER.info("init_client, feed_list:{}, fetch_list: {}".format(
self.right_feed_names, self.right_fetch_names))
return client
def get_input_ops(self):
@@ -599,7 +602,7 @@ class Op(object):
(_, input_dict), = input_dicts.items()
return input_dict, False, None, ""
def process(self, feed_batch, typical_logid=0):
"""
In process stage, send requests to the inference server or predict locally.
@@ -616,19 +619,23 @@ class Op(object):
call_result = None
err_code = ChannelDataErrcode.OK.value
err_info = ""
-@ErrorCatch
+@ErrorCatch
@ParamChecker
-def feed_fetch_list_check_helper(feed_batch : lambda feed_batch: check_feed_dict(feed_batch[0], self.right_feed_names),
-fetch_list : lambda fetch_list: check_fetch_list(fetch_list, self.right_fetch_names),
-log_id):
+def feed_fetch_list_check_helper(
+feed_batch: lambda feed_batch: check_feed_dict(feed_batch[0], self.right_feed_names),
+fetch_list: lambda fetch_list: check_fetch_list(fetch_list, self.right_fetch_names),
+log_id):
return None
-_, resp = feed_fetch_list_check_helper(feed_batch, self._fetch_names, log_id=typical_logid)
+_, resp = feed_fetch_list_check_helper(
+feed_batch, self._fetch_names, log_id=typical_logid)
if resp.err_no != CustomExceptionCode.OK.value:
err_code = resp.err_no
err_info = resp.err_msg
call_result = None
return call_result, err_code, err_info
if self.client_type == "local_predictor":
err, err_info = ChannelData.check_batch_npdata(feed_batch)
if err != 0:
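For readers: `feed_fetch_list_check_helper` above leans on `ParamChecker` treating each parameter annotation as a validator callable. A stripped-down sketch of that pattern (the `param_checker` decorator here is hypothetical, not the `error_catch` implementation):

```python
import functools
import inspect

def param_checker(func):
    """Validate each argument with the callable stored in its annotation."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        bound = inspect.signature(func).bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            validator = func.__annotations__.get(name)
            if callable(validator) and validator(value) is False:
                raise ValueError("invalid argument: {}".format(name))
        return func(*args, **kwargs)
    return wrapper

@param_checker
def check_inputs(feed: lambda d: isinstance(d, dict), log_id=0):
    return None

check_inputs({"image": b"..."}, log_id=0)  # passes
# check_inputs(["not", "a", "dict"])       # raises ValueError
```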
@@ -804,7 +811,7 @@ class Op(object):
self.mkldnn_cache_capacity, self.mkldnn_op_list,
self.mkldnn_bf16_op_list, self.is_jump_op(),
self.get_output_channels_of_jump_ops(),
-self.min_subgraph_size, self.dynamic_shape_info,
+self.min_subgraph_size, self.dynamic_shape_info,
self.use_calib))
p.daemon = True
p.start()
@@ -839,9 +846,9 @@ class Op(object):
self._get_output_channels(), True, trace_buffer,
self.model_config, self.workdir, self.thread_num,
self.device_type, self.devices, self.mem_optim,
-self.ir_optim, self.precision, self.use_mkldnn,
-self.mkldnn_cache_capacity, self.mkldnn_op_list,
-self.mkldnn_bf16_op_list, self.is_jump_op(),
+self.ir_optim, self.precision, self.use_mkldnn,
+self.mkldnn_cache_capacity, self.mkldnn_op_list,
+self.mkldnn_bf16_op_list, self.is_jump_op(),
self.get_output_channels_of_jump_ops(),
self.min_subgraph_size, self.dynamic_shape_info,
self.use_calib))
@@ -873,40 +880,43 @@ class Op(object):
preped_data_dict = collections.OrderedDict()
err_channeldata_dict = collections.OrderedDict()
skip_process_dict = {}
@ErrorCatch
def preprocess_help(self, parsed_data, data_id, logid_dict):
preped_data, is_skip_process, prod_errcode, prod_errinfo = self.preprocess(
parsed_data, data_id, logid_dict.get(data_id))
return preped_data, is_skip_process, prod_errcode, prod_errinfo
for data_id, parsed_data in parsed_data_dict.items():
preped_data, error_channeldata = None, None
is_skip_process = False
prod_errcode, prod_errinfo = None, None
log_id = logid_dict.get(data_id)
-process_res, resp = preprocess_help(self, parsed_data, data_id = data_id,
-logid_dict = logid_dict)
+process_res, resp = preprocess_help(
+self, parsed_data, data_id=data_id, logid_dict=logid_dict)
if resp.err_no == CustomExceptionCode.OK.value:
preped_data, is_skip_process, prod_errcode, prod_errinfo = process_res
if is_skip_process is True:
skip_process_dict[data_id] = True
if prod_errcode is not None:
_LOGGER.error("data_id: {} return product error. Product ErrNo:{}, Product ErrMsg: {}".format(data_id, prod_errcode, prod_errinfo))
_LOGGER.error(
"data_id: {} return product error. Product ErrNo:{}, Product ErrMsg: {}".
format(data_id, prod_errcode, prod_errinfo))
error_channeldata = ChannelData(
-error_code=ChannelDataErrcode.PRODUCT_ERROR.value,
-error_info="",
-prod_error_code=prod_errcode,
-prod_error_info=prod_errinfo,
-data_id=data_id,
-log_id=log_id)
+error_code=ChannelDataErrcode.PRODUCT_ERROR.value,
+error_info="",
+prod_error_code=prod_errcode,
+prod_error_info=prod_errinfo,
+data_id=data_id,
+log_id=log_id)
else:
error_channeldata = ChannelData(
-error_code=resp.err_no,
-error_info=resp.err_msg,
-data_id=data_id,
-log_id=log_id)
-skip_process_dict[data_id] = True
+error_code=resp.err_no,
+error_info=resp.err_msg,
+data_id=data_id,
+log_id=log_id)
+skip_process_dict[data_id] = True
if error_channeldata is not None:
err_channeldata_dict[data_id] = error_channeldata
@@ -1086,8 +1096,8 @@ class Op(object):
# 2 kinds of errors
if error_code != ChannelDataErrcode.OK.value or midped_batch is None:
error_info = "[{}] failed to predict. {}. Please check the input dict and checkout PipelineServingLogs/pipeline.log for more details.".format(
-self.name, error_info)
+self.name, error_info)
_LOGGER.error(error_info)
for data_id in data_ids:
err_channeldata_dict[data_id] = ChannelData(
@@ -1162,12 +1172,16 @@ class Op(object):
_LOGGER.debug("{} Running postprocess".format(op_info_prefix))
postped_data_dict = collections.OrderedDict()
err_channeldata_dict = collections.OrderedDict()
@ErrorCatch
-def postprocess_help(self, parsed_data_dict, midped_data, data_id, logid_dict):
-postped_data, prod_errcode, prod_errinfo = self.postprocess(parsed_data_dict[data_id],
-midped_data, data_id, logid_dict.get(data_id))
+def postprocess_help(self, parsed_data_dict, midped_data, data_id,
+logid_dict):
+postped_data, prod_errcode, prod_errinfo = self.postprocess(
+parsed_data_dict[data_id], midped_data, data_id,
+logid_dict.get(data_id))
if not isinstance(postped_data, dict):
raise CustomException(CustomExceptionCode.TYPE_ERROR, "postprocess should return dict", True)
raise CustomException(CustomExceptionCode.TYPE_ERROR,
"postprocess should return dict", True)
return postped_data, prod_errcode, prod_errinfo
for data_id, midped_data in midped_data_dict.items():
@@ -1175,19 +1189,23 @@ class Op(object):
postped_data, err_channeldata = None, None
prod_errcode, prod_errinfo = None, None
-post_res, resp = postprocess_help(self, parsed_data_dict, midped_data, data_id
-= data_id, logid_dict = logid_dict)
+post_res, resp = postprocess_help(
+self,
+parsed_data_dict,
+midped_data,
+data_id=data_id,
+logid_dict=logid_dict)
if resp.err_no == CustomExceptionCode.OK.value:
postped_data, prod_errcode, prod_errinfo = post_res
if prod_errcode is not None:
-# product errors occured
+# product errors occured
err_channeldata = ChannelData(
-error_code=ChannelDataErrcode.PRODUCT_ERROR.value,
-error_info="",
-prod_error_code=prod_errcode,
-prod_error_info=prod_errinfo,
-data_id=data_id,
-log_id=log_id)
+error_code=ChannelDataErrcode.PRODUCT_ERROR.value,
+error_info="",
+prod_error_code=prod_errcode,
+prod_error_info=prod_errinfo,
+data_id=data_id,
+log_id=log_id)
else:
err_channeldata = ChannelData(
error_code=resp.err_no,
@@ -1203,16 +1221,16 @@ class Op(object):
err, _ = ChannelData.check_npdata(postped_data)
if err == 0:
output_data = ChannelData(
-ChannelDataType.CHANNEL_NPDATA.value,
-npdata=postped_data,
-data_id=data_id,
-log_id=log_id)
+ChannelDataType.CHANNEL_NPDATA.value,
+npdata=postped_data,
+data_id=data_id,
+log_id=log_id)
else:
output_data = ChannelData(
-ChannelDataType.DICT.value,
-dictdata=postped_data,
-data_id=data_id,
-log_id=log_id)
+ChannelDataType.DICT.value,
+dictdata=postped_data,
+data_id=data_id,
+log_id=log_id)
postped_data_dict[data_id] = output_data
_LOGGER.debug("{} Succ postprocess".format(op_info_prefix))
return postped_data_dict, err_channeldata_dict
@@ -1303,10 +1321,10 @@ class Op(object):
def _run(self, concurrency_idx, input_channel, output_channels,
is_thread_op, trace_buffer, model_config, workdir, thread_num,
-device_type, devices, mem_optim, ir_optim, precision,
-use_mkldnn, mkldnn_cache_capacity, mkldnn_op_list,
-mkldnn_bf16_op_list, is_jump_op, output_channels_of_jump_ops,
-min_subgraph_size, dynamic_shape_info, use_calib):
+device_type, devices, mem_optim, ir_optim, precision, use_mkldnn,
+mkldnn_cache_capacity, mkldnn_op_list, mkldnn_bf16_op_list,
+is_jump_op, output_channels_of_jump_ops, min_subgraph_size,
+dynamic_shape_info, use_calib):
"""
_run() is the entry function of OP process / thread model.When client
type is local_predictor in process mode, the CUDA environment needs to
@@ -1344,12 +1362,14 @@ class Op(object):
# init ops
profiler = None
@ErrorCatch
-def check_helper(self, is_thread_op, model_config, workdir,
-thread_num, device_type, devices, mem_optim, ir_optim,
-precision, use_mkldnn, mkldnn_cache_capacity, mkldnn_op_list,
-mkldnn_bf16_op_list, min_subgraph_size, dynamic_shape_info):
+def check_helper(self, is_thread_op, model_config, workdir, thread_num,
+device_type, devices, mem_optim, ir_optim, precision,
+use_mkldnn, mkldnn_cache_capacity, mkldnn_op_list,
+mkldnn_bf16_op_list, min_subgraph_size,
+dynamic_shape_info):
if is_thread_op == False and self.client_type == "local_predictor":
self.service_handler = local_service_handler.LocalServiceHandler(
model_config=model_config,
@@ -1377,17 +1397,19 @@ class Op(object):
profiler = self._initialize(is_thread_op, concurrency_idx)
return profiler
-profiler, resp = check_helper(self, is_thread_op, model_config, workdir,
-thread_num, device_type, devices, mem_optim, ir_optim,
-precision, use_mkldnn, mkldnn_cache_capacity, mkldnn_op_list,
-mkldnn_bf16_op_list, min_subgraph_size, dynamic_shape_info)
+profiler, resp = check_helper(
+self, is_thread_op, model_config, workdir, thread_num, device_type,
+devices, mem_optim, ir_optim, precision, use_mkldnn,
+mkldnn_cache_capacity, mkldnn_op_list, mkldnn_bf16_op_list,
+min_subgraph_size, dynamic_shape_info)
if resp.err_no != CustomExceptionCode.OK.value:
_LOGGER.critical(
"{} failed to init op: {}".format(op_info_prefix, resp.err_msg),
exc_info=False)
print("{} failed to init op: {}".format(op_info_prefix, resp.err_msg))
print("{} failed to init op: {}".format(op_info_prefix,
resp.err_msg))
kill_stop_process_by_pid("kill", os.getpgid(os.getpid()))
_LOGGER.info("{} Succ init".format(op_info_prefix))
@@ -1583,6 +1605,7 @@ class Op(object):
Returns:
TimeProfiler
"""
@ErrorCatch
def init_helper(self, is_thread_op, concurrency_idx):
if is_thread_op:
@@ -1592,7 +1615,7 @@ class Op(object):
self.concurrency_idx = None
# init client
self.client = self.init_client(self._client_config,
-self._server_endpoints)
+self._server_endpoints)
# user defined
self.init_op()
self._succ_init_op = True
@@ -1601,10 +1624,10 @@ class Op(object):
self.concurrency_idx = concurrency_idx
# init client
self.client = self.init_client(self._client_config,
-self._server_endpoints)
+self._server_endpoints)
# user defined
-self.init_op()
+self.init_op()
init_helper(self, is_thread_op, concurrency_idx)
print("[OP Object] init success")
# use a separate TimeProfiler per thread or process
@@ -1910,8 +1933,8 @@ class VirtualOp(Op):
\-> E ----------/
DAG view: [[A], [B, E], [C], [D], [F]]
-BUILD DAG: [A -> B -> C -> D -> E -> F]
-\-> E -> V1-> V2-> V3/
+BUILD DAG: [A -> B -> C -> D -> F]
+\-> E -> V1-> V2->/
"""
def __init__(self, name, concurrency=1):
......
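The docstring fix in the last hunk corrects the BUILD DAG picture: the gather op F, not E, closes the graph, and the shorter branch through E is padded with virtual ops V1 and V2 so both branches reach F at the same depth. A minimal sketch of that padding idea (illustrative only; `pad_branch` is a hypothetical helper, not the PaddleServing DAG builder):

```python
# Illustrative sketch of the virtual-op padding described in the
# VirtualOp docstring; not the actual DAG-building code.
def pad_branch(branch, target_len):
    """Append virtual pass-through ops until the branch has target_len ops."""
    padded = list(branch)
    i = 1
    while len(padded) < target_len:
        padded.append("V{}".format(i))  # a virtual op only forwards channels
        i += 1
    return padded

# From the docstring: the main path [B, C, D] has three ops, the short
# branch [E] has one, so two virtual ops align it: E -> V1 -> V2 -> F.
print(pad_branch(["E"], 3))  # ['E', 'V1', 'V2']
```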