提交 cf1f7cc4 编写于 作者: B barriery 提交者: GitHub

Merge pull request #710 from barrierye/pipeline-update

update pipeline
......@@ -27,7 +27,7 @@ import logging
import enum
import copy
_LOGGER = logging.getLogger(__name__)
_LOGGER = logging.getLogger()
class ChannelDataEcode(enum.Enum):
......@@ -92,7 +92,16 @@ class ChannelData(object):
def check_dictdata(dictdata):
ecode = ChannelDataEcode.OK.value
error_info = None
if not isinstance(dictdata, dict):
if isinstance(dictdata, list):
# batch data
for sample in dictdata:
if not isinstance(sample, dict):
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = "the value of data must " \
"be dict, but get {}.".format(type(sample))
break
elif not isinstance(dictdata, dict):
# batch size = 1
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = "the value of data must " \
"be dict, but get {}.".format(type(dictdata))
......@@ -102,12 +111,32 @@ class ChannelData(object):
def check_npdata(npdata):
ecode = ChannelDataEcode.OK.value
error_info = None
for _, value in npdata.items():
if not isinstance(value, np.ndarray):
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = "the value of data must " \
"be np.ndarray, but get {}.".format(type(value))
break
if isinstance(npdata, list):
# batch data
for sample in npdata:
if not isinstance(sample, dict):
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = "the value of data must " \
"be dict, but get {}.".format(type(sample))
break
for _, value in sample.items():
if not isinstance(value, np.ndarray):
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = "the value of data must " \
"be np.ndarray, but get {}.".format(type(value))
return ecode, error_info
elif isinstance(npdata, dict):
# batch_size = 1
for _, value in npdata.items():
if not isinstance(value, np.ndarray):
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = "the value of data must " \
"be np.ndarray, but get {}.".format(type(value))
break
else:
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = "the value of data must " \
"be dict, but get {}.".format(type(npdata))
return ecode, error_info
def parse(self):
......
......@@ -19,13 +19,14 @@ from paddle_serving_client import MultiLangClient, Client
from concurrent import futures
import logging
import func_timeout
import os
from numpy import *
from .proto import pipeline_service_pb2
from .channel import ThreadChannel, ProcessChannel, ChannelDataEcode, ChannelData, ChannelDataType
from .util import NameGenerator
_LOGGER = logging.getLogger(__name__)
_LOGGER = logging.getLogger()
_op_name_gen = NameGenerator("Op")
......@@ -59,6 +60,10 @@ class Op(object):
self._outputs = []
self._profiler = None
# only for multithread
self._for_init_op_lock = threading.Lock()
self._succ_init_op = False
def init_profiler(self, profiler):
self._profiler = profiler
......@@ -71,18 +76,19 @@ class Op(object):
fetch_names):
if self.with_serving == False:
_LOGGER.debug("{} no client".format(self.name))
return
return None
_LOGGER.debug("{} client_config: {}".format(self.name, client_config))
_LOGGER.debug("{} fetch_names: {}".format(self.name, fetch_names))
if client_type == 'brpc':
self._client = Client()
self._client.load_client_config(client_config)
client = Client()
client.load_client_config(client_config)
elif client_type == 'grpc':
self._client = MultiLangClient()
client = MultiLangClient()
else:
raise ValueError("unknow client type: {}".format(client_type))
self._client.connect(server_endpoints)
client.connect(server_endpoints)
self._fetch_names = fetch_names
return client
def _get_input_channel(self):
return self._input
......@@ -130,19 +136,17 @@ class Op(object):
(_, input_dict), = input_dicts.items()
return input_dict
def process(self, feed_dict):
def process(self, client_predict_handler, feed_dict):
err, err_info = ChannelData.check_npdata(feed_dict)
if err != 0:
raise NotImplementedError(
"{} Please override preprocess func.".format(err_info))
_LOGGER.debug(self._log('feed_dict: {}'.format(feed_dict)))
_LOGGER.debug(self._log('fetch: {}'.format(self._fetch_names)))
call_result = self._client.predict(
call_result = client_predict_handler(
feed=feed_dict, fetch=self._fetch_names)
_LOGGER.debug(self._log("get call_result"))
return call_result
def postprocess(self, fetch_dict):
def postprocess(self, input_dict, fetch_dict):
return fetch_dict
def stop(self):
......@@ -174,7 +178,7 @@ class Op(object):
p = multiprocessing.Process(
target=self._run,
args=(concurrency_idx, self._get_input_channel(),
self._get_output_channels(), client_type))
self._get_output_channels(), client_type, False))
p.start()
proces.append(p)
return proces
......@@ -185,12 +189,12 @@ class Op(object):
t = threading.Thread(
target=self._run,
args=(concurrency_idx, self._get_input_channel(),
self._get_output_channels(), client_type))
self._get_output_channels(), client_type, True))
t.start()
threads.append(t)
return threads
def load_user_resources(self):
def init_op(self):
pass
def _run_preprocess(self, parsed_data, data_id, log_func):
......@@ -222,13 +226,15 @@ class Op(object):
data_id=data_id)
return preped_data, error_channeldata
def _run_process(self, preped_data, data_id, log_func):
def _run_process(self, client_predict_handler, preped_data, data_id,
log_func):
midped_data, error_channeldata = None, None
if self.with_serving:
ecode = ChannelDataEcode.OK.value
if self._timeout <= 0:
try:
midped_data = self.process(preped_data)
midped_data = self.process(client_predict_handler,
preped_data)
except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value
error_info = log_func(e)
......@@ -237,7 +243,11 @@ class Op(object):
for i in range(self._retry):
try:
midped_data = func_timeout.func_timeout(
self._timeout, self.process, args=(preped_data, ))
self._timeout,
self.process,
args=(
client_predict_handler,
preped_data, ))
except func_timeout.FunctionTimedOut as e:
if i + 1 >= self._retry:
ecode = ChannelDataEcode.TIMEOUT.value
......@@ -267,10 +277,10 @@ class Op(object):
midped_data = preped_data
return midped_data, error_channeldata
def _run_postprocess(self, midped_data, data_id, log_func):
def _run_postprocess(self, input_dict, midped_data, data_id, log_func):
output_data, error_channeldata = None, None
try:
postped_data = self.postprocess(midped_data)
postped_data = self.postprocess(input_dict, midped_data)
except Exception as e:
error_info = log_func(e)
_LOGGER.error(error_info)
......@@ -303,8 +313,8 @@ class Op(object):
data_id=data_id)
return output_data, error_channeldata
def _run(self, concurrency_idx, input_channel, output_channels,
client_type):
def _run(self, concurrency_idx, input_channel, output_channels, client_type,
use_multithread):
def get_log_func(op_info_prefix):
def log_func(info_str):
return "{} {}".format(op_info_prefix, info_str)
......@@ -315,12 +325,30 @@ class Op(object):
log = get_log_func(op_info_prefix)
tid = threading.current_thread().ident
client = None
client_predict_handler = None
# create client based on client_type
self.init_client(client_type, self._client_config,
self._server_endpoints, self._fetch_names)
try:
client = self.init_client(client_type, self._client_config,
self._server_endpoints, self._fetch_names)
if client is not None:
client_predict_handler = client.predict
except Exception as e:
_LOGGER.error(log(e))
os._exit(-1)
# load user resources
self.load_user_resources()
try:
if use_multithread:
with self._for_init_op_lock:
if not self._succ_init_op:
self.init_op()
self._succ_init_op = True
else:
self.init_op()
except Exception as e:
_LOGGER.error(log(e))
os._exit(-1)
self._is_run = True
while self._is_run:
......@@ -349,8 +377,8 @@ class Op(object):
# process
self._profiler_record("{}-midp#{}_0".format(op_info_prefix, tid))
midped_data, error_channeldata = self._run_process(preped_data,
data_id, log)
midped_data, error_channeldata = self._run_process(
client_predict_handler, preped_data, data_id, log)
self._profiler_record("{}-midp#{}_1".format(op_info_prefix, tid))
if error_channeldata is not None:
self._push_to_output_channels(error_channeldata,
......@@ -359,8 +387,8 @@ class Op(object):
# postprocess
self._profiler_record("{}-postp#{}_0".format(op_info_prefix, tid))
output_data, error_channeldata = self._run_postprocess(midped_data,
data_id, log)
output_data, error_channeldata = self._run_postprocess(
parsed_data, midped_data, data_id, log)
self._profiler_record("{}-postp#{}_1".format(op_info_prefix, tid))
if error_channeldata is not None:
self._push_to_output_channels(error_channeldata,
......@@ -384,7 +412,11 @@ class RequestOp(Op):
super(RequestOp, self).__init__(
name="#G", input_ops=[], concurrency=concurrency)
# load user resources
self.load_user_resources()
try:
self.init_op()
except Exception as e:
_LOGGER.error(e)
os._exit(-1)
def unpack_request_package(self, request):
dictdata = {}
......@@ -405,7 +437,11 @@ class ResponseOp(Op):
super(ResponseOp, self).__init__(
name="#R", input_ops=input_ops, concurrency=concurrency)
# load user resources
self.load_user_resources()
try:
self.init_op()
except Exception as e:
_LOGGER.error(e)
os._exit(-1)
def pack_response_package(self, channeldata):
resp = pipeline_service_pb2.Response()
......@@ -450,17 +486,26 @@ class VirtualOp(Op):
def add_virtual_pred_op(self, op):
self._virtual_pred_ops.append(op)
def _actual_pred_op_names(self, op):
if not isinstance(op, VirtualOp):
return [op.name]
names = []
for x in op._virtual_pred_ops:
names.extend(self._actual_pred_op_names(x))
return names
def add_output_channel(self, channel):
if not isinstance(channel, (ThreadChannel, ProcessChannel)):
raise TypeError(
self._log('output channel must be Channel type, not {}'.format(
type(channel))))
for op in self._virtual_pred_ops:
channel.add_producer(op.name)
for op_name in self._actual_pred_op_names(op):
channel.add_producer(op_name)
self._outputs.append(channel)
def _run(self, concurrency_idx, input_channel, output_channels,
client_type):
def _run(self, concurrency_idx, input_channel, output_channels, client_type,
use_multithread):
def get_log_func(op_info_prefix):
def log_func(info_str):
return "{} {}".format(op_info_prefix, info_str)
......
......@@ -20,7 +20,7 @@ import functools
from .proto import pipeline_service_pb2
from .proto import pipeline_service_pb2_grpc
_LOGGER = logging.getLogger(__name__)
_LOGGER = logging.getLogger()
class PipelineClient(object):
......@@ -52,7 +52,7 @@ class PipelineClient(object):
return {"ecode": resp.ecode, "error_info": resp.error_info}
fetch_map = {"ecode": resp.ecode}
for idx, key in enumerate(resp.key):
if key not in fetch:
if fetch is not None and key not in fetch:
continue
data = resp.value[idx]
try:
......@@ -62,16 +62,16 @@ class PipelineClient(object):
fetch_map[key] = data
return fetch_map
def predict(self, feed_dict, fetch, asyn=False):
def predict(self, feed_dict, fetch=None, asyn=False):
if not isinstance(feed_dict, dict):
raise TypeError(
"feed must be dict type with format: {name: value}.")
if not isinstance(fetch, list):
if fetch is not None and not isinstance(fetch, list):
raise TypeError("fetch must be list type with format: [name].")
req = self._pack_request_package(feed_dict)
if not asyn:
resp = self._stub.inference(req)
return self._unpack_response_package(resp)
return self._unpack_response_package(resp, fetch)
else:
call_future = self._stub.inference.future(req)
return PipelinePredictFuture(
......
......@@ -45,7 +45,7 @@ from .channel import ThreadChannel, ProcessChannel, ChannelData, ChannelDataEcod
from .profiler import TimeProfiler
from .util import NameGenerator
_LOGGER = logging.getLogger(__name__)
_LOGGER = logging.getLogger()
_profiler = TimeProfiler()
......@@ -235,6 +235,10 @@ class PipelineServer(object):
return use_ops, succ_ops_of_use_op
use_ops, out_degree_ops = get_use_ops(response_op)
_LOGGER.info("================= use op ==================")
for op in use_ops:
_LOGGER.info(op.name)
_LOGGER.info("===========================================")
if len(use_ops) <= 1:
raise Exception(
"Besides RequestOp and ResponseOp, there should be at least one Op in DAG."
......
......@@ -24,7 +24,7 @@ else:
raise Exception("Error Python version")
import time
_LOGGER = logging.getLogger(__name__)
_LOGGER = logging.getLogger()
class TimeProfiler(object):
......@@ -58,7 +58,7 @@ class TimeProfiler(object):
print_str += "{}_{}:{} ".format(name, tag, timestamp)
else:
tmp[name] = (tag, timestamp)
print_str += "\n"
print_str = "\n{}\n".format(print_str)
sys.stderr.write(print_str)
for name, item in tmp.items():
tag, timestamp = item
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册