Commit f1516e90 authored by barrierye

support get profile from client

Parent: d8fd9e4b
@@ -11,20 +11,27 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from paddle_serving_client.pipeline import PipelineClient
+from paddle_serving_server.pipeline import PipelineClient
 import numpy as np
-from line_profiler import LineProfiler
 client = PipelineClient()
-client.connect(['127.0.0.1:8080'])
+client.connect(['127.0.0.1:18080'])
-lp = LineProfiler()
-lp_wrapper = lp(client.predict)
 words = 'i am very sad | 0'
-for i in range(10):
-    fetch_map = lp_wrapper(feed_dict={"words": words}, fetch=["prediction"])
-    print(fetch_map)
-#lp.print_stats()
+futures = []
+for i in range(1):
+    futures.append(
+        client.predict(
+            feed_dict={"words": words},
+            fetch=["prediction"],
+            asyn=True,
+            profile=True))
+for f in futures:
+    res = f.result()
+    if res["ecode"] != 0:
+        print(res)
+        exit(1)
+    print(res)
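Note on the updated client script: with `asyn=True`, `predict` returns a future immediately instead of blocking, and `f.result()` later waits for the response; `profile=True` additionally asks the server to send back its timing trace. A minimal sketch of fanning out several concurrent requests (the loop bound of 10 is an assumption; the diff above uses 1):

```python
# Sketch: fan out several asynchronous requests, then collect results.
# Assumes the client/words setup from the script above.
futures = [
    client.predict(
        feed_dict={"words": words},
        fetch=["prediction"],
        asyn=True,      # return a future instead of blocking
        profile=True)   # piggyback the server's profile trace
    for _ in range(10)  # 10 in-flight requests (assumed; the test uses 1)
]
for f in futures:
    res = f.result()    # block until this response arrives
    if res["ecode"] != 0:
        raise RuntimeError(res)
```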
@@ -21,12 +21,9 @@ import numpy as np
 import logging
 from paddle_serving_app.reader import IMDBDataset
-_LOGGER = logging.getLogger(__name__)
-logging.basicConfig(
-    format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s',
-    datefmt='%Y-%m-%d %H:%M',
-    level=logging.DEBUG)
+logging.basicConfig(level=logging.DEBUG)
+_LOGGER = logging.getLogger()
 class ImdbRequestOp(RequestOp):
@@ -91,7 +88,7 @@ cnn_op = Op(name="cnn",
 combine_op = CombineOp(
     name="combine",
     input_ops=[bow_op, cnn_op],
-    concurrency=1,
+    concurrency=5,
     timeout=-1,
     retry=1)
...
@@ -54,7 +54,8 @@ class ChannelData(object):
                  dictdata=None,
                  data_id=None,
                  ecode=None,
-                 error_info=None):
+                 error_info=None,
+                 client_need_profile=False):
         '''
         There are several ways to use it:
@@ -88,6 +89,13 @@ class ChannelData(object):
         self.id = data_id
         self.ecode = ecode
         self.error_info = error_info
+        self.client_need_profile = client_need_profile
+        self.profile_data_list = []
+
+    def add_profile(self, profile_list):
+        if self.client_need_profile is False:
+            self.client_need_profile = True
+        self.profile_data_list.extend(profile_list)
     @staticmethod
     def check_dictdata(dictdata):
@@ -434,7 +442,7 @@ class ProcessChannel(object):
         return resp  # reference, read only
     def stop(self):
-        _LOGGER.info(self._log("stop."))
+        _LOGGER.debug(self._log("stop."))
         self._stop.value = 1
         with self._cv:
             self._cv.notify_all()
@@ -674,7 +682,7 @@ class ThreadChannel(Queue.Queue):
         return resp
     def stop(self):
-        _LOGGER.info(self._log("stop."))
+        _LOGGER.debug(self._log("stop."))
         self._stop = True
         with self._cv:
             self._cv.notify_all()
...
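The `ChannelData` changes above give each in-flight request a place to accumulate per-Op profile strings. A minimal sketch of the new fields in use (constructor arguments as in the diff; the trace strings are placeholders):

```python
# Sketch: how profile data accumulates on a ChannelData instance.
data = ChannelData(
    datatype=ChannelDataType.DICT.value,
    dictdata={"words": "i am very sad | 0"},
    data_id=0,
    client_need_profile=True)  # set by the DAG when the client asks

# Each Op appends its own trace; add_profile also flips
# client_need_profile to True if it was not set upstream.
data.add_profile(["bow_trace "])
data.add_profile(["cnn_trace "])
assert data.profile_data_list == ["bow_trace ", "cnn_trace "]
```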
@@ -39,11 +39,11 @@ class DAGExecutor(object):
         self._retry = dag_config.get('retry', 1)
         client_type = dag_config.get('client_type', 'brpc')
-        use_profile = dag_config.get('use_profile', False)
+        self._server_use_profile = dag_config.get('use_profile', False)
         channel_size = dag_config.get('channel_size', 0)
         self._is_thread_op = dag_config.get('is_thread_op', True)
-        if show_info and use_profile:
+        if show_info and self._server_use_profile:
             _LOGGER.info("================= PROFILER ================")
             if self._is_thread_op:
                 _LOGGER.info("op: thread")
@@ -55,10 +55,11 @@ class DAGExecutor(object):
         self.name = "@G"
         self._profiler = TimeProfiler()
-        self._profiler.enable(use_profile)
+        self._profiler.enable(True)
-        self._dag = DAG(self.name, response_op, use_profile, self._is_thread_op,
-                        client_type, channel_size, show_info)
+        self._dag = DAG(self.name, response_op, self._server_use_profile,
+                        self._is_thread_op, client_type, channel_size,
+                        show_info)
         (in_channel, out_channel, pack_rpc_func,
          unpack_rpc_func) = self._dag.build()
         self._dag.start()
@@ -79,6 +80,9 @@ class DAGExecutor(object):
         self._fetch_buffer = None
         self._recive_func = None
+        self._client_profile_key = "pipeline.profile"
+        self._client_profile_value = "1"
     def start(self):
         self._recive_func = threading.Thread(
             target=DAGExecutor._recive_out_channel_func, args=(self, ))
@@ -117,7 +121,7 @@ class DAGExecutor(object):
             try:
                 channeldata_dict = self._out_channel.front(self.name)
             except ChannelStopError:
-                _LOGGER.info(self._log("stop."))
+                _LOGGER.debug(self._log("stop."))
                 with self._cv_for_cv_pool:
                     for data_id, cv in self._cv_pool.items():
                         closed_errror_data = ChannelData(
@@ -170,10 +174,19 @@ class DAGExecutor(object):
                 error_info="rpc package error: {}".format(e),
                 data_id=data_id)
         else:
+            # because unpack_rpc_func is rewritten by user, we need
+            # to look for client_profile_key field in rpc_request
+            profile_value = None
+            for idx, key in enumerate(rpc_request.key):
+                if key == self._client_profile_key:
+                    profile_value = rpc_request.value[idx]
+                    break
             return ChannelData(
                 datatype=ChannelDataType.DICT.value,
                 dictdata=dictdata,
-                data_id=data_id)
+                data_id=data_id,
+                client_need_profile=(
+                    profile_value == self._client_profile_value))
     def call(self, rpc_request):
         data_id = self._get_next_data_id()
@@ -193,7 +206,7 @@ class DAGExecutor(object):
         try:
             self._in_channel.push(req_channeldata, self.name)
         except ChannelStopError:
-            _LOGGER.info(self._log("stop."))
+            _LOGGER.debug(self._log("stop."))
             return self._pack_for_rpc_resp(
                 ChannelData(
                     ecode=ChannelDataEcode.CLOSED_ERROR.value,
@@ -215,12 +228,25 @@ class DAGExecutor(object):
         self._profiler.record("postpack_{}#{}_0".format(data_id, self.name))
         rpc_resp = self._pack_for_rpc_resp(resp_channeldata)
         self._profiler.record("postpack_{}#{}_1".format(data_id, self.name))
         if not self._is_thread_op:
             self._profiler.record("call_{}#DAG-{}_1".format(data_id, data_id))
         else:
             self._profiler.record("call_{}#DAG_1".format(data_id))
-        self._profiler.print_profile()
+        #self._profiler.print_profile()
+        profile_str = self._profiler.gen_profile_str()
+        if self._server_use_profile:
+            sys.stderr.write(profile_str)
+
+        # add profile info into rpc_resp
+        profile_value = ""
+        if resp_channeldata.client_need_profile:
+            profile_list = resp_channeldata.profile_data_list
+            profile_list.append(profile_str)
+            profile_value = "".join(profile_list)
+        rpc_resp.key.append(self._client_profile_key)
+        rpc_resp.value.append(profile_value)
         return rpc_resp
     def _pack_for_rpc_resp(self, channeldata):
...
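The `DAGExecutor` changes establish a simple convention on the existing Request/Response key/value protos: the client marks a request by appending the reserved key `"pipeline.profile"` with value `"1"`, and `call()` always echoes the key back, carrying the concatenated profile string when profiling was requested and an empty string otherwise. A rough sketch of the request-side check (proto field names as in the diff; the helper name is hypothetical):

```python
# Sketch: what the reserved profile entry looks like on the wire.
# request  -> key: [..., "pipeline.profile"], value: [..., "1"]
# response -> key: [..., "pipeline.profile"], value: [..., "<trace or ''>"]

def wants_profile(rpc_request, profile_key="pipeline.profile"):
    # Mirrors the scan in _unpack_rpc_func: since unpack_rpc_func is
    # user-overridable, the executor searches the raw request itself.
    for idx, key in enumerate(rpc_request.key):
        if key == profile_key:
            return rpc_request.value[idx] == "1"
    return False
```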
@@ -61,7 +61,7 @@ class Op(object):
         self._input = None
         self._outputs = []
-        self._use_profile = False
+        self._server_use_profile = False
         # only for multithread
         self._for_init_op_lock = threading.Lock()
@@ -70,7 +70,7 @@ class Op(object):
         self._succ_close_op = False
     def use_profiler(self, use_profile):
-        self._use_profile = use_profile
+        self._server_use_profile = use_profile
     def _profiler_record(self, string):
         if self._profiler is None:
@@ -162,24 +162,46 @@ class Op(object):
     def _parse_channeldata(self, channeldata_dict):
         data_id, error_channeldata = None, None
+        client_need_profile, profile_list = False, []
         parsed_data = {}
         key = list(channeldata_dict.keys())[0]
         data_id = channeldata_dict[key].id
+        client_need_profile = channeldata_dict[key].client_need_profile
         for name, data in channeldata_dict.items():
             if data.ecode != ChannelDataEcode.OK.value:
                 error_channeldata = data
                 break
             parsed_data[name] = data.parse()
-        return data_id, error_channeldata, parsed_data
+            if client_need_profile:
+                profile_list.extend(data.profile_data_list)
+        return (data_id, error_channeldata, parsed_data, client_need_profile,
+                profile_list)
-    def _push_to_output_channels(self, data, channels, name=None):
+    def _push_to_output_channels(self,
+                                 data,
+                                 channels,
+                                 name=None,
+                                 client_need_profile=False,
+                                 profile_list=None):
         if name is None:
             name = self.name
+        self._add_profile_into_channeldata(data, client_need_profile,
+                                           profile_list)
         for channel in channels:
             channel.push(data, name)
+    def _add_profile_into_channeldata(self, data, client_need_profile,
+                                      profile_list):
+        profile_str = self._profiler.gen_profile_str()
+        if self._server_use_profile:
+            sys.stderr.write(profile_str)
+        if client_need_profile and profile_list is not None:
+            profile_list.append(profile_str)
+            data.add_profile(profile_list)
     def start_with_process(self, client_type):
         proces = []
         for concurrency_idx in range(self.concurrency):
@@ -335,7 +357,7 @@ class Op(object):
                 if not self._succ_init_op:
                     # init profiler
                     self._profiler = TimeProfiler()
-                    self._profiler.enable(self._use_profile)
+                    self._profiler.enable(True)
                     # init client
                     self.client = self.init_client(
                         client_type, self._client_config,
@@ -347,7 +369,7 @@ class Op(object):
         else:
             # init profiler
             self._profiler = TimeProfiler()
-            self._profiler.enable(self._use_profile)
+            self._profiler.enable(True)
             # init client
             self.client = self.init_client(client_type, self._client_config,
                                            self._server_endpoints,
@@ -363,7 +385,7 @@ class Op(object):
             try:
                 channeldata_dict = input_channel.front(self.name)
             except ChannelStopError:
-                _LOGGER.info(log("stop."))
+                _LOGGER.debug(log("stop."))
                 if is_thread_op:
                     with self._for_close_op_lock:
                         if not self._succ_close_op:
@@ -375,15 +397,17 @@ class Op(object):
             #self._profiler_record("get#{}_1".format(op_info_prefix))
             _LOGGER.debug(log("input_data: {}".format(channeldata_dict)))
-            data_id, error_channeldata, parsed_data = self._parse_channeldata(
-                channeldata_dict)
+            (data_id, error_channeldata, parsed_data, client_need_profile,
+             profile_list) = self._parse_channeldata(channeldata_dict)
             # error data in predecessor Op
             if error_channeldata is not None:
                 try:
+                    # error_channeldata with profile info
                     self._push_to_output_channels(error_channeldata,
                                                   output_channels)
                 except ChannelStopError:
-                    _LOGGER.info(log("stop."))
+                    _LOGGER.debug(log("stop."))
+                    break
                 continue
             # preprecess
@@ -393,10 +417,14 @@ class Op(object):
             self._profiler_record("prep#{}_1".format(op_info_prefix))
             if error_channeldata is not None:
                 try:
-                    self._push_to_output_channels(error_channeldata,
-                                                  output_channels)
+                    self._push_to_output_channels(
+                        error_channeldata,
+                        output_channels,
+                        client_need_profile=client_need_profile,
+                        profile_list=profile_list)
                 except ChannelStopError:
-                    _LOGGER.info(log("stop."))
+                    _LOGGER.debug(log("stop."))
+                    break
                 continue
             # process
@@ -406,10 +434,14 @@ class Op(object):
             self._profiler_record("midp#{}_1".format(op_info_prefix))
             if error_channeldata is not None:
                 try:
-                    self._push_to_output_channels(error_channeldata,
-                                                  output_channels)
+                    self._push_to_output_channels(
+                        error_channeldata,
+                        output_channels,
+                        client_need_profile=client_need_profile,
+                        profile_list=profile_list)
                 except ChannelStopError:
-                    _LOGGER.info(log("stop."))
+                    _LOGGER.debug(log("stop."))
+                    break
                 continue
             # postprocess
@@ -419,27 +451,28 @@ class Op(object):
             self._profiler_record("postp#{}_1".format(op_info_prefix))
             if error_channeldata is not None:
                 try:
-                    self._push_to_output_channels(error_channeldata,
-                                                  output_channels)
+                    self._push_to_output_channels(
+                        error_channeldata,
+                        output_channels,
+                        client_need_profile=client_need_profile,
+                        profile_list=profile_list)
                 except ChannelStopError:
-                    _LOGGER.info(log("stop."))
+                    _LOGGER.debug(log("stop."))
+                    break
                 continue
-            if self._use_profile:
-                profile_str = self._profiler.gen_profile_str()
-                sys.stderr.write(profile_str)
-                #TODO
-                #output_data.add_profile(profile_str)
             # push data to channel (if run succ)
             #self._profiler_record("push#{}_0".format(op_info_prefix))
             try:
-                self._push_to_output_channels(output_data, output_channels)
+                self._push_to_output_channels(
+                    output_data,
+                    output_channels,
+                    client_need_profile=client_need_profile,
+                    profile_list=profile_list)
             except ChannelStopError:
-                _LOGGER.info(log("stop."))
+                _LOGGER.debug(log("stop."))
                 break
             #self._profiler_record("push#{}_1".format(op_info_prefix))
-            #self._profiler.print_profile()
     def _log(self, info):
         return "{} {}".format(self.name, info)
@@ -561,7 +594,7 @@ class VirtualOp(Op):
             try:
                 channeldata_dict = input_channel.front(self.name)
             except ChannelStopError:
-                _LOGGER.info(log("stop."))
+                _LOGGER.debug(log("stop."))
                 break
             try:
@@ -569,5 +602,5 @@ class VirtualOp(Op):
                 self._push_to_output_channels(
                     data, channels=output_channels, name=name)
             except ChannelStopError:
-                _LOGGER.info(log("stop."))
+                _LOGGER.debug(log("stop."))
                 break
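Taken together, the `Op` changes thread the profile flags through every push: `_parse_channeldata` now also returns whether the client asked for profiling plus the traces gathered from upstream, and `_push_to_output_channels` attaches this Op's own trace before forwarding. Note that the `TimeProfiler` is now always enabled (`enable(True)`), so a trace exists whenever a client asks; `_server_use_profile` only controls the server-side stderr dump. A condensed sketch of one loop iteration (names from the diff, error handling elided):

```python
# Condensed sketch of the per-Op profile flow (error paths elided).
(data_id, error_channeldata, parsed_data, client_need_profile,
 profile_list) = self._parse_channeldata(channeldata_dict)

# ... preprocess / process / postprocess produce output_data ...

self._push_to_output_channels(
    output_data,
    output_channels,
    client_need_profile=client_need_profile,
    profile_list=profile_list)
# Internally this calls _add_profile_into_channeldata, which appends
# self._profiler.gen_profile_str() to profile_list and stores it on
# output_data via ChannelData.add_profile.
```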
@@ -13,6 +13,7 @@
 # limitations under the License.
 # pylint: disable=doc-string-missing
 import grpc
+import sys
 import numpy as np
 from numpy import *
 import logging
@@ -26,6 +27,8 @@ _LOGGER = logging.getLogger()
 class PipelineClient(object):
     def __init__(self):
         self._channel = None
+        self._profile_key = "pipeline.profile"
+        self._profile_value = "1"
     def connect(self, endpoints):
         options = [('grpc.max_receive_message_length', 512 * 1024 * 1024),
@@ -36,7 +39,7 @@ class PipelineClient(object):
         self._stub = pipeline_service_pb2_grpc.PipelineServiceStub(
             self._channel)
-    def _pack_request_package(self, feed_dict):
+    def _pack_request_package(self, feed_dict, profile):
         req = pipeline_service_pb2.Request()
         for key, value in feed_dict.items():
             req.key.append(key)
@@ -49,6 +52,9 @@ class PipelineClient(object):
             else:
                 raise TypeError("only str and np.ndarray type is supported: {}".
                                 format(type(value)))
+        if profile:
+            req.key.append(self._profile_key)
+            req.value.append(self._profile_value)
         return req
     def _unpack_response_package(self, resp, fetch):
@@ -56,6 +62,10 @@ class PipelineClient(object):
             return {"ecode": resp.ecode, "error_info": resp.error_info}
         fetch_map = {"ecode": resp.ecode}
         for idx, key in enumerate(resp.key):
+            if key == self._profile_key:
+                if resp.value[idx] != "":
+                    sys.stderr.write(resp.value[idx])
+                continue
             if fetch is not None and key not in fetch:
                 continue
             data = resp.value[idx]
@@ -66,13 +76,13 @@ class PipelineClient(object):
             fetch_map[key] = data
         return fetch_map
-    def predict(self, feed_dict, fetch=None, asyn=False):
+    def predict(self, feed_dict, fetch=None, asyn=False, profile=False):
         if not isinstance(feed_dict, dict):
             raise TypeError(
                 "feed must be dict type with format: {name: value}.")
         if fetch is not None and not isinstance(fetch, list):
             raise TypeError("fetch must be list type with format: [name].")
-        req = self._pack_request_package(feed_dict)
+        req = self._pack_request_package(feed_dict, profile)
         if not asyn:
             resp = self._stub.inference(req)
             return self._unpack_response_package(resp, fetch)
...
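With these client changes, per-request profiling is opt-in via `predict(..., profile=True)`; `_unpack_response_package` strips the reserved key from the result and writes the trace to stderr. A minimal usage sketch (the server address and feed/fetch names are assumptions, taken from the example above):

```python
# Sketch: request server-side profile data for a single call.
# Assumes a pipeline server at 127.0.0.1:18080 exposing
# "words" -> "prediction", as in the IMDB example above.
from paddle_serving_server.pipeline import PipelineClient

client = PipelineClient()
client.connect(['127.0.0.1:18080'])

fetch_map = client.predict(
    feed_dict={"words": "i am very sad | 0"},
    fetch=["prediction"],
    profile=True)  # trace goes to stderr, not into fetch_map
print(fetch_map)   # contains "ecode" and "prediction" only
```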