提交 47d71a5c 编写于 作者: B barriery

update tracer

上级 85343b20
...@@ -239,6 +239,9 @@ class ProcessChannel(object): ...@@ -239,6 +239,9 @@ class ProcessChannel(object):
self._base_cursor = manager.Value('i', 0) self._base_cursor = manager.Value('i', 0)
self._output_buf = manager.list() self._output_buf = manager.list()
def get_maxsize(self):
return self._maxsize
def size(self): def size(self):
return self._que.qsize() return self._que.qsize()
...@@ -533,6 +536,9 @@ class ThreadChannel(Queue.Queue): ...@@ -533,6 +536,9 @@ class ThreadChannel(Queue.Queue):
self._base_cursor = 0 self._base_cursor = 0
self._output_buf = [] self._output_buf = []
def get_maxsize(self):
return self._maxsize
def size(self): def size(self):
return self.qsize() return self.qsize()
......
...@@ -36,19 +36,28 @@ _LOGGER = logging.getLogger() ...@@ -36,19 +36,28 @@ _LOGGER = logging.getLogger()
class DAGExecutor(object): class DAGExecutor(object):
def __init__(self, response_op, dag_conf): def __init__(self, response_op, server_conf):
build_dag_each_worker = server_conf["build_dag_each_worker"]
server_worker_num = server_conf["worker_num"]
dag_conf = server_conf["dag"]
self._retry = dag_conf["retry"] self._retry = dag_conf["retry"]
client_type = dag_conf["client_type"] client_type = dag_conf["client_type"]
self._server_use_profile = dag_conf["use_profile"] self._server_use_profile = dag_conf["use_profile"]
channel_size = dag_conf["channel_size"] channel_size = dag_conf["channel_size"]
self._is_thread_op = dag_conf["is_thread_op"] self._is_thread_op = dag_conf["is_thread_op"]
build_dag_each_worker = dag_conf["build_dag_each_worker"]
self.name = "@G" tracer_conf = dag_conf["tracer"]
tracer_interval_s = tracer_conf["interval_s"]
self.name = "@DAGExecutor"
self._profiler = TimeProfiler() self._profiler = TimeProfiler()
self._profiler.enable(True) self._profiler.enable(True)
self._tracer = PerformanceTracer() self._tracer = None
if tracer_interval_s >= 1:
self._tracer = PerformanceTracer(
self._is_thread_op, tracer_interval_s, server_worker_num)
self._dag = DAG(self.name, response_op, self._server_use_profile, self._dag = DAG(self.name, response_op, self._server_use_profile,
self._is_thread_op, client_type, channel_size, self._is_thread_op, client_type, channel_size,
...@@ -57,14 +66,14 @@ class DAGExecutor(object): ...@@ -57,14 +66,14 @@ class DAGExecutor(object):
unpack_rpc_func) = self._dag.build() unpack_rpc_func) = self._dag.build()
self._dag.start() self._dag.start()
self._tracer.set_channels(self._dag.get_channels())
self._tracer.start()
self._set_in_channel(in_channel) self._set_in_channel(in_channel)
self._set_out_channel(out_channel) self._set_out_channel(out_channel)
self._pack_rpc_func = pack_rpc_func self._pack_rpc_func = pack_rpc_func
self._unpack_rpc_func = unpack_rpc_func self._unpack_rpc_func = unpack_rpc_func
if self._tracer is not None:
self._tracer.start()
self._id_lock = threading.Lock() self._id_lock = threading.Lock()
self._id_counter = 0 self._id_counter = 0
self._reset_max_id = 1000000000000000000 self._reset_max_id = 1000000000000000000
...@@ -211,7 +220,8 @@ class DAGExecutor(object): ...@@ -211,7 +220,8 @@ class DAGExecutor(object):
client_need_profile=client_need_profile) client_need_profile=client_need_profile)
def call(self, rpc_request): def call(self, rpc_request):
data_buffer = self._tracer.data_buffer() if self._tracer is not None:
trace_buffer = self._tracer.data_buffer()
data_id, cond_v = self._get_next_data_id() data_id, cond_v = self._get_next_data_id()
_LOGGER.info("(logid={}) Succ generate id".format(data_id)) _LOGGER.info("(logid={}) Succ generate id".format(data_id))
...@@ -222,7 +232,6 @@ class DAGExecutor(object): ...@@ -222,7 +232,6 @@ class DAGExecutor(object):
data_id, data_id)) data_id, data_id))
else: else:
start_call = self._profiler.record("call_{}#DAG_0".format(data_id)) start_call = self._profiler.record("call_{}#DAG_0".format(data_id))
data_buffer.put(("DAG", "call_{}".format(data_id), 0, start_call))
_LOGGER.debug("(logid={}) Parsing RPC request package".format(data_id)) _LOGGER.debug("(logid={}) Parsing RPC request package".format(data_id))
self._profiler.record("prepack_{}#{}_0".format(data_id, self.name)) self._profiler.record("prepack_{}#{}_0".format(data_id, self.name))
...@@ -250,7 +259,7 @@ class DAGExecutor(object): ...@@ -250,7 +259,7 @@ class DAGExecutor(object):
cond_v) cond_v)
if resp_channeldata.ecode == ChannelDataEcode.OK.value: if resp_channeldata.ecode == ChannelDataEcode.OK.value:
_LOGGER.debug("(logid={}) Succ predict".format(data_id)) _LOGGER.info("(logid={}) Succ predict".format(data_id))
break break
else: else:
_LOGGER.error("(logid={}) Failed to predict: {}" _LOGGER.error("(logid={}) Failed to predict: {}"
...@@ -271,7 +280,14 @@ class DAGExecutor(object): ...@@ -271,7 +280,14 @@ class DAGExecutor(object):
data_id)) data_id))
else: else:
end_call = self._profiler.record("call_{}#DAG_1".format(data_id)) end_call = self._profiler.record("call_{}#DAG_1".format(data_id))
data_buffer.put(("DAG", "call_{}".format(data_id), 1, end_call))
if self._tracer is not None:
if resp_channeldata.ecode == ChannelDataEcode.OK.value:
trace_buffer.put(("DAG", "call_{}".format(data_id), True,
end_call - start_call))
else:
trace_buffer.put(("DAG", "call_{}".format(data_id), False,
end_call - start_call))
profile_str = self._profiler.gen_profile_str() profile_str = self._profiler.gen_profile_str()
if self._server_use_profile: if self._server_use_profile:
...@@ -538,6 +554,8 @@ class DAG(object): ...@@ -538,6 +554,8 @@ class DAG(object):
self._pack_func = pack_func self._pack_func = pack_func
self._unpack_func = unpack_func self._unpack_func = unpack_func
self._tracer.set_channels(self._channels)
return self._input_channel, self._output_channel, self._pack_func, self._unpack_func return self._input_channel, self._output_channel, self._pack_func, self._unpack_func
def start(self): def start(self):
......
...@@ -256,26 +256,32 @@ class Op(object): ...@@ -256,26 +256,32 @@ class Op(object):
channel.push(data, name) channel.push(data, name)
def start_with_process(self, client_type): def start_with_process(self, client_type):
trace_buffer = None
if self._tracer is not None:
trace_buffer = self._tracer.data_buffer()
proces = [] proces = []
for concurrency_idx in range(self.concurrency): for concurrency_idx in range(self.concurrency):
p = multiprocessing.Process( p = multiprocessing.Process(
target=self._run, target=self._run,
args=(concurrency_idx, self._get_input_channel(), args=(concurrency_idx, self._get_input_channel(),
self._get_output_channels(), client_type, False, self._get_output_channels(), client_type, False,
self._tracer.data_buffer())) trace_buffer))
p.daemon = True p.daemon = True
p.start() p.start()
proces.append(p) proces.append(p)
return proces return proces
def start_with_thread(self, client_type): def start_with_thread(self, client_type):
trace_buffer = None
if self._tracer is not None:
trace_buffer = self._tracer.data_buffer()
threads = [] threads = []
for concurrency_idx in range(self.concurrency): for concurrency_idx in range(self.concurrency):
t = threading.Thread( t = threading.Thread(
target=self._run, target=self._run,
args=(concurrency_idx, self._get_input_channel(), args=(concurrency_idx, self._get_input_channel(),
self._get_output_channels(), client_type, True, self._get_output_channels(), client_type, True,
self._tracer.data_buffer())) trace_buffer))
# When a process exits, it attempts to terminate # When a process exits, it attempts to terminate
# all of its daemonic child processes. # all of its daemonic child processes.
t.daemon = True t.daemon = True
...@@ -516,7 +522,6 @@ class Op(object): ...@@ -516,7 +522,6 @@ class Op(object):
start, end = None, None start, end = None, None
while True: while True:
start = int(round(_time() * 1000000)) start = int(round(_time() * 1000000))
trace_buffer.put((self.name, "in", 0, start))
try: try:
channeldata_dict_batch = next(batch_generator) channeldata_dict_batch = next(batch_generator)
except ChannelStopError: except ChannelStopError:
...@@ -524,7 +529,8 @@ class Op(object): ...@@ -524,7 +529,8 @@ class Op(object):
self._finalize(is_thread_op) self._finalize(is_thread_op)
break break
end = int(round(_time() * 1000000)) end = int(round(_time() * 1000000))
trace_buffer.put((self.name, "in", 1, end)) if trace_buffer is not None:
trace_buffer.put((self.name, "in", True, end - start))
# parse channeldata batch # parse channeldata batch
try: try:
...@@ -541,11 +547,11 @@ class Op(object): ...@@ -541,11 +547,11 @@ class Op(object):
# preprecess # preprecess
start = profiler.record("prep#{}_0".format(op_info_prefix)) start = profiler.record("prep#{}_0".format(op_info_prefix))
trace_buffer.put((self.name, "prep", 0, start))
preped_data_dict, err_channeldata_dict \ preped_data_dict, err_channeldata_dict \
= self._run_preprocess(parsed_data_dict, op_info_prefix) = self._run_preprocess(parsed_data_dict, op_info_prefix)
end = profiler.record("prep#{}_1".format(op_info_prefix)) end = profiler.record("prep#{}_1".format(op_info_prefix))
trace_buffer.put((self.name, "prep", 1, end)) if trace_buffer is not None:
trace_buffer.put((self.name, "prep", True, end - start))
try: try:
for data_id, err_channeldata in err_channeldata_dict.items(): for data_id, err_channeldata in err_channeldata_dict.items():
self._push_to_output_channels( self._push_to_output_channels(
...@@ -562,11 +568,11 @@ class Op(object): ...@@ -562,11 +568,11 @@ class Op(object):
# process # process
start = profiler.record("midp#{}_0".format(op_info_prefix)) start = profiler.record("midp#{}_0".format(op_info_prefix))
trace_buffer.put((self.name, "midp", 0, start))
midped_data_dict, err_channeldata_dict \ midped_data_dict, err_channeldata_dict \
= self._run_process(preped_data_dict, op_info_prefix) = self._run_process(preped_data_dict, op_info_prefix)
end = profiler.record("midp#{}_1".format(op_info_prefix)) end = profiler.record("midp#{}_1".format(op_info_prefix))
trace_buffer.put((self.name, "midp", 1, end)) if trace_buffer is not None:
trace_buffer.put((self.name, "midp", True, end - start))
try: try:
for data_id, err_channeldata in err_channeldata_dict.items(): for data_id, err_channeldata in err_channeldata_dict.items():
self._push_to_output_channels( self._push_to_output_channels(
...@@ -583,12 +589,12 @@ class Op(object): ...@@ -583,12 +589,12 @@ class Op(object):
# postprocess # postprocess
start = profiler.record("postp#{}_0".format(op_info_prefix)) start = profiler.record("postp#{}_0".format(op_info_prefix))
trace_buffer.put((self.name, "postp", 0, start))
postped_data_dict, err_channeldata_dict \ postped_data_dict, err_channeldata_dict \
= self._run_postprocess( = self._run_postprocess(
parsed_data_dict, midped_data_dict, op_info_prefix) parsed_data_dict, midped_data_dict, op_info_prefix)
end = profiler.record("postp#{}_1".format(op_info_prefix)) end = profiler.record("postp#{}_1".format(op_info_prefix))
trace_buffer.put((self.name, "postp", 1, end)) if trace_buffer is not None:
trace_buffer.put((self.name, "postp", True, end - start))
try: try:
for data_id, err_channeldata in err_channeldata_dict.items(): for data_id, err_channeldata in err_channeldata_dict.items():
self._push_to_output_channels( self._push_to_output_channels(
...@@ -605,7 +611,6 @@ class Op(object): ...@@ -605,7 +611,6 @@ class Op(object):
# push data to channel (if run succ) # push data to channel (if run succ)
start = int(round(_time() * 1000000)) start = int(round(_time() * 1000000))
trace_buffer.put((self.name, "out", 0, start))
try: try:
profile_str = profiler.gen_profile_str() profile_str = profiler.gen_profile_str()
for data_id, postped_data in postped_data_dict.items(): for data_id, postped_data in postped_data_dict.items():
...@@ -622,7 +627,8 @@ class Op(object): ...@@ -622,7 +627,8 @@ class Op(object):
self._finalize(is_thread_op) self._finalize(is_thread_op)
break break
end = int(round(_time() * 1000000)) end = int(round(_time() * 1000000))
trace_buffer.put((self.name, "out", 1, end)) if trace_buffer is not None:
trace_buffer.put((self.name, "out", True, end - start))
def _initialize(self, is_thread_op, client_type, concurrency_idx): def _initialize(self, is_thread_op, client_type, concurrency_idx):
if is_thread_op: if is_thread_op:
...@@ -669,8 +675,8 @@ class RequestOp(Op): ...@@ -669,8 +675,8 @@ class RequestOp(Op):
""" RequestOp do not run preprocess, process, postprocess. """ """ RequestOp do not run preprocess, process, postprocess. """
def __init__(self): def __init__(self):
# PipelineService.name = "@G" # PipelineService.name = "@DAGExecutor"
super(RequestOp, self).__init__(name="@G", input_ops=[]) super(RequestOp, self).__init__(name="@DAGExecutor", input_ops=[])
# init op # init op
try: try:
self.init_op() self.init_op()
...@@ -694,7 +700,8 @@ class ResponseOp(Op): ...@@ -694,7 +700,8 @@ class ResponseOp(Op):
""" ResponseOp do not run preprocess, process, postprocess. """ """ ResponseOp do not run preprocess, process, postprocess. """
def __init__(self, input_ops): def __init__(self, input_ops):
super(ResponseOp, self).__init__(name="@R", input_ops=input_ops) super(ResponseOp, self).__init__(
name="@DAGExecutor", input_ops=input_ops)
# init op # init op
try: try:
self.init_op() self.init_op()
......
...@@ -96,8 +96,7 @@ class PipelineServer(object): ...@@ -96,8 +96,7 @@ class PipelineServer(object):
"(Make sure that install grpcio whl with --no-binary flag)") "(Make sure that install grpcio whl with --no-binary flag)")
_LOGGER.info("-------------------------------------------") _LOGGER.info("-------------------------------------------")
self._dag_conf = conf["dag"] self._conf = conf
self._dag_conf["build_dag_each_worker"] = self._build_dag_each_worker
def run_server(self): def run_server(self):
if self._build_dag_each_worker: if self._build_dag_each_worker:
...@@ -108,7 +107,7 @@ class PipelineServer(object): ...@@ -108,7 +107,7 @@ class PipelineServer(object):
show_info = (i == 0) show_info = (i == 0)
worker = multiprocessing.Process( worker = multiprocessing.Process(
target=self._run_server_func, target=self._run_server_func,
args=(bind_address, self._response_op, self._dag_conf)) args=(bind_address, self._response_op, self._conf))
worker.start() worker.start()
workers.append(worker) workers.append(worker)
for worker in workers: for worker in workers:
...@@ -117,7 +116,7 @@ class PipelineServer(object): ...@@ -117,7 +116,7 @@ class PipelineServer(object):
server = grpc.server( server = grpc.server(
futures.ThreadPoolExecutor(max_workers=self._worker_num)) futures.ThreadPoolExecutor(max_workers=self._worker_num))
pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server( pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
PipelineServicer(self._response_op, self._dag_conf), server) PipelineServicer(self._response_op, self._conf), server)
server.add_insecure_port('[::]:{}'.format(self._port)) server.add_insecure_port('[::]:{}'.format(self._port))
server.start() server.start()
server.wait_for_termination() server.wait_for_termination()
...@@ -144,8 +143,15 @@ class ServerYamlConfChecker(object): ...@@ -144,8 +143,15 @@ class ServerYamlConfChecker(object):
conf = yaml.load(f.read()) conf = yaml.load(f.read())
ServerYamlConfChecker.check_server_conf(conf) ServerYamlConfChecker.check_server_conf(conf)
ServerYamlConfChecker.check_dag_conf(conf["dag"]) ServerYamlConfChecker.check_dag_conf(conf["dag"])
ServerYamlConfChecker.check_tracer_conf(conf["dag"]["tracer"])
return conf return conf
@staticmethod
def check_conf(conf, default_conf, conf_type, conf_qualification):
ServerYamlConfChecker.fill_with_default_conf(conf, default_conf)
ServerYamlConfChecker.check_conf_type(conf, conf_type)
ServerYamlConfChecker.check_conf_qualification(conf, conf_qualification)
@staticmethod @staticmethod
def check_server_conf(conf): def check_server_conf(conf):
default_conf = { default_conf = {
...@@ -155,22 +161,30 @@ class ServerYamlConfChecker(object): ...@@ -155,22 +161,30 @@ class ServerYamlConfChecker(object):
"dag": {}, "dag": {},
} }
ServerYamlConfChecker.fill_with_default_conf(conf, default_conf)
conf_type = { conf_type = {
"port": int, "port": int,
"worker_num": int, "worker_num": int,
"build_dag_each_worker": bool, "build_dag_each_worker": bool,
} }
ServerYamlConfChecker.check_conf_type(conf, conf_type)
conf_qualification = { conf_qualification = {
"port": [(">=", 1024), ("<=", 65535)], "port": [(">=", 1024), ("<=", 65535)],
"worker_num": (">=", 1), "worker_num": (">=", 1),
} }
ServerYamlConfChecker.check_conf_qualification(conf, conf_qualification) ServerYamlConfChecker.check_conf(conf, default_conf, conf_type,
conf_qualification)
@staticmethod
def check_tracer_conf(conf):
default_conf = {"interval_s": 600, }
conf_type = {"interval_s": int, }
conf_qualification = {}
ServerYamlConfChecker.check_conf(conf, default_conf, conf_type,
conf_qualification)
@staticmethod @staticmethod
def check_dag_conf(conf): def check_dag_conf(conf):
...@@ -179,11 +193,10 @@ class ServerYamlConfChecker(object): ...@@ -179,11 +193,10 @@ class ServerYamlConfChecker(object):
"client_type": "brpc", "client_type": "brpc",
"use_profile": False, "use_profile": False,
"channel_size": 0, "channel_size": 0,
"is_thread_op": True "is_thread_op": True,
"tracer": {},
} }
ServerYamlConfChecker.fill_with_default_conf(conf, default_conf)
conf_type = { conf_type = {
"retry": int, "retry": int,
"client_type": str, "client_type": str,
...@@ -192,15 +205,14 @@ class ServerYamlConfChecker(object): ...@@ -192,15 +205,14 @@ class ServerYamlConfChecker(object):
"is_thread_op": bool, "is_thread_op": bool,
} }
ServerYamlConfChecker.check_conf_type(conf, conf_type)
conf_qualification = { conf_qualification = {
"retry": (">=", 1), "retry": (">=", 1),
"client_type": ("in", ["brpc", "grpc"]), "client_type": ("in", ["brpc", "grpc"]),
"channel_size": (">=", 0), "channel_size": (">=", 0),
} }
ServerYamlConfChecker.check_conf_qualification(conf, conf_qualification) ServerYamlConfChecker.check_conf(conf, default_conf, conf_type,
conf_qualification)
@staticmethod @staticmethod
def fill_with_default_conf(conf, default_conf): def fill_with_default_conf(conf, default_conf):
......
...@@ -27,78 +27,114 @@ import time ...@@ -27,78 +27,114 @@ import time
import threading import threading
import multiprocessing import multiprocessing
_LOGGER = logging.getLogger() _TRACER = logging.getLogger("tracer")
class PerformanceTracer(object): class PerformanceTracer(object):
def __init__(self, interval_s=1): def __init__(self, is_thread_mode, interval_s, server_worker_num):
self._data_buffer = multiprocessing.Manager().Queue() self._is_thread_mode = is_thread_mode
if is_thread_mode:
# Because the Channel in the thread mode cannot be
# accessed across processes, when using thread mode,
# the PerformanceTracer is also the thread mode.
# However, performance may be affected by GIL.
self._data_buffer = Queue.Queue()
else:
self._data_buffer = multiprocessing.Manager().Queue()
self._interval_s = interval_s self._interval_s = interval_s
self._thrd = None
self._proc = None self._proc = None
self._channels = [] self._channels = []
self._trace_filename = os.path.join("PipelineServingLogs", "INDEX.log") # The size of data in Channel will not exceed server_worker_num
self._server_worker_num = server_worker_num
def data_buffer(self): def data_buffer(self):
return self._data_buffer return self._data_buffer
def start(self): def start(self):
self._proc = multiprocessing.Process( if self._is_thread_mode:
target=self._trace_func, args=(self._channels, )) self._thrd = threading.Thread(
self._proc.daemon = True target=self._trace_func, args=(self._channels, ))
self._proc.start() self._thrd.daemon = True
self._thrd.start()
else:
self._proc = multiprocessing.Process(
target=self._trace_func, args=(self._channels, ))
self._proc.daemon = True
self._proc.start()
def set_channels(self, channels): def set_channels(self, channels):
self._channels = channels self._channels = channels
def _trace_func(self, channels): def _trace_func(self, channels):
trace_file = open(self._trace_filename, "a") actions = ["in", "prep", "midp", "postp", "out"]
actions = ["prep", "midp", "postp"] calcu_actions = ["prep", "midp", "postp"]
tag_dict = {}
while True: while True:
op_cost = {} op_cost = {}
trace_file.write("==========================") err_count = 0
_TRACER.info("==================== TRACER ======================")
# op # op
while not self._data_buffer.empty(): while True:
name, action, stage, timestamp = self._data_buffer.get() try:
tag = "{}_{}".format(name, action) name, action, stage, cost = self._data_buffer.get_nowait()
if tag in tag_dict: if stage == False:
assert stage == 1 # only for name == DAG
start_timestamp = tag_dict.pop(tag) assert name == "DAG"
cost = timestamp - start_timestamp err_count += 1
if name not in op_cost: if name not in op_cost:
op_cost[name] = {} op_cost[name] = {}
if action not in op_cost[name]: if action not in op_cost[name]:
op_cost[name][action] = [] op_cost[name][action] = []
op_cost[name][action].append(cost) op_cost[name][action].append(cost)
else: except Queue.Empty:
assert stage == 0 break
tag_dict[tag] = timestamp
if len(op_cost) != 0:
for name in op_cost: for name in op_cost:
tot_cost, cal_cost = 0.0, 0.0 tot_cost, calcu_cost = 0.0, 0.0
for action, costs in op_cost[name].items(): for action, costs in op_cost[name].items():
op_cost[name][action] = sum(costs) / (1e3 * len(costs)) op_cost[name][action] = sum(costs) / (1e3 * len(costs))
tot_cost += op_cost[name][action] tot_cost += op_cost[name][action]
msg = ", ".join([ if name != "DAG":
"{}[{} ms]".format(action, cost) _TRACER.info("Op({}):".format(name))
for action, cost in op_cost[name].items() for action in actions:
]) if action in op_cost[name]:
_TRACER.info("\t{}[{} ms]".format(
for action in actions: action, op_cost[name][action]))
if action in op_cost[name]: for action in calcu_actions:
cal_cost += op_cost[name][action] if action in op_cost[name]:
calcu_cost += op_cost[name][action]
trace_file.write("Op({}) {}".format(name, msg)) _TRACER.info("\tidle[{}]".format(1 - 1.0 * calcu_cost /
if name != "DAG": tot_cost))
trace_file.write("Op({}) idle[{}]".format(
name, 1 - 1.0 * cal_cost / tot_cost)) if "DAG" in op_cost:
calls = op_cost["DAG"].values()
calls.sort()
tot = len(calls)
qps = 1.0 * tot / self._interval_s
ave_cost = sum(calls) / tot
latencys = [50, 60, 70, 80, 90, 95, 99]
_TRACER.info("DAGExecutor:")
_TRACER.info("\tquery count[{}]".format(tot))
_TRACER.info("\tqps[{} q/s]".format(qps))
_TRACER.info("\tsucc[{}]".format(1 - 1.0 * err_count / tot))
_TRACER.info("\tlatency:")
_TRACER.info("\t\tave[{} ms]".format(ave_cost))
for latency in latencys:
_TRACER.info("\t\t.{}[{} ms]".format(latency, calls[int(
tot * latency / 100.0)]))
# channel # channel
_TRACER.info("Channel (server worker num[{}]):".format(
self._server_worker_num))
for channel in channels: for channel in channels:
trace_file.write("Channel({}) size[{}]".format(channel.name, _TRACER.info("\t{}(In: {}, Out: {}) size[{}/{}]".format(
channel.size())) channel.name,
channel.get_producers(),
channel.get_consumers(),
channel.size(), channel.get_maxsize()))
time.sleep(self._interval_s) time.sleep(self._interval_s)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册