提交 9226c0d3 编写于 作者: W wangjiawei04

bprc version

上级 7a5221b6
...@@ -14,18 +14,12 @@ ...@@ -14,18 +14,12 @@
# pylint: disable=doc-string-missing # pylint: disable=doc-string-missing
import threading import threading
import multiprocessing import multiprocessing
import multiprocessing.queues import Queue
import sys
if sys.version_info.major == 2:
import Queue
elif sys.version_info.major == 3:
import queue as Queue
else:
raise Exception("Error Python version")
import os import os
import sys
import paddle_serving_server import paddle_serving_server
from paddle_serving_client import MultiLangClient as Client #from paddle_serving_client import MultiLangClient as Client
from paddle_serving_client import MultiLangPredictFuture from paddle_serving_client import Client
from concurrent import futures from concurrent import futures
import numpy as np import numpy as np
import grpc import grpc
...@@ -116,27 +110,34 @@ class ChannelData(object): ...@@ -116,27 +110,34 @@ class ChannelData(object):
4. ChannelData(ChannelDataType.CHANNEL_PBDATA.value, npdata, data_id) 4. ChannelData(ChannelDataType.CHANNEL_PBDATA.value, npdata, data_id)
5. ChannelData(ChannelDataType.CHANNEL_NPDATA.value, npdata, data_id) 5. ChannelData(ChannelDataType.CHANNEL_NPDATA.value, npdata, data_id)
6. ChannelData(ecode, error_info, data_id) 6. ChannelData(ecode, error_info, data_id)
Protobufs are not pickle-able:
https://stackoverflow.com/questions/55344376/how-to-import-protobuf-module
''' '''
if ecode is not None: if ecode is not None:
if data_id is None or error_info is None: if data_id is None or error_info is None:
raise ValueError("data_id and error_info cannot be None") raise ValueError("data_id and error_info cannot be None")
pbdata = channel_pb2.ChannelData()
pbdata.ecode = ecode
pbdata.id = data_id
pbdata.error_info = error_info
datatype = ChannelDataType.ERROR.value datatype = ChannelDataType.ERROR.value
else: else:
if datatype == ChannelDataType.CHANNEL_FUTURE.value: if datatype == ChannelDataType.CHANNEL_FUTURE.value:
if pbdata is None:
if data_id is None: if data_id is None:
raise ValueError("data_id cannot be None") raise ValueError("data_id cannot be None")
ecode = ChannelDataEcode.OK.value pbdata = channel_pb2.ChannelData()
pbdata.ecode = ChannelDataEcode.OK.value
pbdata.id = data_id
elif datatype == ChannelDataType.CHANNEL_PBDATA.value: elif datatype == ChannelDataType.CHANNEL_PBDATA.value:
if pbdata is None: if pbdata is None:
if data_id is None: if data_id is None:
raise ValueError("data_id cannot be None") raise ValueError("data_id cannot be None")
pbdata = channel_pb2.ChannelData() pbdata = channel_pb2.ChannelData()
pbdata.id = data_id
ecode, error_info = self._check_npdata(npdata) ecode, error_info = self._check_npdata(npdata)
if ecode != ChannelDataEcode.OK.value: pbdata.ecode = ecode
logging.error(error_info) if pbdata.ecode != ChannelDataEcode.OK.value:
pbdata.error_info = error_info
logging.error(pbdata.error_info)
else: else:
for name, value in npdata.items(): for name, value in npdata.items():
inst = channel_pb2.Inst() inst = channel_pb2.Inst()
...@@ -148,18 +149,23 @@ class ChannelData(object): ...@@ -148,18 +149,23 @@ class ChannelData(object):
pbdata.insts.append(inst) pbdata.insts.append(inst)
elif datatype == ChannelDataType.CHANNEL_NPDATA.value: elif datatype == ChannelDataType.CHANNEL_NPDATA.value:
ecode, error_info = self._check_npdata(npdata) ecode, error_info = self._check_npdata(npdata)
if ecode != ChannelDataEcode.OK.value: pbdata = channel_pb2.ChannelData()
logging.error(error_info) pbdata.id = data_id
pbdata.ecode = ecode
if pbdata.ecode != ChannelDataEcode.OK.value:
pbdata.error_info = error_info
logging.error(pbdata.error_info)
else: else:
raise ValueError("datatype not match") raise ValueError("datatype not match")
if not isinstance(pbdata, channel_pb2.ChannelData):
raise TypeError(
"pbdata must be pyserving_channel_pb2.ChannelData type({})".
format(type(pbdata)))
self.future = future self.future = future
self.pbdata = pbdata self.pbdata = pbdata
self.npdata = npdata self.npdata = npdata
self.datatype = datatype self.datatype = datatype
self.callback_func = callback_func self.callback_func = callback_func
self.id = data_id
self.ecode = ecode
self.error_info = error_info
def _check_npdata(self, npdata): def _check_npdata(self, npdata):
ecode = ChannelDataEcode.OK.value ecode = ChannelDataEcode.OK.value
...@@ -192,15 +198,15 @@ class ChannelData(object): ...@@ -192,15 +198,15 @@ class ChannelData(object):
elif self.datatype == ChannelDataType.CHANNEL_NPDATA.value: elif self.datatype == ChannelDataType.CHANNEL_NPDATA.value:
feed = self.npdata feed = self.npdata
else: else:
raise TypeError("Error type({}) in datatype.".format(self.datatype)) raise TypeError("Error type({}) in datatype.".format(datatype))
return feed return feed
def __str__(self): def __str__(self):
return "type[{}], ecode[{}], id[{}]".format( return "type[{}], ecode[{}]".format(
ChannelDataType(self.datatype).name, self.ecode, self.id) ChannelDataType(self.datatype).name, self.pbdata.ecode)
class Channel(multiprocessing.queues.Queue): class Channel(Queue.Queue):
""" """
The channel used for communication between Ops. The channel used for communication between Ops.
...@@ -218,36 +224,23 @@ class Channel(multiprocessing.queues.Queue): ...@@ -218,36 +224,23 @@ class Channel(multiprocessing.queues.Queue):
and can only be called during initialization. and can only be called during initialization.
""" """
def __init__(self, manager, name=None, maxsize=0, timeout=None): def __init__(self, name=None, maxsize=-1, timeout=None):
# https://stackoverflow.com/questions/39496554/cannot-subclass-multiprocessing-queue-in-python-3-5/ Queue.Queue.__init__(self, maxsize=maxsize)
if sys.version_info.major == 2:
super(Channel, self).__init__(maxsize=maxsize)
elif sys.version_info.major == 3:
super(Channel, self).__init__(
maxsize=maxsize, ctx=multiprocessing.get_context())
else:
raise Exception("Error Python version")
self._maxsize = maxsize self._maxsize = maxsize
self._timeout = timeout self._timeout = timeout
self.name = name self.name = name
self._stop = False self._stop = False
self._cv = multiprocessing.Condition() self._cv = threading.Condition()
self._producers = [] self._producers = []
self._producer_res_count = manager.dict() # {data_id: count} self._producer_res_count = {} # {data_id: count}
# self._producer_res_count = {} # {data_id: count} self._push_res = {} # {data_id: {op_name: data}}
self._push_res = manager.dict() # {data_id: {op_name: data}}
# self._push_res = {} # {data_id: {op_name: data}} self._consumers = {} # {op_name: idx}
self._idx_consumer_num = {} # {idx: num}
self._consumers = manager.dict() # {op_name: idx} self._consumer_base_idx = 0
# self._consumers = {} # {op_name: idx} self._front_res = []
self._idx_consumer_num = manager.dict() # {idx: num}
# self._idx_consumer_num = {} # {idx: num}
self._consumer_base_idx = manager.Value('i', 0)
# self._consumer_base_idx = 0
self._front_res = manager.list()
# self._front_res = []
def get_producers(self): def get_producers(self):
return self._producers return self._producers
...@@ -297,11 +290,7 @@ class Channel(multiprocessing.queues.Queue): ...@@ -297,11 +290,7 @@ class Channel(multiprocessing.queues.Queue):
break break
except Queue.Full: except Queue.Full:
self._cv.wait() self._cv.wait()
logging.debug(
self._log("{} channel size: {}".format(op_name,
self.qsize())))
self._cv.notify_all() self._cv.notify_all()
logging.debug(self._log("{} notify all".format(op_name)))
logging.debug(self._log("{} push data succ!".format(op_name))) logging.debug(self._log("{} push data succ!".format(op_name)))
return True return True
elif op_name is None: elif op_name is None:
...@@ -310,7 +299,7 @@ class Channel(multiprocessing.queues.Queue): ...@@ -310,7 +299,7 @@ class Channel(multiprocessing.queues.Queue):
"There are multiple producers, so op_name cannot be None.")) "There are multiple producers, so op_name cannot be None."))
producer_num = len(self._producers) producer_num = len(self._producers)
data_id = channeldata.id data_id = channeldata.pbdata.id
put_data = None put_data = None
with self._cv: with self._cv:
logging.debug(self._log("{} get lock".format(op_name))) logging.debug(self._log("{} get lock".format(op_name)))
...@@ -320,12 +309,7 @@ class Channel(multiprocessing.queues.Queue): ...@@ -320,12 +309,7 @@ class Channel(multiprocessing.queues.Queue):
for name in self._producers for name in self._producers
} }
self._producer_res_count[data_id] = 0 self._producer_res_count[data_id] = 0
# see: https://docs.python.org/3.6/library/multiprocessing.html?highlight=multiprocess#proxy-objects self._push_res[data_id][op_name] = channeldata
# self._push_res[data_id][op_name] = channeldata
tmp_push_res = self._push_res[data_id]
tmp_push_res[op_name] = channeldata
self._push_res[data_id] = tmp_push_res
if self._producer_res_count[data_id] + 1 == producer_num: if self._producer_res_count[data_id] + 1 == producer_num:
put_data = self._push_res[data_id] put_data = self._push_res[data_id]
self._push_res.pop(data_id) self._push_res.pop(data_id)
...@@ -340,9 +324,6 @@ class Channel(multiprocessing.queues.Queue): ...@@ -340,9 +324,6 @@ class Channel(multiprocessing.queues.Queue):
else: else:
while self._stop is False: while self._stop is False:
try: try:
logging.debug(
self._log("{} push data succ: {}".format(
op_name, put_data.__str__())))
self.put(put_data, timeout=0) self.put(put_data, timeout=0)
break break
except Queue.Empty: except Queue.Empty:
...@@ -354,7 +335,7 @@ class Channel(multiprocessing.queues.Queue): ...@@ -354,7 +335,7 @@ class Channel(multiprocessing.queues.Queue):
return True return True
def front(self, op_name=None): def front(self, op_name=None):
logging.debug(self._log("{} try to get data...".format(op_name))) logging.debug(self._log("{} try to get data".format(op_name)))
if len(self._consumers) == 0: if len(self._consumers) == 0:
raise Exception( raise Exception(
self._log( self._log(
...@@ -365,26 +346,9 @@ class Channel(multiprocessing.queues.Queue): ...@@ -365,26 +346,9 @@ class Channel(multiprocessing.queues.Queue):
with self._cv: with self._cv:
while self._stop is False and resp is None: while self._stop is False and resp is None:
try: try:
logging.debug(
self._log("{} try to get(with channel empty: {})".
format(op_name, self.empty())))
# For Python2, after putting an object on an empty queue there may
# be an infinitessimal delay before the queue's :meth:`~Queue.empty`
# see more:
# - https://bugs.python.org/issue18277
# - https://hg.python.org/cpython/rev/860fc6a2bd21
if sys.version_info.major == 2:
resp = self.get(timeout=1e-3)
elif sys.version_info.major == 3:
resp = self.get(timeout=0) resp = self.get(timeout=0)
else:
raise Exception("Error Python version")
break break
except Queue.Empty: except Queue.Empty:
logging.debug(
self._log(
"{} wait for empty queue(with channel empty: {})".
format(op_name, self.empty())))
self._cv.wait() self._cv.wait()
logging.debug( logging.debug(
self._log("{} get data succ: {}".format(op_name, resp.__str__( self._log("{} get data succ: {}".format(op_name, resp.__str__(
...@@ -398,39 +362,16 @@ class Channel(multiprocessing.queues.Queue): ...@@ -398,39 +362,16 @@ class Channel(multiprocessing.queues.Queue):
with self._cv: with self._cv:
# data_idx = consumer_idx - base_idx # data_idx = consumer_idx - base_idx
while self._stop is False and self._consumers[ while self._stop is False and self._consumers[
op_name] - self._consumer_base_idx.value >= len( op_name] - self._consumer_base_idx >= len(self._front_res):
self._front_res):
logging.debug(
self._log(
"({}) B self._consumers: {}, self._consumer_base_idx: {}, len(self._front_res): {}".
format(op_name, self._consumers, self.
_consumer_base_idx.value, len(self._front_res))))
try: try:
logging.debug(
self._log("{} try to get(with channel size: {})".format(
op_name, self.qsize())))
# For Python2, after putting an object on an empty queue there may
# be an infinitessimal delay before the queue's :meth:`~Queue.empty`
# see more:
# - https://bugs.python.org/issue18277
# - https://hg.python.org/cpython/rev/860fc6a2bd21
if sys.version_info.major == 2:
channeldata = self.get(timeout=1e-3)
elif sys.version_info.major == 3:
channeldata = self.get(timeout=0) channeldata = self.get(timeout=0)
else:
raise Exception("Error Python version")
self._front_res.append(channeldata) self._front_res.append(channeldata)
break break
except Queue.Empty: except Queue.Empty:
logging.debug(
self._log(
"{} wait for empty queue(with channel size: {})".
format(op_name, self.qsize())))
self._cv.wait() self._cv.wait()
consumer_idx = self._consumers[op_name] consumer_idx = self._consumers[op_name]
base_idx = self._consumer_base_idx.value base_idx = self._consumer_base_idx
data_idx = consumer_idx - base_idx data_idx = consumer_idx - base_idx
resp = self._front_res[data_idx] resp = self._front_res[data_idx]
logging.debug(self._log("{} get data: {}".format(op_name, resp))) logging.debug(self._log("{} get data: {}".format(op_name, resp)))
...@@ -440,19 +381,14 @@ class Channel(multiprocessing.queues.Queue): ...@@ -440,19 +381,14 @@ class Channel(multiprocessing.queues.Queue):
consumer_idx] == 0: consumer_idx] == 0:
self._idx_consumer_num.pop(consumer_idx) self._idx_consumer_num.pop(consumer_idx)
self._front_res.pop(0) self._front_res.pop(0)
self._consumer_base_idx.value += 1 self._consumer_base_idx += 1
self._consumers[op_name] += 1 self._consumers[op_name] += 1
new_consumer_idx = self._consumers[op_name] new_consumer_idx = self._consumers[op_name]
if self._idx_consumer_num.get(new_consumer_idx) is None: if self._idx_consumer_num.get(new_consumer_idx) is None:
self._idx_consumer_num[new_consumer_idx] = 0 self._idx_consumer_num[new_consumer_idx] = 0
self._idx_consumer_num[new_consumer_idx] += 1 self._idx_consumer_num[new_consumer_idx] += 1
logging.debug(
self._log(
"({}) A self._consumers: {}, self._consumer_base_idx: {}, len(self._front_res): {}".
format(op_name, self._consumers, self._consumer_base_idx.
value, len(self._front_res))))
logging.debug(self._log("{} notify all".format(op_name)))
self._cv.notify_all() self._cv.notify_all()
logging.debug(self._log("multi | {} get data succ!".format(op_name))) logging.debug(self._log("multi | {} get data succ!".format(op_name)))
...@@ -478,42 +414,33 @@ class Op(object): ...@@ -478,42 +414,33 @@ class Op(object):
concurrency=1, concurrency=1,
timeout=-1, timeout=-1,
retry=2): retry=2):
self._is_run = False self._run = False
self.name = name # to identify the type of OP, it must be globally unique self.name = name # to identify the type of OP, it must be globally unique
self._concurrency = concurrency # amount of concurrency self._concurrency = concurrency # amount of concurrency
self.set_input_ops(inputs) self.set_input_ops(inputs)
self.set_client(client_config, server_name, fetch_names)
self._server_model = server_model
self._server_port = server_port
self._device = device
self._timeout = timeout self._timeout = timeout
self._retry = max(1, retry) self._retry = max(1, retry)
self._input = None self._input = None
self._outputs = [] self._outputs = []
self.with_serving = False def set_client(self, client_config, server_name, fetch_names):
self._client_config = client_config self._client = None
self._server_name = server_name if client_config is None or \
self._fetch_names = fetch_names server_name is None or \
self._server_model = server_model fetch_names is None:
self._server_port = server_port
self._device = device
if self._client_config is not None and \
self._server_name is not None and \
self._fetch_names is not None and \
self._server_model is not None and \
self._server_port is not None and \
self._device is not None:
self.with_serving = True
def init_client(self, client_config, server_name, fetch_names):
if self.with_serving == False:
logging.debug("{} no client".format(self.name))
return return
logging.debug("{} client_config: {}".format(self.name, client_config))
logging.debug("{} server_name: {}".format(self.name, server_name))
logging.debug("{} fetch_names: {}".format(self.name, fetch_names))
self._client = Client() self._client = Client()
self._client.load_client_config(client_config) self._client.load_client_config(client_config)
self._client.connect([server_name]) self._client.connect([server_name])
self._fetch_names = fetch_names self._fetch_names = fetch_names
def with_serving(self):
return self._client is not None
def get_input_channel(self): def get_input_channel(self):
return self._input return self._input
...@@ -558,7 +485,7 @@ class Op(object): ...@@ -558,7 +485,7 @@ class Op(object):
feed = channeldata.parse() feed = channeldata.parse()
return feed return feed
def midprocess(self, data): def midprocess(self, data, asyn):
if not isinstance(data, dict): if not isinstance(data, dict):
raise Exception( raise Exception(
self._log( self._log(
...@@ -566,10 +493,12 @@ class Op(object): ...@@ -566,10 +493,12 @@ class Op(object):
format(type(data)))) format(type(data))))
logging.debug(self._log('data: {}'.format(data))) logging.debug(self._log('data: {}'.format(data)))
logging.debug(self._log('fetch: {}'.format(self._fetch_names))) logging.debug(self._log('fetch: {}'.format(self._fetch_names)))
call_future = self._client.predict( #call_result = self._client.predict(
feed=data, fetch=self._fetch_names, asyn=True) # feed=data, fetch=self._fetch_names, asyn=asyn)
call_result = self._client.predict(
feed=data, fetch=self._fetch_names)
logging.debug(self._log("get call_future")) logging.debug(self._log("get call_future"))
return call_future return call_result
def postprocess(self, output_data): def postprocess(self, output_data):
return output_data return output_data
...@@ -578,59 +507,48 @@ class Op(object): ...@@ -578,59 +507,48 @@ class Op(object):
self._input.stop() self._input.stop()
for channel in self._outputs: for channel in self._outputs:
channel.stop() channel.stop()
self._is_run = False self._run = False
def _parse_channeldata(self, channeldata): def _parse_channeldata(self, channeldata):
data_id, error_channeldata = None, None data_id, error_pbdata = None, None
if isinstance(channeldata, dict): if isinstance(channeldata, dict):
parsed_data = {} parsed_data = {}
key = channeldata.keys()[0] key = channeldata.keys()[0]
data_id = channeldata[key].id data_id = channeldata[key].pbdata.id
for _, data in channeldata.items(): for _, data in channeldata.items():
if data.ecode != ChannelDataEcode.OK.value: if data.pbdata.ecode != ChannelDataEcode.OK.value:
error_channeldata = data error_pbdata = data.pbdata
break break
else: else:
data_id = channeldata.id data_id = channeldata.pbdata.id
if channeldata.ecode != ChannelDataEcode.OK.value: if channeldata.pbdata.ecode != ChannelDataEcode.OK.value:
error_channeldata = channeldata error_pbdata = channeldata.pbdata
return data_id, error_channeldata return data_id, error_pbdata
def _push_to_output_channels(self, data, channels, name=None): def _push_to_output_channels(self, data, name=None):
if name is None: if name is None:
name = self.name name = self.name
for channel in channels: for channel in self._outputs:
channel.push(data, name) channel.push(data, name)
def start(self): def start(self, concurrency_idx):
proces = []
for concurrency_idx in range(self._concurrency):
p = multiprocessing.Process(
target=self._run,
args=(concurrency_idx, self.get_input_channel(),
self.get_output_channels()))
p.start()
proces.append(p)
return proces
def _run(self, concurrency_idx, input_channel, output_channels):
self.init_client(self._client_config, self._server_name,
self._fetch_names)
op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx) op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
log = self._get_log_func(op_info_prefix) log = self._get_log_func(op_info_prefix)
self._is_run = True self._run = True
while self._is_run: while self._run:
_profiler.record("{}-get_0".format(op_info_prefix)) _profiler.record("{}-get_0".format(op_info_prefix))
channeldata = input_channel.front(self.name) channeldata = self._input.front(self.name)
_profiler.record("{}-get_1".format(op_info_prefix)) _profiler.record("{}-get_1".format(op_info_prefix))
logging.debug(log("input_data: {}".format(channeldata))) logging.debug(log("input_data: {}".format(channeldata)))
data_id, error_channeldata = self._parse_channeldata(channeldata) data_id, error_pbdata = self._parse_channeldata(channeldata)
# error data in predecessor Op # error data in predecessor Op
if error_channeldata is not None: if error_pbdata is not None:
self._push_to_output_channels(error_channeldata, self._push_to_output_channels(
output_channels) ChannelData(
datatype=ChannelDataType.CHANNEL_PBDATA.value,
pbdata=error_pbdata))
continue continue
# preprecess # preprecess
...@@ -646,8 +564,7 @@ class Op(object): ...@@ -646,8 +564,7 @@ class Op(object):
ChannelData( ChannelData(
ecode=ChannelDataEcode.NOT_IMPLEMENTED.value, ecode=ChannelDataEcode.NOT_IMPLEMENTED.value,
error_info=error_info, error_info=error_info,
data_id=data_id), data_id=data_id))
output_channels)
continue continue
except TypeError as e: except TypeError as e:
# Error type in channeldata.datatype # Error type in channeldata.datatype
...@@ -657,8 +574,7 @@ class Op(object): ...@@ -657,8 +574,7 @@ class Op(object):
ChannelData( ChannelData(
ecode=ChannelDataEcode.TYPE_ERROR.value, ecode=ChannelDataEcode.TYPE_ERROR.value,
error_info=error_info, error_info=error_info,
data_id=data_id), data_id=data_id))
output_channels)
continue continue
except Exception as e: except Exception as e:
error_info = log(e) error_info = log(e)
...@@ -667,18 +583,18 @@ class Op(object): ...@@ -667,18 +583,18 @@ class Op(object):
ChannelData( ChannelData(
ecode=ChannelDataEcode.UNKNOW.value, ecode=ChannelDataEcode.UNKNOW.value,
error_info=error_info, error_info=error_info,
data_id=data_id), data_id=data_id))
output_channels)
continue continue
# midprocess # midprocess
call_future = None midped_data = None
if self.with_serving: asyn = False
if self.with_serving():
ecode = ChannelDataEcode.OK.value ecode = ChannelDataEcode.OK.value
_profiler.record("{}-midp_0".format(op_info_prefix)) _profiler.record("{}-midp_0".format(op_info_prefix))
if self._timeout <= 0: if self._timeout <= 0:
try: try:
call_future = self.midprocess(preped_data) midped_data = self.midprocess(preped_data, asyn)
except Exception as e: except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value ecode = ChannelDataEcode.UNKNOW.value
error_info = log(e) error_info = log(e)
...@@ -686,10 +602,10 @@ class Op(object): ...@@ -686,10 +602,10 @@ class Op(object):
else: else:
for i in range(self._retry): for i in range(self._retry):
try: try:
call_future = func_timeout.func_timeout( midped_data = func_timeout.func_timeout(
self._timeout, self._timeout,
self.midprocess, self.midprocess,
args=(preped_data, )) args=(preped_data, asyn))
except func_timeout.FunctionTimedOut as e: except func_timeout.FunctionTimedOut as e:
if i + 1 >= self._retry: if i + 1 >= self._retry:
ecode = ChannelDataEcode.TIMEOUT.value ecode = ChannelDataEcode.TIMEOUT.value
...@@ -709,33 +625,25 @@ class Op(object): ...@@ -709,33 +625,25 @@ class Op(object):
self._push_to_output_channels( self._push_to_output_channels(
ChannelData( ChannelData(
ecode=ecode, error_info=error_info, ecode=ecode, error_info=error_info,
data_id=data_id), data_id=data_id))
output_channels)
continue continue
_profiler.record("{}-midp_1".format(op_info_prefix)) _profiler.record("{}-midp_1".format(op_info_prefix))
else:
midped_data = preped_data
# postprocess # postprocess
output_data = None output_data = None
_profiler.record("{}-postp_0".format(op_info_prefix)) _profiler.record("{}-postp_0".format(op_info_prefix))
if self.with_serving: if self.with_serving() and asyn:
# use call_future # use call_future
output_data = ChannelData( output_data = ChannelData(
datatype=ChannelDataType.CHANNEL_FUTURE.value, datatype=ChannelDataType.CHANNEL_FUTURE.value,
future=call_future, future=midped_data,
data_id=data_id, data_id=data_id,
callback_func=self.postprocess) callback_func=self.postprocess)
#TODO: for future are not picklable
npdata = self.postprocess(call_future.result())
self._push_to_output_channels(
ChannelData(
ChannelDataType.CHANNEL_NPDATA.value,
npdata=npdata,
data_id=data_id),
output_channels)
continue
else: else:
try: try:
postped_data = self.postprocess(preped_data) postped_data = self.postprocess(midped_data)
except Exception as e: except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value ecode = ChannelDataEcode.UNKNOW.value
error_info = log(e) error_info = log(e)
...@@ -743,8 +651,7 @@ class Op(object): ...@@ -743,8 +651,7 @@ class Op(object):
self._push_to_output_channels( self._push_to_output_channels(
ChannelData( ChannelData(
ecode=ecode, error_info=error_info, ecode=ecode, error_info=error_info,
data_id=data_id), data_id=data_id))
output_channels)
continue continue
if not isinstance(postped_data, dict): if not isinstance(postped_data, dict):
ecode = ChannelDataEcode.TYPE_ERROR.value ecode = ChannelDataEcode.TYPE_ERROR.value
...@@ -754,8 +661,7 @@ class Op(object): ...@@ -754,8 +661,7 @@ class Op(object):
self._push_to_output_channels( self._push_to_output_channels(
ChannelData( ChannelData(
ecode=ecode, error_info=error_info, ecode=ecode, error_info=error_info,
data_id=data_id), data_id=data_id))
output_channels)
continue continue
output_data = ChannelData( output_data = ChannelData(
...@@ -766,7 +672,7 @@ class Op(object): ...@@ -766,7 +672,7 @@ class Op(object):
# push data to channel (if run succ) # push data to channel (if run succ)
_profiler.record("{}-push_0".format(op_info_prefix)) _profiler.record("{}-push_0".format(op_info_prefix))
self._push_to_output_channels(output_data, output_channels) self._push_to_output_channels(output_data)
_profiler.record("{}-push_1".format(op_info_prefix)) _profiler.record("{}-push_1".format(op_info_prefix))
def _log(self, info): def _log(self, info):
...@@ -802,30 +708,27 @@ class VirtualOp(Op): ...@@ -802,30 +708,27 @@ class VirtualOp(Op):
channel.add_producer(op.name) channel.add_producer(op.name)
self._outputs.append(channel) self._outputs.append(channel)
def _run(self, input_channel, output_channels): def start(self, concurrency_idx):
op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx) op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
log = self._get_log_func(op_info_prefix) log = self._get_log_func(op_info_prefix)
self._is_run = True self._run = True
while self._is_run: while self._run:
_profiler.record("{}-get_0".format(op_info_prefix)) _profiler.record("{}-get_0".format(op_info_prefix))
channeldata = input_channel.front(self.name) channeldata = self._input.front(self.name)
_profiler.record("{}-get_1".format(op_info_prefix)) _profiler.record("{}-get_1".format(op_info_prefix))
_profiler.record("{}-push_0".format(op_info_prefix)) _profiler.record("{}-push_0".format(op_info_prefix))
if isinstance(channeldata, dict): if isinstance(channeldata, dict):
for name, data in channeldata.items(): for name, data in channeldata.items():
self._push_to_output_channels( self._push_to_output_channels(data, name=name)
data, channels=output_channels, name=name)
else: else:
self._push_to_output_channels( self._push_to_output_channels(channeldata,
channeldata, self._virtual_pred_ops[0].name)
channels=output_channels,
name=self._virtual_pred_ops[0].name)
_profiler.record("{}-push_1".format(op_info_prefix)) _profiler.record("{}-push_1".format(op_info_prefix))
class GeneralPythonService( class GeneralPythonService(
general_python_service_pb2_grpc.GeneralPythonServiceServicer): general_python_service_pb2_grpc.GeneralPythonService):
def __init__(self, in_channel, out_channel, retry=2): def __init__(self, in_channel, out_channel, retry=2):
super(GeneralPythonService, self).__init__() super(GeneralPythonService, self).__init__()
self.name = "#G" self.name = "#G"
...@@ -872,7 +775,7 @@ class GeneralPythonService( ...@@ -872,7 +775,7 @@ class GeneralPythonService(
self._log('data must be ChannelData type, but get {}'. self._log('data must be ChannelData type, but get {}'.
format(type(channeldata)))) format(type(channeldata))))
with self._cv: with self._cv:
data_id = channeldata.id data_id = channeldata.pbdata.id
self._globel_resp_dict[data_id] = channeldata self._globel_resp_dict[data_id] = channeldata
self._cv.notify_all() self._cv.notify_all()
...@@ -892,33 +795,33 @@ class GeneralPythonService( ...@@ -892,33 +795,33 @@ class GeneralPythonService(
def _pack_data_for_infer(self, request): def _pack_data_for_infer(self, request):
logging.debug(self._log('start inferce')) logging.debug(self._log('start inferce'))
pbdata = channel_pb2.ChannelData()
data_id = self._get_next_id() data_id = self._get_next_id()
npdata = {} pbdata.id = data_id
pbdata.ecode = ChannelDataEcode.OK.value
try: try:
for idx, name in enumerate(request.feed_var_names): for idx, name in enumerate(request.feed_var_names):
logging.debug( logging.debug(
self._log('name: {}'.format(request.feed_var_names[idx]))) self._log('name: {}'.format(request.feed_var_names[idx])))
logging.debug( logging.debug(
self._log('data: {}'.format(request.feed_insts[idx]))) self._log('data: {}'.format(request.feed_insts[idx])))
npdata[name] = np.frombuffer( inst = channel_pb2.Inst()
request.feed_insts[idx], dtype=request.type[idx]) inst.data = request.feed_insts[idx]
npdata[name].shape = np.frombuffer( inst.shape = request.shape[idx]
request.shape[idx], dtype="int32") inst.name = name
inst.type = request.type[idx]
pbdata.insts.append(inst)
except Exception as e: except Exception as e:
pbdata.ecode = ChannelDataEcode.RPC_PACKAGE_ERROR.value
pbdata.error_info = "rpc package error"
return ChannelData( return ChannelData(
ecode=ChannelDataEcode.RPC_PACKAGE_ERROR.value, datatype=ChannelDataType.CHANNEL_PBDATA.value,
error_info="rpc package error", pbdata=pbdata), data_id
data_id=data_id), data_id
else:
return ChannelData(
datatype=ChannelDataType.CHANNEL_NPDATA.value,
npdata=npdata,
data_id=data_id), data_id
def _pack_data_for_resp(self, channeldata): def _pack_data_for_resp(self, channeldata):
logging.debug(self._log('get channeldata')) logging.debug(self._log('get channeldata'))
resp = pyservice_pb2.Response() resp = pyservice_pb2.Response()
resp.ecode = channeldata.ecode resp.ecode = channeldata.pbdata.ecode
if resp.ecode == ChannelDataEcode.OK.value: if resp.ecode == ChannelDataEcode.OK.value:
if channeldata.datatype == ChannelDataType.CHANNEL_PBDATA.value: if channeldata.datatype == ChannelDataType.CHANNEL_PBDATA.value:
for inst in channeldata.pbdata.insts: for inst in channeldata.pbdata.insts:
...@@ -941,7 +844,7 @@ class GeneralPythonService( ...@@ -941,7 +844,7 @@ class GeneralPythonService(
self._log("Error type({}) in datatype.".format( self._log("Error type({}) in datatype.".format(
channeldata.datatype))) channeldata.datatype)))
else: else:
resp.error_info = channeldata.error_info resp.error_info = channeldata.pbdata.error_info
return resp return resp
def inference(self, request, context): def inference(self, request, context):
...@@ -961,11 +864,11 @@ class GeneralPythonService( ...@@ -961,11 +864,11 @@ class GeneralPythonService(
resp_channeldata = self._get_data_in_globel_resp_dict(data_id) resp_channeldata = self._get_data_in_globel_resp_dict(data_id)
_profiler.record("{}-fetch_1".format(self.name)) _profiler.record("{}-fetch_1".format(self.name))
if resp_channeldata.ecode == ChannelDataEcode.OK.value: if resp_channeldata.pbdata.ecode == ChannelDataEcode.OK.value:
break break
if i + 1 < self._retry: if i + 1 < self._retry:
logging.warn("retry({}): {}".format( logging.warn("retry({}): {}".format(
i + 1, resp_channeldata.error_info)) i + 1, resp_channeldata.pbdata.error_info))
_profiler.record("{}-postpack_0".format(self.name)) _profiler.record("{}-postpack_0".format(self.name))
resp = self._pack_data_for_resp(resp_channeldata) resp = self._pack_data_for_resp(resp_channeldata)
...@@ -979,12 +882,12 @@ class PyServer(object): ...@@ -979,12 +882,12 @@ class PyServer(object):
self._channels = [] self._channels = []
self._user_ops = [] self._user_ops = []
self._actual_ops = [] self._actual_ops = []
self._op_threads = []
self._port = None self._port = None
self._worker_num = None self._worker_num = None
self._in_channel = None self._in_channel = None
self._out_channel = None self._out_channel = None
self._retry = retry self._retry = retry
self._manager = multiprocessing.Manager()
_profiler.enable(profile) _profiler.enable(profile)
def add_channel(self, channel): def add_channel(self, channel):
...@@ -1009,7 +912,6 @@ class PyServer(object): ...@@ -1009,7 +912,6 @@ class PyServer(object):
op.name = "#G" # update read_op.name op.name = "#G" # update read_op.name
break break
outdegs = {op.name: [] for op in self._user_ops} outdegs = {op.name: [] for op in self._user_ops}
zero_indeg_num, zero_outdeg_num = 0, 0
for idx, op in enumerate(self._user_ops): for idx, op in enumerate(self._user_ops):
# check the name of op is globally unique # check the name of op is globally unique
if op.name in indeg_num: if op.name in indeg_num:
...@@ -1017,16 +919,8 @@ class PyServer(object): ...@@ -1017,16 +919,8 @@ class PyServer(object):
indeg_num[op.name] = len(op.get_input_ops()) indeg_num[op.name] = len(op.get_input_ops())
if indeg_num[op.name] == 0: if indeg_num[op.name] == 0:
ques[que_idx].put(op) ques[que_idx].put(op)
zero_indeg_num += 1
for pred_op in op.get_input_ops(): for pred_op in op.get_input_ops():
outdegs[pred_op.name].append(op) outdegs[pred_op.name].append(op)
if zero_indeg_num != 1:
raise Exception("DAG contains multiple input Ops")
for _, succ_list in outdegs.items():
if len(succ_list) == 0:
zero_outdeg_num += 1
if zero_outdeg_num != 1:
raise Exception("DAG contains multiple output Ops")
# topo sort to get dag_views # topo sort to get dag_views
dag_views = [] dag_views = []
...@@ -1049,6 +943,10 @@ class PyServer(object): ...@@ -1049,6 +943,10 @@ class PyServer(object):
que_idx = (que_idx + 1) % 2 que_idx = (que_idx + 1) % 2
if sorted_op_num < len(self._user_ops): if sorted_op_num < len(self._user_ops):
raise Exception("not legal DAG") raise Exception("not legal DAG")
if len(dag_views[0]) != 1:
raise Exception("DAG contains multiple input Ops")
if len(dag_views[-1]) != 1:
raise Exception("DAG contains multiple output Ops")
# create channels and virtual ops # create channels and virtual ops
def name_generator(prefix): def name_generator(prefix):
...@@ -1086,14 +984,7 @@ class PyServer(object): ...@@ -1086,14 +984,7 @@ class PyServer(object):
else: else:
# create virtual op # create virtual op
virtual_op = None virtual_op = None
if sys.version_info.major == 2: virtual_op = VirtualOp(name=virtual_op_name_gen.next())
virtual_op = VirtualOp(
name=virtual_op_name_gen.next())
elif sys.version_info.major == 3:
virtual_op = VirtualOp(
name=virtual_op_name_gen.__next__())
else:
raise Exception("Error Python version")
virtual_ops.append(virtual_op) virtual_ops.append(virtual_op)
outdegs[virtual_op.name] = [succ_op] outdegs[virtual_op.name] = [succ_op]
actual_next_view.append(virtual_op) actual_next_view.append(virtual_op)
...@@ -1105,14 +996,7 @@ class PyServer(object): ...@@ -1105,14 +996,7 @@ class PyServer(object):
for o_idx, op in enumerate(actual_next_view): for o_idx, op in enumerate(actual_next_view):
if op.name in processed_op: if op.name in processed_op:
continue continue
if sys.version_info.major == 2: channel = Channel(name=channel_name_gen.next())
channel = Channel(
self._manager, name=channel_name_gen.next())
elif sys.version_info.major == 3:
channel = Channel(
self._manager, name=channel_name_gen.__next__())
else:
raise Exception("Error Python version")
channels.append(channel) channels.append(channel)
logging.debug("{} => {}".format(channel.name, op.name)) logging.debug("{} => {}".format(channel.name, op.name))
op.add_input_channel(channel) op.add_input_channel(channel)
...@@ -1143,14 +1027,7 @@ class PyServer(object): ...@@ -1143,14 +1027,7 @@ class PyServer(object):
other_op.name)) other_op.name))
other_op.add_input_channel(channel) other_op.add_input_channel(channel)
processed_op.add(other_op.name) processed_op.add(other_op.name)
if sys.version_info.major == 2: output_channel = Channel(name=channel_name_gen.next())
output_channel = Channel(
self._manager, name=channel_name_gen.next())
elif sys.version_info.major == 3:
output_channel = Channel(
self._manager, name=channel_name_gen.__next__())
else:
raise Exception("Error Python version")
channels.append(output_channel) channels.append(output_channel)
last_op = dag_views[-1][0] last_op = dag_views[-1][0]
last_op.add_output_channel(output_channel) last_op.add_output_channel(output_channel)
...@@ -1174,22 +1051,30 @@ class PyServer(object): ...@@ -1174,22 +1051,30 @@ class PyServer(object):
self._in_channel = input_channel self._in_channel = input_channel
self._out_channel = output_channel self._out_channel = output_channel
for op in self._actual_ops: for op in self._actual_ops:
if op.with_serving: if op.with_serving():
self.prepare_serving(op) self.prepare_serving(op)
self.gen_desc() self.gen_desc()
def _op_start_wrapper(self, op, concurrency_idx):
return op.start(concurrency_idx)
def _run_ops(self): def _run_ops(self):
proces = []
for op in self._actual_ops: for op in self._actual_ops:
proces.extend(op.start()) op_concurrency = op.get_concurrency()
return proces logging.debug("run op: {}, op_concurrency: {}".format(
op.name, op_concurrency))
for c in range(op_concurrency):
th = threading.Thread(
target=self._op_start_wrapper, args=(op, c))
th.start()
self._op_threads.append(th)
def _stop_ops(self): def _stop_ops(self):
for op in self._actual_ops: for op in self._actual_ops:
op.stop() op.stop()
def run_server(self): def run_server(self):
op_proces = self._run_ops() self._run_ops()
server = grpc.server( server = grpc.server(
futures.ThreadPoolExecutor(max_workers=self._worker_num)) futures.ThreadPoolExecutor(max_workers=self._worker_num))
general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server( general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server(
...@@ -1199,8 +1084,8 @@ class PyServer(object): ...@@ -1199,8 +1084,8 @@ class PyServer(object):
server.start() server.start()
server.wait_for_termination() server.wait_for_termination()
self._stop_ops() # TODO self._stop_ops() # TODO
for p in op_proces: for th in self._op_threads:
p.join() th.join()
def prepare_serving(self, op): def prepare_serving(self, op):
model_path = op._server_model model_path = op._server_model
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册