diff --git a/python/paddle_serving_server/pyserver.py b/python/paddle_serving_server/pyserver.py index 7285ccc7b622d1afeab152b410b43f48aca581fd..4ca99c18eb30d5432b59a169747e9a91e8ab5bc7 100644 --- a/python/paddle_serving_server/pyserver.py +++ b/python/paddle_serving_server/pyserver.py @@ -14,18 +14,12 @@ # pylint: disable=doc-string-missing import threading import multiprocessing -import multiprocessing.queues -import sys -if sys.version_info.major == 2: - import Queue -elif sys.version_info.major == 3: - import queue as Queue -else: - raise Exception("Error Python version") +import Queue import os +import sys import paddle_serving_server -from paddle_serving_client import MultiLangClient as Client -from paddle_serving_client import MultiLangPredictFuture +#from paddle_serving_client import MultiLangClient as Client +from paddle_serving_client import Client from concurrent import futures import numpy as np import grpc @@ -116,27 +110,34 @@ class ChannelData(object): 4. ChannelData(ChannelDataType.CHANNEL_PBDATA.value, npdata, data_id) 5. ChannelData(ChannelDataType.CHANNEL_NPDATA.value, npdata, data_id) 6. ChannelData(ecode, error_info, data_id) - - Protobufs are not pickle-able: - https://stackoverflow.com/questions/55344376/how-to-import-protobuf-module ''' if ecode is not None: if data_id is None or error_info is None: raise ValueError("data_id and error_info cannot be None") + pbdata = channel_pb2.ChannelData() + pbdata.ecode = ecode + pbdata.id = data_id + pbdata.error_info = error_info datatype = ChannelDataType.ERROR.value else: if datatype == ChannelDataType.CHANNEL_FUTURE.value: - if data_id is None: - raise ValueError("data_id cannot be None") - ecode = ChannelDataEcode.OK.value + if pbdata is None: + if data_id is None: + raise ValueError("data_id cannot be None") + pbdata = channel_pb2.ChannelData() + pbdata.ecode = ChannelDataEcode.OK.value + pbdata.id = data_id elif datatype == ChannelDataType.CHANNEL_PBDATA.value: if pbdata is None: if data_id is None: raise ValueError("data_id cannot be None") pbdata = channel_pb2.ChannelData() + pbdata.id = data_id ecode, error_info = self._check_npdata(npdata) - if ecode != ChannelDataEcode.OK.value: - logging.error(error_info) + pbdata.ecode = ecode + if pbdata.ecode != ChannelDataEcode.OK.value: + pbdata.error_info = error_info + logging.error(pbdata.error_info) else: for name, value in npdata.items(): inst = channel_pb2.Inst() @@ -148,18 +149,23 @@ class ChannelData(object): pbdata.insts.append(inst) elif datatype == ChannelDataType.CHANNEL_NPDATA.value: ecode, error_info = self._check_npdata(npdata) - if ecode != ChannelDataEcode.OK.value: - logging.error(error_info) + pbdata = channel_pb2.ChannelData() + pbdata.id = data_id + pbdata.ecode = ecode + if pbdata.ecode != ChannelDataEcode.OK.value: + pbdata.error_info = error_info + logging.error(pbdata.error_info) else: raise ValueError("datatype not match") + if not isinstance(pbdata, channel_pb2.ChannelData): + raise TypeError( + "pbdata must be pyserving_channel_pb2.ChannelData type({})". + format(type(pbdata))) self.future = future self.pbdata = pbdata self.npdata = npdata self.datatype = datatype self.callback_func = callback_func - self.id = data_id - self.ecode = ecode - self.error_info = error_info def _check_npdata(self, npdata): ecode = ChannelDataEcode.OK.value @@ -192,15 +198,15 @@ class ChannelData(object): elif self.datatype == ChannelDataType.CHANNEL_NPDATA.value: feed = self.npdata else: - raise TypeError("Error type({}) in datatype.".format(self.datatype)) + raise TypeError("Error type({}) in datatype.".format(datatype)) return feed def __str__(self): - return "type[{}], ecode[{}], id[{}]".format( - ChannelDataType(self.datatype).name, self.ecode, self.id) + return "type[{}], ecode[{}]".format( + ChannelDataType(self.datatype).name, self.pbdata.ecode) -class Channel(multiprocessing.queues.Queue): +class Channel(Queue.Queue): """ The channel used for communication between Ops. @@ -218,36 +224,23 @@ class Channel(multiprocessing.queues.Queue): and can only be called during initialization. """ - def __init__(self, manager, name=None, maxsize=0, timeout=None): - # https://stackoverflow.com/questions/39496554/cannot-subclass-multiprocessing-queue-in-python-3-5/ - if sys.version_info.major == 2: - super(Channel, self).__init__(maxsize=maxsize) - elif sys.version_info.major == 3: - super(Channel, self).__init__( - maxsize=maxsize, ctx=multiprocessing.get_context()) - else: - raise Exception("Error Python version") + def __init__(self, name=None, maxsize=-1, timeout=None): + Queue.Queue.__init__(self, maxsize=maxsize) self._maxsize = maxsize self._timeout = timeout self.name = name self._stop = False - self._cv = multiprocessing.Condition() + self._cv = threading.Condition() self._producers = [] - self._producer_res_count = manager.dict() # {data_id: count} - # self._producer_res_count = {} # {data_id: count} - self._push_res = manager.dict() # {data_id: {op_name: data}} - # self._push_res = {} # {data_id: {op_name: data}} - - self._consumers = manager.dict() # {op_name: idx} - # self._consumers = {} # {op_name: idx} - self._idx_consumer_num = manager.dict() # {idx: num} - # self._idx_consumer_num = {} # {idx: num} - self._consumer_base_idx = manager.Value('i', 0) - # self._consumer_base_idx = 0 - self._front_res = manager.list() - # self._front_res = [] + self._producer_res_count = {} # {data_id: count} + self._push_res = {} # {data_id: {op_name: data}} + + self._consumers = {} # {op_name: idx} + self._idx_consumer_num = {} # {idx: num} + self._consumer_base_idx = 0 + self._front_res = [] def get_producers(self): return self._producers @@ -297,11 +290,7 @@ class Channel(multiprocessing.queues.Queue): break except Queue.Full: self._cv.wait() - logging.debug( - self._log("{} channel size: {}".format(op_name, - self.qsize()))) self._cv.notify_all() - logging.debug(self._log("{} notify all".format(op_name))) logging.debug(self._log("{} push data succ!".format(op_name))) return True elif op_name is None: @@ -310,7 +299,7 @@ class Channel(multiprocessing.queues.Queue): "There are multiple producers, so op_name cannot be None.")) producer_num = len(self._producers) - data_id = channeldata.id + data_id = channeldata.pbdata.id put_data = None with self._cv: logging.debug(self._log("{} get lock".format(op_name))) @@ -320,12 +309,7 @@ class Channel(multiprocessing.queues.Queue): for name in self._producers } self._producer_res_count[data_id] = 0 - # see: https://docs.python.org/3.6/library/multiprocessing.html?highlight=multiprocess#proxy-objects - # self._push_res[data_id][op_name] = channeldata - tmp_push_res = self._push_res[data_id] - tmp_push_res[op_name] = channeldata - self._push_res[data_id] = tmp_push_res - + self._push_res[data_id][op_name] = channeldata if self._producer_res_count[data_id] + 1 == producer_num: put_data = self._push_res[data_id] self._push_res.pop(data_id) @@ -340,9 +324,6 @@ class Channel(multiprocessing.queues.Queue): else: while self._stop is False: try: - logging.debug( - self._log("{} push data succ: {}".format( - op_name, put_data.__str__()))) self.put(put_data, timeout=0) break except Queue.Empty: @@ -354,7 +335,7 @@ class Channel(multiprocessing.queues.Queue): return True def front(self, op_name=None): - logging.debug(self._log("{} try to get data...".format(op_name))) + logging.debug(self._log("{} try to get data".format(op_name))) if len(self._consumers) == 0: raise Exception( self._log( @@ -365,26 +346,9 @@ class Channel(multiprocessing.queues.Queue): with self._cv: while self._stop is False and resp is None: try: - logging.debug( - self._log("{} try to get(with channel empty: {})". - format(op_name, self.empty()))) - # For Python2, after putting an object on an empty queue there may - # be an infinitessimal delay before the queue's :meth:`~Queue.empty` - # see more: - # - https://bugs.python.org/issue18277 - # - https://hg.python.org/cpython/rev/860fc6a2bd21 - if sys.version_info.major == 2: - resp = self.get(timeout=1e-3) - elif sys.version_info.major == 3: - resp = self.get(timeout=0) - else: - raise Exception("Error Python version") + resp = self.get(timeout=0) break except Queue.Empty: - logging.debug( - self._log( - "{} wait for empty queue(with channel empty: {})". - format(op_name, self.empty()))) self._cv.wait() logging.debug( self._log("{} get data succ: {}".format(op_name, resp.__str__( @@ -398,39 +362,16 @@ class Channel(multiprocessing.queues.Queue): with self._cv: # data_idx = consumer_idx - base_idx while self._stop is False and self._consumers[ - op_name] - self._consumer_base_idx.value >= len( - self._front_res): - logging.debug( - self._log( - "({}) B self._consumers: {}, self._consumer_base_idx: {}, len(self._front_res): {}". - format(op_name, self._consumers, self. - _consumer_base_idx.value, len(self._front_res)))) + op_name] - self._consumer_base_idx >= len(self._front_res): try: - logging.debug( - self._log("{} try to get(with channel size: {})".format( - op_name, self.qsize()))) - # For Python2, after putting an object on an empty queue there may - # be an infinitessimal delay before the queue's :meth:`~Queue.empty` - # see more: - # - https://bugs.python.org/issue18277 - # - https://hg.python.org/cpython/rev/860fc6a2bd21 - if sys.version_info.major == 2: - channeldata = self.get(timeout=1e-3) - elif sys.version_info.major == 3: - channeldata = self.get(timeout=0) - else: - raise Exception("Error Python version") + channeldata = self.get(timeout=0) self._front_res.append(channeldata) break except Queue.Empty: - logging.debug( - self._log( - "{} wait for empty queue(with channel size: {})". - format(op_name, self.qsize()))) self._cv.wait() consumer_idx = self._consumers[op_name] - base_idx = self._consumer_base_idx.value + base_idx = self._consumer_base_idx data_idx = consumer_idx - base_idx resp = self._front_res[data_idx] logging.debug(self._log("{} get data: {}".format(op_name, resp))) @@ -440,19 +381,14 @@ class Channel(multiprocessing.queues.Queue): consumer_idx] == 0: self._idx_consumer_num.pop(consumer_idx) self._front_res.pop(0) - self._consumer_base_idx.value += 1 + self._consumer_base_idx += 1 self._consumers[op_name] += 1 new_consumer_idx = self._consumers[op_name] if self._idx_consumer_num.get(new_consumer_idx) is None: self._idx_consumer_num[new_consumer_idx] = 0 self._idx_consumer_num[new_consumer_idx] += 1 - logging.debug( - self._log( - "({}) A self._consumers: {}, self._consumer_base_idx: {}, len(self._front_res): {}". - format(op_name, self._consumers, self._consumer_base_idx. - value, len(self._front_res)))) - logging.debug(self._log("{} notify all".format(op_name))) + self._cv.notify_all() logging.debug(self._log("multi | {} get data succ!".format(op_name))) @@ -478,42 +414,33 @@ class Op(object): concurrency=1, timeout=-1, retry=2): - self._is_run = False + self._run = False self.name = name # to identify the type of OP, it must be globally unique self._concurrency = concurrency # amount of concurrency self.set_input_ops(inputs) + self.set_client(client_config, server_name, fetch_names) + self._server_model = server_model + self._server_port = server_port + self._device = device self._timeout = timeout self._retry = max(1, retry) self._input = None self._outputs = [] - self.with_serving = False - self._client_config = client_config - self._server_name = server_name - self._fetch_names = fetch_names - self._server_model = server_model - self._server_port = server_port - self._device = device - if self._client_config is not None and \ - self._server_name is not None and \ - self._fetch_names is not None and \ - self._server_model is not None and \ - self._server_port is not None and \ - self._device is not None: - self.with_serving = True - - def init_client(self, client_config, server_name, fetch_names): - if self.with_serving == False: - logging.debug("{} no client".format(self.name)) + def set_client(self, client_config, server_name, fetch_names): + self._client = None + if client_config is None or \ + server_name is None or \ + fetch_names is None: return - logging.debug("{} client_config: {}".format(self.name, client_config)) - logging.debug("{} server_name: {}".format(self.name, server_name)) - logging.debug("{} fetch_names: {}".format(self.name, fetch_names)) self._client = Client() self._client.load_client_config(client_config) self._client.connect([server_name]) self._fetch_names = fetch_names + def with_serving(self): + return self._client is not None + def get_input_channel(self): return self._input @@ -558,7 +485,7 @@ class Op(object): feed = channeldata.parse() return feed - def midprocess(self, data): + def midprocess(self, data, asyn): if not isinstance(data, dict): raise Exception( self._log( @@ -566,10 +493,12 @@ class Op(object): format(type(data)))) logging.debug(self._log('data: {}'.format(data))) logging.debug(self._log('fetch: {}'.format(self._fetch_names))) - call_future = self._client.predict( - feed=data, fetch=self._fetch_names, asyn=True) + #call_result = self._client.predict( + # feed=data, fetch=self._fetch_names, asyn=asyn) + call_result = self._client.predict( + feed=data, fetch=self._fetch_names) logging.debug(self._log("get call_future")) - return call_future + return call_result def postprocess(self, output_data): return output_data @@ -578,59 +507,48 @@ class Op(object): self._input.stop() for channel in self._outputs: channel.stop() - self._is_run = False + self._run = False def _parse_channeldata(self, channeldata): - data_id, error_channeldata = None, None + data_id, error_pbdata = None, None if isinstance(channeldata, dict): parsed_data = {} key = channeldata.keys()[0] - data_id = channeldata[key].id + data_id = channeldata[key].pbdata.id for _, data in channeldata.items(): - if data.ecode != ChannelDataEcode.OK.value: - error_channeldata = data + if data.pbdata.ecode != ChannelDataEcode.OK.value: + error_pbdata = data.pbdata break else: - data_id = channeldata.id - if channeldata.ecode != ChannelDataEcode.OK.value: - error_channeldata = channeldata - return data_id, error_channeldata + data_id = channeldata.pbdata.id + if channeldata.pbdata.ecode != ChannelDataEcode.OK.value: + error_pbdata = channeldata.pbdata + return data_id, error_pbdata - def _push_to_output_channels(self, data, channels, name=None): + def _push_to_output_channels(self, data, name=None): if name is None: name = self.name - for channel in channels: + for channel in self._outputs: channel.push(data, name) - def start(self): - proces = [] - for concurrency_idx in range(self._concurrency): - p = multiprocessing.Process( - target=self._run, - args=(concurrency_idx, self.get_input_channel(), - self.get_output_channels())) - p.start() - proces.append(p) - return proces - - def _run(self, concurrency_idx, input_channel, output_channels): - self.init_client(self._client_config, self._server_name, - self._fetch_names) + def start(self, concurrency_idx): op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx) log = self._get_log_func(op_info_prefix) - self._is_run = True - while self._is_run: + self._run = True + while self._run: _profiler.record("{}-get_0".format(op_info_prefix)) - channeldata = input_channel.front(self.name) + channeldata = self._input.front(self.name) _profiler.record("{}-get_1".format(op_info_prefix)) logging.debug(log("input_data: {}".format(channeldata))) - data_id, error_channeldata = self._parse_channeldata(channeldata) + data_id, error_pbdata = self._parse_channeldata(channeldata) # error data in predecessor Op - if error_channeldata is not None: - self._push_to_output_channels(error_channeldata, - output_channels) + if error_pbdata is not None: + self._push_to_output_channels( + ChannelData( + datatype=ChannelDataType.CHANNEL_PBDATA.value, + pbdata=error_pbdata)) continue # preprecess @@ -646,8 +564,7 @@ class Op(object): ChannelData( ecode=ChannelDataEcode.NOT_IMPLEMENTED.value, error_info=error_info, - data_id=data_id), - output_channels) + data_id=data_id)) continue except TypeError as e: # Error type in channeldata.datatype @@ -657,8 +574,7 @@ class Op(object): ChannelData( ecode=ChannelDataEcode.TYPE_ERROR.value, error_info=error_info, - data_id=data_id), - output_channels) + data_id=data_id)) continue except Exception as e: error_info = log(e) @@ -667,18 +583,18 @@ class Op(object): ChannelData( ecode=ChannelDataEcode.UNKNOW.value, error_info=error_info, - data_id=data_id), - output_channels) + data_id=data_id)) continue # midprocess - call_future = None - if self.with_serving: + midped_data = None + asyn = False + if self.with_serving(): ecode = ChannelDataEcode.OK.value _profiler.record("{}-midp_0".format(op_info_prefix)) if self._timeout <= 0: try: - call_future = self.midprocess(preped_data) + midped_data = self.midprocess(preped_data, asyn) except Exception as e: ecode = ChannelDataEcode.UNKNOW.value error_info = log(e) @@ -686,10 +602,10 @@ class Op(object): else: for i in range(self._retry): try: - call_future = func_timeout.func_timeout( + midped_data = func_timeout.func_timeout( self._timeout, self.midprocess, - args=(preped_data, )) + args=(preped_data, asyn)) except func_timeout.FunctionTimedOut as e: if i + 1 >= self._retry: ecode = ChannelDataEcode.TIMEOUT.value @@ -709,33 +625,25 @@ class Op(object): self._push_to_output_channels( ChannelData( ecode=ecode, error_info=error_info, - data_id=data_id), - output_channels) + data_id=data_id)) continue _profiler.record("{}-midp_1".format(op_info_prefix)) + else: + midped_data = preped_data # postprocess output_data = None _profiler.record("{}-postp_0".format(op_info_prefix)) - if self.with_serving: + if self.with_serving() and asyn: # use call_future output_data = ChannelData( datatype=ChannelDataType.CHANNEL_FUTURE.value, - future=call_future, + future=midped_data, data_id=data_id, callback_func=self.postprocess) - #TODO: for future are not picklable - npdata = self.postprocess(call_future.result()) - self._push_to_output_channels( - ChannelData( - ChannelDataType.CHANNEL_NPDATA.value, - npdata=npdata, - data_id=data_id), - output_channels) - continue else: try: - postped_data = self.postprocess(preped_data) + postped_data = self.postprocess(midped_data) except Exception as e: ecode = ChannelDataEcode.UNKNOW.value error_info = log(e) @@ -743,8 +651,7 @@ class Op(object): self._push_to_output_channels( ChannelData( ecode=ecode, error_info=error_info, - data_id=data_id), - output_channels) + data_id=data_id)) continue if not isinstance(postped_data, dict): ecode = ChannelDataEcode.TYPE_ERROR.value @@ -754,8 +661,7 @@ class Op(object): self._push_to_output_channels( ChannelData( ecode=ecode, error_info=error_info, - data_id=data_id), - output_channels) + data_id=data_id)) continue output_data = ChannelData( @@ -766,7 +672,7 @@ class Op(object): # push data to channel (if run succ) _profiler.record("{}-push_0".format(op_info_prefix)) - self._push_to_output_channels(output_data, output_channels) + self._push_to_output_channels(output_data) _profiler.record("{}-push_1".format(op_info_prefix)) def _log(self, info): @@ -802,30 +708,27 @@ class VirtualOp(Op): channel.add_producer(op.name) self._outputs.append(channel) - def _run(self, input_channel, output_channels): + def start(self, concurrency_idx): op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx) log = self._get_log_func(op_info_prefix) - self._is_run = True - while self._is_run: + self._run = True + while self._run: _profiler.record("{}-get_0".format(op_info_prefix)) - channeldata = input_channel.front(self.name) + channeldata = self._input.front(self.name) _profiler.record("{}-get_1".format(op_info_prefix)) _profiler.record("{}-push_0".format(op_info_prefix)) if isinstance(channeldata, dict): for name, data in channeldata.items(): - self._push_to_output_channels( - data, channels=output_channels, name=name) + self._push_to_output_channels(data, name=name) else: - self._push_to_output_channels( - channeldata, - channels=output_channels, - name=self._virtual_pred_ops[0].name) + self._push_to_output_channels(channeldata, + self._virtual_pred_ops[0].name) _profiler.record("{}-push_1".format(op_info_prefix)) class GeneralPythonService( - general_python_service_pb2_grpc.GeneralPythonServiceServicer): + general_python_service_pb2_grpc.GeneralPythonService): def __init__(self, in_channel, out_channel, retry=2): super(GeneralPythonService, self).__init__() self.name = "#G" @@ -872,7 +775,7 @@ class GeneralPythonService( self._log('data must be ChannelData type, but get {}'. format(type(channeldata)))) with self._cv: - data_id = channeldata.id + data_id = channeldata.pbdata.id self._globel_resp_dict[data_id] = channeldata self._cv.notify_all() @@ -892,33 +795,33 @@ class GeneralPythonService( def _pack_data_for_infer(self, request): logging.debug(self._log('start inferce')) + pbdata = channel_pb2.ChannelData() data_id = self._get_next_id() - npdata = {} + pbdata.id = data_id + pbdata.ecode = ChannelDataEcode.OK.value try: for idx, name in enumerate(request.feed_var_names): logging.debug( self._log('name: {}'.format(request.feed_var_names[idx]))) logging.debug( self._log('data: {}'.format(request.feed_insts[idx]))) - npdata[name] = np.frombuffer( - request.feed_insts[idx], dtype=request.type[idx]) - npdata[name].shape = np.frombuffer( - request.shape[idx], dtype="int32") + inst = channel_pb2.Inst() + inst.data = request.feed_insts[idx] + inst.shape = request.shape[idx] + inst.name = name + inst.type = request.type[idx] + pbdata.insts.append(inst) except Exception as e: - return ChannelData( - ecode=ChannelDataEcode.RPC_PACKAGE_ERROR.value, - error_info="rpc package error", - data_id=data_id), data_id - else: - return ChannelData( - datatype=ChannelDataType.CHANNEL_NPDATA.value, - npdata=npdata, - data_id=data_id), data_id + pbdata.ecode = ChannelDataEcode.RPC_PACKAGE_ERROR.value + pbdata.error_info = "rpc package error" + return ChannelData( + datatype=ChannelDataType.CHANNEL_PBDATA.value, + pbdata=pbdata), data_id def _pack_data_for_resp(self, channeldata): logging.debug(self._log('get channeldata')) resp = pyservice_pb2.Response() - resp.ecode = channeldata.ecode + resp.ecode = channeldata.pbdata.ecode if resp.ecode == ChannelDataEcode.OK.value: if channeldata.datatype == ChannelDataType.CHANNEL_PBDATA.value: for inst in channeldata.pbdata.insts: @@ -941,7 +844,7 @@ class GeneralPythonService( self._log("Error type({}) in datatype.".format( channeldata.datatype))) else: - resp.error_info = channeldata.error_info + resp.error_info = channeldata.pbdata.error_info return resp def inference(self, request, context): @@ -961,11 +864,11 @@ class GeneralPythonService( resp_channeldata = self._get_data_in_globel_resp_dict(data_id) _profiler.record("{}-fetch_1".format(self.name)) - if resp_channeldata.ecode == ChannelDataEcode.OK.value: + if resp_channeldata.pbdata.ecode == ChannelDataEcode.OK.value: break if i + 1 < self._retry: logging.warn("retry({}): {}".format( - i + 1, resp_channeldata.error_info)) + i + 1, resp_channeldata.pbdata.error_info)) _profiler.record("{}-postpack_0".format(self.name)) resp = self._pack_data_for_resp(resp_channeldata) @@ -979,12 +882,12 @@ class PyServer(object): self._channels = [] self._user_ops = [] self._actual_ops = [] + self._op_threads = [] self._port = None self._worker_num = None self._in_channel = None self._out_channel = None self._retry = retry - self._manager = multiprocessing.Manager() _profiler.enable(profile) def add_channel(self, channel): @@ -1009,7 +912,6 @@ class PyServer(object): op.name = "#G" # update read_op.name break outdegs = {op.name: [] for op in self._user_ops} - zero_indeg_num, zero_outdeg_num = 0, 0 for idx, op in enumerate(self._user_ops): # check the name of op is globally unique if op.name in indeg_num: @@ -1017,16 +919,8 @@ class PyServer(object): indeg_num[op.name] = len(op.get_input_ops()) if indeg_num[op.name] == 0: ques[que_idx].put(op) - zero_indeg_num += 1 for pred_op in op.get_input_ops(): outdegs[pred_op.name].append(op) - if zero_indeg_num != 1: - raise Exception("DAG contains multiple input Ops") - for _, succ_list in outdegs.items(): - if len(succ_list) == 0: - zero_outdeg_num += 1 - if zero_outdeg_num != 1: - raise Exception("DAG contains multiple output Ops") # topo sort to get dag_views dag_views = [] @@ -1049,6 +943,10 @@ class PyServer(object): que_idx = (que_idx + 1) % 2 if sorted_op_num < len(self._user_ops): raise Exception("not legal DAG") + if len(dag_views[0]) != 1: + raise Exception("DAG contains multiple input Ops") + if len(dag_views[-1]) != 1: + raise Exception("DAG contains multiple output Ops") # create channels and virtual ops def name_generator(prefix): @@ -1086,14 +984,7 @@ class PyServer(object): else: # create virtual op virtual_op = None - if sys.version_info.major == 2: - virtual_op = VirtualOp( - name=virtual_op_name_gen.next()) - elif sys.version_info.major == 3: - virtual_op = VirtualOp( - name=virtual_op_name_gen.__next__()) - else: - raise Exception("Error Python version") + virtual_op = VirtualOp(name=virtual_op_name_gen.next()) virtual_ops.append(virtual_op) outdegs[virtual_op.name] = [succ_op] actual_next_view.append(virtual_op) @@ -1105,14 +996,7 @@ class PyServer(object): for o_idx, op in enumerate(actual_next_view): if op.name in processed_op: continue - if sys.version_info.major == 2: - channel = Channel( - self._manager, name=channel_name_gen.next()) - elif sys.version_info.major == 3: - channel = Channel( - self._manager, name=channel_name_gen.__next__()) - else: - raise Exception("Error Python version") + channel = Channel(name=channel_name_gen.next()) channels.append(channel) logging.debug("{} => {}".format(channel.name, op.name)) op.add_input_channel(channel) @@ -1143,14 +1027,7 @@ class PyServer(object): other_op.name)) other_op.add_input_channel(channel) processed_op.add(other_op.name) - if sys.version_info.major == 2: - output_channel = Channel( - self._manager, name=channel_name_gen.next()) - elif sys.version_info.major == 3: - output_channel = Channel( - self._manager, name=channel_name_gen.__next__()) - else: - raise Exception("Error Python version") + output_channel = Channel(name=channel_name_gen.next()) channels.append(output_channel) last_op = dag_views[-1][0] last_op.add_output_channel(output_channel) @@ -1174,22 +1051,30 @@ class PyServer(object): self._in_channel = input_channel self._out_channel = output_channel for op in self._actual_ops: - if op.with_serving: + if op.with_serving(): self.prepare_serving(op) self.gen_desc() + def _op_start_wrapper(self, op, concurrency_idx): + return op.start(concurrency_idx) + def _run_ops(self): - proces = [] for op in self._actual_ops: - proces.extend(op.start()) - return proces + op_concurrency = op.get_concurrency() + logging.debug("run op: {}, op_concurrency: {}".format( + op.name, op_concurrency)) + for c in range(op_concurrency): + th = threading.Thread( + target=self._op_start_wrapper, args=(op, c)) + th.start() + self._op_threads.append(th) def _stop_ops(self): for op in self._actual_ops: op.stop() def run_server(self): - op_proces = self._run_ops() + self._run_ops() server = grpc.server( futures.ThreadPoolExecutor(max_workers=self._worker_num)) general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server( @@ -1199,8 +1084,8 @@ class PyServer(object): server.start() server.wait_for_termination() self._stop_ops() # TODO - for p in op_proces: - p.join() + for th in self._op_threads: + th.join() def prepare_serving(self, op): model_path = op._server_model