Commit 85343b20 authored by barriery

add Tracer

Parent bd8bc552
@@ -239,6 +239,9 @@ class ProcessChannel(object):
         self._base_cursor = manager.Value('i', 0)
         self._output_buf = manager.list()
 
+    def size(self):
+        return self._que.qsize()
+
     def get_producers(self):
         return self._producers
@@ -530,6 +533,9 @@ class ThreadChannel(Queue.Queue):
         self._base_cursor = 0
         self._output_buf = []
 
+    def size(self):
+        return self.qsize()
+
     def get_producers(self):
         return self._producers
......
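Both channel variants now expose the same `size()` accessor, so the tracer can poll queue depth without caring whether a channel is process- or thread-based. A minimal runnable sketch of the same idea (the `ToyChannel` class here is illustrative, not the pipeline's real channel):

```python
import queue

# Toy stand-in for ThreadChannel: exposes the same size() accessor the
# tracer relies on, so a monitor can poll channels uniformly.
class ToyChannel(queue.Queue):
    def __init__(self, name):
        super().__init__()
        self.name = name

    def size(self):
        return self.qsize()

ch = ToyChannel("img_in")
ch.put("payload")
print("Channel({}) size[{}]".format(ch.name, ch.size()))  # Channel(img_in) size[1]
```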
@@ -28,7 +28,7 @@ import logging
 from .operator import Op, RequestOp, ResponseOp, VirtualOp
 from .channel import (ThreadChannel, ProcessChannel, ChannelData,
                       ChannelDataEcode, ChannelDataType, ChannelStopError)
-from .profiler import TimeProfiler
+from .profiler import TimeProfiler, PerformanceTracer
 from .util import NameGenerator
 from .proto import pipeline_service_pb2
@@ -48,13 +48,18 @@ class DAGExecutor(object):
         self._profiler = TimeProfiler()
         self._profiler.enable(True)
 
+        self._tracer = PerformanceTracer()
+
         self._dag = DAG(self.name, response_op, self._server_use_profile,
                         self._is_thread_op, client_type, channel_size,
-                        build_dag_each_worker)
+                        build_dag_each_worker, self._tracer)
         (in_channel, out_channel, pack_rpc_func,
          unpack_rpc_func) = self._dag.build()
         self._dag.start()
 
+        self._tracer.set_channels(self._dag.get_channels())
+        self._tracer.start()
+
         self._set_in_channel(in_channel)
         self._set_out_channel(out_channel)
         self._pack_rpc_func = pack_rpc_func
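The wiring order matters here: channels are created inside `build()`, so the tracer can only be handed the channel list afterwards. A hedged sketch of the startup sequence this hunk introduces (constructor arguments are elided and the variable names are illustrative, not runnable as-is):

```python
# Sketch of the DAGExecutor startup path added in this commit.
tracer = PerformanceTracer()
dag = DAG(name, response_op, server_use_profile, is_thread_op,
          client_type, channel_size, build_dag_each_worker, tracer)
(in_channel, out_channel, pack_rpc_func, unpack_rpc_func) = dag.build()
dag.start()                              # spawn per-Op workers
tracer.set_channels(dag.get_channels())  # channels exist only after build()
tracer.start()                           # daemonized monitor process
```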
@@ -74,6 +79,7 @@ class DAGExecutor(object):
     def start(self):
         self._recive_func = threading.Thread(
             target=DAGExecutor._recive_out_channel_func, args=(self, ))
+        self._recive_func.daemon = True
         self._recive_func.start()
         _LOGGER.debug("[DAG Executor] Start recive thread")
@@ -205,6 +211,8 @@ class DAGExecutor(object):
             client_need_profile=client_need_profile)
 
     def call(self, rpc_request):
+        data_buffer = self._tracer.data_buffer()
+
         data_id, cond_v = self._get_next_data_id()
         _LOGGER.info("(logid={}) Succ generate id".format(data_id))
@@ -214,6 +222,7 @@ class DAGExecutor(object):
                 data_id, data_id))
         else:
             start_call = self._profiler.record("call_{}#DAG_0".format(data_id))
+            data_buffer.put(("DAG", "call_{}".format(data_id), 0, start_call))
 
         _LOGGER.debug("(logid={}) Parsing RPC request package".format(data_id))
         self._profiler.record("prepack_{}#{}_0".format(data_id, self.name))
@@ -262,9 +271,7 @@ class DAGExecutor(object):
                 data_id))
         else:
             end_call = self._profiler.record("call_{}#DAG_1".format(data_id))
-            _LOGGER.log(level=1,
-                        msg="(logid={}) call[{} ms]".format(
-                            data_id, (end_call - start_call) / 1e3))
+            data_buffer.put(("DAG", "call_{}".format(data_id), 1, end_call))
 
         profile_str = self._profiler.gen_profile_str()
         if self._server_use_profile:
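Instead of logging the call latency directly, `call()` now pushes a pair of events into the tracer's buffer: stage `0` with the start timestamp and stage `1` with the end timestamp, both in microseconds (the same unit as the `int(round(_time() * 1000000))` stamps used in the operator changes below). The tracer pairs the two and divides by `1e3` to report milliseconds, exactly what the removed log line used to compute:

```python
from time import time as _time

# Timestamps are microseconds; a paired delta / 1e3 yields milliseconds.
start_call = int(round(_time() * 1000000))
end_call = start_call + 4200          # pretend the call took 4200 us
print("call[{} ms]".format((end_call - start_call) / 1e3))  # call[4.2 ms]
```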
@@ -297,7 +304,7 @@ class DAGExecutor(object):
 
 class DAG(object):
     def __init__(self, request_name, response_op, use_profile, is_thread_op,
-                 client_type, channel_size, build_dag_each_worker):
+                 client_type, channel_size, build_dag_each_worker, tracer):
         self._request_name = request_name
         self._response_op = response_op
         self._use_profile = use_profile
@@ -305,6 +312,7 @@ class DAG(object):
         self._channel_size = channel_size
         self._client_type = client_type
         self._build_dag_each_worker = build_dag_each_worker
+        self._tracer = tracer
         if not self._is_thread_op:
             self._manager = multiprocessing.Manager()
         _LOGGER.info("[DAG] Succ init")
@@ -515,6 +523,9 @@ class DAG(object):
         return (actual_ops, channels, input_channel, output_channel, pack_func,
                 unpack_func)
 
+    def get_channels(self):
+        return self._channels
+
     def build(self):
         (actual_ops, channels, input_channel, output_channel, pack_func,
          unpack_func) = self._build_dag(self._response_op)
@@ -533,6 +544,7 @@ class DAG(object):
         self._threads_or_proces = []
         for op in self._actual_ops:
             op.use_profiler(self._use_profile)
+            op.set_tracer(self._tracer)
             if self._is_thread_op:
                 self._threads_or_proces.extend(
                     op.start_with_thread(self._client_type))
......
@@ -13,6 +13,7 @@
 # limitations under the License.
 # pylint: disable=doc-string-missing
 
 from time import time as _time
+import time
 import threading
 import multiprocessing
 from paddle_serving_client import MultiLangClient, Client
@@ -97,6 +98,7 @@ class Op(object):
                            self._batch_size, self._auto_batching_timeout)))
 
         self._server_use_profile = False
+        self._tracer = None
 
         # only for thread op
         self._for_init_op_lock = threading.Lock()
def use_profiler(self, use_profile): def use_profiler(self, use_profile):
self._server_use_profile = use_profile self._server_use_profile = use_profile
def set_tracer(self, tracer):
self._tracer = tracer
def init_client(self, client_type, client_config, server_endpoints, def init_client(self, client_type, client_config, server_endpoints,
fetch_names): fetch_names):
if self.with_serving == False: if self.with_serving == False:
@@ -256,7 +261,9 @@ class Op(object):
             p = multiprocessing.Process(
                 target=self._run,
                 args=(concurrency_idx, self._get_input_channel(),
-                      self._get_output_channels(), client_type, False))
+                      self._get_output_channels(), client_type, False,
+                      self._tracer.data_buffer()))
+            p.daemon = True
             p.start()
             proces.append(p)
         return proces
@@ -267,7 +274,8 @@ class Op(object):
             t = threading.Thread(
                 target=self._run,
                 args=(concurrency_idx, self._get_input_channel(),
-                      self._get_output_channels(), client_type, True))
+                      self._get_output_channels(), client_type, True,
+                      self._tracer.data_buffer()))
             # When a process exits, it attempts to terminate
             # all of its daemonic child processes.
             t.daemon = True
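Worker processes are now daemonized alongside the worker threads, and the servicer's `__del__` hook that called `stop()` is removed further down; teardown presumably relies on daemonic children being killed when the parent exits rather than on an explicit stop. A small runnable illustration of that behavior:

```python
import multiprocessing
import time

def worker():
    while True:          # would block a normal shutdown...
        time.sleep(1)

if __name__ == "__main__":
    p = multiprocessing.Process(target=worker)
    p.daemon = True      # ...but daemonic children die with the parent
    p.start()
    time.sleep(0.1)
    # no p.join(): interpreter exit terminates the daemonic worker
```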
@@ -482,7 +490,7 @@ class Op(object):
         return parsed_data_dict, need_profile_dict, profile_dict
 
     def _run(self, concurrency_idx, input_channel, output_channels, client_type,
-             is_thread_op):
+             is_thread_op, trace_buffer):
         op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
         tid = threading.current_thread().ident
@@ -505,16 +513,18 @@ class Op(object):
             timeout=self._auto_batching_timeout,
             op_info_prefix=op_info_prefix)
 
-        start_prep, end_prep = None, None
-        start_midp, end_midp = None, None
-        start_postp, end_postp = None, None
+        start, end = None, None
         while True:
+            start = int(round(_time() * 1000000))
+            trace_buffer.put((self.name, "in", 0, start))
             try:
                 channeldata_dict_batch = next(batch_generator)
             except ChannelStopError:
                 _LOGGER.debug("{} Stop.".format(op_info_prefix))
                 self._finalize(is_thread_op)
                 break
+            end = int(round(_time() * 1000000))
+            trace_buffer.put((self.name, "in", 1, end))
 
             # parse channeldata batch
             try:
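Note what the `in` event pair measures: the span from the top of the loop to the return of `next(batch_generator)`, i.e. the time the worker spends blocked waiting to assemble an input batch, not time doing work. A toy sketch of the same bracketing pattern (the plain queue stands in for the input channel, and the Op name "uci" is made up):

```python
import queue
from time import time as _time

trace_buffer = queue.Queue()
input_q = queue.Queue()
input_q.put("batch-0")

start = int(round(_time() * 1000000))
trace_buffer.put(("uci", "in", 0, start))  # stage 0: started waiting
batch = input_q.get()                      # blocks until a batch arrives
end = int(round(_time() * 1000000))
trace_buffer.put(("uci", "in", 1, end))    # stage 1: batch received
```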
@@ -530,14 +540,12 @@ class Op(object):
                 continue
 
             # preprocess
-            start_prep = profiler.record("prep#{}_0".format(op_info_prefix))
+            start = profiler.record("prep#{}_0".format(op_info_prefix))
+            trace_buffer.put((self.name, "prep", 0, start))
             preped_data_dict, err_channeldata_dict \
                     = self._run_preprocess(parsed_data_dict, op_info_prefix)
-            end_prep = profiler.record("prep#{}_1".format(op_info_prefix))
-            _LOGGER.log(level=1,
-                        msg="(logid={}) {} prep[{} ms]".format(
-                            parsed_data_dict.keys(), op_info_prefix,
-                            (end_prep - start_prep) / 1e3))
+            end = profiler.record("prep#{}_1".format(op_info_prefix))
+            trace_buffer.put((self.name, "prep", 1, end))
             try:
                 for data_id, err_channeldata in err_channeldata_dict.items():
                     self._push_to_output_channels(
continue continue
# process # process
start_midp = profiler.record("midp#{}_0".format(op_info_prefix)) start = profiler.record("midp#{}_0".format(op_info_prefix))
trace_buffer.put((self.name, "midp", 0, start))
midped_data_dict, err_channeldata_dict \ midped_data_dict, err_channeldata_dict \
= self._run_process(preped_data_dict, op_info_prefix) = self._run_process(preped_data_dict, op_info_prefix)
end_midp = profiler.record("midp#{}_1".format(op_info_prefix)) end = profiler.record("midp#{}_1".format(op_info_prefix))
_LOGGER.log(level=1, trace_buffer.put((self.name, "midp", 1, end))
msg="(logid={}) {} midp[{} ms]".format(
preped_data_dict.keys(), op_info_prefix,
(end_midp - start_midp) / 1e3))
try: try:
for data_id, err_channeldata in err_channeldata_dict.items(): for data_id, err_channeldata in err_channeldata_dict.items():
self._push_to_output_channels( self._push_to_output_channels(
@@ -576,15 +582,13 @@ class Op(object):
                 continue
 
             # postprocess
-            start_postp = profiler.record("postp#{}_0".format(op_info_prefix))
+            start = profiler.record("postp#{}_0".format(op_info_prefix))
+            trace_buffer.put((self.name, "postp", 0, start))
             postped_data_dict, err_channeldata_dict \
                     = self._run_postprocess(
                             parsed_data_dict, midped_data_dict, op_info_prefix)
-            end_postp = profiler.record("postp#{}_1".format(op_info_prefix))
-            _LOGGER.log(level=1,
-                        msg="(logid={}) {} postp[{} ms]".format(
-                            midped_data_dict.keys(), op_info_prefix,
-                            (end_midp - start_midp) / 1e3))
+            end = profiler.record("postp#{}_1".format(op_info_prefix))
+            trace_buffer.put((self.name, "postp", 1, end))
             try:
                 for data_id, err_channeldata in err_channeldata_dict.items():
                     self._push_to_output_channels(
@@ -600,6 +604,8 @@ class Op(object):
                 continue
 
             # push data to channel (if run succ)
+            start = int(round(_time() * 1000000))
+            trace_buffer.put((self.name, "out", 0, start))
             try:
                 profile_str = profiler.gen_profile_str()
                 for data_id, postped_data in postped_data_dict.items():
@@ -615,6 +621,8 @@ class Op(object):
                 _LOGGER.debug("{} Stop.".format(op_info_prefix))
                 self._finalize(is_thread_op)
                 break
+            end = int(round(_time() * 1000000))
+            trace_buffer.put((self.name, "out", 1, end))
 
     def _initialize(self, is_thread_op, client_type, concurrency_idx):
         if is_thread_op:
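Each stage of `_run` is now bracketed the same way, so every loop iteration emits up to five event pairs per Op: `in`, `prep`, `midp`, `postp`, and `out`. One caveat: the tracer pairs events by the `name_action` tag alone, which appears to assume a stage-0 event is always drained before the next stage-0 with the same tag; with several concurrent workers for one Op the pairing is best-effort. A compact, runnable sketch of the emit side under those assumptions (the Op name and the no-op stage functions are hypothetical):

```python
import queue
from time import time as _time

def now_us():
    # same microsecond convention as int(round(_time() * 1000000)) in _run
    return int(round(_time() * 1000000))

trace_buffer = queue.Queue()
op_name = "uci"  # made-up Op name, for illustration only

# Each stage is bracketed by (name, action, 0, ts) / (name, action, 1, ts).
for action, work in (("prep", lambda: None),
                     ("midp", lambda: None),
                     ("postp", lambda: None)):
    trace_buffer.put((op_name, action, 0, now_us()))
    work()  # stands in for _run_preprocess / _run_process / _run_postprocess
    trace_buffer.put((op_name, action, 1, now_us()))

print(trace_buffer.qsize())  # 6 events: three stage pairs
```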
......
@@ -41,9 +41,6 @@ class PipelineServicer(pipeline_service_pb2_grpc.PipelineServiceServicer):
         resp = self._dag_executor.call(request)
         return resp
 
-    def __del__(self):
-        self._dag_executor.stop()
-
 
 @contextlib.contextmanager
 def _reserve_port(port):
......
@@ -23,11 +23,85 @@ elif sys.version_info.major == 3:
 else:
     raise Exception("Error Python version")
+from time import time as _time
+import time
 import threading
+import multiprocessing
 
 _LOGGER = logging.getLogger()
 
+
+class PerformanceTracer(object):
+    def __init__(self, interval_s=1):
+        self._data_buffer = multiprocessing.Manager().Queue()
+        self._interval_s = interval_s
+        self._proc = None
+        self._channels = []
+        self._trace_filename = os.path.join("PipelineServingLogs", "INDEX.log")
+
+    def data_buffer(self):
+        return self._data_buffer
+
+    def start(self):
+        self._proc = multiprocessing.Process(
+            target=self._trace_func, args=(self._channels, ))
+        self._proc.daemon = True
+        self._proc.start()
+
+    def set_channels(self, channels):
+        self._channels = channels
+
+    def _trace_func(self, channels):
+        trace_file = open(self._trace_filename, "a")
+        actions = ["prep", "midp", "postp"]
+        tag_dict = {}
+        while True:
+            op_cost = {}
+            trace_file.write("==========================\n")
+            # op: drain the buffer and pair stage-0/stage-1 events by tag
+            while not self._data_buffer.empty():
+                name, action, stage, timestamp = self._data_buffer.get()
+                tag = "{}_{}".format(name, action)
+                if tag in tag_dict:
+                    assert stage == 1
+                    start_timestamp = tag_dict.pop(tag)
+                    cost = timestamp - start_timestamp
+                    if name not in op_cost:
+                        op_cost[name] = {}
+                    if action not in op_cost[name]:
+                        op_cost[name][action] = []
+                    op_cost[name][action].append(cost)
+                else:
+                    assert stage == 0
+                    tag_dict[tag] = timestamp
+
+            for name in op_cost:
+                tot_cost, cal_cost = 0.0, 0.0
+                for action, costs in op_cost[name].items():
+                    # average microsecond deltas into milliseconds
+                    op_cost[name][action] = sum(costs) / (1e3 * len(costs))
+                    tot_cost += op_cost[name][action]
+                msg = ", ".join([
+                    "{}[{} ms]".format(action, cost)
+                    for action, cost in op_cost[name].items()
+                ])
+                for action in actions:
+                    if action in op_cost[name]:
+                        cal_cost += op_cost[name][action]
+                trace_file.write("Op({}) {}\n".format(name, msg))
+                if name != "DAG":
+                    trace_file.write("Op({}) idle[{}]\n".format(
+                        name, 1 - 1.0 * cal_cost / tot_cost))
+
+            # channel: sample queue depth via the new size() accessor
+            for channel in channels:
+                trace_file.write("Channel({}) size[{}]\n".format(
+                    channel.name, channel.size()))
+            trace_file.flush()  # keep INDEX.log current despite buffering
+            time.sleep(self._interval_s)
+
 
 class UnsafeTimeProfiler(object):
     """ thread unsafe profiler """
......
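Reader side: `_trace_func` wakes every `interval_s` seconds, drains the buffer, pairs stage-0/stage-1 events by their `name_action` tag, averages the microsecond deltas into milliseconds, and derives an idle ratio as `1 - (prep + midp + postp) / total`, where total also includes the `in` and `out` waits. A self-contained sketch of that aggregation with invented sample timestamps:

```python
import queue

buf = queue.Queue()
# One "in" pair and one "prep" pair for a made-up Op "uci" (microseconds):
for ev in [("uci", "in", 0, 0), ("uci", "in", 1, 9000),
           ("uci", "prep", 0, 9000), ("uci", "prep", 1, 12500)]:
    buf.put(ev)

tag_dict, op_cost = {}, {}
while not buf.empty():
    name, action, stage, ts = buf.get()
    tag = "{}_{}".format(name, action)
    if stage == 0:
        tag_dict[tag] = ts
    else:
        cost = ts - tag_dict.pop(tag)
        op_cost.setdefault(name, {}).setdefault(action, []).append(cost)

for name, per_action in op_cost.items():
    avg_ms = {a: sum(c) / (1e3 * len(c)) for a, c in per_action.items()}
    tot_cost = sum(avg_ms.values())
    cal_cost = sum(avg_ms.get(a, 0.0) for a in ("prep", "midp", "postp"))
    print("Op({}) {}".format(name, ", ".join(
        "{}[{} ms]".format(a, v) for a, v in avg_ms.items())))
    print("Op({}) idle[{}]".format(name, 1 - cal_cost / tot_cost))
# e.g. Op(uci) in[9.0 ms], prep[3.5 ms]
#      Op(uci) idle[0.72]
```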