Commit 5936dfc6 authored by barrierye

add ResponseOp and RequestOp

Parent 5ad0861c
use_multithread: true
client_type: brpc
-retry: 2
+retry: 1
profile: false
port: 8080
worker_num: 2
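These keys are consumed when PipelineServer.prepare_server('config.yml') runs (see the server script below). A minimal sketch of loading them, assuming the key names in the fragment above (the helper name is illustrative, not part of the library):

import yaml

def load_pipeline_config(path='config.yml'):
    # illustrative helper: prepare_server() reads the same keys
    with open(path) as f:
        conf = yaml.safe_load(f)
    assert conf['client_type'] in ('brpc', 'grpc')
    return conf['port'], conf['worker_num'], conf['retry']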
@@ -23,7 +23,7 @@ lp_wrapper = lp(client.predict)
words = 'i am very sad | 0'
-for i in range(10):
+for i in range(1):
fetch_map = lp_wrapper(feed_dict={"words": words}, fetch=["prediction"])
print(fetch_map)
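For context, lp in this fragment is presumably a line_profiler.LineProfiler instance wrapping client.predict. A self-contained sketch of that wrapper pattern (predict here is a stand-in for the real client method):

from line_profiler import LineProfiler

def predict(feed_dict, fetch):  # stand-in for client.predict
    return {"prediction": [0.5]}

lp = LineProfiler()
lp_wrapper = lp(predict)  # returns a wrapper that records per-line timings
fetch_map = lp_wrapper(feed_dict={"words": "i am very sad | 0"}, fetch=["prediction"])
print(fetch_map)
lp.print_stats()  # per-line timings for predict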
@@ -13,43 +13,66 @@
# limitations under the License.
# pylint: disable=doc-string-missing
-from paddle_serving_server.pipeline import Op, ReadOp
+from paddle_serving_server.pipeline import Op, RequestOp, ResponseOp
from paddle_serving_server.pipeline import PipelineServer
+from paddle_serving_server.pipeline.proto import pipeline_service_pb2
+from paddle_serving_server.pipeline.channel import ChannelDataEcode
import numpy as np
import logging
from paddle_serving_app.reader import IMDBDataset
+_LOGGER = logging.getLogger(__name__)
logging.basicConfig(
format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s',
datefmt='%Y-%m-%d %H:%M',
-level=logging.INFO)
+level=logging.DEBUG)
-class ImdbOp(Op):
+class ImdbRequestOp(RequestOp):
def load_user_resources(self):
self.imdb_dataset = IMDBDataset()
self.imdb_dataset.load_resource('imdb.vocab')
-def preprocess(self, input_data):
-data = input_data.parse()
-word_ids, _ = self.imdb_dataset.get_words_and_label(data['words'])
-return {"words": word_ids}
+def unpack_request_package(self, request):
+dictdata = {}
+for idx, key in enumerate(request.key):
+if key != "words":
+continue
+words = request.value[idx]
+word_ids, _ = self.imdb_dataset.get_words_and_label(words)
+dictdata[key] = np.array(word_ids)
+return dictdata
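For intuition, the incoming gRPC Request carries parallel key/value string lists, as the loop above implies. A small sketch of building one by hand (the concrete values are illustrative):

from paddle_serving_server.pipeline.proto import pipeline_service_pb2

req = pipeline_service_pb2.Request()
req.key.append("words")
req.value.append("i am very sad | 0")
# ImdbRequestOp.unpack_request_package(req) then returns a dict like
# {"words": np.array([...])}, with word ids looked up via imdb.vocab.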
class CombineOp(Op):
def preprocess(self, input_data):
combined_prediction = 0
-for op_name, channeldata in input_data.items():
-data = channeldata.parse()
-logging.info("{}: {}".format(op_name, data["prediction"]))
+for op_name, data in input_data.items():
+_LOGGER.info("{}: {}".format(op_name, data["prediction"]))
combined_prediction += data["prediction"]
data = {"prediction": combined_prediction / 2}
return data
-read_op = ReadOp()
-bow_op = ImdbOp(
-name="bow",
+class ImdbResponseOp(ResponseOp):
+# Here ImdbResponseOp is consistent with the default ResponseOp implementation
+def pack_response_package(self, channeldata):
+resp = pipeline_service_pb2.Response()
+resp.ecode = channeldata.ecode
+if resp.ecode == ChannelDataEcode.OK.value:
+feed = channeldata.parse()
+# ndarray to string
+for name, var in feed.items():
+resp.value.append(var.__repr__())
+resp.key.append(name)
+else:
+resp.error_info = channeldata.error_info
+return resp
+read_op = ImdbRequestOp()
+bow_op = Op(name="bow",
input_ops=[read_op],
server_endpoints=["127.0.0.1:9393"],
fetch_list=["prediction"],
@@ -57,8 +80,7 @@ bow_op = ImdbOp(
concurrency=1,
timeout=-1,
retry=1)
-cnn_op = ImdbOp(
-name="cnn",
+cnn_op = Op(name="cnn",
input_ops=[read_op],
server_endpoints=["127.0.0.1:9292"],
fetch_list=["prediction"],
@@ -73,9 +95,16 @@ combine_op = CombineOp(
timeout=-1,
retry=1)
+# fetch output of bow_op
+# response_op = ImdbResponseOp(input_ops=[bow_op])
+# fetch output of combine_op
+response_op = ImdbResponseOp(input_ops=[combine_op])
+# use default ResponseOp implementation
+# response_op = ResponseOp(input_ops=[combine_op])
server = PipelineServer()
server.add_ops([read_op, bow_op, cnn_op, combine_op])
-#server.set_response_op(bow_op)
-server.set_response_op(combine_op)
+server.set_response_op(response_op)
server.prepare_server('config.yml')
server.run_server()
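Once the server is up, a client can hit it on the port from config.yml. A minimal sketch, assuming PipelineClient.connect accepts an endpoint string (the test fragment earlier in this diff only shows the predict call):

from paddle_serving_server.pipeline import PipelineClient

client = PipelineClient()
client.connect('127.0.0.1:8080')  # port from config.yml; connect signature assumed
words = 'i am very sad | 0'
fetch_map = client.predict(feed_dict={"words": words}, fetch=["prediction"])
print(fetch_map)  # reply packed by ImdbResponseOp.pack_response_package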
@@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from operator import Op, ReadOp
+from operator import Op, RequestOp, ResponseOp
from pipeline_server import PipelineServer
from pipeline_client import PipelineClient
@@ -27,6 +27,8 @@ import logging
import enum
import copy
+_LOGGER = logging.getLogger(__name__)
class ChannelDataEcode(enum.Enum):
OK = 0
@@ -71,12 +73,12 @@ class ChannelData(object):
ecode, error_info = ChannelData.check_npdata(npdata)
if ecode != ChannelDataEcode.OK.value:
datatype = ChannelDataType.ERROR.value
-logging.error(error_info)
+_LOGGER.error(error_info)
elif datatype == ChannelDataType.DICT.value:
ecode, error_info = ChannelData.check_dictdata(dictdata)
if ecode != ChannelDataEcode.OK.value:
datatype = ChannelDataType.ERROR.value
-logging.error(error_info)
+_LOGGER.error(error_info)
else:
raise ValueError("datatype not match")
self.datatype = datatype
@@ -92,8 +94,8 @@ class ChannelData(object):
error_info = None
if not isinstance(dictdata, dict):
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = "the value of postped_data must " \
"be dict, but get {}".format(type(dictdata))
error_info = "the value of data must " \
"be dict, but get {}.".format(type(dictdata))
return ecode, error_info
@staticmethod
@@ -103,8 +105,8 @@ class ChannelData(object):
for _, value in npdata.items():
if not isinstance(value, np.ndarray):
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = "the value of postped_data must " \
"be np.ndarray, but get {}".format(type(value))
error_info = "the value of data must " \
"be np.ndarray, but get {}.".format(type(value))
break
return ecode, error_info
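A quick illustration of these two validators (a sketch; OK is 0 and TYPE_ERROR is the mismatch code defined in ChannelDataEcode):

import numpy as np
from paddle_serving_server.pipeline.channel import ChannelData, ChannelDataEcode

ecode, _ = ChannelData.check_npdata({"words": np.array([1, 2, 3])})
assert ecode == ChannelDataEcode.OK.value
ecode, info = ChannelData.check_npdata({"words": [1, 2, 3]})  # plain list, not ndarray
assert ecode == ChannelDataEcode.TYPE_ERROR.value  # info carries the message above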
@@ -200,7 +202,7 @@ class ProcessChannel(multiprocessing.queues.Queue):
self._idx_consumer_num[0] += 1
def push(self, channeldata, op_name=None):
-logging.debug(
+_LOGGER.debug(
self._log("{} try to push data: {}".format(op_name,
channeldata.__str__())))
if len(self._producers) == 0:
@@ -212,16 +214,16 @@ class ProcessChannel(multiprocessing.queues.Queue):
with self._cv:
while self._stop is False:
try:
-self.put(channeldata, timeout=0)
+self.put({op_name: channeldata}, timeout=0)
break
except Queue.Full:
self._cv.wait()
-logging.debug(
+_LOGGER.debug(
self._log("{} channel size: {}".format(op_name,
self.qsize())))
self._cv.notify_all()
logging.debug(self._log("{} notify all".format(op_name)))
logging.debug(self._log("{} push data succ!".format(op_name)))
_LOGGER.debug(self._log("{} notify all".format(op_name)))
_LOGGER.debug(self._log("{} push data succ!".format(op_name)))
return True
elif op_name is None:
raise Exception(
@@ -232,7 +234,7 @@ class ProcessChannel(multiprocessing.queues.Queue):
data_id = channeldata.id
put_data = None
with self._cv:
logging.debug(self._log("{} get lock".format(op_name)))
_LOGGER.debug(self._log("{} get lock".format(op_name)))
if data_id not in self._push_res:
self._push_res[data_id] = {
name: None
@@ -253,13 +255,13 @@ class ProcessChannel(multiprocessing.queues.Queue):
self._producer_res_count[data_id] += 1
if put_data is None:
-logging.debug(
+_LOGGER.debug(
self._log("{} push data succ, but not push to queue.".
format(op_name)))
else:
while self._stop is False:
try:
-logging.debug(
+_LOGGER.debug(
self._log("{} push data succ: {}".format(
op_name, put_data.__str__())))
self.put(put_data, timeout=0)
@@ -267,13 +269,13 @@ class ProcessChannel(multiprocessing.queues.Queue):
except Queue.Empty:
self._cv.wait()
-logging.debug(
+_LOGGER.debug(
self._log("multi | {} push data succ!".format(op_name)))
self._cv.notify_all()
return True
def front(self, op_name=None):
logging.debug(self._log("{} try to get data...".format(op_name)))
_LOGGER.debug(self._log("{} try to get data...".format(op_name)))
if len(self._consumers) == 0:
raise Exception(
self._log(
@@ -284,7 +286,7 @@ class ProcessChannel(multiprocessing.queues.Queue):
with self._cv:
while self._stop is False and resp is None:
try:
-logging.debug(
+_LOGGER.debug(
self._log("{} try to get(with channel empty: {})".
format(op_name, self.empty())))
# For queue multiprocess: after putting an object on
@@ -296,12 +298,12 @@ class ProcessChannel(multiprocessing.queues.Queue):
resp = self.get(timeout=1e-3)
break
except Queue.Empty:
-logging.debug(
+_LOGGER.debug(
self._log(
"{} wait for empty queue(with channel empty: {})".
format(op_name, self.empty())))
self._cv.wait()
-logging.debug(
+_LOGGER.debug(
self._log("{} get data succ: {}".format(op_name, resp.__str__(
))))
return resp
@@ -315,13 +317,13 @@ class ProcessChannel(multiprocessing.queues.Queue):
while self._stop is False and self._consumers[
op_name] - self._consumer_base_idx.value >= len(
self._front_res):
-logging.debug(
+_LOGGER.debug(
self._log(
"({}) B self._consumers: {}, self._consumer_base_idx: {}, len(self._front_res): {}".
format(op_name, self._consumers, self.
_consumer_base_idx.value, len(self._front_res))))
try:
-logging.debug(
+_LOGGER.debug(
self._log("{} try to get(with channel size: {})".format(
op_name, self.qsize())))
# For queue multiprocess: after putting an object on
@@ -334,7 +336,7 @@ class ProcessChannel(multiprocessing.queues.Queue):
self._front_res.append(channeldata)
break
except Queue.Empty:
-logging.debug(
+_LOGGER.debug(
self._log(
"{} wait for empty queue(with channel size: {})".
format(op_name, self.qsize())))
@@ -344,7 +346,7 @@ class ProcessChannel(multiprocessing.queues.Queue):
base_idx = self._consumer_base_idx.value
data_idx = consumer_idx - base_idx
resp = self._front_res[data_idx]
logging.debug(self._log("{} get data: {}".format(op_name, resp)))
_LOGGER.debug(self._log("{} get data: {}".format(op_name, resp)))
self._idx_consumer_num[consumer_idx] -= 1
if consumer_idx == base_idx and self._idx_consumer_num[
@@ -358,15 +360,15 @@ class ProcessChannel(multiprocessing.queues.Queue):
if self._idx_consumer_num.get(new_consumer_idx) is None:
self._idx_consumer_num[new_consumer_idx] = 0
self._idx_consumer_num[new_consumer_idx] += 1
-logging.debug(
+_LOGGER.debug(
self._log(
"({}) A self._consumers: {}, self._consumer_base_idx: {}, len(self._front_res): {}".
format(op_name, self._consumers, self._consumer_base_idx.
value, len(self._front_res))))
logging.debug(self._log("{} notify all".format(op_name)))
_LOGGER.debug(self._log("{} notify all".format(op_name)))
self._cv.notify_all()
logging.debug(self._log("multi | {} get data succ!".format(op_name)))
_LOGGER.debug(self._log("multi | {} get data succ!".format(op_name)))
return resp # reference, read only
def stop(self):
@@ -444,7 +446,7 @@ class ThreadChannel(Queue.Queue):
self._idx_consumer_num[0] += 1
def push(self, channeldata, op_name=None):
-logging.debug(
+_LOGGER.debug(
self._log("{} try to push data: {}".format(op_name,
channeldata.__str__())))
if len(self._producers) == 0:
@@ -456,12 +458,12 @@ class ThreadChannel(Queue.Queue):
with self._cv:
while self._stop is False:
try:
-self.put(channeldata, timeout=0)
+self.put({op_name: channeldata}, timeout=0)
break
except Queue.Full:
self._cv.wait()
self._cv.notify_all()
logging.debug(self._log("{} push data succ!".format(op_name)))
_LOGGER.debug(self._log("{} push data succ!".format(op_name)))
return True
elif op_name is None:
raise Exception(
@@ -472,7 +474,7 @@ class ThreadChannel(Queue.Queue):
data_id = channeldata.id
put_data = None
with self._cv:
logging.debug(self._log("{} get lock".format(op_name)))
_LOGGER.debug(self._log("{} get lock".format(op_name)))
if data_id not in self._push_res:
self._push_res[data_id] = {
name: None
@@ -488,7 +490,7 @@ class ThreadChannel(Queue.Queue):
self._producer_res_count[data_id] += 1
if put_data is None:
-logging.debug(
+_LOGGER.debug(
self._log("{} push data succ, but not push to queue.".
format(op_name)))
else:
@@ -499,13 +501,13 @@ class ThreadChannel(Queue.Queue):
except Queue.Empty:
self._cv.wait()
-logging.debug(
+_LOGGER.debug(
self._log("multi | {} push data succ!".format(op_name)))
self._cv.notify_all()
return True
def front(self, op_name=None):
logging.debug(self._log("{} try to get data".format(op_name)))
_LOGGER.debug(self._log("{} try to get data".format(op_name)))
if len(self._consumers) == 0:
raise Exception(
self._log(
@@ -520,7 +522,7 @@ class ThreadChannel(Queue.Queue):
break
except Queue.Empty:
self._cv.wait()
-logging.debug(
+_LOGGER.debug(
self._log("{} get data succ: {}".format(op_name, resp.__str__(
))))
return resp
@@ -544,7 +546,7 @@ class ThreadChannel(Queue.Queue):
base_idx = self._consumer_base_idx
data_idx = consumer_idx - base_idx
resp = self._front_res[data_idx]
logging.debug(self._log("{} get data: {}".format(op_name, resp)))
_LOGGER.debug(self._log("{} get data: {}".format(op_name, resp)))
self._idx_consumer_num[consumer_idx] -= 1
if consumer_idx == base_idx and self._idx_consumer_num[
@@ -561,7 +563,7 @@ class ThreadChannel(Queue.Queue):
self._cv.notify_all()
logging.debug(self._log("multi | {} get data succ!".format(op_name)))
_LOGGER.debug(self._log("multi | {} get data succ!".format(op_name)))
# return resp # reference, read only
return copy.deepcopy(resp)
@@ -20,9 +20,11 @@ from concurrent import futures
import logging
import func_timeout
from .proto import pipeline_service_pb2
from .channel import ThreadChannel, ProcessChannel, ChannelDataEcode, ChannelData, ChannelDataType
from .util import NameGenerator
+_LOGGER = logging.getLogger(__name__)
_op_name_gen = NameGenerator("Op")
@@ -67,17 +69,17 @@ class Op(object):
def init_client(self, client_type, client_config, server_endpoints,
fetch_names):
if self.with_serving == False:
logging.debug("{} no client".format(self.name))
_LOGGER.debug("{} no client".format(self.name))
return
logging.debug("{} client_config: {}".format(self.name, client_config))
logging.debug("{} fetch_names: {}".format(self.name, fetch_names))
_LOGGER.debug("{} client_config: {}".format(self.name, client_config))
_LOGGER.debug("{} fetch_names: {}".format(self.name, fetch_names))
if client_type == 'brpc':
self._client = Client()
-self._client.load_client_config(client_config)
elif client_type == 'grpc':
self._client = MultiLangClient()
else:
raise ValueError("unknow client type: {}".format(client_type))
+self._client.load_client_config(client_config)
self._client.connect(server_endpoints)
self._fetch_names = fetch_names
@@ -117,38 +119,31 @@ class Op(object):
channel.add_producer(self.name)
self._outputs.append(channel)
-def preprocess(self, channeldata):
+def preprocess(self, input_dicts):
# multiple previous Op
-if isinstance(channeldata, dict):
+if len(input_dicts) != 1:
raise NotImplementedError(
-'this Op has multiple previous inputs. Please override this method'
+'this Op has multiple previous inputs. Please override this func.'
)
-if channeldata.datatype is not ChannelDataType.CHANNEL_NPDATA.value:
-raise NotImplementedError(
-'datatype in channeldata is not CHANNEL_NPDATA({}). '
-'Please override this method'.format(channeldata.datatype))
-# get numpy dict
-feed_data = channeldata.parse()
-return feed_data
+(_, input_dict), = input_dicts.items()
+return input_dict
def process(self, feed_dict):
-if not isinstance(feed_dict, dict):
-raise Exception(
-self._log(
-'feed_dict must be dict type(the output of preprocess()), but get {}'.
-format(type(feed_dict))))
-logging.debug(self._log('feed_dict: {}'.format(feed_dict)))
-logging.debug(self._log('fetch: {}'.format(self._fetch_names)))
+err, err_info = ChannelData.check_npdata(feed_dict)
+if err != 0:
+raise NotImplementedError(
+"{} Please override preprocess func.".format(err_info))
+_LOGGER.debug(self._log('feed_dict: {}'.format(feed_dict)))
+_LOGGER.debug(self._log('fetch: {}'.format(self._fetch_names)))
if isinstance(self._client, MultiLangClient):
call_result = self._client.predict(
feed=feed_dict, fetch=self._fetch_names)
logging.debug(self._log("get call_result"))
_LOGGER.debug(self._log("get call_result"))
else:
call_result = self._client.predict(
feed=feed_dict, fetch=self._fetch_names)
logging.debug(self._log("get fetch_dict"))
_LOGGER.debug(self._log("get fetch_dict"))
return call_result
def postprocess(self, fetch_dict):
@@ -160,21 +155,19 @@ class Op(object):
channel.stop()
self._is_run = False
-def _parse_channeldata(self, channeldata):
+def _parse_channeldata(self, channeldata_dict):
data_id, error_channeldata = None, None
-if isinstance(channeldata, dict):
-key = list(channeldata.keys())[0]
-data_id = channeldata[key].id
-for _, data in channeldata.items():
+parsed_data = {}
+key = list(channeldata_dict.keys())[0]
+data_id = channeldata_dict[key].id
+for name, data in channeldata_dict.items():
if data.ecode != ChannelDataEcode.OK.value:
error_channeldata = data
break
-else:
-data_id = channeldata.id
-if channeldata.ecode != ChannelDataEcode.OK.value:
-error_channeldata = channeldata
-return data_id, error_channeldata
+parsed_data[name] = data.parse()
+return data_id, error_channeldata, parsed_data
def _push_to_output_channels(self, data, channels, name=None):
if name is None:
@@ -229,11 +222,12 @@ class Op(object):
self._is_run = True
while self._is_run:
self._profiler_record("{}-get#{}_0".format(op_info_prefix, tid))
-channeldata = input_channel.front(self.name)
+channeldata_dict = input_channel.front(self.name)
self._profiler_record("{}-get#{}_1".format(op_info_prefix, tid))
-logging.debug(log("input_data: {}".format(channeldata)))
+_LOGGER.debug(log("input_data: {}".format(channeldata_dict)))
-data_id, error_channeldata = self._parse_channeldata(channeldata)
+data_id, error_channeldata, parsed_data = self._parse_channeldata(
+channeldata_dict)
# error data in predecessor Op
if error_channeldata is not None:
@@ -245,13 +239,13 @@ class Op(object):
try:
self._profiler_record("{}-prep#{}_0".format(op_info_prefix,
tid))
-preped_data = self.preprocess(channeldata)
+preped_data = self.preprocess(parsed_data)
self._profiler_record("{}-prep#{}_1".format(op_info_prefix,
tid))
except NotImplementedError as e:
# preprocess function not implemented
error_info = log(e)
-logging.error(error_info)
+_LOGGER.error(error_info)
self._push_to_output_channels(
ChannelData(
ecode=ChannelDataEcode.NOT_IMPLEMENTED.value,
@@ -262,7 +256,7 @@ class Op(object):
except TypeError as e:
# Error type in channeldata.datatype
error_info = log(e)
-logging.error(error_info)
+_LOGGER.error(error_info)
self._push_to_output_channels(
ChannelData(
ecode=ChannelDataEcode.TYPE_ERROR.value,
@@ -272,7 +266,7 @@ class Op(object):
continue
except Exception as e:
error_info = log(e)
-logging.error(error_info)
+_LOGGER.error(error_info)
self._push_to_output_channels(
ChannelData(
ecode=ChannelDataEcode.UNKNOW.value,
@@ -293,7 +287,7 @@ class Op(object):
except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value
error_info = log(e)
-logging.error(error_info)
+_LOGGER.error(error_info)
else:
for i in range(self._retry):
try:
@@ -305,14 +299,14 @@ class Op(object):
if i + 1 >= self._retry:
ecode = ChannelDataEcode.TIMEOUT.value
error_info = log(e)
-logging.error(error_info)
+_LOGGER.error(error_info)
else:
-logging.warn(
+_LOGGER.warn(
log("timeout, retry({})".format(i + 1)))
except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value
error_info = log(e)
-logging.error(error_info)
+_LOGGER.error(error_info)
break
else:
break
@@ -346,7 +340,7 @@ class Op(object):
except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value
error_info = log(e)
-logging.error(error_info)
+_LOGGER.error(error_info)
self._push_to_output_channels(
ChannelData(
ecode=ecode, error_info=error_info, data_id=data_id),
@@ -356,7 +350,7 @@ class Op(object):
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = log("output of postprocess funticon must be " \
"dict type, but get {}".format(type(postped_data)))
-logging.error(error_info)
+_LOGGER.error(error_info)
self._push_to_output_channels(
ChannelData(
ecode=ecode, error_info=error_info, data_id=data_id),
@@ -385,11 +379,62 @@ class Op(object):
return "{} {}".format(self.name, info)
-class ReadOp(Op):
+class RequestOp(Op):
+""" RequestOp do not run preprocess, process, postprocess. """
def __init__(self, concurrency=1):
# PipelineService.name = "#G"
-super(ReadOp, self).__init__(
+super(RequestOp, self).__init__(
name="#G", input_ops=[], concurrency=concurrency)
+# load user resources
+self.load_user_resources()
+def unpack_request_package(self, request):
+dictdata = {}
+for idx, key in enumerate(request.key):
+dictdata[key] = request.value[idx]
+return dictdata
+class ResponseOp(Op):
+""" ResponseOp do not run preprocess, process, postprocess. """
+def __init__(self, input_ops, concurrency=1):
+super(ResponseOp, self).__init__(
+name="#R", input_ops=input_ops, concurrency=concurrency)
+# load user resources
+self.load_user_resources()
+def pack_response_package(self, channeldata):
+resp = pipeline_service_pb2.Response()
+resp.ecode = channeldata.ecode
+if resp.ecode == ChannelDataEcode.OK.value:
+if channeldata.datatype == ChannelDataType.CHANNEL_NPDATA.value:
+feed = channeldata.parse()
+# ndarray to string:
+# https://stackoverflow.com/questions/30167538/convert-a-numpy-ndarray-to-stringor-bytes-and-convert-it-back-to-numpy-ndarray
+for name, var in feed.items():
+resp.value.append(var.__repr__())
+resp.key.append(name)
+elif channeldata.datatype == ChannelDataType.DICT.value:
+feed = channeldata.parse()
+for name, var in feed.items():
+if not isinstance(var, str):
+resp.ecode = ChannelDataEcode.TYPE_ERROR.value
+resp.error_info = self._log(
+"fetch var type must be str({}).".format(
+type(var)))
+break
+resp.value.append(var)
+resp.key.append(name)
+else:
+resp.ecode = ChannelDataEcode.TYPE_ERROR.value
+resp.error_info = self._log(
+"Error type({}) in datatype.".format(channeldata.datatype))
+_LOGGER.error(resp.error_info)
+else:
+resp.error_info = channeldata.error_info
+return resp
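The __repr__ serialization above pairs with the client-side `from numpy import array` import added later in this diff: the printed representation can be evaluated back into an ndarray. A rough sketch of the round trip (eval of repr, as the linked StackOverflow thread describes; no error handling):

import numpy as np
from numpy import array  # needed so eval("array([...])") resolves

var = np.array([0.342, 0.658])
s = var.__repr__()  # "array([0.342, 0.658])"
restored = eval(s)  # client-side reconstruction of the ndarray
assert np.allclose(var, restored)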
class VirtualOp(Op):
@@ -427,17 +472,11 @@ class VirtualOp(Op):
self._is_run = True
while self._is_run:
self._profiler_record("{}-get#{}_0".format(op_info_prefix, tid))
-channeldata = input_channel.front(self.name)
+channeldata_dict = input_channel.front(self.name)
self._profiler_record("{}-get#{}_1".format(op_info_prefix, tid))
self._profiler_record("{}-push#{}_0".format(op_info_prefix, tid))
-if isinstance(channeldata, dict):
-for name, data in channeldata.items():
+for name, data in channeldata_dict.items():
self._push_to_output_channels(
data, channels=output_channels, name=name)
-else:
-self._push_to_output_channels(
-channeldata,
-channels=output_channels,
-name=self._virtual_pred_ops[0].name)
self._profiler_record("{}-push#{}_1".format(op_info_prefix, tid))
@@ -14,10 +14,13 @@
# pylint: disable=doc-string-missing
import grpc
import numpy as np
+from numpy import array
+from numpy import *
import logging
from .proto import pipeline_service_pb2
from .proto import pipeline_service_pb2_grpc
+_LOGGER = logging.getLogger(__name__)
class PipelineClient(object):
def __init__(self):
@@ -40,22 +40,24 @@ import yaml
from .proto import pipeline_service_pb2
from .proto import pipeline_service_pb2_grpc
-from .operator import Op, ReadOp, VirtualOp
+from .operator import Op, RequestOp, ResponseOp, VirtualOp
from .channel import ThreadChannel, ProcessChannel, ChannelData, ChannelDataEcode, ChannelDataType
from .profiler import TimeProfiler
from .util import NameGenerator
+_LOGGER = logging.getLogger(__name__)
_profiler = TimeProfiler()
class PipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
-def __init__(self, in_channel, out_channel, retry=2):
+def __init__(self, in_channel, out_channel, unpack_func, pack_func,
+retry=2):
super(PipelineService, self).__init__()
self.name = "#G"
self.set_in_channel(in_channel)
self.set_out_channel(out_channel)
-logging.debug(self._log(in_channel.debug()))
-logging.debug(self._log(out_channel.debug()))
+_LOGGER.debug(self._log(in_channel.debug()))
+_LOGGER.debug(self._log(out_channel.debug()))
#TODO:
# multi-lock for different clients
# different lock for server and client
@@ -64,6 +66,8 @@ class PipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
self._globel_resp_dict = {}
self._id_counter = 0
self._retry = retry
+self._pack_func = pack_func
+self._unpack_func = unpack_func
self._recive_func = threading.Thread(
target=PipelineService._recive_out_channel_func, args=(self, ))
self._recive_func.start()
@@ -89,7 +93,10 @@ class PipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
def _recive_out_channel_func(self):
while True:
-channeldata = self._out_channel.front(self.name)
+channeldata_dict = self._out_channel.front(self.name)
+if len(channeldata_dict) != 1:
+raise Exception("out_channel cannot have multiple input ops")
+(_, channeldata), = channeldata_dict.items()
if not isinstance(channeldata, ChannelData):
raise TypeError(
self._log('data must be ChannelData type, but get {}'.
@@ -114,18 +121,15 @@ class PipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
return resp
def _pack_data_for_infer(self, request):
-logging.debug(self._log('start inference'))
+_LOGGER.debug(self._log('start inference'))
data_id = self._get_next_id()
-dictdata = {}
+dictdata = None
try:
-for idx, key in enumerate(request.key):
-logging.debug(self._log('key: {}'.format(key)))
-logging.debug(self._log('value: {}'.format(request.value[idx])))
-dictdata[key] = request.value[idx]
+dictdata = self._unpack_func(request)
except Exception as e:
return ChannelData(
ecode=ChannelDataEcode.RPC_PACKAGE_ERROR.value,
error_info="rpc package error",
error_info="rpc package error: {}".format(e),
data_id=data_id), data_id
else:
return ChannelData(
@@ -134,35 +138,8 @@ class PipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
data_id=data_id), data_id
def _pack_data_for_resp(self, channeldata):
-logging.debug(self._log('get channeldata'))
-resp = pipeline_service_pb2.Response()
-resp.ecode = channeldata.ecode
-if resp.ecode == ChannelDataEcode.OK.value:
-if channeldata.datatype == ChannelDataType.CHANNEL_NPDATA.value:
-feed = channeldata.parse()
-# ndarray to string
-for name, var in feed.items():
-resp.value.append(var.__repr__())
-resp.key.append(name)
-elif channeldata.datatype == ChannelDataType.DICT.value:
-feed = channeldata.parse()
-for name, var in feed.items():
-if not isinstance(var, str):
-resp.ecode = ChannelDataEcode.TYPE_ERROR.value
-resp.error_info = self._log(
-"fetch var type must be str({}).".format(
-type(var)))
-break
-resp.value.append(var)
-resp.key.append(name)
-else:
-resp.ecode = ChannelDataEcode.TYPE_ERROR.value
-resp.error_info = self._log(
-"Error type({}) in datatype.".format(channeldata.datatype))
-logging.error(resp.error_info)
-else:
-resp.error_info = channeldata.error_info
-return resp
+_LOGGER.debug(self._log('get channeldata'))
+return self._pack_func(channeldata)
def inference(self, request, context):
_profiler.record("{}-prepack_0".format(self.name))
@@ -171,12 +148,12 @@ class PipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
resp_channeldata = None
for i in range(self._retry):
-logging.debug(self._log('push data'))
+_LOGGER.debug(self._log('push data'))
_profiler.record("{}-push_0".format(self.name))
self._in_channel.push(data, self.name)
_profiler.record("{}-push_1".format(self.name))
-logging.debug(self._log('wait for infer'))
+_LOGGER.debug(self._log('wait for infer'))
_profiler.record("{}-fetch_0".format(self.name))
resp_channeldata = self._get_data_in_globel_resp_dict(data_id)
_profiler.record("{}-fetch_1".format(self.name))
@@ -184,7 +161,7 @@ class PipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
if resp_channeldata.ecode == ChannelDataEcode.OK.value:
break
if i + 1 < self._retry:
logging.warn("retry({}): {}".format(
_LOGGER.warn("retry({}): {}".format(
i + 1, resp_channeldata.error_info))
_profiler.record("{}-postpack_0".format(self.name))
@@ -197,32 +174,27 @@
class PipelineServer(object):
def __init__(self):
self._channels = []
self._user_ops = []
self._actual_ops = []
self._port = None
self._worker_num = None
self._in_channel = None
self._out_channel = None
self._response_op = None
+self._pack_func = None
+self._unpack_func = None
def add_channel(self, channel):
self._channels.append(channel)
def add_op(self, op):
self._user_ops.append(op)
def add_ops(self, ops):
self._user_ops.extend(ops)
def gen_desc(self):
-logging.info('here will generate desc for PAAS')
+_LOGGER.info('here will generate desc for PAAS')
pass
def set_response_op(self, response_op):
if not isinstance(response_op, Op):
raise Exception("response_op must be Op type.")
-if len(response_op.get_input_ops()) == 0:
-raise Exception("response_op cannot be ReadOp.")
+if len(response_op.get_input_ops()) != 1:
+raise Exception("response_op can only have one previous op.")
self._response_op = response_op
def _topo_sort(self, response_op):
@@ -230,18 +202,20 @@
raise Exception("response_op has not been set.")
def get_use_ops(root):
# root: response_op
unique_names = set()
use_ops = set()
succ_ops_of_use_op = {} # {op_name: succ_ops}
que = Queue.Queue()
que.put(root)
-use_ops.add(root)
-unique_names.add(root.name)
+#use_ops.add(root)
+#unique_names.add(root.name)
while que.qsize() != 0:
op = que.get()
for pred_op in op.get_input_ops():
if pred_op.name not in succ_ops_of_use_op:
succ_ops_of_use_op[pred_op.name] = []
+if op != root:
succ_ops_of_use_op[pred_op.name].append(op)
if pred_op not in use_ops:
que.put(pred_op)
@@ -268,7 +242,8 @@ class PipelineServer(object):
zero_indegree_num += 1
if zero_indegree_num != 1:
raise Exception("DAG contains multiple input Ops")
-ques[que_idx].put(response_op)
+last_op = response_op.get_input_ops()[0]
+ques[que_idx].put(last_op)
# topo sort to get dag_views
dag_views = []
@@ -344,7 +319,7 @@ class PipelineServer(object):
continue
channel = gen_channel(channel_name_gen)
channels.append(channel)
logging.debug("{} => {}".format(channel.name, op.name))
_LOGGER.debug("{} => {}".format(channel.name, op.name))
op.add_input_channel(channel)
pred_ops = pred_op_of_next_view_op[op.name]
if v_idx == 0:
@@ -352,7 +327,7 @@
else:
# if pred_op is virtual op, it will use ancestors as producers to channel
for pred_op in pred_ops:
logging.debug("{} => {}".format(pred_op.name,
_LOGGER.debug("{} => {}".format(pred_op.name,
channel.name))
pred_op.add_output_channel(channel)
processed_op.add(op.name)
@@ -369,24 +344,26 @@
same_flag = False
break
if same_flag:
logging.debug("{} => {}".format(channel.name,
_LOGGER.debug("{} => {}".format(channel.name,
other_op.name))
other_op.add_input_channel(channel)
processed_op.add(other_op.name)
output_channel = gen_channel(channel_name_gen)
channels.append(output_channel)
-response_op.add_output_channel(output_channel)
+last_op.add_output_channel(output_channel)
+pack_func, unpack_func = None, None
+pack_func = self._response_op.pack_response_package
self._actual_ops = virtual_ops
for op in use_ops:
if len(op.get_input_ops()) == 0:
# pass read op
+unpack_func = op.unpack_request_package
continue
self._actual_ops.append(op)
self._channels = channels
for c in channels:
-logging.debug(c.debug())
-return input_channel, output_channel
+_LOGGER.debug(c.debug())
+return input_channel, output_channel, pack_func, unpack_func
def _port_is_available(self, port):
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
@@ -414,7 +391,8 @@
"profile cannot be used in multiprocess version temporarily")
_profiler.enable(profile)
-input_channel, output_channel = self._topo_sort(self._response_op)
+input_channel, output_channel, self._pack_func, self._unpack_func = self._topo_sort(
+self._response_op)
self._in_channel = input_channel
self._out_channel = output_channel
for op in self._actual_ops:
@@ -443,7 +421,8 @@
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=self._worker_num))
pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
-PipelineService(self._in_channel, self._out_channel, self._retry),
+PipelineService(self._in_channel, self._out_channel,
+self._unpack_func, self._pack_func, self._retry),
server)
server.add_insecure_port('[::]:{}'.format(self._port))
server.start()
@@ -454,4 +433,4 @@
def prepare_serving(self, op):
# run a server (not in PyServing)
logging.info("run a server (not in PyServing)")
_LOGGER.info("run a server (not in PyServing)")
@@ -15,6 +15,7 @@
import os
import sys
+import logging
if sys.version_info.major == 2:
import Queue
elif sys.version_info.major == 3:
@@ -23,6 +24,8 @@ else:
raise Exception("Error Python version")
import time
+_LOGGER = logging.getLogger(__name__)
class TimeProfiler(object):
def __init__(self):