diff --git a/python/paddle_serving_server/pyserver.py b/python/paddle_serving_server/pyserver.py index dfe6d73dfabe710859332ca86523c4ab54d1747b..393e9270268411c1d64809b9204c05369b4d0d31 100644 --- a/python/paddle_serving_server/pyserver.py +++ b/python/paddle_serving_server/pyserver.py @@ -14,6 +14,7 @@ # pylint: disable=doc-string-missing import threading import multiprocessing +import multiprocessing.queues import Queue import os import sys @@ -109,34 +110,27 @@ class ChannelData(object): 4. ChannelData(ChannelDataType.CHANNEL_PBDATA.value, npdata, data_id) 5. ChannelData(ChannelDataType.CHANNEL_NPDATA.value, npdata, data_id) 6. ChannelData(ecode, error_info, data_id) + + Protobufs are not pickle-able: + https://stackoverflow.com/questions/55344376/how-to-import-protobuf-module ''' if ecode is not None: if data_id is None or error_info is None: raise ValueError("data_id and error_info cannot be None") - pbdata = channel_pb2.ChannelData() - pbdata.ecode = ecode - pbdata.id = data_id - pbdata.error_info = error_info datatype = ChannelDataType.ERROR.value else: if datatype == ChannelDataType.CHANNEL_FUTURE.value: - if pbdata is None: - if data_id is None: - raise ValueError("data_id cannot be None") - pbdata = channel_pb2.ChannelData() - pbdata.ecode = ChannelDataEcode.OK.value - pbdata.id = data_id + if data_id is None: + raise ValueError("data_id cannot be None") + ecode = ChannelDataEcode.OK.value elif datatype == ChannelDataType.CHANNEL_PBDATA.value: if pbdata is None: if data_id is None: raise ValueError("data_id cannot be None") pbdata = channel_pb2.ChannelData() - pbdata.id = data_id ecode, error_info = self._check_npdata(npdata) - pbdata.ecode = ecode - if pbdata.ecode != ChannelDataEcode.OK.value: - pbdata.error_info = error_info - logging.error(pbdata.error_info) + if ecode != ChannelDataEcode.OK.value: + logging.error(error_info) else: for name, value in npdata.items(): inst = channel_pb2.Inst() @@ -148,23 +142,18 @@ class ChannelData(object): pbdata.insts.append(inst) elif datatype == ChannelDataType.CHANNEL_NPDATA.value: ecode, error_info = self._check_npdata(npdata) - pbdata = channel_pb2.ChannelData() - pbdata.id = data_id - pbdata.ecode = ecode - if pbdata.ecode != ChannelDataEcode.OK.value: - pbdata.error_info = error_info - logging.error(pbdata.error_info) + if ecode != ChannelDataEcode.OK.value: + logging.error(error_info) else: raise ValueError("datatype not match") - if not isinstance(pbdata, channel_pb2.ChannelData): - raise TypeError( - "pbdata must be pyserving_channel_pb2.ChannelData type({})". - format(type(pbdata))) self.future = future self.pbdata = pbdata self.npdata = npdata self.datatype = datatype self.callback_func = callback_func + self.id = data_id + self.ecode = ecode + self.error_info = error_info def _check_npdata(self, npdata): ecode = ChannelDataEcode.OK.value @@ -176,8 +165,8 @@ class ChannelData(object): "be str, but get {}".format(type(name))) break if not isinstance(value, np.ndarray): - pbdata.ecode = ChannelDataEcode.TYPE_ERROR.value - pbdata.error_info = log("the value of postped_data must " \ + ecode = ChannelDataEcode.TYPE_ERROR.value + error_info = log("the value of postped_data must " \ "be np.ndarray, but get {}".format(type(value))) break return ecode, error_info @@ -197,16 +186,15 @@ class ChannelData(object): elif self.datatype == ChannelDataType.CHANNEL_NPDATA.value: feed = self.npdata else: - raise TypeError("Error type({}) in pbdata.type.".format( - self.pbdata.type)) + raise TypeError("Error type({}) in datatype.".format(self.datatype)) return feed def __str__(self): return "type[{}], ecode[{}]".format( - ChannelDataType(self.datatype).name, self.pbdata.ecode) + ChannelDataType(self.datatype).name, self.ecode) -class Channel(Queue.Queue): +class Channel(multiprocessing.queues.Queue): """ The channel used for communication between Ops. @@ -225,7 +213,8 @@ class Channel(Queue.Queue): """ def __init__(self, name=None, maxsize=-1, timeout=None): - Queue.Queue.__init__(self, maxsize=maxsize) + # https://stackoverflow.com/questions/39496554/cannot-subclass-multiprocessing-queue-in-python-3-5 + multiprocessing.queues.Queue.__init__(self, maxsize=maxsize) self._maxsize = maxsize self._timeout = timeout self.name = name @@ -299,7 +288,7 @@ class Channel(Queue.Queue): "There are multiple producers, so op_name cannot be None.")) producer_num = len(self._producers) - data_id = channeldata.pbdata.id + data_id = channeldata.id put_data = None with self._cv: logging.debug(self._log("{} get lock".format(op_name))) @@ -418,29 +407,41 @@ class Op(object): self.name = name # to identify the type of OP, it must be globally unique self._concurrency = concurrency # amount of concurrency self.set_input_ops(inputs) - self.set_client(client_config, server_name, fetch_names) - self._server_model = server_model - self._server_port = server_port - self._device = device self._timeout = timeout self._retry = max(1, retry) self._input = None self._outputs = [] - def set_client(self, client_config, server_name, fetch_names): + self.with_serving = False + self._client_config = client_config + self._server_name = server_name + self._fetch_names = fetch_names + self._server_model = server_model + self._server_port = server_port + self._device = device + if self._client_config is not None and \ + self._server_name is not None and \ + self._fetch_names is not None and \ + self._server_model is not None and \ + self._server_port is not None and \ + self._device is not None: + self.with_serving = True + + def init_client(self, client_config, server_name, fetch_names): self._client = None if client_config is None or \ server_name is None or \ fetch_names is None: + logging.debug("no client") return + logging.debug("client_config: {}".format(client_config)) + logging.debug("server_name: {}".format(server_name)) + logging.debug("fetch_names: {}".format(fetch_names)) self._client = Client() self._client.load_client_config(client_config) self._client.connect([server_name]) self._fetch_names = fetch_names - def with_serving(self): - return self._client is not None - def get_input_channel(self): return self._input @@ -508,45 +509,56 @@ class Op(object): self._run = False def _parse_channeldata(self, channeldata): - data_id, error_pbdata = None, None + data_id, error_channeldata = None, None if isinstance(channeldata, dict): parsed_data = {} key = channeldata.keys()[0] - data_id = channeldata[key].pbdata.id + data_id = channeldata[key].id for _, data in channeldata.items(): - if data.pbdata.ecode != ChannelDataEcode.OK.value: - error_pbdata = data.pbdata + if data.ecode != ChannelDataEcode.OK.value: + error_channeldata = data break else: - data_id = channeldata.pbdata.id - if channeldata.pbdata.ecode != ChannelDataEcode.OK.value: - error_pbdata = channeldata.pbdata - return data_id, error_pbdata + data_id = channeldata.id + if channeldata.ecode != ChannelDataEcode.OK.value: + error_channeldata = channeldata + return data_id, error_channeldata - def _push_to_output_channels(self, data, name=None): + def _push_to_output_channels(self, data, channels, name=None): if name is None: name = self.name - for channel in self._outputs: + for channel in channels: channel.push(data, name) - def start(self, concurrency_idx): + def start(self): + proces = [] + for concurrency_idx in range(self._concurrency): + p = multiprocessing.Process( + target=self._run, + args=(concurrency_idx, self.get_input_channel(), + self.get_output_channels())) + p.start() + proces.append(p) + return proces + + def _run(self, input_channel, output_channels): + self.init_client(self._client_config, self._server_name, + self._fetch_names) op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx) log = self._get_log_func(op_info_prefix) self._run = True while self._run: _profiler.record("{}-get_0".format(op_info_prefix)) - channeldata = self._input.front(self.name) + channeldata = input_channel.front(self.name) _profiler.record("{}-get_1".format(op_info_prefix)) logging.debug(log("input_data: {}".format(channeldata))) - data_id, error_pbdata = self._parse_channeldata(channeldata) + data_id, error_channeldata = self._parse_channeldata(channeldata) # error data in predecessor Op - if error_pbdata is not None: - self._push_to_output_channels( - ChannelData( - datatype=ChannelDataType.CHANNEL_PBDATA.value, - pbdata=error_pbdata)) + if error_channeldata is not None: + self._push_to_output_channels(error_channeldata, + output_channels) continue # preprecess @@ -562,17 +574,19 @@ class Op(object): ChannelData( ecode=ChannelDataEcode.NOT_IMPLEMENTED.value, error_info=error_info, - data_id=data_id)) + data_id=data_id), + output_channels) continue except TypeError as e: - # Error type in channeldata.pbdata.type + # Error type in channeldata.datatype error_info = log(e) logging.error(error_info) self._push_to_output_channels( ChannelData( ecode=ChannelDataEcode.TYPE_ERROR.value, error_info=error_info, - data_id=data_id)) + data_id=data_id), + output_channels) continue except Exception as e: error_info = log(e) @@ -581,12 +595,13 @@ class Op(object): ChannelData( ecode=ChannelDataEcode.TYPE_ERROR.value, error_info=error_info, - data_id=data_id)) + data_id=data_id), + output_channels) continue # midprocess call_future = None - if self.with_serving(): + if self.with_serving: ecode = ChannelDataEcode.OK.value _profiler.record("{}-midp_0".format(op_info_prefix)) if self._timeout <= 0: @@ -622,14 +637,15 @@ class Op(object): self._push_to_output_channels( ChannelData( ecode=ecode, error_info=error_info, - data_id=data_id)) + data_id=data_id), + output_channels) continue _profiler.record("{}-midp_1".format(op_info_prefix)) # postprocess output_data = None _profiler.record("{}-postp_0".format(op_info_prefix)) - if self.with_serving(): + if self.with_serving: # use call_future output_data = ChannelData( datatype=ChannelDataType.CHANNEL_FUTURE.value, @@ -646,7 +662,8 @@ class Op(object): self._push_to_output_channels( ChannelData( ecode=ecode, error_info=error_info, - data_id=data_id)) + data_id=data_id), + output_channels) continue if not isinstance(postped_data, dict): ecode = ChannelDataEcode.TYPE_ERROR.value @@ -656,7 +673,8 @@ class Op(object): self._push_to_output_channels( ChannelData( ecode=ecode, error_info=error_info, - data_id=data_id)) + data_id=data_id), + output_channels) continue output_data = ChannelData( @@ -667,7 +685,7 @@ class Op(object): # push data to channel (if run succ) _profiler.record("{}-push_0".format(op_info_prefix)) - self._push_to_output_channels(output_data) + self._push_to_output_channels(output_data, output_channels) _profiler.record("{}-push_1".format(op_info_prefix)) def _log(self, info): @@ -703,22 +721,25 @@ class VirtualOp(Op): channel.add_producer(op.name) self._outputs.append(channel) - def start(self, concurrency_idx): + def _run(self, input_channel, output_channels): op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx) log = self._get_log_func(op_info_prefix) self._run = True while self._run: _profiler.record("{}-get_0".format(op_info_prefix)) - channeldata = self._input.front(self.name) + channeldata = input_channel.front(self.name) _profiler.record("{}-get_1".format(op_info_prefix)) _profiler.record("{}-push_0".format(op_info_prefix)) if isinstance(channeldata, dict): for name, data in channeldata.items(): - self._push_to_output_channels(data, name=name) + self._push_to_output_channels( + data, channels=output_channels, name=name) else: - self._push_to_output_channels(channeldata, - self._virtual_pred_ops[0].name) + self._push_to_output_channels( + channeldata, + channels=output_channels, + name=self._virtual_pred_ops[0].name) _profiler.record("{}-push_1".format(op_info_prefix)) @@ -770,7 +791,7 @@ class GeneralPythonService( self._log('data must be ChannelData type, but get {}'. format(type(channeldata)))) with self._cv: - data_id = channeldata.pbdata.id + data_id = channeldata.id self._globel_resp_dict[data_id] = channeldata self._cv.notify_all() @@ -790,33 +811,33 @@ class GeneralPythonService( def _pack_data_for_infer(self, request): logging.debug(self._log('start inferce')) - pbdata = channel_pb2.ChannelData() data_id = self._get_next_id() - pbdata.id = data_id - pbdata.ecode = ChannelDataEcode.OK.value + npdata = {} try: for idx, name in enumerate(request.feed_var_names): logging.debug( self._log('name: {}'.format(request.feed_var_names[idx]))) logging.debug( self._log('data: {}'.format(request.feed_insts[idx]))) - inst = channel_pb2.Inst() - inst.data = request.feed_insts[idx] - inst.shape = request.shape[idx] - inst.name = name - inst.type = request.type[idx] - pbdata.insts.append(inst) + npdata[name] = np.frombuffer( + request.feed_insts[idx], dtype=request.type[idx]) + npdata[name].shape = np.frombuffer( + request.shape[idx], dtype="int32") except Exception as e: - pbdata.ecode = ChannelDataEcode.RPC_PACKAGE_ERROR.value - pbdata.error_info = "rpc package error" - return ChannelData( - datatype=ChannelDataType.CHANNEL_PBDATA.value, - pbdata=pbdata), data_id + return ChannelData( + ecode=ChannelDataEcode.RPC_PACKAGE_ERROR.value, + error_info="rpc package error", + data_id=data_id), data_id + else: + return ChannelData( + datatype=ChannelDataType.CHANNEL_NPDATA.value, + npdata=npdata, + data_id=data_id), data_id def _pack_data_for_resp(self, channeldata): logging.debug(self._log('get channeldata')) resp = pyservice_pb2.Response() - resp.ecode = channeldata.pbdata.ecode + resp.ecode = channeldata.ecode if resp.ecode == ChannelDataEcode.OK.value: if channeldata.datatype == ChannelDataType.CHANNEL_PBDATA.value: for inst in channeldata.pbdata.insts: @@ -836,10 +857,10 @@ class GeneralPythonService( resp.type.append(str(var.dtype)) else: raise TypeError( - self._log("Error type({}) in pbdata.type.".format( + self._log("Error type({}) in datatype.".format( channeldata.datatype))) else: - resp.error_info = channeldata.pbdata.error_info + resp.error_info = channeldata.error_info return resp def inference(self, request, context): @@ -859,11 +880,11 @@ class GeneralPythonService( resp_channeldata = self._get_data_in_globel_resp_dict(data_id) _profiler.record("{}-fetch_1".format(self.name)) - if resp_channeldata.pbdata.ecode == ChannelDataEcode.OK.value: + if resp_channeldata.ecode == ChannelDataEcode.OK.value: break if i + 1 < self._retry: logging.warn("retry({}): {}".format( - i + 1, resp_channeldata.pbdata.error_info)) + i + 1, resp_channeldata.error_info)) _profiler.record("{}-postpack_0".format(self.name)) resp = self._pack_data_for_resp(resp_channeldata) @@ -877,7 +898,6 @@ class PyServer(object): self._channels = [] self._user_ops = [] self._actual_ops = [] - self._op_threads = [] self._port = None self._worker_num = None self._in_channel = None @@ -1046,30 +1066,22 @@ class PyServer(object): self._in_channel = input_channel self._out_channel = output_channel for op in self._actual_ops: - if op.with_serving(): + if op.with_serving: self.prepare_serving(op) self.gen_desc() - def _op_start_wrapper(self, op, concurrency_idx): - return op.start(concurrency_idx) - def _run_ops(self): + proces = [] for op in self._actual_ops: - op_concurrency = op.get_concurrency() - logging.debug("run op: {}, op_concurrency: {}".format( - op.name, op_concurrency)) - for c in range(op_concurrency): - th = threading.Thread( - target=self._op_start_wrapper, args=(op, c)) - th.start() - self._op_threads.append(th) + proces.extend(op.start()) + return proces def _stop_ops(self): for op in self._actual_ops: op.stop() def run_server(self): - self._run_ops() + op_proces = self._run_ops() server = grpc.server( futures.ThreadPoolExecutor(max_workers=self._worker_num)) general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server( @@ -1079,8 +1091,8 @@ class PyServer(object): server.start() server.wait_for_termination() self._stop_ops() # TODO - for th in self._op_threads: - th.join() + for p in op_proces: + p.join() def prepare_serving(self, op): model_path = op._server_model