# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import threading
import multiprocessing
import sys
if sys.version_info.major == 2:
    import Queue
elif sys.version_info.major == 3:
    import queue as Queue
else:
    raise Exception("Unsupported Python version: {}".format(
        sys.version_info.major))
import os
import logging

from .operator import Op, RequestOp, ResponseOp, VirtualOp
from .channel import (ThreadChannel, ProcessChannel, ChannelData,
                      ChannelDataEcode, ChannelDataType)
from .util import NameGenerator

_LOGGER = logging.getLogger()


class DAGExecutor(object):
    def __init__(self, response_op, profiler, use_multithread, retry,
                 client_type, channel_size):
        self.name = "#G"
        # at least one attempt is made even if retry <= 0
        self._retry = max(retry, 1)
        self._profiler = profiler
        self._dag = DAG(response_op, profiler, use_multithread, client_type,
                        channel_size)
        (in_channel, out_channel, pack_rpc_func,
         unpack_rpc_func) = self._dag.build()
        self._dag.start()

        self._set_in_channel(in_channel)
        self._set_out_channel(out_channel)
        self._pack_rpc_func = pack_rpc_func
        self._unpack_rpc_func = unpack_rpc_func

        _LOGGER.debug(self._log(in_channel.debug()))
        _LOGGER.debug(self._log(out_channel.debug()))

        self._id_lock = threading.Lock()
        self._cv = threading.Condition()
        self._global_resp_dict = {}
        self._id_counter = 0
        # data ids wrap around at this bound to stay finite
        self._reset_max_id = 1000000000000000000
        self._is_run = False
        self._receive_func = None

    def start(self):
        self._is_run = True
        self._receive_func = threading.Thread(
            target=DAGExecutor._receive_out_channel_func, args=(self, ))
        self._receive_func.start()

    def stop(self):
        self._is_run = False
        self._dag.stop()
        self._dag.join()

    def _get_next_data_id(self):
        with self._id_lock:
            if self._id_counter >= self._reset_max_id:
                self._id_counter -= self._reset_max_id
            self._id_counter += 1
            return self._id_counter - 1

    def _set_in_channel(self, in_channel):
        if not isinstance(in_channel, (ThreadChannel, ProcessChannel)):
            raise TypeError(
                self._log('in_channel must be Channel type, but got {}'.
                          format(type(in_channel))))
        in_channel.add_producer(self.name)
        self._in_channel = in_channel

    def _set_out_channel(self, out_channel):
        if not isinstance(out_channel, (ThreadChannel, ProcessChannel)):
            raise TypeError(
                self._log('out_channel must be Channel type, but got {}'.
                          format(type(out_channel))))
        out_channel.add_consumer(self.name)
        self._out_channel = out_channel

    def _receive_out_channel_func(self):
        while self._is_run:
            channeldata_dict = self._out_channel.front(self.name)
            if len(channeldata_dict) != 1:
                _LOGGER.error("out_channel cannot have multiple input ops")
                os._exit(-1)
            (_, channeldata), = channeldata_dict.items()
            if not isinstance(channeldata, ChannelData):
                raise TypeError(
                    self._log('data must be ChannelData type, but got {}'.
                              format(type(channeldata))))
            with self._cv:
                data_id = channeldata.id
                self._global_resp_dict[data_id] = channeldata
                self._cv.notify_all()

    def _get_channeldata_from_fetch_buffer(self, data_id):
        resp = None
        with self._cv:
            while data_id not in self._global_resp_dict:
                self._cv.wait()
            resp = self._global_resp_dict.pop(data_id)
            self._cv.notify_all()
        return resp

    def _pack_channeldata(self, rpc_request):
        _LOGGER.debug(self._log('start inference'))
        data_id = self._get_next_data_id()
        dictdata = None
        try:
            dictdata = self._unpack_rpc_func(rpc_request)
        except Exception as e:
            return ChannelData(
                ecode=ChannelDataEcode.RPC_PACKAGE_ERROR.value,
                error_info="rpc package error: {}".format(e),
                data_id=data_id), data_id
        else:
            return ChannelData(
                datatype=ChannelDataType.DICT.value,
                dictdata=dictdata,
                data_id=data_id), data_id

    def call(self, rpc_request):
        self._profiler.record("{}-prepack_0".format(self.name))
        req_channeldata, data_id = self._pack_channeldata(rpc_request)
        self._profiler.record("{}-prepack_1".format(self.name))

        resp_channeldata = None
        for i in range(self._retry):
            _LOGGER.debug(self._log('push data'))
            #self._profiler.record("{}-push_0".format(self.name))
            self._in_channel.push(req_channeldata, self.name)
            #self._profiler.record("{}-push_1".format(self.name))

            _LOGGER.debug(self._log('wait for infer'))
            #self._profiler.record("{}-fetch_0".format(self.name))
            resp_channeldata = self._get_channeldata_from_fetch_buffer(data_id)
            #self._profiler.record("{}-fetch_1".format(self.name))

            if resp_channeldata.ecode == ChannelDataEcode.OK.value:
                break
            if i + 1 < self._retry:
                _LOGGER.warning("retry({}): {}".format(
                    i + 1, resp_channeldata.error_info))

        self._profiler.record("{}-postpack_0".format(self.name))
        rpc_resp = self._pack_for_rpc_resp(resp_channeldata)
        self._profiler.record("{}-postpack_1".format(self.name))
        self._profiler.print_profile()
        return rpc_resp

    def _pack_for_rpc_resp(self, channeldata):
        _LOGGER.debug(self._log('get channeldata'))
        return self._pack_rpc_func(channeldata)

    def _log(self, info_str):
        return "[{}] {}".format(self.name, info_str)
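
# A minimal sketch of the request lifecycle driven by DAGExecutor above
# (hypothetical variable names; in practice the pipeline server constructs
# the executor and feeds it RPC requests):
#
#   executor = DAGExecutor(response_op, profiler, use_multithread=True,
#                          retry=1, client_type="brpc", channel_size=0)
#   executor.start()                   # spawn the response-collector thread
#   rpc_resp = executor.call(rpc_request)
#   # call() packs the request into a ChannelData with a fresh data_id,
#   # pushes it into the DAG's input channel, blocks on the condition
#   # variable until _receive_out_channel_func stores the matching response
#   # in _global_resp_dict, then packs it into the RPC response.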

class DAG(object):
    def __init__(self, response_op, profiler, use_multithread, client_type,
                 channel_size):
        self._response_op = response_op
        self._use_multithread = use_multithread
        self._channel_size = channel_size
        self._client_type = client_type
        if not self._use_multithread:
            # ProcessChannel shares state across workers via a Manager
            self._manager = multiprocessing.Manager()
        self._profiler = profiler

    def get_use_ops(self, response_op):
        unique_names = set()
        used_ops = set()
        succ_ops_of_use_op = {}  # {op_name: succ_ops}
        que = Queue.Queue()
        que.put(response_op)
        while que.qsize() != 0:
            op = que.get()
            for pred_op in op.get_input_ops():
                if pred_op.name not in succ_ops_of_use_op:
                    succ_ops_of_use_op[pred_op.name] = []
                if op != response_op:
                    succ_ops_of_use_op[pred_op.name].append(op)
                if pred_op not in used_ops:
                    que.put(pred_op)
                    used_ops.add(pred_op)
                    # check that each Op name is globally unique
                    if pred_op.name in unique_names:
                        raise Exception("the name of Op must be unique: {}".
                                        format(pred_op.name))
                    unique_names.add(pred_op.name)
        return used_ops, succ_ops_of_use_op
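
    # Example of get_use_ops (hypothetical op names): for the chain
    #     read_op(RequestOp) -> op_a -> op_b -> response_op(ResponseOp)
    # it returns
    #     used_ops           = {read_op, op_a, op_b}
    #     succ_ops_of_use_op = {"read_op": [op_a], "op_a": [op_b], "op_b": []}
    # response_op itself is excluded, so "op_b" records no successors.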
    def _gen_channel(self, name_gen):
        channel = None
        if self._use_multithread:
            channel = ThreadChannel(
                name=name_gen.next(), maxsize=self._channel_size)
        else:
            channel = ProcessChannel(
                self._manager,
                name=name_gen.next(),
                maxsize=self._channel_size)
        return channel

    def _gen_virtual_op(self, name_gen):
        return VirtualOp(name=name_gen.next())

    def _topo_sort(self, used_ops, response_op, out_degree_ops):
        out_degree_num = {
            name: len(ops)
            for name, ops in out_degree_ops.items()
        }
        que_idx = 0  # scroll queue
        ques = [Queue.Queue() for _ in range(2)]
        zero_indegree_num = 0
        for op in used_ops:
            if len(op.get_input_ops()) == 0:
                zero_indegree_num += 1
        if zero_indegree_num != 1:
            raise Exception("DAG must have exactly one input Op")
        last_op = response_op.get_input_ops()[0]
        ques[que_idx].put(last_op)

        # topo sort to get dag_views
        dag_views = []
        sorted_op_num = 0
        while True:
            que = ques[que_idx]
            next_que = ques[(que_idx + 1) % 2]
            dag_view = []
            while que.qsize() != 0:
                op = que.get()
                dag_view.append(op)
                sorted_op_num += 1
                for pred_op in op.get_input_ops():
                    out_degree_num[pred_op.name] -= 1
                    if out_degree_num[pred_op.name] == 0:
                        next_que.put(pred_op)
            dag_views.append(dag_view)
            if next_que.qsize() == 0:
                break
            que_idx = (que_idx + 1) % 2
        if sorted_op_num < len(used_ops):
            raise Exception("the graph is not a legal DAG")

        return dag_views, last_op
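
    # Worked example of the scroll-queue sort above (hypothetical names) for
    # the diamond  req -> a -> {b, c} -> d -> resp:
    #     view 0: [d]     # last_op, the single predecessor of response_op
    #     view 1: [b, c]  # each decrements a's out-degree; it reaches 0
    #     view 2: [a]     # decrements req's out-degree to 0
    #     view 3: [req]
    # _build_dag reverses this into [[req], [a], [b, c], [d]] and creates
    # channels level by level from request to response.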
    def _build_dag(self, response_op):
        if response_op is None:
            raise Exception("response_op has not been set.")
        used_ops, out_degree_ops = self.get_use_ops(response_op)
        _LOGGER.info("================= used op =================")
        for op in used_ops:
            _LOGGER.info(op.name)
        _LOGGER.info("===========================================")
        if len(used_ops) <= 1:
            raise Exception(
                "besides RequestOp and ResponseOp, the DAG must contain at least one Op."
            )

        dag_views, last_op = self._topo_sort(used_ops, response_op,
                                             out_degree_ops)

        # create channels and virtual ops
        virtual_op_name_gen = NameGenerator("vir")
        channel_name_gen = NameGenerator("chl")
        virtual_ops = []
        channels = []
        input_channel = None
        actual_view = None
        dag_views = list(reversed(dag_views))
        for v_idx, view in enumerate(dag_views):
            if v_idx + 1 >= len(dag_views):
                break
            next_view = dag_views[v_idx + 1]
            if actual_view is None:
                actual_view = view
            actual_next_view = []
            pred_op_of_next_view_op = {}
            for op in actual_view:
                # find the actual succ op in the next view, creating virtual
                # ops to bridge edges that skip over a view
                for succ_op in out_degree_ops[op.name]:
                    if succ_op in next_view:
                        if succ_op not in actual_next_view:
                            actual_next_view.append(succ_op)
                        if succ_op.name not in pred_op_of_next_view_op:
                            pred_op_of_next_view_op[succ_op.name] = []
                        pred_op_of_next_view_op[succ_op.name].append(op)
                    else:
                        # create a virtual op
                        virtual_op = self._gen_virtual_op(virtual_op_name_gen)
                        virtual_ops.append(virtual_op)
                        out_degree_ops[virtual_op.name] = [succ_op]
                        actual_next_view.append(virtual_op)
                        pred_op_of_next_view_op[virtual_op.name] = [op]
                        virtual_op.add_virtual_pred_op(op)
            actual_view = actual_next_view
            # create channels
            processed_op = set()
            for o_idx, op in enumerate(actual_next_view):
                if op.name in processed_op:
                    continue
                channel = self._gen_channel(channel_name_gen)
                channels.append(channel)
                _LOGGER.debug("{} => {}".format(channel.name, op.name))
                op.add_input_channel(channel)
                pred_ops = pred_op_of_next_view_op[op.name]
                if v_idx == 0:
                    input_channel = channel
                else:
                    # if pred_op is a virtual op, its ancestors act as
                    # producers to the channel
                    for pred_op in pred_ops:
                        _LOGGER.debug("{} => {}".format(pred_op.name,
                                                        channel.name))
                        pred_op.add_output_channel(channel)
                processed_op.add(op.name)
                # ops that share the same set of pred ops can share one channel
                for other_op in actual_next_view[o_idx + 1:]:
                    if other_op.name in processed_op:
                        continue
                    other_pred_ops = pred_op_of_next_view_op[other_op.name]
                    if len(other_pred_ops) != len(pred_ops):
                        continue
                    same_flag = True
                    for pred_op in pred_ops:
                        if pred_op not in other_pred_ops:
                            same_flag = False
                            break
                    if same_flag:
                        _LOGGER.debug("{} => {}".format(channel.name,
                                                        other_op.name))
                        other_op.add_input_channel(channel)
                        processed_op.add(other_op.name)
        output_channel = self._gen_channel(channel_name_gen)
        channels.append(output_channel)
        last_op.add_output_channel(output_channel)

        pack_func, unpack_func = None, None
        pack_func = response_op.pack_response_package

        actual_ops = virtual_ops
        for op in used_ops:
            if len(op.get_input_ops()) == 0:
                # the input Op only contributes its unpack function; it does
                # not run as a worker in the DAG
                unpack_func = op.unpack_request_package
                continue
            actual_ops.append(op)

        for c in channels:
            _LOGGER.debug(c.debug())

        return (actual_ops, channels, input_channel, output_channel,
                pack_func, unpack_func)

    def build(self):
        (actual_ops, channels, input_channel, output_channel, pack_func,
         unpack_func) = self._build_dag(self._response_op)

        self._actual_ops = actual_ops
        self._channels = channels
        self._input_channel = input_channel
        self._output_channel = output_channel
        self._pack_func = pack_func
        self._unpack_func = unpack_func

        return (self._input_channel, self._output_channel, self._pack_func,
                self._unpack_func)

    def start(self):
        self._threads_or_procs = []
        for op in self._actual_ops:
            op.init_profiler(self._profiler)
            if self._use_multithread:
                self._threads_or_procs.extend(
                    op.start_with_thread(self._client_type))
            else:
                self._threads_or_procs.extend(
                    op.start_with_process(self._client_type))
        # not joined here, so the caller keeps running
        return self._threads_or_procs

    def join(self):
        for x in self._threads_or_procs:
            x.join()
    def stop(self):
        for op in self._actual_ops:
            op.stop()
        for chl in self._channels:
            chl.stop()
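
# A minimal wiring sketch for using DAG directly (hypothetical Op subclass
# and parameters; DAGExecutor wraps exactly these steps, and real pipelines
# are assembled by the pipeline server from its configuration):
#
#   class UciOp(Op):
#       def process(self, input_dict):
#           ...  # user-defined forward logic
#
#   read_op = RequestOp()
#   uci_op = UciOp(name="uci", input_ops=[read_op])
#   response_op = ResponseOp(input_ops=[uci_op])
#
#   dag = DAG(response_op, profiler, use_multithread=True,
#             client_type="brpc", channel_size=0)
#   in_channel, out_channel, pack_func, unpack_func = dag.build()
#   dag.start()   # one worker (thread or process) per Op concurrency slot
#   ...
#   dag.stop()
#   dag.join()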