diff --git a/python/pipeline/dag.py b/python/pipeline/dag.py
new file mode 100644
index 0000000000000000000000000000000000000000..24128e966c461035b4db41a9a6c48e3f9bfadb58
--- /dev/null
+++ b/python/pipeline/dag.py
@@ -0,0 +1,398 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# pylint: disable=doc-string-missing
+import threading
+import multiprocessing
+import sys
+if sys.version_info.major == 2:
+    import Queue
+elif sys.version_info.major == 3:
+    import queue as Queue
+else:
+    raise Exception("Unsupported Python version")
+import os
+import logging
+
+from .operator import Op, RequestOp, ResponseOp, VirtualOp
+from .channel import ThreadChannel, ProcessChannel, ChannelData, ChannelDataEcode, ChannelDataType
+from .util import NameGenerator
+
+_LOGGER = logging.getLogger()
+
+
+class DAGExecutor(object):
+    def __init__(self, response_op, profiler, use_multithread, retry,
+                 client_type, channel_size):
+        self.name = "#G"
+        # retry must be at least 1, otherwise the inference loop never runs
+        self._retry = max(retry, 1)
+        self._profiler = profiler
+        self._dag = DAG(response_op, profiler, use_multithread, client_type,
+                        channel_size)
+        in_channel, out_channel, pack_rpc_func, unpack_rpc_func = self._dag.build()
+        self._dag.start()
+
+        self._set_in_channel(in_channel)
+        self._set_out_channel(out_channel)
+        self._pack_rpc_func = pack_rpc_func
+        self._unpack_rpc_func = unpack_rpc_func
+
+        _LOGGER.debug(self._log(in_channel.debug()))
+        _LOGGER.debug(self._log(out_channel.debug()))
+
+        self._id_lock = threading.Lock()
+        self._cv = threading.Condition()
+        self._globel_resp_dict = {}
+        self._id_counter = 0
+        self._reset_max_id = 1000000000000000000
+        self._is_run = True
+        self._recive_func = threading.Thread(
+            target=DAGExecutor._recive_out_channel_func, args=(self, ))
+        self._recive_func.start()
+
+    def stop(self):
+        self._is_run = False
+        self._dag.stop()
+        self._dag.join()
+
+    def _get_next_data_id(self):
+        with self._id_lock:
+            if self._id_counter >= self._reset_max_id:
+                self._id_counter -= self._reset_max_id
+            self._id_counter += 1
+            return self._id_counter - 1
+
+    def _set_in_channel(self, in_channel):
+        if not isinstance(in_channel, (ThreadChannel, ProcessChannel)):
+            raise TypeError(
+                self._log('in_channel must be Channel type, but get {}'.format(
+                    type(in_channel))))
+        in_channel.add_producer(self.name)
+        self._in_channel = in_channel
+
+    def _set_out_channel(self, out_channel):
+        if not isinstance(out_channel, (ThreadChannel, ProcessChannel)):
+            raise TypeError(
+                self._log('out_channel must be Channel type, but get {}'.format(
+                    type(out_channel))))
+        out_channel.add_consumer(self.name)
+        self._out_channel = out_channel
+
+    def _recive_out_channel_func(self):
+        while self._is_run:
+            channeldata_dict = self._out_channel.front(self.name)
+            if len(channeldata_dict) != 1:
+                _LOGGER.error("out_channel cannot have multiple input ops")
+                os._exit(-1)
+            (_, channeldata), = channeldata_dict.items()
+            if not isinstance(channeldata, ChannelData):
+                raise TypeError(
+                    self._log('data must be ChannelData type, but get {}'.
+                              format(type(channeldata))))
+            with self._cv:
+                data_id = channeldata.id
+                self._globel_resp_dict[data_id] = channeldata
+                self._cv.notify_all()
+
+    def _get_channeldata_from_fetch_buffer(self, data_id):
+        resp = None
+        with self._cv:
+            while data_id not in self._globel_resp_dict:
+                self._cv.wait()
+            resp = self._globel_resp_dict.pop(data_id)
+            self._cv.notify_all()
+        return resp
+
+    def _pack_channeldata(self, rpc_request):
+        _LOGGER.debug(self._log('start inference'))
+        data_id = self._get_next_data_id()
+        dictdata = None
+        try:
+            dictdata = self._unpack_rpc_func(rpc_request)
+        except Exception as e:
+            return ChannelData(
+                ecode=ChannelDataEcode.RPC_PACKAGE_ERROR.value,
+                error_info="rpc package error: {}".format(e),
+                data_id=data_id), data_id
+        else:
+            return ChannelData(
+                datatype=ChannelDataType.DICT.value,
+                dictdata=dictdata,
+                data_id=data_id), data_id
+
+    def call(self, rpc_request):
+        self._profiler.record("{}-prepack_0".format(self.name))
+        req_channeldata, data_id = self._pack_channeldata(rpc_request)
+        self._profiler.record("{}-prepack_1".format(self.name))
+
+        resp_channeldata = None
+        for i in range(self._retry):
+            _LOGGER.debug(self._log('push data'))
+            #self._profiler.record("{}-push_0".format(self.name))
+            self._in_channel.push(req_channeldata, self.name)
+            #self._profiler.record("{}-push_1".format(self.name))
+
+            _LOGGER.debug(self._log('wait for infer'))
+            #self._profiler.record("{}-fetch_0".format(self.name))
+            resp_channeldata = self._get_channeldata_from_fetch_buffer(data_id)
+            #self._profiler.record("{}-fetch_1".format(self.name))
+
+            if resp_channeldata.ecode == ChannelDataEcode.OK.value:
+                break
+            if i + 1 < self._retry:
+                _LOGGER.warning("retry({}): {}".format(
+                    i + 1, resp_channeldata.error_info))
+
+        self._profiler.record("{}-postpack_0".format(self.name))
+        rpc_resp = self._pack_for_rpc_resp(resp_channeldata)
+        self._profiler.record("{}-postpack_1".format(self.name))
+        self._profiler.print_profile()
+        return rpc_resp
+
+    def _pack_for_rpc_resp(self, channeldata):
+        _LOGGER.debug(self._log('get channeldata'))
+        return self._pack_rpc_func(channeldata)
+
+    def _log(self, info_str):
+        return "[{}] {}".format(self.name, info_str)
+
+
+class DAG(object):
+    def __init__(self, response_op, profiler, use_multithread, client_type,
+                 channel_size):
+        self._response_op = response_op
+        self._use_multithread = use_multithread
+        self._channel_size = channel_size
+        self._client_type = client_type
+        if not self._use_multithread:
+            self._manager = multiprocessing.Manager()
+        self._profiler = profiler
+
+    def get_use_ops(self, response_op):
+        unique_names = set()
+        use_ops = set()
+        succ_ops_of_use_op = {}  # {op_name: succ_ops}
+        que = Queue.Queue()
+        que.put(response_op)
+        while que.qsize() != 0:
+            op = que.get()
+            for pred_op in op.get_input_ops():
+                if pred_op.name not in succ_ops_of_use_op:
+                    succ_ops_of_use_op[pred_op.name] = []
+                if op != response_op:
+                    succ_ops_of_use_op[pred_op.name].append(op)
+                if pred_op not in use_ops:
+                    que.put(pred_op)
+                    use_ops.add(pred_op)
+                # check that the name of each Op is globally unique
+                if pred_op.name in unique_names:
+                    raise Exception("the name of Op must be unique: {}".
+                                    format(pred_op.name))
+                unique_names.add(pred_op.name)
+        return use_ops, succ_ops_of_use_op
+
+    def _gen_channel(self, name_gen):
+        channel = None
+        if self._use_multithread:
+            channel = ThreadChannel(
+                name=name_gen.next(), maxsize=self._channel_size)
+        else:
+            channel = ProcessChannel(
+                self._manager, name=name_gen.next(), maxsize=self._channel_size)
+        return channel
+
+    def _gen_virtual_op(self, name_gen):
+        return VirtualOp(name=name_gen.next())
+
+    def _topo_sort(self, used_ops, response_op, out_degree_ops):
+        out_degree_num = {
+            name: len(ops)
+            for name, ops in out_degree_ops.items()
+        }
+        que_idx = 0  # swap between the two queues
+        ques = [Queue.Queue() for _ in range(2)]
+        zero_indegree_num = 0
+        for op in used_ops:
+            if len(op.get_input_ops()) == 0:
+                zero_indegree_num += 1
+        if zero_indegree_num != 1:
+            raise Exception("DAG contains multiple input Ops")
+        last_op = response_op.get_input_ops()[0]
+        ques[que_idx].put(last_op)
+
+        # topo sort to get dag_views
+        dag_views = []
+        sorted_op_num = 0
+        while True:
+            que = ques[que_idx]
+            next_que = ques[(que_idx + 1) % 2]
+            dag_view = []
+            while que.qsize() != 0:
+                op = que.get()
+                dag_view.append(op)
+                sorted_op_num += 1
+                for pred_op in op.get_input_ops():
+                    out_degree_num[pred_op.name] -= 1
+                    if out_degree_num[pred_op.name] == 0:
+                        next_que.put(pred_op)
+            dag_views.append(dag_view)
+            if next_que.qsize() == 0:
+                break
+            que_idx = (que_idx + 1) % 2
+        if sorted_op_num < len(used_ops):
+            raise Exception("not legal DAG")
+
+        return dag_views, last_op
+
+    def _build_dag(self, response_op):
+        if response_op is None:
+            raise Exception("response_op has not been set.")
+        use_ops, out_degree_ops = self.get_use_ops(response_op)
+        _LOGGER.info("================= use op ==================")
+        for op in use_ops:
+            _LOGGER.info(op.name)
+        _LOGGER.info("===========================================")
+        if len(use_ops) <= 1:
+            raise Exception(
+                "Besides RequestOp and ResponseOp, there should be at least one Op in DAG."
+            )
+
+        dag_views, last_op = self._topo_sort(use_ops, response_op,
+                                             out_degree_ops)
+
+        # create channels and virtual ops
+        virtual_op_name_gen = NameGenerator("vir")
+        channel_name_gen = NameGenerator("chl")
+        virtual_ops = []
+        channels = []
+        input_channel = None
+        actual_view = None
+        dag_views = list(reversed(dag_views))
+        for v_idx, view in enumerate(dag_views):
+            if v_idx + 1 >= len(dag_views):
+                break
+            next_view = dag_views[v_idx + 1]
+            if actual_view is None:
+                actual_view = view
+            actual_next_view = []
+            pred_op_of_next_view_op = {}
+            for op in actual_view:
+                # find actual succ op in next view and create virtual op
+                for succ_op in out_degree_ops[op.name]:
+                    if succ_op in next_view:
+                        if succ_op not in actual_next_view:
+                            actual_next_view.append(succ_op)
+                        if succ_op.name not in pred_op_of_next_view_op:
+                            pred_op_of_next_view_op[succ_op.name] = []
+                        pred_op_of_next_view_op[succ_op.name].append(op)
+                    else:
+                        # create virtual op
+                        virtual_op = self._gen_virtual_op(virtual_op_name_gen)
+                        virtual_ops.append(virtual_op)
+                        out_degree_ops[virtual_op.name] = [succ_op]
+                        actual_next_view.append(virtual_op)
+                        pred_op_of_next_view_op[virtual_op.name] = [op]
+                        virtual_op.add_virtual_pred_op(op)
+            actual_view = actual_next_view
+            # create channel
+            processed_op = set()
+            for o_idx, op in enumerate(actual_next_view):
+                if op.name in processed_op:
+                    continue
+                channel = self._gen_channel(channel_name_gen)
+                channels.append(channel)
+                _LOGGER.debug("{} => {}".format(channel.name, op.name))
+                op.add_input_channel(channel)
+                pred_ops = pred_op_of_next_view_op[op.name]
+                if v_idx == 0:
+                    input_channel = channel
+                else:
+                    # if pred_op is virtual op, it will use ancestors as producers to channel
+                    for pred_op in pred_ops:
+                        _LOGGER.debug("{} => {}".format(pred_op.name,
+                                                        channel.name))
+                        pred_op.add_output_channel(channel)
+                processed_op.add(op.name)
+                # find same input op to combine channel
+                for other_op in actual_next_view[o_idx + 1:]:
+                    if other_op.name in processed_op:
+                        continue
+                    other_pred_ops = pred_op_of_next_view_op[other_op.name]
+                    if len(other_pred_ops) != len(pred_ops):
+                        continue
+                    same_flag = True
+                    for pred_op in pred_ops:
+                        if pred_op not in other_pred_ops:
+                            same_flag = False
+                            break
+                    if same_flag:
+                        _LOGGER.debug("{} => {}".format(channel.name,
+                                                        other_op.name))
+                        other_op.add_input_channel(channel)
+                        processed_op.add(other_op.name)
+        output_channel = self._gen_channel(channel_name_gen)
+        channels.append(output_channel)
+        last_op.add_output_channel(output_channel)
+
+        pack_func, unpack_func = None, None
+        pack_func = response_op.pack_response_package
+
+        actual_ops = virtual_ops
+        for op in use_ops:
+            if len(op.get_input_ops()) == 0:
+                unpack_func = op.unpack_request_package
+                continue
+            actual_ops.append(op)
+
+        for c in channels:
+            _LOGGER.debug(c.debug())
+
+        return (actual_ops, channels, input_channel, output_channel, pack_func,
+                unpack_func)
+
+    def build(self):
+        (actual_ops, channels, input_channel, output_channel, pack_func,
+         unpack_func) = self._build_dag(self._response_op)
+
+        self._actual_ops = actual_ops
+        self._channels = channels
+        self._input_channel = input_channel
+        self._output_channel = output_channel
+        self._pack_func = pack_func
+        self._unpack_func = unpack_func
+
+        return self._input_channel, self._output_channel, self._pack_func, self._unpack_func
+
+    def start(self):
+        self._threads_or_proces = []
+        for op in self._actual_ops:
+            op.init_profiler(self._profiler)
+            if self._use_multithread:
+                self._threads_or_proces.extend(
+                    op.start_with_thread(self._client_type))
+            else:
+                self._threads_or_proces.extend(
+                    op.start_with_process(self._client_type))
+        # not join yet
+        return self._threads_or_proces
+
+    def join(self):
+        for x in self._threads_or_proces:
+            x.join()
+
+    def stop(self):
+        for op in self._actual_ops:
+            op.stop()
+        for chl in self._channels:
+            chl.stop()
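
For reference, a minimal sketch of how this executor might be driven. It only relies on what the patch itself shows (the DAGExecutor constructor signature, call(), stop()); the import paths, the Op/ResponseOp constructor arguments, the TimeProfiler class, and the client_type value are assumptions about the surrounding pipeline package, not part of this patch.

# Illustrative sketch only -- names outside dag.py are assumptions.
from pipeline.operator import RequestOp, ResponseOp, Op   # assumed package path
from pipeline.profiler import TimeProfiler                # assumed profiler with record()/print_profile()
from pipeline.dag import DAGExecutor

read_op = RequestOp()                                     # graph entry: unpacks the RPC request
infer_op = Op(name="infer", input_ops=[read_op])          # hypothetical model Op
response_op = ResponseOp(input_ops=[infer_op])            # graph exit: packs the RPC response

executor = DAGExecutor(
    response_op=response_op,
    profiler=TimeProfiler(),
    use_multithread=True,   # ThreadChannel path; False switches to ProcessChannel
    retry=1,
    client_type="brpc",     # assumed client type string
    channel_size=0)

# rpc_resp = executor.call(rpc_request)   # rpc_request format is defined by RequestOp
# executor.stop()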