Unverified    Commit f3eb9d4a    authored by: B barriery    committed by: GitHub

Merge pull request #710 from barrierye/pipeline-update

update pipeline
@@ -27,7 +27,7 @@ import logging
 import enum
 import copy

-_LOGGER = logging.getLogger(__name__)
+_LOGGER = logging.getLogger()


 class ChannelDataEcode(enum.Enum):
@@ -92,7 +92,16 @@ class ChannelData(object):
     def check_dictdata(dictdata):
         ecode = ChannelDataEcode.OK.value
         error_info = None
-        if not isinstance(dictdata, dict):
+        if isinstance(dictdata, list):
+            # batch data
+            for sample in dictdata:
+                if not isinstance(sample, dict):
+                    ecode = ChannelDataEcode.TYPE_ERROR.value
+                    error_info = "the value of data must " \
+                            "be dict, but get {}.".format(type(sample))
+                    break
+        elif not isinstance(dictdata, dict):
+            # batch size = 1
             ecode = ChannelDataEcode.TYPE_ERROR.value
             error_info = "the value of data must " \
                     "be dict, but get {}.".format(type(dictdata))
@@ -102,12 +111,32 @@ class ChannelData(object):
     def check_npdata(npdata):
         ecode = ChannelDataEcode.OK.value
         error_info = None
+        if isinstance(npdata, list):
+            # batch data
+            for sample in npdata:
+                if not isinstance(sample, dict):
+                    ecode = ChannelDataEcode.TYPE_ERROR.value
+                    error_info = "the value of data must " \
+                            "be dict, but get {}.".format(type(sample))
+                    break
+                for _, value in sample.items():
+                    if not isinstance(value, np.ndarray):
+                        ecode = ChannelDataEcode.TYPE_ERROR.value
+                        error_info = "the value of data must " \
+                                "be np.ndarray, but get {}.".format(type(value))
+                        return ecode, error_info
+        elif isinstance(npdata, dict):
+            # batch_size = 1
             for _, value in npdata.items():
                 if not isinstance(value, np.ndarray):
                     ecode = ChannelDataEcode.TYPE_ERROR.value
                     error_info = "the value of data must " \
                             "be np.ndarray, but get {}.".format(type(value))
                     break
+        else:
+            ecode = ChannelDataEcode.TYPE_ERROR.value
+            error_info = "the value of data must " \
+                    "be dict, but get {}.".format(type(npdata))
         return ecode, error_info

     def parse(self):
...
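With this change the checkers accept either a single sample (a dict) or a batch (a list of such dicts). A minimal hedged sketch of the shapes check_npdata now accepts; the import path and the "image" key are assumptions, not part of the diff:

import numpy as np
from paddle_serving_server.pipeline.channel import ChannelData  # assumed install path of the module patched above

single = {"image": np.zeros((3, 224, 224), dtype="float32")}           # batch_size = 1: dict of np.ndarray
batch = [single, {"image": np.ones((3, 224, 224), dtype="float32")}]   # batch data: list of such dicts

print(ChannelData.check_npdata(single))   # (ChannelDataEcode.OK.value, None)
print(ChannelData.check_npdata(batch))    # also OK after this change
print(ChannelData.check_npdata("oops"))   # falls into the new else branch: TYPE_ERROR
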
@@ -19,13 +19,14 @@ from paddle_serving_client import MultiLangClient, Client
 from concurrent import futures
 import logging
 import func_timeout
+import os
 from numpy import *

 from .proto import pipeline_service_pb2
 from .channel import ThreadChannel, ProcessChannel, ChannelDataEcode, ChannelData, ChannelDataType
 from .util import NameGenerator

-_LOGGER = logging.getLogger(__name__)
+_LOGGER = logging.getLogger()
 _op_name_gen = NameGenerator("Op")
@@ -59,6 +60,10 @@ class Op(object):
         self._outputs = []
         self._profiler = None

+        # only for multithread
+        self._for_init_op_lock = threading.Lock()
+        self._succ_init_op = False
+
     def init_profiler(self, profiler):
         self._profiler = profiler
@@ -71,18 +76,19 @@ class Op(object):
                     fetch_names):
         if self.with_serving == False:
             _LOGGER.debug("{} no client".format(self.name))
-            return
+            return None
         _LOGGER.debug("{} client_config: {}".format(self.name, client_config))
         _LOGGER.debug("{} fetch_names: {}".format(self.name, fetch_names))
         if client_type == 'brpc':
-            self._client = Client()
-            self._client.load_client_config(client_config)
+            client = Client()
+            client.load_client_config(client_config)
         elif client_type == 'grpc':
-            self._client = MultiLangClient()
+            client = MultiLangClient()
         else:
             raise ValueError("unknow client type: {}".format(client_type))
-        self._client.connect(server_endpoints)
+        client.connect(server_endpoints)
         self._fetch_names = fetch_names
+        return client

     def _get_input_channel(self):
         return self._input
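init_client now returns the client it builds (or None when with_serving is False) instead of caching it on self._client, so each worker thread or process in _run can own its own connection. A hedged sketch of the calling pattern, not part of the diff (the helper name is made up):

def make_predict_handler(op, client_type, client_config, server_endpoints, fetch_names):
    # illustrative helper: build one client per worker and hand back its predict method
    client = op.init_client(client_type, client_config, server_endpoints, fetch_names)
    if client is None:          # ops with with_serving == False get no client
        return None
    return client.predict       # later passed into op.process(handler, feed_dict)
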
@@ -130,19 +136,17 @@ class Op(object):
         (_, input_dict), = input_dicts.items()
         return input_dict

-    def process(self, feed_dict):
+    def process(self, client_predict_handler, feed_dict):
         err, err_info = ChannelData.check_npdata(feed_dict)
         if err != 0:
             raise NotImplementedError(
                 "{} Please override preprocess func.".format(err_info))
-        _LOGGER.debug(self._log('feed_dict: {}'.format(feed_dict)))
-        _LOGGER.debug(self._log('fetch: {}'.format(self._fetch_names)))
-        call_result = self._client.predict(
+        call_result = client_predict_handler(
             feed=feed_dict, fetch=self._fetch_names)
         _LOGGER.debug(self._log("get call_result"))
         return call_result

-    def postprocess(self, fetch_dict):
+    def postprocess(self, input_dict, fetch_dict):
         return fetch_dict

     def stop(self):
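process now takes the per-worker predict handler as its first argument, and postprocess receives the op's parsed input alongside the fetch dict, so responses can reference the original request. A hedged sketch of a user Op written against the new signatures; the import path, the "x" key, and the computation are assumptions, not from the diff:

import numpy as np
from paddle_serving_server.pipeline import Op   # assumed import path, as in the pipeline examples

class ExampleOp(Op):
    def preprocess(self, input_dicts):
        # single predecessor: unpack its dict, as the default preprocess does
        (_, input_dict), = input_dicts.items()
        # feed values must be np.ndarray so ChannelData.check_npdata passes
        return {"x": np.asarray(input_dict["x"], dtype="float32")}

    def postprocess(self, input_dict, fetch_dict):
        # new signature: the parsed input arrives alongside the serving result
        fetch_dict["num_input_fields"] = np.array([len(input_dict)])
        return fetch_dict
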
@@ -174,7 +178,7 @@ class Op(object):
             p = multiprocessing.Process(
                 target=self._run,
                 args=(concurrency_idx, self._get_input_channel(),
-                      self._get_output_channels(), client_type))
+                      self._get_output_channels(), client_type, False))
             p.start()
             proces.append(p)
         return proces
@@ -185,12 +189,12 @@ class Op(object):
             t = threading.Thread(
                 target=self._run,
                 args=(concurrency_idx, self._get_input_channel(),
-                      self._get_output_channels(), client_type))
+                      self._get_output_channels(), client_type, True))
             t.start()
             threads.append(t)
         return threads

-    def load_user_resources(self):
+    def init_op(self):
         pass

     def _run_preprocess(self, parsed_data, data_id, log_func):
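load_user_resources is renamed to init_op, and _run now calls it once per worker (guarded by _for_init_op_lock in thread mode, where workers share one Op object). User ops that loaded resources in load_user_resources should override init_op instead; a hedged sketch, with the import path and vocab file made up:

from paddle_serving_server.pipeline import Op   # assumed import path

class DictLoaderOp(Op):
    # this hook was called load_user_resources() before the PR
    def init_op(self):
        # runs once per worker process, and only once per Op object in thread mode
        with open("vocab.txt") as f:            # made-up resource file
            self.vocab = [line.strip() for line in f]
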
@@ -222,13 +226,15 @@ class Op(object):
                 data_id=data_id)
         return preped_data, error_channeldata

-    def _run_process(self, preped_data, data_id, log_func):
+    def _run_process(self, client_predict_handler, preped_data, data_id,
+                     log_func):
         midped_data, error_channeldata = None, None
         if self.with_serving:
             ecode = ChannelDataEcode.OK.value
             if self._timeout <= 0:
                 try:
-                    midped_data = self.process(preped_data)
+                    midped_data = self.process(client_predict_handler,
+                                               preped_data)
                 except Exception as e:
                     ecode = ChannelDataEcode.UNKNOW.value
                     error_info = log_func(e)
@@ -237,7 +243,11 @@ class Op(object):
                 for i in range(self._retry):
                     try:
                         midped_data = func_timeout.func_timeout(
-                            self._timeout, self.process, args=(preped_data, ))
+                            self._timeout,
+                            self.process,
+                            args=(
+                                client_predict_handler,
+                                preped_data, ))
                     except func_timeout.FunctionTimedOut as e:
                         if i + 1 >= self._retry:
                             ecode = ChannelDataEcode.TIMEOUT.value
@@ -267,10 +277,10 @@ class Op(object):
             midped_data = preped_data
         return midped_data, error_channeldata

-    def _run_postprocess(self, midped_data, data_id, log_func):
+    def _run_postprocess(self, input_dict, midped_data, data_id, log_func):
         output_data, error_channeldata = None, None
         try:
-            postped_data = self.postprocess(midped_data)
+            postped_data = self.postprocess(input_dict, midped_data)
         except Exception as e:
             error_info = log_func(e)
             _LOGGER.error(error_info)
@@ -303,8 +313,8 @@ class Op(object):
                 data_id=data_id)
         return output_data, error_channeldata

-    def _run(self, concurrency_idx, input_channel, output_channels,
-             client_type):
+    def _run(self, concurrency_idx, input_channel, output_channels, client_type,
+             use_multithread):
         def get_log_func(op_info_prefix):
             def log_func(info_str):
                 return "{} {}".format(op_info_prefix, info_str)
@@ -315,12 +325,30 @@ class Op(object):
         log = get_log_func(op_info_prefix)
         tid = threading.current_thread().ident

+        client = None
+        client_predict_handler = None
         # create client based on client_type
-        self.init_client(client_type, self._client_config,
-                         self._server_endpoints, self._fetch_names)
+        try:
+            client = self.init_client(client_type, self._client_config,
+                                      self._server_endpoints, self._fetch_names)
+            if client is not None:
+                client_predict_handler = client.predict
+        except Exception as e:
+            _LOGGER.error(log(e))
+            os._exit(-1)

         # load user resources
-        self.load_user_resources()
+        try:
+            if use_multithread:
+                with self._for_init_op_lock:
+                    if not self._succ_init_op:
+                        self.init_op()
+                        self._succ_init_op = True
+            else:
+                self.init_op()
+        except Exception as e:
+            _LOGGER.error(log(e))
+            os._exit(-1)

         self._is_run = True
         while self._is_run:
@@ -349,8 +377,8 @@ class Op(object):
             # process
             self._profiler_record("{}-midp#{}_0".format(op_info_prefix, tid))
-            midped_data, error_channeldata = self._run_process(preped_data,
-                                                               data_id, log)
+            midped_data, error_channeldata = self._run_process(
+                client_predict_handler, preped_data, data_id, log)
             self._profiler_record("{}-midp#{}_1".format(op_info_prefix, tid))
             if error_channeldata is not None:
                 self._push_to_output_channels(error_channeldata,
@@ -359,8 +387,8 @@ class Op(object):
             # postprocess
             self._profiler_record("{}-postp#{}_0".format(op_info_prefix, tid))
-            output_data, error_channeldata = self._run_postprocess(midped_data,
-                                                                   data_id, log)
+            output_data, error_channeldata = self._run_postprocess(
+                parsed_data, midped_data, data_id, log)
             self._profiler_record("{}-postp#{}_1".format(op_info_prefix, tid))
             if error_channeldata is not None:
                 self._push_to_output_channels(error_channeldata,
@@ -384,7 +412,11 @@ class RequestOp(Op):
         super(RequestOp, self).__init__(
             name="#G", input_ops=[], concurrency=concurrency)
         # load user resources
-        self.load_user_resources()
+        try:
+            self.init_op()
+        except Exception as e:
+            _LOGGER.error(e)
+            os._exit(-1)

     def unpack_request_package(self, request):
         dictdata = {}
@@ -405,7 +437,11 @@ class ResponseOp(Op):
         super(ResponseOp, self).__init__(
             name="#R", input_ops=input_ops, concurrency=concurrency)
         # load user resources
-        self.load_user_resources()
+        try:
+            self.init_op()
+        except Exception as e:
+            _LOGGER.error(e)
+            os._exit(-1)

     def pack_response_package(self, channeldata):
         resp = pipeline_service_pb2.Response()
@@ -450,17 +486,26 @@ class VirtualOp(Op):
     def add_virtual_pred_op(self, op):
         self._virtual_pred_ops.append(op)

+    def _actual_pred_op_names(self, op):
+        if not isinstance(op, VirtualOp):
+            return [op.name]
+        names = []
+        for x in op._virtual_pred_ops:
+            names.extend(self._actual_pred_op_names(x))
+        return names
+
     def add_output_channel(self, channel):
         if not isinstance(channel, (ThreadChannel, ProcessChannel)):
             raise TypeError(
                 self._log('output channel must be Channel type, not {}'.format(
                     type(channel))))
         for op in self._virtual_pred_ops:
-            channel.add_producer(op.name)
+            for op_name in self._actual_pred_op_names(op):
+                channel.add_producer(op_name)
         self._outputs.append(channel)

-    def _run(self, concurrency_idx, input_channel, output_channels,
-             client_type):
+    def _run(self, concurrency_idx, input_channel, output_channels, client_type,
+             use_multithread):
         def get_log_func(op_info_prefix):
             def log_func(info_str):
                 return "{} {}".format(op_info_prefix, info_str)
...
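The new _actual_pred_op_names walks nested VirtualOps so an output channel registers the real upstream ops as producers. A standalone hedged mirror of the recursion; the op names and the dict are made up for illustration and are not the pipeline classes themselves:

# hypothetical {virtual_op_name: [predecessor_names]} mapping
virtual_preds = {"V1": ["OpA", "OpB"], "V2": ["V1", "OpC"]}

def actual_pred_names(name):
    if name not in virtual_preds:            # a real Op keeps its own name
        return [name]
    names = []
    for pred in virtual_preds[name]:         # a VirtualOp recurses into its predecessors
        names.extend(actual_pred_names(pred))
    return names

print(actual_pred_names("V2"))   # ['OpA', 'OpB', 'OpC']
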
@@ -20,7 +20,7 @@ import functools
 from .proto import pipeline_service_pb2
 from .proto import pipeline_service_pb2_grpc

-_LOGGER = logging.getLogger(__name__)
+_LOGGER = logging.getLogger()


 class PipelineClient(object):
@@ -52,7 +52,7 @@ class PipelineClient(object):
             return {"ecode": resp.ecode, "error_info": resp.error_info}
         fetch_map = {"ecode": resp.ecode}
         for idx, key in enumerate(resp.key):
-            if key not in fetch:
+            if fetch is not None and key not in fetch:
                 continue
             data = resp.value[idx]
             try:
@@ -62,16 +62,16 @@ class PipelineClient(object):
             fetch_map[key] = data
         return fetch_map

-    def predict(self, feed_dict, fetch, asyn=False):
+    def predict(self, feed_dict, fetch=None, asyn=False):
         if not isinstance(feed_dict, dict):
             raise TypeError(
                 "feed must be dict type with format: {name: value}.")
-        if not isinstance(fetch, list):
+        if fetch is not None and not isinstance(fetch, list):
             raise TypeError("fetch must be list type with format: [name].")
         req = self._pack_request_package(feed_dict)
         if not asyn:
             resp = self._stub.inference(req)
-            return self._unpack_response_package(resp)
+            return self._unpack_response_package(resp, fetch)
         else:
             call_future = self._stub.inference.future(req)
             return PipelinePredictFuture(
...
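With fetch now optional, predict returns every key in the response when no fetch list is given. A hedged usage sketch; the import path, endpoint handling, and feed key are assumptions that vary by release, so treat them as placeholders:

import numpy as np
from paddle_serving_client.pipeline import PipelineClient   # assumed import path

client = PipelineClient()
client.connect('127.0.0.1:18080')            # placeholder endpoint; connect signature may differ

feed = {"words": np.array([8, 233, 52])}     # placeholder feed key/value
print(client.predict(feed_dict=feed))                          # fetch omitted: all response keys kept
print(client.predict(feed_dict=feed, fetch=["prediction"]))    # only "prediction" kept
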
@@ -45,7 +45,7 @@ from .channel import ThreadChannel, ProcessChannel, ChannelData, ChannelDataEcod
 from .profiler import TimeProfiler
 from .util import NameGenerator

-_LOGGER = logging.getLogger(__name__)
+_LOGGER = logging.getLogger()

 _profiler = TimeProfiler()
@@ -235,6 +235,10 @@ class PipelineServer(object):
             return use_ops, succ_ops_of_use_op

         use_ops, out_degree_ops = get_use_ops(response_op)
+        _LOGGER.info("================= use op ==================")
+        for op in use_ops:
+            _LOGGER.info(op.name)
+        _LOGGER.info("===========================================")
         if len(use_ops) <= 1:
             raise Exception(
                 "Besides RequestOp and ResponseOp, there should be at least one Op in DAG."
...
@@ -24,7 +24,7 @@ else:
     raise Exception("Error Python version")
 import time

-_LOGGER = logging.getLogger(__name__)
+_LOGGER = logging.getLogger()


 class TimeProfiler(object):
@@ -58,7 +58,7 @@ class TimeProfiler(object):
                 print_str += "{}_{}:{} ".format(name, tag, timestamp)
             else:
                 tmp[name] = (tag, timestamp)
-        print_str += "\n"
+        print_str = "\n{}\n".format(print_str)
         sys.stderr.write(print_str)
         for name, item in tmp.items():
             tag, timestamp = item
...
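The profiler change wraps the whole record string in newlines instead of only appending one, so each flush starts on its own stderr line. A tiny hedged illustration of the formatting difference; the entry names and timestamps are made up:

# hypothetical profiler entries: (name, tag, timestamp)
entries = [("Op0-prep#0", "0", 1594967900000000), ("Op0-prep#0", "1", 1594967900003000)]

print_str = ""
for name, tag, timestamp in entries:
    print_str += "{}_{}:{} ".format(name, tag, timestamp)

old_style = print_str + "\n"              # before: trailing newline only
new_style = "\n{}\n".format(print_str)    # after: leading and trailing newline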