From 0f18e403f043575ab8b714db698be62c996eb13e Mon Sep 17 00:00:00 2001 From: barrierye Date: Thu, 11 Jun 2020 16:41:21 +0800 Subject: [PATCH] [WIP] remove channel def in user side --- python/examples/fit_a_line/test_py_server.py | 28 +- python/paddle_serving_server/pyserver.py | 269 ++++++++++++++----- 2 files changed, 203 insertions(+), 94 deletions(-) diff --git a/python/examples/fit_a_line/test_py_server.py b/python/examples/fit_a_line/test_py_server.py index c7bc5f6b..a372f10f 100644 --- a/python/examples/fit_a_line/test_py_server.py +++ b/python/examples/fit_a_line/test_py_server.py @@ -25,8 +25,6 @@ logging.basicConfig( #level=logging.DEBUG) level=logging.INFO) -# channel data: {name(str): data(narray)} - class CombineOp(Op): def preprocess(self, input_data): @@ -39,13 +37,9 @@ class CombineOp(Op): return data -read_channel = Channel(name="read_channel") -combine_channel = Channel(name="combine_channel") -out_channel = Channel(name="out_channel") - +read_op = Op(name="read", input=None) uci1_op = Op(name="uci1", - input=read_channel, - outputs=[combine_channel], + inputs=[read_op], server_model="./uci_housing_model", server_port="9393", device="cpu", @@ -55,10 +49,8 @@ uci1_op = Op(name="uci1", concurrency=1, timeout=0.1, retry=2) - uci2_op = Op(name="uci2", - input=read_channel, - outputs=[combine_channel], + inputs=[read_op], server_model="./uci_housing_model", server_port="9292", device="cpu", @@ -68,24 +60,14 @@ uci2_op = Op(name="uci2", concurrency=1, timeout=-1, retry=1) - combine_op = CombineOp( name="combine", - input=combine_channel, - outputs=[out_channel], + inputs=[uci1_op, uci2_op], concurrency=1, timeout=-1, retry=1) -logging.info(read_channel.debug()) -logging.info(combine_channel.debug()) -logging.info(out_channel.debug()) pyserver = PyServer(profile=False, retry=1) -pyserver.add_channel(read_channel) -pyserver.add_channel(combine_channel) -pyserver.add_channel(out_channel) -pyserver.add_op(uci1_op) -pyserver.add_op(uci2_op) 
-pyserver.add_op(combine_op) +pyserver.add_ops([read_op, uci1_op, uci2_op, combine_op]) pyserver.prepare_server(port=8080, worker_num=2) pyserver.run_server() diff --git a/python/paddle_serving_server/pyserver.py b/python/paddle_serving_server/pyserver.py index e37a5f90..0ce967c5 100644 --- a/python/paddle_serving_server/pyserver.py +++ b/python/paddle_serving_server/pyserver.py @@ -31,6 +31,7 @@ import random import time import func_timeout import enum +import collections class _TimeProfiler(object): @@ -140,7 +141,7 @@ class Channel(Queue.Queue): Queue.Queue.__init__(self, maxsize=maxsize) self._maxsize = maxsize self._timeout = timeout - self._name = name + self.name = name self._stop = False self._cv = threading.Condition() @@ -161,7 +162,7 @@ class Channel(Queue.Queue): return self._consumers.keys() def _log(self, info_str): - return "[{}] {}".format(self._name, info_str) + return "[{}] {}".format(self.name, info_str) def debug(self): return self._log("p: {}, c: {}".format(self.get_producers(), @@ -313,8 +314,7 @@ class Channel(Queue.Queue): class Op(object): def __init__(self, name, - input, - outputs, + inputs, server_model=None, server_port=None, device=None, @@ -325,23 +325,24 @@ class Op(object): timeout=-1, retry=2): self._run = False - # TODO: globally unique check - self._name = name # to identify the type of OP, it must be globally unique + self.name = name # to identify the type of OP, it must be globally unique self._concurrency = concurrency # amount of concurrency - self.set_input(input) - self.set_outputs(outputs) - self._client = None - if client_config is not None and \ - server_name is not None and \ - fetch_names is not None: - self.set_client(client_config, server_name, fetch_names) + self.set_input_ops(inputs) + self.set_client(client_config, server_name, fetch_names) self._server_model = server_model self._server_port = server_port self._device = device self._timeout = timeout self._retry = retry + self._input = None + self._outputs = [] 
def set_client(self, client_config, server_name, fetch_names): + self._client = None + if client_config is None or \ + server_name is None or \ + fetch_names is None: + return self._client = Client() self._client.load_client_config(client_config) self._client.connect([server_name]) @@ -350,28 +351,41 @@ class Op(object): def with_serving(self): return self._client is not None - def get_input(self): + def get_input_channel(self): return self._input - def set_input(self, channel): + def get_input_ops(self): + return self._input_ops + + def set_input_ops(self, ops): + if not isinstance(ops, list): + ops = [] if ops is None else [ops] + self._input_ops = [] + for op in ops: + if not isinstance(op, Op): + raise TypeError( + self._log('input op must be Op type, not {}'.format( + type(op)))) + self._input_ops.append(op) + + def add_input_channel(self, channel): if not isinstance(channel, Channel): raise TypeError( self._log('input channel must be Channel type, not {}'.format( type(channel)))) - channel.add_consumer(self._name) + channel.add_consumer(self.name) self._input = channel - def get_outputs(self): + def get_output_channels(self): return self._outputs - def set_outputs(self, channels): - if not isinstance(channels, list): + def add_output_channel(self, channel): + if not isinstance(channel, Channel): raise TypeError( - self._log('output channels must be list type, not {}'.format( - type(channels)))) - for channel in channels: - channel.add_producer(self._name) - self._outputs = channels + self._log('output channel must be Channel type, not {}'.format( + type(channel)))) + channel.add_producer(self.name) + self._outputs.append(channel) def preprocess(self, channeldata): if isinstance(channeldata, dict): @@ -430,26 +444,26 @@ class Op(object): def start(self, concurrency_idx): self._run = True while self._run: - _profiler.record("{}{}-get_0".format(self._name, concurrency_idx)) - input_data = self._input.front(self._name) - 
_profiler.record("{}{}-get_1".format(self._name, concurrency_idx)) + _profiler.record("{}{}-get_0".format(self.name, concurrency_idx)) + input_data = self._input.front(self.name) + _profiler.record("{}{}-get_1".format(self.name, concurrency_idx)) logging.debug(self._log("input_data: {}".format(input_data))) data_id, error_data = self._parse_channeldata(input_data) output_data = None if error_data is None: - _profiler.record("{}{}-prep_0".format(self._name, + _profiler.record("{}{}-prep_0".format(self.name, concurrency_idx)) data = self.preprocess(input_data) - _profiler.record("{}{}-prep_1".format(self._name, + _profiler.record("{}{}-prep_1".format(self.name, concurrency_idx)) call_future = None error_info = None if self.with_serving(): for i in range(self._retry): - _profiler.record("{}{}-midp_0".format(self._name, + _profiler.record("{}{}-midp_0".format(self.name, concurrency_idx)) if self._timeout > 0: try: @@ -460,21 +474,21 @@ class Op(object): except func_timeout.FunctionTimedOut: logging.error("error: timeout") error_info = "{}({}): timeout".format( - self._name, concurrency_idx) + self.name, concurrency_idx) except Exception as e: logging.error("error: {}".format(e)) error_info = "{}({}): {}".format( - self._name, concurrency_idx, e) + self.name, concurrency_idx, e) else: call_future = self.midprocess(data) - _profiler.record("{}{}-midp_1".format(self._name, + _profiler.record("{}{}-midp_1".format(self.name, concurrency_idx)) if i + 1 < self._retry: error_info = None logging.warn( self._log("warn: timeout, retry({})".format(i + 1))) - _profiler.record("{}{}-postp_0".format(self._name, + _profiler.record("{}{}-postp_0".format(self.name, concurrency_idx)) if error_info is not None: error_data = self.errorprocess(error_info, data_id) @@ -504,18 +518,18 @@ class Op(object): pbdata.ecode = 0 pbdata.id = data_id output_data = ChannelData(pbdata=pbdata) - _profiler.record("{}{}-postp_1".format(self._name, + _profiler.record("{}{}-postp_1".format(self.name, 
concurrency_idx)) else: output_data = ChannelData(pbdata=error_data) - _profiler.record("{}{}-push_0".format(self._name, concurrency_idx)) + _profiler.record("{}{}-push_0".format(self.name, concurrency_idx)) for channel in self._outputs: - channel.push(output_data, self._name) - _profiler.record("{}{}-push_1".format(self._name, concurrency_idx)) + channel.push(output_data, self.name) + _profiler.record("{}{}-push_1".format(self.name, concurrency_idx)) def _log(self, info_str): - return "[{}] {}".format(self._name, info_str) + return "[{}] {}".format(self.name, info_str) def get_concurrency(self): return self._concurrency @@ -525,7 +539,7 @@ class GeneralPythonService( general_python_service_pb2_grpc.GeneralPythonService): def __init__(self, in_channel, out_channel, retry=2): super(GeneralPythonService, self).__init__() - self._name = "#G" + self.name = "#G" self.set_in_channel(in_channel) self.set_out_channel(out_channel) logging.debug(self._log(in_channel.debug())) @@ -543,14 +557,14 @@ class GeneralPythonService( self._recive_func.start() def _log(self, info_str): - return "[{}] {}".format(self._name, info_str) + return "[{}] {}".format(self.name, info_str) def set_in_channel(self, in_channel): if not isinstance(in_channel, Channel): raise TypeError( self._log('in_channel must be Channel type, but get {}'.format( type(in_channel)))) - in_channel.add_producer(self._name) + in_channel.add_producer(self.name) self._in_channel = in_channel def set_out_channel(self, out_channel): @@ -558,12 +572,12 @@ class GeneralPythonService( raise TypeError( self._log('out_channel must be Channel type, but get {}'.format( type(out_channel)))) - out_channel.add_consumer(self._name) + out_channel.add_consumer(self.name) self._out_channel = out_channel def _recive_out_channel_func(self): while True: - channeldata = self._out_channel.front(self._name) + channeldata = self._out_channel.front(self.name) if not isinstance(channeldata, ChannelData): raise TypeError( self._log('data must 
be ChannelData type, but get {}'. @@ -644,38 +658,43 @@ class GeneralPythonService( return resp def inference(self, request, context): - _profiler.record("{}-prepack_0".format(self._name)) + _profiler.record("{}-prepack_0".format(self.name)) data, data_id = self._pack_data_for_infer(request) - _profiler.record("{}-prepack_1".format(self._name)) + _profiler.record("{}-prepack_1".format(self.name)) resp_channeldata = None for i in range(self._retry): logging.debug(self._log('push data')) - _profiler.record("{}-push_0".format(self._name)) - self._in_channel.push(data, self._name) - _profiler.record("{}-push_1".format(self._name)) + _profiler.record("{}-push_0".format(self.name)) + self._in_channel.push(data, self.name) + _profiler.record("{}-push_1".format(self.name)) logging.debug(self._log('wait for infer')) - _profiler.record("{}-fetch_0".format(self._name)) + _profiler.record("{}-fetch_0".format(self.name)) resp_channeldata = self._get_data_in_globel_resp_dict(data_id) - _profiler.record("{}-fetch_1".format(self._name)) + _profiler.record("{}-fetch_1".format(self.name)) if resp_channeldata.pbdata.ecode == 0: break logging.warn("retry({}): {}".format( i + 1, resp_channeldata.pbdata.error_info)) - _profiler.record("{}-postpack_0".format(self._name)) + _profiler.record("{}-postpack_0".format(self.name)) resp = self._pack_data_for_resp(resp_channeldata) - _profiler.record("{}-postpack_1".format(self._name)) + _profiler.record("{}-postpack_1".format(self.name)) _profiler.print_profile() return resp +class VirtualOp(Op): + pass + + class PyServer(object): def __init__(self, retry=2, profile=False): self._channels = [] - self._ops = [] + self._user_ops = [] + self._total_ops = [] self._op_threads = [] self._port = None self._worker_num = None @@ -688,40 +707,147 @@ self._channels.append(channel) def add_op(self, op): - self._user_ops.append(op) + self._user_ops.append(op) + + def add_ops(self, ops): + self._user_ops.extend(ops) def gen_desc(self):
logging.info('here will generate desc for PAAS') pass + def _topo_sort(self): + indeg_num = {} + outdegs = {} + que_idx = 0 # scroll queue + ques = [Queue.Queue() for _ in range(2)] + for idx, op in enumerate(self._user_ops): + # check the name of op is globally unique + if op.name in indeg_num: + raise Exception("the name of Op must be unique") + indeg_num[op.name] = len(op.get_input_ops()) + if indeg_num[op.name] == 0: + ques[que_idx].put(op) + for pred_op in op.get_input_ops(): + if pred_op.name in outdegs: + outdegs[pred_op.name].append(op) + else: + outdegs[pred_op.name] = [op] + + # get dag_views + dag_views = [] + sorted_op_num = 0 + while True: + que = ques[que_idx] + next_que = ques[(que_idx + 1) % 2] + dag_view = [] + while que.qsize() != 0: + op = que.get() + dag_view.append(op) + op_name = op.name + sorted_op_num += 1 + for succ_op in outdegs.get(op_name, []): + indeg_num[succ_op.name] -= 1 + if indeg_num[succ_op.name] == 0: + next_que.put(succ_op) + dag_views.append(dag_view) + if next_que.qsize() == 0: + break + que_idx = (que_idx + 1) % 2 + if sorted_op_num < len(self._user_ops): + raise Exception("not legal DAG") + if len(dag_views[0]) != 1: + raise Exception("DAG contains multiple input Ops") + if len(dag_views[-1]) != 1: + raise Exception("DAG contains multiple output Ops") + + # create channels and virtual ops + virtual_op_idx = 0 + channel_idx = 0 + virtual_ops = [] + channels = [] + input_channel = None + for v_idx, view in enumerate(dag_views): + if v_idx + 1 >= len(dag_views): + break + next_view = dag_views[v_idx + 1] + actual_next_view = [] + pred_op_of_next_view_op = {} + for op in view: + # create virtual op + for succ_op in outdegs[op.name]: + if succ_op in next_view: + actual_next_view.append(succ_op) + if succ_op.name not in pred_op_of_next_view_op: + pred_op_of_next_view_op[succ_op.name] = [] + pred_op_of_next_view_op[succ_op.name].append(op) + else: + vop = VirtualOp(name="vir{}".format(virtual_op_idx)) + virtual_op_idx += 1 +
virtual_ops.append(vop) + outdegs[vop.name] = [succ_op] + actual_next_view.append(vop) + # TODO: combine vop + pred_op_of_next_view_op[vop.name] = [op] + # create channel + processed_op = set() + for o_idx, op in enumerate(actual_next_view): + op_name = op.name + if op_name in processed_op: + continue + channel = Channel(name="chl{}".format(channel_idx)) + channel_idx += 1 + channels.append(channel) + op.add_input_channel(channel) + pred_ops = pred_op_of_next_view_op[op_name] + if v_idx == 0: + input_channel = channel + else: + for pred_op in pred_ops: + pred_op.add_output_channel(channel) + processed_op.add(op_name) + # combine channel + for other_op in actual_next_view[o_idx:]: + if other_op.name in processed_op: + continue + other_pred_ops = pred_op_of_next_view_op[other_op.name] + if len(other_pred_ops) != len(pred_ops): + continue + same_flag = True + for pred_op in pred_ops: + if pred_op not in other_pred_ops: + same_flag = False + break + if same_flag: + other_op.add_input_channel(channel) + processed_op.add(other_op.name) + output_channel = Channel(name="Ochl") + channels.append(output_channel) + last_op = dag_views[-1][0] + last_op.add_output_channel(output_channel) + + self._ops = self._user_ops + virtual_ops + self._channels = channels + return input_channel, output_channel + def prepare_server(self, port, worker_num): self._port = port self._worker_num = worker_num - inputs = set() - outputs = set() - for op in self._ops: - inputs |= set([op.get_input()]) - outputs |= set(op.get_outputs()) - if op.with_serving(): - self.prepare_serving(op) - in_channel = inputs - outputs - out_channel = outputs - inputs - if len(in_channel) != 1 or len(out_channel) != 1: - raise Exception( - "in_channel(out_channel) more than 1 or no in_channel(out_channel)" - ) - self._in_channel = in_channel.pop() - self._out_channel = out_channel.pop() + + input_channel, output_channel = self._topo_sort() + self._in_channel = input_channel + self._out_channel = output_channel + for op in self._ops: + if op.with_serving(): + self.prepare_serving(op)
self.gen_desc() def _op_start_wrapper(self, op, concurrency_idx): return op.start(concurrency_idx) def _run_ops(self): + #TODO for op in self._ops: op_concurrency = op.get_concurrency() logging.debug("run op: {}, op_concurrency: {}".format( - op._name, op_concurrency)) + op.name, op_concurrency)) for c in range(op_concurrency): # th = multiprocessing.Process( th = threading.Thread( @@ -730,6 +856,7 @@ class PyServer(object): self._op_threads.append(th) def _stop_ops(self): + # TODO for op in self._ops: op.stop() -- GitLab