提交 ccaccc39 编写于 作者: B barrierye

update profiler for process

上级 c64f6ddd
......@@ -40,18 +40,26 @@ class DAGExecutor(object):
use_multithread = yml_config.get('use_multithread', True)
use_profile = yml_config.get('profile', False)
channel_size = yml_config.get('channel_size', 0)
self._asyn_profile = yml_config.get('asgn_profile', False)
if not use_multithread:
if use_profile:
raise Exception(
"profile cannot be used in multiprocess version temporarily")
if use_profile:
_LOGGER.info("====> profiler <====")
if use_multithread:
_LOGGER.info("op: thread")
else:
_LOGGER.info("op: process")
if self._asyn_profile:
_LOGGER.info("profile mode: asyn")
else:
_LOGGER.info("profile mode: sync")
_LOGGER.info("====================")
self.name = "@G"
self._profiler = TimeProfiler()
self._profiler.enable(use_profile)
self._dag = DAG(response_op, self._profiler, use_multithread,
client_type, channel_size)
self._dag = DAG(response_op, self._profiler, use_profile,
use_multithread, client_type, channel_size)
(in_channel, out_channel, pack_rpc_func,
unpack_rpc_func) = self._dag.build()
self._dag.start()
......@@ -131,9 +139,8 @@ class DAGExecutor(object):
self._cv.notify_all()
return resp
def _pack_channeldata(self, rpc_request):
def _pack_channeldata(self, rpc_request, data_id):
_LOGGER.debug(self._log('start inferce'))
data_id = self._get_next_data_id()
dictdata = None
try:
dictdata = self._unpack_rpc_func(rpc_request)
......@@ -141,31 +148,35 @@ class DAGExecutor(object):
return ChannelData(
ecode=ChannelDataEcode.RPC_PACKAGE_ERROR.value,
error_info="rpc package error: {}".format(e),
data_id=data_id), data_id
data_id=data_id)
else:
return ChannelData(
datatype=ChannelDataType.DICT.value,
dictdata=dictdata,
data_id=data_id), data_id
data_id=data_id)
def call(self, rpc_request):
self._profiler.record("call#DAG_0")
data_id = self._get_next_data_id()
if self._asyn_profile:
self._profiler.record("call_{}#DAG-{}_0".format(data_id, data_id))
else:
self._profiler.record("call_{}#DAG_0".format(data_id))
self._profiler.record("prepack#{}_0".format(self.name))
req_channeldata, data_id = self._pack_channeldata(rpc_request)
self._profiler.record("prepack#{}_1".format(self.name))
self._profiler.record("prepack_{}#{}_0".format(data_id, self.name))
req_channeldata = self._pack_channeldata(rpc_request, data_id)
self._profiler.record("prepack_{}#{}_1".format(data_id, self.name))
resp_channeldata = None
for i in range(self._retry):
_LOGGER.debug(self._log('push data'))
#self._profiler.record("push#{}_0".format(self.name))
#self._profiler.record("push_{}#{}_0".format(data_id, self.name))
self._in_channel.push(req_channeldata, self.name)
#self._profiler.record("push#{}_1".format(self.name))
#self._profiler.record("push_{}#{}_1".format(data_id, self.name))
_LOGGER.debug(self._log('wait for infer'))
#self._profiler.record("fetch#{}_0".format(self.name))
#self._profiler.record("fetch_{}#{}_0".format(data_id, self.name))
resp_channeldata = self._get_channeldata_from_fetch_buffer(data_id)
#self._profiler.record("fetch#{}_1".format(self.name))
#self._profiler.record("fetch_{}#{}_1".format(data_id, self.name))
if resp_channeldata.ecode == ChannelDataEcode.OK.value:
break
......@@ -173,11 +184,14 @@ class DAGExecutor(object):
_LOGGER.warn("retry({}): {}".format(
i + 1, resp_channeldata.error_info))
self._profiler.record("postpack#{}_0".format(self.name))
self._profiler.record("postpack_{}#{}_0".format(data_id, self.name))
rpc_resp = self._pack_for_rpc_resp(resp_channeldata)
self._profiler.record("postpack#{}_1".format(self.name))
self._profiler.record("postpack_{}#{}_1".format(data_id, self.name))
self._profiler.record("call#DAG_1")
if self._asyn_profile:
self._profiler.record("call_{}#DAG-{}_1".format(data_id, data_id))
else:
self._profiler.record("call_{}#DAG_1".format(data_id))
self._profiler.print_profile()
return rpc_resp
......@@ -190,9 +204,10 @@ class DAGExecutor(object):
class DAG(object):
def __init__(self, response_op, profiler, use_multithread, client_type,
channel_size):
def __init__(self, response_op, profiler, use_profile, use_multithread,
client_type, channel_size):
self._response_op = response_op
self._use_profile = use_profile
self._use_multithread = use_multithread
self._channel_size = channel_size
self._client_type = client_type
......@@ -398,7 +413,7 @@ class DAG(object):
def start(self):
self._threads_or_proces = []
for op in self._actual_ops:
op.init_profiler(self._profiler)
op.init_profiler(self._profiler, self._use_profile)
if self._use_multithread:
self._threads_or_proces.extend(
op.start_with_thread(self._client_type))
......
......@@ -65,8 +65,9 @@ class Op(object):
self._for_init_op_lock = threading.Lock()
self._succ_init_op = False
def init_profiler(self, profiler):
def init_profiler(self, profiler, use_profile):
self._profiler = profiler
self._use_profile = use_profile
def _profiler_record(self, string):
if self._profiler is None:
......@@ -349,6 +350,11 @@ class Op(object):
_LOGGER.error(log(e))
os._exit(-1)
# init profiler
if not use_multithread:
self._profiler = TimeProfiler()
self._profiler.enable(self._use_profile)
self._is_run = True
while self._is_run:
#self._profiler_record("get#{}_0".format(op_info_prefix))
......@@ -398,6 +404,7 @@ class Op(object):
#self._profiler_record("push#{}_0".format(op_info_prefix))
self._push_to_output_channels(output_data, output_channels)
#self._profiler_record("push#{}_1".format(op_info_prefix))
self._profiler.print_profile()
def _log(self, info):
return "{} {}".format(self.name, info)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册