diff --git a/python/examples/imdb/test_py_server.py b/python/examples/imdb/test_py_server.py
index e2e8de97987dfcda13ecf9de36653b647e10acb5..d887956400411dec6ec8f806727e8f75f552119c 100644
--- a/python/examples/imdb/test_py_server.py
+++ b/python/examples/imdb/test_py_server.py
@@ -27,8 +27,6 @@ logging.basicConfig(
 
 
 class CombineOp(Op):
-    pass
-    '''
     def preprocess(self, input_data):
         combined_prediction = 0
         for op_name, channeldata in input_data.items():
@@ -37,7 +35,6 @@ class CombineOp(Op):
             combined_prediction += data["prediction"]
         data = {"combined_prediction": combined_prediction / 2}
         return data
-    '''
 
 
 read_op = Op(name="read", inputs=None)
diff --git a/python/paddle_serving_server/pyserver.py b/python/paddle_serving_server/pyserver.py
index 9f77b8506c82ed1b32728eacf884347f19dfe608..216a2db140aa19efabc1713db27228df8ef58521 100644
--- a/python/paddle_serving_server/pyserver.py
+++ b/python/paddle_serving_server/pyserver.py
@@ -98,10 +98,10 @@ class ChannelData(object):
         '''
         There are several ways to use it:
 
-        - ChannelData(future, pbdata[, callback_func])
-        - ChannelData(future, data_id[, callback_func])
-        - ChannelData(pbdata)
-        - ChannelData(ecode, error_info, data_id)
+        1. ChannelData(future, pbdata[, callback_func])
+        2. ChannelData(future, data_id[, callback_func])
+        3. ChannelData(pbdata)
+        4. ChannelData(ecode, error_info, data_id)
         '''
         if ecode is not None:
             if data_id is None or error_info is None:
@@ -138,9 +138,8 @@ class ChannelData(object):
             if self.callback_func is not None:
                 feed = self.callback_func(feed)
         else:
-            raise TypeError(
-                self._log("Error type({}) in pbdata.type.".format(
-                    self.pbdata.type)))
+            raise TypeError("Error type({}) in pbdata.type.".format(
+                self.pbdata.type))
         return feed
@@ -334,6 +333,7 @@ class Channel(Queue.Queue):
         #TODO
         self.close()
         self._stop = True
+        self._cv.notify_all()
@@ -358,7 +358,7 @@ class Op(object):
         self._server_port = server_port
         self._device = device
         self._timeout = timeout
-        self._retry = retry
+        self._retry = max(1, retry)
         self._input = None
         self._outputs = []
@@ -443,48 +443,51 @@ class Op(object):
         self._run = False
 
     def _parse_channeldata(self, channeldata):
-        data_id, error_data = None, None
+        data_id, error_pbdata = None, None
         if isinstance(channeldata, dict):
             parsed_data = {}
             key = channeldata.keys()[0]
             data_id = channeldata[key].pbdata.id
             for _, data in channeldata.items():
-                if data.pbdata.ecode != 0:
-                    error_data = data.pbdata
+                if data.pbdata.ecode != ChannelDataEcode.OK.value:
+                    error_pbdata = data.pbdata
                     break
         else:
             data_id = channeldata.pbdata.id
-            if channeldata.pbdata.ecode != 0:
-                error_data = channeldata.pbdata
-        return data_id, error_data
+            if channeldata.pbdata.ecode != ChannelDataEcode.OK.value:
+                error_pbdata = channeldata.pbdata
+        return data_id, error_pbdata
 
-    def _push_to_output_channels(self, data):
+    def _push_to_output_channels(self, data, name=None):
+        if name is None:
+            name = self.name
         for channel in self._outputs:
-            channel.push(data, self.name)
+            channel.push(data, name)
 
     def start(self, concurrency_idx):
-        op_info_prefix = "[{}{}]".format(self.name, concurrency_idx)
+        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
         log = self._get_log_func(op_info_prefix)
         self._run = True
         while self._run:
             _profiler.record("{}-get_0".format(op_info_prefix))
-            input_data = self._input.front(self.name)
+            channeldata = self._input.front(self.name)
             _profiler.record("{}-get_1".format(op_info_prefix))
-            logging.debug(log("input_data: {}".format(input_data)))
+            logging.debug(log("input_data: {}".format(channeldata)))
 
-            data_id, error_data = self._parse_channeldata(input_data)
+            data_id, error_pbdata = self._parse_channeldata(channeldata)
 
-            # predecessor Op error
-            if error_data is not None:
-                self._push_to_output_channels(ChannelData(pbdata=error_data))
+            # error data in predecessor Op
+            if error_pbdata is not None:
+                self._push_to_output_channels(ChannelData(pbdata=error_pbdata))
                 continue
 
-            # preprocess function not implemented
+            # preprocess
            try:
                 _profiler.record("{}-prep_0".format(op_info_prefix))
-                data = self.preprocess(input_data)
+                preped_data = self.preprocess(channeldata)
                 _profiler.record("{}-prep_1".format(op_info_prefix))
             except NotImplementedError as e:
+                # preprocess function not implemented
                 error_info = log(e)
                 logging.error(error_info)
                 self._push_to_output_channels(
@@ -493,18 +496,35 @@
                         error_info=error_info,
                         data_id=data_id))
                 continue
+            except TypeError as e:
+                # Error type in channeldata.pbdata.type
+                error_info = log(e)
+                logging.error(error_info)
+                self._push_to_output_channels(
+                    ChannelData(
+                        ecode=ChannelDataEcode.TYPE_ERROR.value,
+                        error_info=error_info,
+                        data_id=data_id))
+                continue
+            except Exception as e:
+                error_info = log(e)
+                logging.error(error_info)
+                self._push_to_output_channels(
+                    ChannelData(
+                        ecode=ChannelDataEcode.TYPE_ERROR.value,
+                        error_info=error_info,
+                        data_id=data_id))
+                continue
 
             # midprocess
             call_future = None
-            ecode = 0
-            error_info = None
             if self.with_serving():
+                ecode = ChannelDataEcode.OK.value
                 _profiler.record("{}-midp_0".format(op_info_prefix))
                 if self._timeout <= 0:
                     try:
-                        call_future = self.midprocess(data)
+                        call_future = self.midprocess(preped_data)
                     except Exception as e:
-                        logging.error(self._log(e))
                         ecode = ChannelDataEcode.UNKNOW.value
                         error_info = log(e)
                         logging.error(error_info)
@@ -512,15 +532,17 @@
                     for i in range(self._retry):
                         try:
                             call_future = func_timeout.func_timeout(
-                                self._timeout, self.midprocess, args=(data, ))
-                        except func_timeout.FunctionTimedOut:
+                                self._timeout,
+                                self.midprocess,
+                                args=(preped_data, ))
+                        except func_timeout.FunctionTimedOut as e:
                             if i + 1 >= self._retry:
                                 ecode = ChannelDataEcode.TIMEOUT.value
-                                error_info = "{} timeout".format(op_info_prefix)
+                                error_info = log(e)
+                                logging.error(error_info)
                             else:
                                 logging.warn(
-                                    log("warn: timeout, retry({})".format(i +
-                                                                          1)))
+                                    log("timeout, retry({})".format(i + 1)))
                         except Exception as e:
                             ecode = ChannelDataEcode.UNKNOW.value
                             error_info = log(e)
@@ -528,7 +550,7 @@
                             break
                     else:
                         break
-            if ecode != 0:
+            if ecode != ChannelDataEcode.OK.value:
                 self._push_to_output_channels(
                     ChannelData(
                         ecode=ecode, error_info=error_info,
@@ -539,32 +561,63 @@
             # postprocess
             output_data = None
             _profiler.record("{}-postp_0".format(op_info_prefix))
-            if self.with_serving():  # use call_future
+            if self.with_serving():
+                # use call_future
                 output_data = ChannelData(
                     future=call_future,
                     data_id=data_id,
                     callback_func=self.postprocess)
             else:
-                post_data = self.postprocess(data)
-                if not isinstance(post_data, dict):
+                try:
+                    postped_data = self.postprocess(preped_data)
+                except Exception as e:
+                    ecode = ChannelDataEcode.UNKNOW.value
+                    error_info = log(e)
+                    logging.error(error_info)
+                    self._push_to_output_channels(
+                        ChannelData(
+                            ecode=ecode, error_info=error_info,
+                            data_id=data_id))
+                    continue
+                if not isinstance(postped_data, dict):
                     ecode = ChannelDataEcode.TYPE_ERROR.value
                     error_info = log("output of postprocess funticon must be " \
-                        "dict type, but get {}".format(type(post_data)))
+                        "dict type, but get {}".format(type(postped_data)))
                     logging.error(error_info)
                     self._push_to_output_channels(
                         ChannelData(
                             ecode=ecode, error_info=error_info,
                             data_id=data_id))
                     continue
+
+                ecode = ChannelDataEcode.OK.value
+                error_info = None
                 pbdata = channel_pb2.ChannelData()
-                for name, value in post_data.items():
+                for name, value in postped_data.items():
+                    if not isinstance(name, (str, unicode)):
+                        ecode = ChannelDataEcode.TYPE_ERROR.value
+                        error_info = log("the key of postped_data must " \
+                            "be str, but get {}".format(type(name)))
+                        break
+                    if not isinstance(value, np.ndarray):
+                        ecode = ChannelDataEcode.TYPE_ERROR.value
+                        error_info = log("the value of postped_data must " \
+                            "be np.ndarray, but get {}".format(type(value)))
+                        break
                     inst = channel_pb2.Inst()
                     inst.data = value.tobytes()
                     inst.name = name
                     inst.shape = np.array(value.shape, dtype="int32").tobytes()
                     inst.type = str(value.dtype)
                     pbdata.insts.append(inst)
-                pbdata.ecode = 0
+                if ecode != ChannelDataEcode.OK.value:
+                    logging.error(error_info)
+                    self._push_to_output_channels(
+                        ChannelData(
+                            ecode=ecode, error_info=error_info,
+                            data_id=data_id))
+                    continue
+                pbdata.ecode = ecode
                 pbdata.id = data_id
                 output_data = ChannelData(pbdata=pbdata)
             _profiler.record("{}-postp_1".format(op_info_prefix))
@@ -587,6 +640,45 @@
         return self._concurrency
 
 
+class VirtualOp(Op):
+    ''' For connecting two channels. '''
+
+    def __init__(self, name, concurrency=1):
+        super(VirtualOp, self).__init__(
+            name=name, inputs=None, concurrency=concurrency)
+        self._virtual_pred_ops = []
+
+    def add_virtual_pred_op(self, op):
+        self._virtual_pred_ops.append(op)
+
+    def add_output_channel(self, channel):
+        if not isinstance(channel, Channel):
+            raise TypeError(
+                self._log('output channel must be Channel type, not {}'.format(
+                    type(channel))))
+        for op in self._virtual_pred_ops:
+            channel.add_producer(op.name)
+        self._outputs.append(channel)
+
+    def start(self, concurrency_idx):
+        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
+        log = self._get_log_func(op_info_prefix)
+        self._run = True
+        while self._run:
+            _profiler.record("{}-get_0".format(op_info_prefix))
+            channeldata = self._input.front(self.name)
+            _profiler.record("{}-get_1".format(op_info_prefix))
+
+            _profiler.record("{}-push_0".format(op_info_prefix))
+            if isinstance(channeldata, dict):
+                for name, data in channeldata.items():
+                    self._push_to_output_channels(data, name=name)
+            else:
+                self._push_to_output_channels(channeldata,
+                                               self._virtual_pred_ops[0].name)
+            _profiler.record("{}-push_1".format(op_info_prefix))
+
+
 class GeneralPythonService(
         general_python_service_pb2_grpc.GeneralPythonService):
     def __init__(self, in_channel, out_channel, retry=2):
@@ -668,35 +760,27 @@
             inst.name = name
             inst.type = request.type[idx]
             pbdata.insts.append(inst)
-        pbdata.ecode = 0  #TODO: parse request error
+        pbdata.ecode = ChannelDataEcode.OK.value  #TODO: parse request error
         return ChannelData(pbdata=pbdata), data_id
 
     def _pack_data_for_resp(self, channeldata):
         logging.debug(self._log('get channeldata'))
-        logging.debug(self._log('gen resp'))
         resp = pyservice_pb2.Response()
         resp.ecode = channeldata.pbdata.ecode
-        if resp.ecode == 0:
+        if resp.ecode == ChannelDataEcode.OK.value:
             if channeldata.pbdata.type == ChannelDataType.CHANNEL_PBDATA.value:
                 for inst in channeldata.pbdata.insts:
-                    logging.debug(self._log('append data'))
                     resp.fetch_insts.append(inst.data)
-                    logging.debug(self._log('append name'))
                     resp.fetch_var_names.append(inst.name)
-                    logging.debug(self._log('append shape'))
                     resp.shape.append(inst.shape)
-                    logging.debug(self._log('append type'))
                     resp.type.append(inst.type)
             elif channeldata.pbdata.type == ChannelDataType.CHANNEL_FUTURE.value:
                 feed = channeldata.futures.result()
                 if channeldata.callback_func is not None:
                     feed = channeldata.callback_func(feed)
                 for name, var in feed:
-                    logging.debug(self._log('append data'))
                     resp.fetch_insts.append(var.tobytes())
-                    logging.debug(self._log('append name'))
                     resp.fetch_var_names.append(name)
-                    logging.debug(self._log('append shape'))
                     resp.shape.append(
                         np.array(
                             var.shape, dtype="int32").tobytes())
@@ -726,7 +810,7 @@
             resp_channeldata = self._get_data_in_globel_resp_dict(data_id)
             _profiler.record("{}-fetch_1".format(self.name))
 
-            if resp_channeldata.pbdata.ecode == 0:
+            if resp_channeldata.pbdata.ecode == ChannelDataEcode.OK.value:
                 break
             if i + 1 < self._retry:
                 logging.warn("retry({}): {}".format(
@@ -743,7 +827,7 @@ class PyServer(object):
     def __init__(self, retry=2, profile=False):
         self._channels = []
         self._user_ops = []
-        self._total_ops = []
+        self._actual_ops = []
         self._op_threads = []
         self._port = None
         self._worker_num = None
@@ -767,9 +851,13 @@
 
     def _topo_sort(self):
         indeg_num = {}
-        outdegs = {op.name: [] for op in self._user_ops}
         que_idx = 0  # scroll queue
         ques = [Queue.Queue() for _ in range(2)]
+        for op in self._user_ops:
+            if len(op.get_input_ops()) == 0:
+                op.name = "#G"  # update read_op.name
+                break
+        outdegs = {op.name: [] for op in self._user_ops}
         for idx, op in enumerate(self._user_ops):
             # check the name of op is globally unique
             if op.name in indeg_num:
@@ -780,7 +868,7 @@
             for pred_op in op.get_input_ops():
                 outdegs[pred_op.name].append(op)
 
-        # get dag_views
+        # topo sort to get dag_views
         dag_views = []
         sorted_op_num = 0
         while True:
@@ -790,9 +878,8 @@
             while que.qsize() != 0:
                 op = que.get()
                 dag_view.append(op)
-                op_name = op.name
                 sorted_op_num += 1
-                for succ_op in outdegs[op_name]:
+                for succ_op in outdegs[op.name]:
                     indeg_num[succ_op.name] -= 1
                     if indeg_num[succ_op.name] == 0:
                         next_que.put(succ_op)
@@ -808,52 +895,69 @@
             raise Exception("DAG contains multiple output Ops")
 
         # create channels and virtual ops
-        virtual_op_idx = 0
-        channel_idx = 0
+        def name_generator(prefix):
+            def number_generator():
+                idx = 0
+                while True:
+                    yield "{}{}".format(prefix, idx)
+                    idx += 1
+
+            return number_generator()
+
+        virtual_op_name_gen = name_generator("vir")
+        channel_name_gen = name_generator("chl")
         virtual_ops = []
         channels = []
         input_channel = None
+        actual_view = None
         for v_idx, view in enumerate(dag_views):
             if v_idx + 1 >= len(dag_views):
                 break
             next_view = dag_views[v_idx + 1]
+            if actual_view is None:
+                actual_view = view
             actual_next_view = []
             pred_op_of_next_view_op = {}
-            for op in view:
-                # create virtual op
+            for op in actual_view:
+                # find actual succ op in next view and create virtual op
                 for succ_op in outdegs[op.name]:
                     if succ_op in next_view:
-                        actual_next_view.append(succ_op)
+                        if succ_op not in actual_next_view:
+                            actual_next_view.append(succ_op)
                        if succ_op.name not in pred_op_of_next_view_op:
                             pred_op_of_next_view_op[succ_op.name] = []
                         pred_op_of_next_view_op[succ_op.name].append(op)
                     else:
-                        vop = Op(name="vir{}".format(virtual_op_idx), inputs=[])
-                        virtual_op_idx += 1
+                        # create virtual op
+                        virtual_op = None
+                        virtual_op = VirtualOp(name=virtual_op_name_gen.next())
                         virtual_ops.append(virtual_op)
-                        outdegs[vop.name] = [succ_op]
-                        actual_next_view.append(vop)
-                        # TODO: combine vop
-                        pred_op_of_next_view_op[vop.name] = [op]
+                        outdegs[virtual_op.name] = [succ_op]
+                        actual_next_view.append(virtual_op)
+                        pred_op_of_next_view_op[virtual_op.name] = [op]
+                        virtual_op.add_virtual_pred_op(op)
+            actual_view = actual_next_view
             # create channel
             processed_op = set()
             for o_idx, op in enumerate(actual_next_view):
-                op_name = op.name
-                if op_name in processed_op:
+                if op.name in processed_op:
                     continue
-                channel = Channel(name="chl{}".format(channel_idx))
-                channel_idx += 1
+                channel = Channel(name=channel_name_gen.next())
                 channels.append(channel)
+                logging.debug("{} => {}".format(channel.name, op.name))
                 op.add_input_channel(channel)
-                pred_ops = pred_op_of_next_view_op[op_name]
+                pred_ops = pred_op_of_next_view_op[op.name]
                 if v_idx == 0:
                     input_channel = channel
                 else:
+                    # if pred_op is virtual op, it will use ancestors as producers to channel
                    for pred_op in pred_ops:
+                        logging.debug("{} => {}".format(pred_op.name,
                                                        channel.name))
                         pred_op.add_output_channel(channel)
-                processed_op.add(op_name)
-                # combine channel
-                for other_op in actual_next_view[o_idx:]:
+                processed_op.add(op.name)
+                # find same input op to combine channel
+                for other_op in actual_next_view[o_idx + 1:]:
                     if other_op.name in processed_op:
                         continue
                     other_pred_ops = pred_op_of_next_view_op[other_op.name]
@@ -865,19 +969,24 @@
                             same_flag = False
                             break
                     if same_flag:
+                        logging.debug("{} => {}".format(channel.name,
                                                        other_op.name))
                         other_op.add_input_channel(channel)
                         processed_op.add(other_op.name)
-        output_channel = Channel(name="Ochl")
+        output_channel = Channel(name=channel_name_gen.next())
         channels.append(output_channel)
         last_op = dag_views[-1][0]
         last_op.add_output_channel(output_channel)
 
-        self._ops = virtual_ops
+        self._actual_ops = virtual_ops
         for op in self._user_ops:
             if len(op.get_input_ops()) == 0:
+                # pass read op
                 continue
-            self._ops.append(op)
+            self._actual_ops.append(op)
         self._channels = channels
+        for c in channels:
+            logging.debug(c.debug())
         return input_channel, output_channel
 
     def prepare_server(self, port, worker_num):
@@ -887,7 +996,7 @@
         input_channel, output_channel = self._topo_sort()
         self._in_channel = input_channel
         self._out_channel = output_channel
-        for op in self._ops:
+        for op in self._actual_ops:
             if op.with_serving():
                 self.prepare_serving(op)
         self.gen_desc()
@@ -896,7 +1005,7 @@
         return op.start(concurrency_idx)
 
     def _run_ops(self):
-        for op in self._ops:
+        for op in self._actual_ops:
             op_concurrency = op.get_concurrency()
             logging.debug("run op: {}, op_concurrency: {}".format(
                 op.name, op_concurrency))
@@ -907,7 +1016,7 @@
             self._op_threads.append(th)
 
     def _stop_ops(self):
-        for op in self._ops:
+        for op in self._actual_ops:
             op.stop()
 
     def run_server(self):
@@ -921,6 +1030,8 @@
         server.start()
         server.wait_for_termination()
         self._stop_ops()  # TODO
+        for th in self._op_threads:
+            th.join()
 
     def prepare_serving(self, op):
         model_path = op._server_model
@@ -935,5 +1046,3 @@
             " --model {} --thread 4 --port {} --use_multilang &>/dev/null &".format(model_path, port)
         # run a server (not in PyServing)
         logging.info("run a server (not in PyServing): {}".format(cmd))
-        return
-        # os.system(cmd)