提交 e5e0bab6 编写于 作者: B barrierye

add profile option

上级 ff8f45f3
......@@ -29,7 +29,7 @@ x = np.array(
lp = LineProfiler()
lp_wrapper = lp(client.predict)
for i in range(5):
for i in range(3):
fetch_map = lp_wrapper(
feed={"x": x}, fetch_with_type={"combine_op_output": "float"})
# fetch_map = client.predict(
......
......@@ -61,7 +61,7 @@ combine_channel = Channel(name="combine_channel")
out_channel = Channel(name="out_channel")
cnn_op = UciOp(
name="cnn_op",
name="cnn",
input=read_channel,
in_dtype='float',
outputs=[combine_channel],
......@@ -75,7 +75,7 @@ cnn_op = UciOp(
concurrency=1)
bow_op = UciOp(
name="bow_op",
name="bow",
input=read_channel,
in_dtype='float',
outputs=[combine_channel],
......@@ -89,7 +89,7 @@ bow_op = UciOp(
concurrency=1)
combine_op = CombineOp(
name="combine_op",
name="combine",
input=combine_channel,
in_dtype='float',
outputs=[out_channel],
......@@ -99,7 +99,7 @@ combine_op = CombineOp(
logging.info(read_channel.debug())
logging.info(combine_channel.debug())
logging.info(out_channel.debug())
pyserver = PyServer()
pyserver = PyServer(profile=False)
pyserver.add_channel(read_channel)
pyserver.add_channel(combine_channel)
pyserver.add_channel(out_channel)
......
......@@ -16,6 +16,7 @@ import threading
import multiprocessing
import Queue
import os
import sys
import paddle_serving_server
from paddle_serving_client import Client
from concurrent import futures
......@@ -29,6 +30,42 @@ import random
import time
class _TimeProfiler(object):
def __init__(self):
self._pid = os.getpid()
self._print_head = 'PROFILE\tpid:{}\t'.format(self._pid)
self._time_record = Queue.Queue()
self._enable = False
def enable(self, enable):
self._enable = enable
def record(self, name_with_tag):
name_with_tag = name_with_tag.split("_")
tag = name_with_tag[-1]
name = '_'.join(name_with_tag[:-1])
self._time_record.put((name, tag, int(round(time.time() * 1000000))))
def print_profile(self):
sys.stderr.write(self._print_head)
tmp = {}
while not self._time_record.empty():
name, tag, timestamp = self._time_record.get()
if name in tmp:
ptag, ptimestamp = tmp.pop(name)
sys.stderr.write("{}_{}:{} ".format(name, ptag, ptimestamp))
sys.stderr.write("{}_{}:{} ".format(name, tag, timestamp))
else:
tmp[name] = (tag, timestamp)
sys.stderr.write('\n')
for name, item in tmp.items():
tag, timestamp = item
self._time_record.put((name, tag, timestamp))
_profiler = _TimeProfiler()
class Channel(Queue.Queue):
"""
The channel used for communication between Ops.
......@@ -313,7 +350,9 @@ class Op(object):
def start(self):
self._run = True
while self._run:
_profiler.record("{}-get_0".format(self._name))
input_data = self._input.front(self._name)
_profiler.record("{}-get_1".format(self._name))
data_id = None
logging.debug(self._log("input_data: {}".format(input_data)))
if isinstance(input_data, dict):
......@@ -322,10 +361,18 @@ class Op(object):
else:
data_id = input_data.id
_profiler.record("{}-prep_0".format(self._name))
data = self.preprocess(input_data)
_profiler.record("{}-prep_1".format(self._name))
if self.with_serving():
_profiler.record("{}-midp_0".format(self._name))
data = self.midprocess(data)
_profiler.record("{}-midp_1".format(self._name))
_profiler.record("{}-postp_0".format(self._name))
output_data = self.postprocess(data)
_profiler.record("{}-postp_1".format(self._name))
if not isinstance(output_data,
python_service_channel_pb2.ChannelData):
......@@ -335,8 +382,10 @@ class Op(object):
format(type(output_data))))
output_data.id = data_id
_profiler.record("{}-push_0".format(self._name))
for channel in self._outputs:
channel.push(output_data, self._name)
_profiler.record("{}-push_1".format(self._name))
def _log(self, info_str):
return "[{}] {}".format(self._name, info_str)
......@@ -349,7 +398,7 @@ class GeneralPythonService(
general_python_service_pb2_grpc.GeneralPythonService):
def __init__(self, in_channel, out_channel):
super(GeneralPythonService, self).__init__()
self._name = "__GeneralPythonService__"
self._name = "#G"
self.set_in_channel(in_channel)
self.set_out_channel(out_channel)
logging.debug(self._log(in_channel.debug()))
......@@ -437,18 +486,30 @@ class GeneralPythonService(
return resp
def inference(self, request, context):
_profiler.record("{}-prepack_0".format(self._name))
data, data_id = self._pack_data_for_infer(request)
_profiler.record("{}-prepack_1".format(self._name))
logging.debug(self._log('push data'))
_profiler.record("{}-push_0".format(self._name))
self._in_channel.push(data, self._name)
_profiler.record("{}-push_1".format(self._name))
logging.debug(self._log('wait for infer'))
resp_data = None
_profiler.record("{}-fetch_0".format(self._name))
resp_data = self._get_data_in_globel_resp_dict(data_id)
_profiler.record("{}-fetch_1".format(self._name))
_profiler.record("{}-postpack_0".format(self._name))
resp = self._pack_data_for_resp(resp_data)
_profiler.record("{}-postpack_1".format(self._name))
_profiler.print_profile()
return resp
class PyServer(object):
def __init__(self):
def __init__(self, profile=False):
self._channels = []
self._ops = []
self._op_threads = []
......@@ -456,6 +517,7 @@ class PyServer(object):
self._worker_num = None
self._in_channel = None
self._out_channel = None
_profiler.enable(profile)
def add_channel(self, channel):
self._channels.append(channel)
......@@ -496,7 +558,7 @@ class PyServer(object):
logging.debug("run op: {}, op_concurrency: {}".format(
op._name, op_concurrency))
for c in range(op_concurrency):
# th = multiprocessing.Process(target=self._op_start_wrapper, args=(op, ))
# th = multiprocessing.Process(
th = threading.Thread(
target=self._op_start_wrapper, args=(op, ))
th.start()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册