Commit a1521423 authored by barrierye

support get profile from client

Parent 72be624d
......@@ -11,20 +11,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_client.pipeline import PipelineClient
from paddle_serving_server.pipeline import PipelineClient
import numpy as np
from line_profiler import LineProfiler
client = PipelineClient()
client.connect(['127.0.0.1:8080'])
lp = LineProfiler()
lp_wrapper = lp(client.predict)
client.connect(['127.0.0.1:18080'])
words = 'i am very sad | 0'
for i in range(10):
fetch_map = lp_wrapper(feed_dict={"words": words}, fetch=["prediction"])
print(fetch_map)
futures = []
for i in range(1):
futures.append(
client.predict(
feed_dict={"words": words},
fetch=["prediction"],
asyn=True,
profile=True))
for f in futures:
res = f.result()
if res["ecode"] != 0:
print(res)
exit(1)
#lp.print_stats()
print(res)
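For reference, the same profile round-trip also works without the async path; a minimal sketch, assuming the PipelineClient API shown above and a pipeline server on 127.0.0.1:18080:

```python
from paddle_serving_server.pipeline import PipelineClient

client = PipelineClient()
client.connect(['127.0.0.1:18080'])

# profile=True asks the server to attach its timing records to the
# response; the client writes them to stderr and strips them from
# the returned fetch_map
fetch_map = client.predict(
    feed_dict={"words": 'i am very sad | 0'},
    fetch=["prediction"],
    profile=True)
if fetch_map["ecode"] != 0:
    print(fetch_map)
```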
......@@ -21,12 +21,9 @@ import numpy as np
import logging
from paddle_serving_app.reader import IMDBDataset
_LOGGER = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)
logging.basicConfig(
format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s',
datefmt='%Y-%m-%d %H:%M',
level=logging.DEBUG)
_LOGGER = logging.getLogger()
class ImdbRequestOp(RequestOp):
......@@ -91,7 +88,7 @@ cnn_op = Op(name="cnn",
combine_op = CombineOp(
name="combine",
input_ops=[bow_op, cnn_op],
concurrency=1,
concurrency=5,
timeout=-1,
retry=1)
......
......@@ -54,7 +54,8 @@ class ChannelData(object):
dictdata=None,
data_id=None,
ecode=None,
error_info=None):
error_info=None,
client_need_profile=False):
'''
There are several ways to use it:
......@@ -88,6 +89,13 @@ class ChannelData(object):
self.id = data_id
self.ecode = ecode
self.error_info = error_info
self.client_need_profile = client_need_profile
self.profile_data_list = []
def add_profile(self, profile_list):
if self.client_need_profile is False:
self.client_need_profile = True
self.profile_data_list.extend(profile_list)
@staticmethod
def check_dictdata(dictdata):
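The two new fields give every ChannelData a place to carry timing records back toward the client: add_profile marks the data as profile-carrying on first use and accumulates strings from each Op it passes through. A minimal standalone sketch of that pattern (a hypothetical stand-in, not the real class):

```python
class ProfileCarrier(object):
    """Hypothetical stand-in mirroring ChannelData's profile fields."""

    def __init__(self, client_need_profile=False):
        self.client_need_profile = client_need_profile
        self.profile_data_list = []

    def add_profile(self, profile_list):
        # receiving profile data implies the client asked for it
        if self.client_need_profile is False:
            self.client_need_profile = True
        self.profile_data_list.extend(profile_list)


data = ProfileCarrier()
data.add_profile(["prep#bow_0 ...\n"])
data.add_profile(["midp#bow_0 ...\n"])
print(data.client_need_profile, len(data.profile_data_list))  # True 2
```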
......@@ -434,7 +442,7 @@ class ProcessChannel(object):
return resp # reference, read only
def stop(self):
_LOGGER.info(self._log("stop."))
_LOGGER.debug(self._log("stop."))
self._stop.value = 1
with self._cv:
self._cv.notify_all()
......@@ -674,7 +682,7 @@ class ThreadChannel(Queue.Queue):
return resp
def stop(self):
_LOGGER.info(self._log("stop."))
_LOGGER.debug(self._log("stop."))
self._stop = True
with self._cv:
self._cv.notify_all()
......
......@@ -39,11 +39,11 @@ class DAGExecutor(object):
self._retry = dag_config.get('retry', 1)
client_type = dag_config.get('client_type', 'brpc')
use_profile = dag_config.get('use_profile', False)
self._server_use_profile = dag_config.get('use_profile', False)
channel_size = dag_config.get('channel_size', 0)
self._is_thread_op = dag_config.get('is_thread_op', True)
if show_info and use_profile:
if show_info and self._server_use_profile:
_LOGGER.info("================= PROFILER ================")
if self._is_thread_op:
_LOGGER.info("op: thread")
......@@ -55,10 +55,11 @@ class DAGExecutor(object):
self.name = "@G"
self._profiler = TimeProfiler()
self._profiler.enable(use_profile)
self._profiler.enable(True)
self._dag = DAG(self.name, response_op, use_profile, self._is_thread_op,
client_type, channel_size, show_info)
self._dag = DAG(self.name, response_op, self._server_use_profile,
self._is_thread_op, client_type, channel_size,
show_info)
(in_channel, out_channel, pack_rpc_func,
unpack_rpc_func) = self._dag.build()
self._dag.start()
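Because any single request may now ask for profile data, the executor keeps its TimeProfiler enabled unconditionally and only decides at response time where the records go (server stderr, client response, or both). A rough stand-in with the same enable/record/gen_profile_str surface, purely illustrative of the call pattern rather than the real class:

```python
import time


class SketchProfiler(object):
    """Hypothetical profiler with TimeProfiler's visible interface."""

    def __init__(self):
        self._enabled = False
        self._records = []

    def enable(self, flag):
        self._enabled = flag

    def record(self, tag):
        # tags look like "prepack_<data_id>#<name>_0" in the DAG code
        if self._enabled:
            self._records.append("{}:{}".format(tag, int(time.time() * 1e6)))

    def gen_profile_str(self):
        # drain on read so one request's timings do not leak into the
        # next (an assumption about the real class's behavior)
        out = " ".join(self._records) + "\n"
        self._records = []
        return out
```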
......@@ -79,6 +80,9 @@ class DAGExecutor(object):
self._fetch_buffer = None
self._recive_func = None
self._client_profile_key = "pipeline.profile"
self._client_profile_value = "1"
def start(self):
self._recive_func = threading.Thread(
target=DAGExecutor._recive_out_channel_func, args=(self, ))
......@@ -117,7 +121,7 @@ class DAGExecutor(object):
try:
channeldata_dict = self._out_channel.front(self.name)
except ChannelStopError:
_LOGGER.info(self._log("stop."))
_LOGGER.debug(self._log("stop."))
with self._cv_for_cv_pool:
for data_id, cv in self._cv_pool.items():
closed_errror_data = ChannelData(
......@@ -170,10 +174,19 @@ class DAGExecutor(object):
error_info="rpc package error: {}".format(e),
data_id=data_id)
else:
# because unpack_rpc_func is rewritten by the user, we need to
# look for the client_profile_key field in rpc_request ourselves
profile_value = None
for idx, key in enumerate(rpc_request.key):
if key == self._client_profile_key:
profile_value = rpc_request.value[idx]
break
return ChannelData(
datatype=ChannelDataType.DICT.value,
dictdata=dictdata,
data_id=data_id)
data_id=data_id,
client_need_profile=(
profile_value == self._client_profile_value))
def call(self, rpc_request):
data_id = self._get_next_data_id()
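Since unpack_rpc_func is user-supplied, the executor cannot trust it to surface the reserved key, so it scans the request's parallel key/value fields itself. The lookup pattern, extracted into a small sketch (plain lists stand in for the protobuf repeated fields):

```python
def find_value(keys, values, target):
    """Return the value paired with target in parallel key/value
    lists, or None if the key is absent."""
    for idx, key in enumerate(keys):
        if key == target:
            return values[idx]
    return None


# the reserved pair a profiling client appends to its request
keys = ["words", "pipeline.profile"]
values = ["i am very sad | 0", "1"]
assert find_value(keys, values, "pipeline.profile") == "1"
```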
......@@ -193,7 +206,7 @@ class DAGExecutor(object):
try:
self._in_channel.push(req_channeldata, self.name)
except ChannelStopError:
_LOGGER.info(self._log("stop."))
_LOGGER.debug(self._log("stop."))
return self._pack_for_rpc_resp(
ChannelData(
ecode=ChannelDataEcode.CLOSED_ERROR.value,
......@@ -215,12 +228,25 @@ class DAGExecutor(object):
self._profiler.record("postpack_{}#{}_0".format(data_id, self.name))
rpc_resp = self._pack_for_rpc_resp(resp_channeldata)
self._profiler.record("postpack_{}#{}_1".format(data_id, self.name))
if not self._is_thread_op:
self._profiler.record("call_{}#DAG-{}_1".format(data_id, data_id))
else:
self._profiler.record("call_{}#DAG_1".format(data_id))
self._profiler.print_profile()
#self._profiler.print_profile()
profile_str = self._profiler.gen_profile_str()
if self._server_use_profile:
sys.stderr.write(profile_str)
# add profile info into rpc_resp
profile_value = ""
if resp_channeldata.client_need_profile:
profile_list = resp_channeldata.profile_data_list
profile_list.append(profile_str)
profile_value = "".join(profile_list)
rpc_resp.key.append(self._client_profile_key)
rpc_resp.value.append(profile_value)
return rpc_resp
def _pack_for_rpc_resp(self, channeldata):
......
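On the way out, call() joins the per-Op profile strings carried by the ChannelData with the DAG's own record and appends the result to the response under the reserved key (an empty string when the client did not ask). A sketch of that packing step, again with plain lists standing in for the protobuf repeated fields:

```python
def attach_profile(resp_keys, resp_values, op_profiles, dag_profile,
                   client_need_profile, profile_key="pipeline.profile"):
    """Join accumulated profile strings and append them to the
    response's parallel key/value lists."""
    profile_value = ""
    if client_need_profile:
        profile_value = "".join(op_profiles + [dag_profile])
    resp_keys.append(profile_key)
    resp_values.append(profile_value)


keys, values = ["prediction"], ["[0.1, 0.9]"]
attach_profile(keys, values, ["bow op timings\n"], "dag timings\n", True)
# values[-1] == "bow op timings\ndag timings\n"
```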
......@@ -61,7 +61,7 @@ class Op(object):
self._input = None
self._outputs = []
self._use_profile = False
self._server_use_profile = False
# only for multithread
self._for_init_op_lock = threading.Lock()
......@@ -70,7 +70,7 @@ class Op(object):
self._succ_close_op = False
def use_profiler(self, use_profile):
self._use_profile = use_profile
self._server_use_profile = use_profile
def _profiler_record(self, string):
if self._profiler is None:
......@@ -162,24 +162,46 @@ class Op(object):
def _parse_channeldata(self, channeldata_dict):
data_id, error_channeldata = None, None
client_need_profile, profile_list = False, []
parsed_data = {}
key = list(channeldata_dict.keys())[0]
data_id = channeldata_dict[key].id
client_need_profile = channeldata_dict[key].client_need_profile
for name, data in channeldata_dict.items():
if data.ecode != ChannelDataEcode.OK.value:
error_channeldata = data
break
parsed_data[name] = data.parse()
return data_id, error_channeldata, parsed_data
def _push_to_output_channels(self, data, channels, name=None):
if client_need_profile:
profile_list.extend(data.profile_data_list)
return (data_id, error_channeldata, parsed_data, client_need_profile,
profile_list)
def _push_to_output_channels(self,
data,
channels,
name=None,
client_need_profile=False,
profile_list=None):
if name is None:
name = self.name
self._add_profile_into_channeldata(data, client_need_profile,
profile_list)
for channel in channels:
channel.push(data, name)
def _add_profile_into_channeldata(self, data, client_need_profile,
profile_list):
profile_str = self._profiler.gen_profile_str()
if self._server_use_profile:
sys.stderr.write(profile_str)
if client_need_profile and profile_list is not None:
profile_list.append(profile_str)
data.add_profile(profile_list)
def start_with_process(self, client_type):
proces = []
for concurrency_idx in range(self.concurrency):
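_add_profile_into_channeldata routes one profile string to up to two sinks: the server's own stderr when use_profile is set in the config, and the request's profile list when the client asked for it. The branching, reduced to a sketch:

```python
import sys


def emit_profile(profile_str, server_use_profile, client_need_profile,
                 profile_list):
    """Sketch of the two profile sinks: both, either, or neither
    may apply to a given request."""
    if server_use_profile:
        sys.stderr.write(profile_str)
    if client_need_profile and profile_list is not None:
        profile_list.append(profile_str)
```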
......@@ -335,7 +357,7 @@ class Op(object):
if not self._succ_init_op:
# init profiler
self._profiler = TimeProfiler()
self._profiler.enable(self._use_profile)
self._profiler.enable(True)
# init client
self.client = self.init_client(
client_type, self._client_config,
......@@ -347,7 +369,7 @@ class Op(object):
else:
# init profiler
self._profiler = TimeProfiler()
self._profiler.enable(self._use_profile)
self._profiler.enable(True)
# init client
self.client = self.init_client(client_type, self._client_config,
self._server_endpoints,
......@@ -363,7 +385,7 @@ class Op(object):
try:
channeldata_dict = input_channel.front(self.name)
except ChannelStopError:
_LOGGER.info(log("stop."))
_LOGGER.debug(log("stop."))
if is_thread_op:
with self._for_close_op_lock:
if not self._succ_close_op:
......@@ -375,15 +397,17 @@ class Op(object):
#self._profiler_record("get#{}_1".format(op_info_prefix))
_LOGGER.debug(log("input_data: {}".format(channeldata_dict)))
data_id, error_channeldata, parsed_data = self._parse_channeldata(
channeldata_dict)
(data_id, error_channeldata, parsed_data, client_need_profile,
profile_list) = self._parse_channeldata(channeldata_dict)
# error data in predecessor Op
if error_channeldata is not None:
try:
# error_channeldata with profile info
self._push_to_output_channels(error_channeldata,
output_channels)
except ChannelStopError:
_LOGGER.info(log("stop."))
_LOGGER.debug(log("stop."))
break
continue
# preprocess
......@@ -393,10 +417,14 @@ class Op(object):
self._profiler_record("prep#{}_1".format(op_info_prefix))
if error_channeldata is not None:
try:
self._push_to_output_channels(error_channeldata,
output_channels)
self._push_to_output_channels(
error_channeldata,
output_channels,
client_need_profile=client_need_profile,
profile_list=profile_list)
except ChannelStopError:
_LOGGER.info(log("stop."))
_LOGGER.debug(log("stop."))
break
continue
# process
......@@ -406,10 +434,14 @@ class Op(object):
self._profiler_record("midp#{}_1".format(op_info_prefix))
if error_channeldata is not None:
try:
self._push_to_output_channels(error_channeldata,
output_channels)
self._push_to_output_channels(
error_channeldata,
output_channels,
client_need_profile=client_need_profile,
profile_list=profile_list)
except ChannelStopError:
_LOGGER.info(log("stop."))
_LOGGER.debug(log("stop."))
break
continue
# postprocess
......@@ -419,27 +451,28 @@ class Op(object):
self._profiler_record("postp#{}_1".format(op_info_prefix))
if error_channeldata is not None:
try:
self._push_to_output_channels(error_channeldata,
output_channels)
self._push_to_output_channels(
error_channeldata,
output_channels,
client_need_profile=client_need_profile,
profile_list=profile_list)
except ChannelStopError:
_LOGGER.info(log("stop."))
_LOGGER.debug(log("stop."))
break
continue
if self._use_profile:
profile_str = self._profiler.gen_profile_str()
sys.stderr.write(profile_str)
#TODO
#output_data.add_profile(profile_str)
# push data to channel (if run succ)
#self._profiler_record("push#{}_0".format(op_info_prefix))
try:
self._push_to_output_channels(output_data, output_channels)
self._push_to_output_channels(
output_data,
output_channels,
client_need_profile=client_need_profile,
profile_list=profile_list)
except ChannelStopError:
_LOGGER.info(log("stop."))
_LOGGER.debug(log("stop."))
break
#self._profiler_record("push#{}_1".format(op_info_prefix))
#self._profiler.print_profile()
def _log(self, info):
return "{} {}".format(self.name, info)
......@@ -561,7 +594,7 @@ class VirtualOp(Op):
try:
channeldata_dict = input_channel.front(self.name)
except ChannelStopError:
_LOGGER.info(log("stop."))
_LOGGER.debug(log("stop."))
break
try:
......@@ -569,5 +602,5 @@ class VirtualOp(Op):
self._push_to_output_channels(
data, channels=output_channels, name=name)
except ChannelStopError:
_LOGGER.info(log("stop."))
_LOGGER.debug(log("stop."))
break
......@@ -13,6 +13,7 @@
# limitations under the License.
# pylint: disable=doc-string-missing
import grpc
import sys
import numpy as np
from numpy import *
import logging
......@@ -26,6 +27,8 @@ _LOGGER = logging.getLogger()
class PipelineClient(object):
def __init__(self):
self._channel = None
self._profile_key = "pipeline.profile"
self._profile_value = "1"
def connect(self, endpoints):
options = [('grpc.max_receive_message_length', 512 * 1024 * 1024),
......@@ -36,7 +39,7 @@ class PipelineClient(object):
self._stub = pipeline_service_pb2_grpc.PipelineServiceStub(
self._channel)
def _pack_request_package(self, feed_dict):
def _pack_request_package(self, feed_dict, profile):
req = pipeline_service_pb2.Request()
for key, value in feed_dict.items():
req.key.append(key)
......@@ -49,6 +52,9 @@ class PipelineClient(object):
else:
raise TypeError("only str and np.ndarray type is supported: {}".
format(type(value)))
if profile:
req.key.append(self._profile_key)
req.value.append(self._profile_value)
return req
def _unpack_response_package(self, resp, fetch):
......@@ -56,6 +62,10 @@ class PipelineClient(object):
return {"ecode": resp.ecode, "error_info": resp.error_info}
fetch_map = {"ecode": resp.ecode}
for idx, key in enumerate(resp.key):
if key == self._profile_key:
if resp.value[idx] != "":
sys.stderr.write(resp.value[idx])
continue
if fetch is not None and key not in fetch:
continue
data = resp.value[idx]
......@@ -66,13 +76,13 @@ class PipelineClient(object):
fetch_map[key] = data
return fetch_map
def predict(self, feed_dict, fetch=None, asyn=False):
def predict(self, feed_dict, fetch=None, asyn=False, profile=False):
if not isinstance(feed_dict, dict):
raise TypeError(
"feed must be dict type with format: {name: value}.")
if fetch is not None and not isinstance(fetch, list):
raise TypeError("fetch must be list type with format: [name].")
req = self._pack_request_package(feed_dict)
req = self._pack_request_package(feed_dict, profile)
if not asyn:
resp = self._stub.inference(req)
return self._unpack_response_package(resp, fetch)
......
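Putting the two ends together: the client appends ("pipeline.profile", "1") to its request, the server answers with its joined timing records under the same key, and _unpack_response_package writes them to stderr instead of the fetch_map. A self-contained sketch of the client-side filtering (strings stand in for serialized values, and the ecode field is omitted for brevity):

```python
import sys

PROFILE_KEY = "pipeline.profile"


def unpack(resp_keys, resp_values, fetch):
    """Mirror the filtering in _unpack_response_package: profile data
    goes to stderr and never appears in the returned fetch_map."""
    fetch_map = {}
    for idx, key in enumerate(resp_keys):
        if key == PROFILE_KEY:
            if resp_values[idx] != "":
                sys.stderr.write(resp_values[idx])
            continue
        if fetch is not None and key not in fetch:
            continue
        fetch_map[key] = resp_values[idx]
    return fetch_map


print(unpack(["prediction", PROFILE_KEY],
             ["[0.1, 0.9]", "dag timings\n"], ["prediction"]))
# stderr: dag timings
# stdout: {'prediction': '[0.1, 0.9]'}
```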