# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing

import threading
import multiprocessing
import sys
import copy
if sys.version_info.major == 2:
    import Queue
elif sys.version_info.major == 3:
    import queue as Queue
else:
    raise Exception("Error Python version")
import os
import logging
import collections
import json

from .operator import Op, RequestOp, ResponseOp, VirtualOp
from .channel import (ThreadChannel, ProcessChannel, ChannelData,
                      ChannelDataErrcode, ChannelDataType, ChannelStopError,
                      ProductErrCode)
from .profiler import TimeProfiler, PerformanceTracer
from .util import NameGenerator, ThreadIdGenerator, PipelineProcSyncManager
from .proto import pipeline_service_pb2

_LOGGER = logging.getLogger(__name__)


class DAGExecutor(object):
    """
    DAG Executor, the service entrance of DAG.

    Each call() registers a (data_id, condition variable) pair, pushes the
    request into the DAG input channel and blocks until the receive thread
    (started by start()) delivers the result carrying the same data_id from
    the DAG output channel.

    Locking rule: whenever a per-request condition variable and
    _cv_for_cv_pool are both held, the condition variable is acquired first.
    Every method in this class follows that single order so the receive
    thread and callers cannot deadlock each other.
    """

    def __init__(self, response_op, server_conf, worker_idx):
        """
        Initialize DAGExecutor: build and start the DAG, and prepare the
        request/response bookkeeping.

        Args:
            response_op: Response OP
            server_conf: server conf. config.yaml
            worker_idx: DAGExecutor index, PipelineServer creates many
                DAGExecutors when _build_dag_each_worker is true.

        Returns:
            None.
        """
        build_dag_each_worker = server_conf["build_dag_each_worker"]
        server_worker_num = server_conf["worker_num"]
        dag_conf = server_conf["dag"]

        self._retry = dag_conf["retry"]
        self._server_use_profile = dag_conf["use_profile"]
        channel_size = dag_conf["channel_size"]
        channel_recv_frist_arrive = dag_conf["channel_recv_frist_arrive"]
        self._is_thread_op = dag_conf["is_thread_op"]

        tracer_conf = dag_conf["tracer"]
        tracer_interval_s = tracer_conf["interval_s"]

        self.name = "@DAGExecutor"
        self._profiler = TimeProfiler()
        self._profiler.enable(True)

        self._tracer = None
        if tracer_interval_s >= 1:
            self._tracer = PerformanceTracer(
                self._is_thread_op, tracer_interval_s, server_worker_num)

        self._dag = DAG(self.name, response_op, self._server_use_profile,
                        self._is_thread_op, channel_size, build_dag_each_worker,
                        self._tracer, channel_recv_frist_arrive)
        (in_channel, out_channel, pack_rpc_func,
         unpack_rpc_func) = self._dag.build()
        self._dag.start()

        self._set_in_channel(in_channel)
        self._set_out_channel(out_channel)
        self._pack_rpc_func = pack_rpc_func
        self._unpack_rpc_func = unpack_rpc_func

        if self._tracer is not None:
            self._tracer.start()

        # generate id
        # data_id: Server Unique ID, automatically generated by the framework
        # log_id: Trace one product request, can be empty, not unique.
        base_counter = 0
        gen_id_step = 1
        if build_dag_each_worker:
            # Presumably makes each worker generate a disjoint id sequence
            # (offset worker_idx, stride worker_num) — confirm in util.
            base_counter = worker_idx
            gen_id_step = server_worker_num
        self._id_generator = ThreadIdGenerator(
            max_id=1000000000000000000,
            base_counter=base_counter,
            step=gen_id_step)

        # data_id -> threading.Condition the requesting caller waits on.
        self._cv_pool = {}
        # Guards _cv_pool and _fetch_buffer.
        self._cv_for_cv_pool = threading.Condition()
        # data_id -> result ChannelData, or None while still pending.
        self._fetch_buffer = {}
        self._recive_func = None

        self._client_profile_key = "pipeline.profile"
        self._client_profile_value = "1"

    def start(self):
        """
        Starting one daemon thread that receives data from the last channel
        in the background.

        Args:
            None

        Returns:
            None
        """
        self._recive_func = threading.Thread(
            target=DAGExecutor._recive_out_channel_func, args=(self, ))
        self._recive_func.daemon = True
        self._recive_func.start()
        _LOGGER.debug("[DAG Executor] Start recive thread")

    def stop(self):
        """
        Stopping DAG: stop all channels and ops, then join their workers.

        Args:
            None

        Returns:
            None
        """
        self._dag.stop()
        self._dag.join()
        _LOGGER.info("[DAG Executor] Stop")

    def _register_data_id(self, data_id, cond_v):
        """
        Register a pending request so the receive thread can route its result
        back to the waiting caller.

        Args:
            data_id: request id
            cond_v: condition variable the caller waits on

        Returns:
            None
        """
        with self._cv_for_cv_pool:
            self._cv_pool[data_id] = cond_v
            self._fetch_buffer[data_id] = None

    def _get_next_data_id(self):
        """
        Generate data_id incrementally and uniquely, and register it.

        Args:
            None

        Returns:
            data_id: uniq id
            cond_v: condition variable
        """
        data_id = self._id_generator.next()
        cond_v = threading.Condition()
        self._register_data_id(data_id, cond_v)
        return data_id, cond_v

    def _set_in_channel(self, in_channel):
        """
        Set in_channel of DAG. Exits the process if it is not a channel.

        Args:
            in_channel: input channel of DAG

        Returns:
            None
        """
        if not isinstance(in_channel, (ThreadChannel, ProcessChannel)):
            _LOGGER.critical("[DAG Executor] Failed to set in_channel: "
                             "in_channel must be Channel type, but get {}".
                             format(type(in_channel)))
            os._exit(-1)

        self._in_channel = in_channel
        _LOGGER.info("[DAG] set in channel succ, name [{}]".format(self.name))

    def _set_out_channel(self, out_channel):
        """
        Set out_channel of DAG and register this executor as its consumer.
        Exits the process if it is not a channel.

        Args:
            out_channel: output channel of DAG

        Returns:
            None
        """
        if not isinstance(out_channel, (ThreadChannel, ProcessChannel)):
            _LOGGER.critical("[DAG Executor] Failed to set out_channel: "
                             "must be Channel type, but get {}".format(
                                 type(out_channel)))
            os._exit(-1)
        out_channel.add_consumer(self.name)
        self._out_channel = out_channel

    def _put_into_fetch_buffer(self, data_id, cond_v, channeldata):
        """
        Store channeldata as the result of data_id and wake its waiter.

        Acquires cond_v first and _cv_for_cv_pool second, the same order as
        _get_channeldata_from_fetch_buffer.

        Args:
            data_id: request id
            cond_v: condition variable registered for data_id
            channeldata: result to deliver

        Returns:
            True if delivered, False if data_id is no longer registered
            (its caller already consumed a result or gave up).
        """
        with cond_v:
            with self._cv_for_cv_pool:
                if data_id not in self._fetch_buffer:
                    return False
                self._fetch_buffer[data_id] = channeldata
            cond_v.notify_all()
        return True

    def _recive_out_channel_func(self):
        """
        Receiving data from the output channel, and pushing data into
        _fetch_buffer, where _get_channeldata_from_fetch_buffer picks it up.

        On ChannelStopError every still-pending request receives a
        CLOSED_ERROR ChannelData so no caller blocks forever, and the loop
        exits.

        Args:
            None

        Returns:
            None
        """
        while True:
            try:
                channeldata_dict = self._out_channel.front(self.name)
            except ChannelStopError:
                _LOGGER.info("[DAG Executor] Stop.")
                # Snapshot under the pool lock, then notify without holding
                # it: taking a caller's condition variable while holding
                # _cv_for_cv_pool would invert the lock order and could
                # deadlock against a caller inside
                # _get_channeldata_from_fetch_buffer.
                with self._cv_for_cv_pool:
                    pending = list(self._cv_pool.items())
                for data_id, cv in pending:
                    closed_errror_data = ChannelData(
                        error_code=ChannelDataErrcode.CLOSED_ERROR.value,
                        error_info="dag closed.",
                        data_id=data_id)
                    self._put_into_fetch_buffer(data_id, cv,
                                                closed_errror_data)
                break

            if len(channeldata_dict) != 1:
                _LOGGER.critical(
                    "[DAG Executor] Failed to fetch result: out_channel "
                    "cannot have multiple input ops")
                os._exit(-1)
            (_, channeldata), = channeldata_dict.items()
            if not isinstance(channeldata, ChannelData):
                _LOGGER.critical(
                    "[DAG Executor] Failed to fetch result: data in "
                    "out_channel must be ChannelData type, but get {}".format(
                        type(channeldata)))
                os._exit(-1)

            data_id = channeldata.id
            _LOGGER.debug("(logid={}) [recive thread] Fetched data".format(
                data_id))
            with self._cv_for_cv_pool:
                cond_v = self._cv_pool.get(data_id)
            # An unknown id must not raise here: an exception would end this
            # thread and every later caller would then block forever.
            if cond_v is None or not self._put_into_fetch_buffer(
                    data_id, cond_v, channeldata):
                _LOGGER.warning(
                    "(data_id={}) [recive thread] No request is waiting for "
                    "this data, drop it".format(data_id))

    def _get_channeldata_from_fetch_buffer(self, data_id, cond_v):
        """
        Block until the result of data_id is in _fetch_buffer, then consume
        it (data_id is removed from _cv_pool and _fetch_buffer).

        Args:
            data_id: search key
            cond_v: conditional variable

        Returns:
            ready_data: one channel data processed
        """
        ready_data = None
        with cond_v:
            while True:
                with self._cv_for_cv_pool:
                    ready_data = self._fetch_buffer[data_id]
                    if ready_data is not None:
                        self._cv_pool.pop(data_id)
                        self._fetch_buffer.pop(data_id)
                        break
                # Re-check after every wake-up rather than trusting a single
                # wait(); the data may not have been stored yet.
                cond_v.wait()
        _LOGGER.debug("(data_id={}) [resp thread] Got data".format(data_id))
        return ready_data

    def _pack_channeldata(self, rpc_request, data_id):
        """
        Unpacking data from RPC request, and creating one ChannelData.

        Args:
            rpc_request: one RPC request
            data_id: data id, unique

        Returns:
            ChannelData: one channel data to be processed. It carries
            RPC_PACKAGE_ERROR if the user unpack function raised, and
            PRODUCT_ERROR if it returned a product error code.
        """
        dictdata = None
        log_id = None
        try:
            dictdata, log_id, prod_errcode, prod_errinfo = self._unpack_rpc_func(
                rpc_request)
        except Exception as e:
            _LOGGER.error(
                "(logid={}) Failed to parse RPC request package: {}"
                .format(data_id, e),
                exc_info=True)
            return ChannelData(
                error_code=ChannelDataErrcode.RPC_PACKAGE_ERROR.value,
                error_info="rpc package error: {}".format(e),
                data_id=data_id,
                log_id=log_id)

        # because unpack_rpc_func is rewritten by user, we need to look
        # for product_errcode in returns, and client_profile_key field
        # in rpc_request
        if prod_errcode is not None:
            # product errors occured
            _LOGGER.error("unpack_rpc_func prod_errcode:{}".format(
                prod_errcode))
            return ChannelData(
                error_code=ChannelDataErrcode.PRODUCT_ERROR.value,
                error_info="",
                prod_error_code=prod_errcode,
                prod_error_info=prod_errinfo,
                data_id=data_id,
                log_id=log_id)

        profile_value = dictdata.get(self._client_profile_key)
        client_need_profile = (profile_value == self._client_profile_value)
        return ChannelData(
            datatype=ChannelDataType.DICT.value,
            dictdata=dictdata,
            data_id=data_id,
            log_id=log_id,
            client_need_profile=client_need_profile)

    def call(self, rpc_request):
        """
        DAGExecutor entrance function. There are 5 steps:
        1._get_next_data_id: Generate an incremental ID
        2._pack_channeldata: pack the channel data from request.
        3.retry loop:
            a. push channel_data into _in_channel
            b. get_channeldata_from_fetch_buffer: get results.
           Only TIMEOUT errors are retried; at least one attempt is made.
        4._pack_for_rpc_resp: pack RPC responses
        5.profile: generate profile string and pack into response.

        Args:
            rpc_request: one RPC request

        Returns:
            rpc_resp: one RPC response
        """
        if self._tracer is not None:
            trace_buffer = self._tracer.data_buffer()

        data_id, cond_v = self._get_next_data_id()
        start_call, end_call = None, None
        if not self._is_thread_op:
            start_call = self._profiler.record("call_{}#DAG-{}_0".format(
                data_id, data_id))
        else:
            start_call = self._profiler.record("call_{}#DAG_0".format(data_id))

        self._profiler.record("prepack_{}#{}_0".format(data_id, self.name))
        req_channeldata = self._pack_channeldata(rpc_request, data_id)
        self._profiler.record("prepack_{}#{}_1".format(data_id, self.name))

        log_id = req_channeldata.log_id
        _LOGGER.info("(data_id={} log_id={}) Succ Generate ID ".format(data_id,
                                                                     log_id))

        # Always make at least one attempt; with zero attempts
        # resp_channeldata would stay None and packing would fail.
        retry_num = max(1, self._retry)
        resp_channeldata = None
        for i in range(retry_num):
            if i > 0:
                # The previous attempt consumed data_id's slot when it fetched
                # its result; register again BEFORE pushing so the receive
                # thread can route the new result back here.
                self._register_data_id(data_id, cond_v)
            _LOGGER.debug("(data_id={}) Pushing data into Graph engine".format(
                data_id))
            try:
                if req_channeldata is None:
                    _LOGGER.critical(
                        "(data_id={} log_id={}) req_channeldata is None"
                        .format(data_id, log_id))
                if not isinstance(self._in_channel,
                                  (ThreadChannel, ProcessChannel)):
                    _LOGGER.critical(
                        "(data_id={} log_id={})[DAG Executor] Failed to "
                        "set in_channel: in_channel must be Channel type, but get {}".
                        format(data_id, log_id, type(self._in_channel)))
                self._in_channel.push(req_channeldata, self.name)
            except ChannelStopError:
                _LOGGER.error("(data_id:{} log_id={})[DAG Executor] Stop".
                              format(data_id, log_id))
                # Drop both bookkeeping entries; leaving _fetch_buffer's slot
                # behind would leak it.
                with self._cv_for_cv_pool:
                    self._cv_pool.pop(data_id, None)
                    self._fetch_buffer.pop(data_id, None)
                return self._pack_for_rpc_resp(
                    ChannelData(
                        error_code=ChannelDataErrcode.CLOSED_ERROR.value,
                        error_info="dag closed.",
                        data_id=data_id))

            _LOGGER.debug("(data_id={} log_id={}) Wait for Graph engine...".
                          format(data_id, log_id))
            resp_channeldata = self._get_channeldata_from_fetch_buffer(data_id,
                                                                       cond_v)

            if resp_channeldata.error_code == ChannelDataErrcode.OK.value:
                _LOGGER.info("(data_id={} log_id={}) Succ predict".format(
                    data_id, log_id))
                break
            else:
                _LOGGER.error("(data_id={} log_id={}) Failed to predict: {}"
                              .format(data_id, log_id,
                                      resp_channeldata.error_info))
                if resp_channeldata.error_code != ChannelDataErrcode.TIMEOUT.value:
                    break

            if i + 1 < retry_num:
                _LOGGER.warning(
                    "(data_id={} log_id={}) DAGExecutor retry({}/{})"
                    .format(data_id, log_id, i + 1, retry_num))

        _LOGGER.debug("(data_id={} log_id={}) Packing RPC response package"
                      .format(data_id, log_id))
        self._profiler.record("postpack_{}#{}_0".format(data_id, self.name))
        rpc_resp = self._pack_for_rpc_resp(resp_channeldata)
        self._profiler.record("postpack_{}#{}_1".format(data_id, self.name))
        if not self._is_thread_op:
            end_call = self._profiler.record("call_{}#DAG-{}_1".format(
                data_id, data_id))
        else:
            end_call = self._profiler.record("call_{}#DAG_1".format(data_id))

        if self._tracer is not None:
            trace_buffer.put({
                "name": "DAG",
                "id": data_id,
                "succ":
                resp_channeldata.error_code == ChannelDataErrcode.OK.value,
                "actions": {
                    "call_{}".format(data_id): end_call - start_call,
                },
            })

        profile_str = self._profiler.gen_profile_str()
        if self._server_use_profile:
            sys.stderr.write(profile_str)

        # add profile info into rpc_resp
        if resp_channeldata.client_need_profile:
            profile_set = resp_channeldata.profile_data_set
            profile_set.add(profile_str)
            profile_value = "".join(list(profile_set))
            rpc_resp.key.append(self._client_profile_key)
            rpc_resp.value.append(profile_value)

        return rpc_resp

    def _pack_for_rpc_resp(self, channeldata):
        """
        Packing one RPC response with the response op's pack function.

        Args:
            channeldata: one channel data to be packed

        Returns:
            resp: one RPC response. If packing raises, a bare Response with
            err_no RPC_PACKAGE_ERROR is returned instead.
        """
        try:
            return self._pack_rpc_func(channeldata)
        except Exception as e:
            _LOGGER.error(
                "(logid={}) Failed to pack RPC response package: {}"
                .format(channeldata.id, e),
                exc_info=True)
            resp = pipeline_service_pb2.Response()
            resp.err_no = ChannelDataErrcode.RPC_PACKAGE_ERROR.value
            resp.err_msg = "rpc package error: {}".format(e)
            return resp


class DAG(object):
    """
    Directed Acyclic Graph(DAG) engine, builds one DAG topology.
    """

    def __init__(self, request_name, response_op, use_profile, is_thread_op,
                 channel_size, build_dag_each_worker, tracer,
                 channel_recv_frist_arrive):
        """
        Store the DAG configuration. In process mode a
        PipelineProcSyncManager is created; _gen_channel passes it to every
        ProcessChannel.
        """
        self._request_name = request_name
        self._response_op = response_op
        self._use_profile = use_profile
        self._is_thread_op = is_thread_op
        self._channel_size = channel_size
        self._build_dag_each_worker = build_dag_each_worker
        self._tracer = tracer
        self._channel_recv_frist_arrive = channel_recv_frist_arrive
        if not self._is_thread_op:
            self._manager = PipelineProcSyncManager()
        _LOGGER.info("[DAG] Succ init")

    @staticmethod
    def get_use_ops(response_op):
        """
        Starting from ResponseOp, traverse the front OPs breadth-first.
        Getting all used ops and the post op list of each op (excluding
        ResponseOp). Exits the process if two distinct ops share a name.

        Args:
            response_op: ResponseOp

        Returns:
            used_ops: used ops, set
            succ_ops_of_use_op: op name -> list of its next ops, dict.
        """
        unique_names = set()
        used_ops = set()
        succ_ops_of_use_op = {}  # {op_name: succ_ops}
        que = Queue.Queue()
        que.put(response_op)
        while not que.empty():
            op = que.get()
            for pred_op in op.get_input_ops():
                if pred_op.name not in succ_ops_of_use_op:
                    succ_ops_of_use_op[pred_op.name] = []
                if op != response_op:
                    succ_ops_of_use_op[pred_op.name].append(op)
                if pred_op not in used_ops:
                    que.put(pred_op)
                    used_ops.add(pred_op)
                    # check the name of op is globally unique
                    if pred_op.name in unique_names:
                        _LOGGER.critical("Failed to get used Ops: the"
                                         " name of Op must be unique: {}".
                                         format(pred_op.name))
                        os._exit(-1)
                    unique_names.add(pred_op.name)
        return used_ops, succ_ops_of_use_op

    def _gen_channel(self, name_gen):
        """
        Generate one ThreadChannel or ProcessChannel, depending on
        _is_thread_op.

        Args:
            name_gen: channel name generator

        Returns:
            channel: one channel generated
        """
        channel = None
        if self._is_thread_op:
            channel = ThreadChannel(
                name=name_gen.next(),
                maxsize=self._channel_size,
                channel_recv_frist_arrive=self._channel_recv_frist_arrive)
        else:
            channel = ProcessChannel(
                self._manager,
                name=name_gen.next(),
                maxsize=self._channel_size,
                channel_recv_frist_arrive=self._channel_recv_frist_arrive)
        _LOGGER.debug("[DAG] Generate channel: {}".format(channel.name))
        return channel

    def _gen_virtual_op(self, name_gen):
        """
        Generate one virtual Op

        Args:
            name_gen: Op name generator

        Returns:
            vir_op: one virtual Op object.
        """
        vir_op = VirtualOp(name=name_gen.next())
        _LOGGER.debug("[DAG] Generate virtual_op: {}".format(vir_op.name))
        return vir_op

    def _topo_sort(self, used_ops, response_op, out_degree_ops):
        """
        Topological sort of DAG, creates inverted multi-layers views.
        Exits the process if the DAG does not have exactly one op without
        inputs (the RequestOp) or contains a cycle.

        Args:
            used_ops: op used in DAG
            response_op: response op
            out_degree_ops: Next op list for each op, dict. the output of
                get_use_ops()

        Returns:
            dag_views: the inverted hierarchical topology list. e.g.
                DAG edges: A -> B, B -> C, B -> D, C -> E, D -> E
                dag_views: [[E], [C, D], [B], [A]]
            last_op: the last op in front of ResponseOp (E above)
        """
        out_degree_num = {
            name: len(ops)
            for name, ops in out_degree_ops.items()
        }
        que_idx = 0  # scroll queue
        ques = [Queue.Queue() for _ in range(2)]
        zero_indegree_num = 0
        for op in used_ops:
            if len(op.get_input_ops()) == 0:
                zero_indegree_num += 1
        if zero_indegree_num != 1:
            _LOGGER.critical("Failed to topo sort: DAG contains "
                             "multiple RequestOps")
            os._exit(-1)
        last_op = response_op.get_input_ops()[0]
        ques[que_idx].put(last_op)

        # topo sort to get dag_views
        dag_views = []
        sorted_op_num = 0
        while True:
            que = ques[que_idx]
            next_que = ques[(que_idx + 1) % 2]
            dag_view = []
            while not que.empty():
                op = que.get()
                dag_view.append(op)
                sorted_op_num += 1
                for pred_op in op.get_input_ops():
                    out_degree_num[pred_op.name] -= 1
                    if out_degree_num[pred_op.name] == 0:
                        next_que.put(pred_op)
            dag_views.append(dag_view)
            if next_que.empty():
                break
            que_idx = (que_idx + 1) % 2
        if sorted_op_num < len(used_ops):
            _LOGGER.critical("Failed to topo sort: not legal DAG")
            os._exit(-1)

        return dag_views, last_op

    def _build_dag(self, response_op):
        """
        Building DAG, the most important function in class DAG.
        Core steps:
        1.get_use_ops: Getting used ops, and out degree op list for each op.
        2._topo_sort: Topological sort creates inverted multi-layers views.
        3.create channels and virtual ops. A virtual op is inserted wherever
          an edge skips a view, so every channel connects adjacent views.

        Args:
            response_op: ResponseOp

        Returns:
            actual_ops: all OPs used in DAG, including virtual OPs
            channels: all channels used in DAG
            input_channel: the channel of first OP
            output_channel: the channel of last OP
            pack_func: pack_response_package function of response_op
            unpack_func: unpack_request_package function of request_op
        """
        if response_op is None:
            _LOGGER.critical("Failed to build DAG: ResponseOp"
                             " has not been set.")
            os._exit(-1)
        used_ops, out_degree_ops = DAG.get_use_ops(response_op)
        if not self._build_dag_each_worker:
            _LOGGER.info("================= USED OP =================")
            for op in used_ops:
                if not isinstance(op, RequestOp):
                    _LOGGER.info(op.name)
            _LOGGER.info("-------------------------------------------")
        if len(used_ops) <= 1:
            _LOGGER.critical(
                "Failed to build DAG: besides RequestOp and ResponseOp, "
                "there should be at least one Op in DAG.")
            os._exit(-1)
        if self._build_dag_each_worker:
            _LOGGER.info("Because `build_dag_each_worker` mode is used, "
                         "Auto-batching is set to the default config: "
                         "batch_size=1, auto_batching_timeout=None")
            for op in used_ops:
                op.use_default_auto_batching_config()

        dag_views, last_op = self._topo_sort(used_ops, response_op,
                                             out_degree_ops)
        dag_views = list(reversed(dag_views))
        if not self._build_dag_each_worker:
            _LOGGER.info("================== DAG ====================")
            for idx, view in enumerate(dag_views):
                _LOGGER.info("(VIEW {})".format(idx))
                for op in view:
                    _LOGGER.info("  [{}]".format(op.name))
                    for out_op in out_degree_ops[op.name]:
                        _LOGGER.info("    - {}".format(out_op.name))
            _LOGGER.info("-------------------------------------------")

        # create channels and virtual ops
        virtual_op_name_gen = NameGenerator("vir")
        channel_name_gen = NameGenerator("chl")
        virtual_ops = []
        channels = []
        input_channel = None
        actual_view = None
        for v_idx, view in enumerate(dag_views):
            if v_idx + 1 >= len(dag_views):
                break
            next_view = dag_views[v_idx + 1]
            if actual_view is None:
                actual_view = view
            actual_next_view = []
            pred_op_of_next_view_op = {}
            for op in actual_view:
                # find actual succ op in next view and create virtual op
                for succ_op in out_degree_ops[op.name]:
                    if succ_op in next_view:
                        if succ_op not in actual_next_view:
                            actual_next_view.append(succ_op)
                        if succ_op.name not in pred_op_of_next_view_op:
                            pred_op_of_next_view_op[succ_op.name] = []
                        pred_op_of_next_view_op[succ_op.name].append(op)
                    else:
                        # create virtual op
                        virtual_op = self._gen_virtual_op(virtual_op_name_gen)
                        virtual_ops.append(virtual_op)
                        out_degree_ops[virtual_op.name] = [succ_op]
                        actual_next_view.append(virtual_op)
                        pred_op_of_next_view_op[virtual_op.name] = [op]
                        virtual_op.add_virtual_pred_op(op)
            actual_view = actual_next_view
            # create channel
            processed_op = set()
            for o_idx, op in enumerate(actual_next_view):
                if op.name in processed_op:
                    continue
                channel = self._gen_channel(channel_name_gen)
                channels.append(channel)
                op.add_input_channel(channel)
                _LOGGER.info("op:{} add input channel.".format(op.name))
                pred_ops = pred_op_of_next_view_op[op.name]
                if v_idx == 0:
                    input_channel = channel
                else:
                    # if pred_op is virtual op, it will use ancestors as
                    # producers to channel
                    for pred_op in pred_ops:
                        pred_op.add_output_channel(channel)
                        _LOGGER.info("pred_op:{} add output channel".format(
                            pred_op.name))
                processed_op.add(op.name)
                # find same input op to combine channel
                for other_op in actual_next_view[o_idx + 1:]:
                    if other_op.name in processed_op:
                        continue
                    other_pred_ops = pred_op_of_next_view_op[other_op.name]
                    if len(other_pred_ops) != len(pred_ops):
                        continue
                    same_flag = True
                    for pred_op in pred_ops:
                        if pred_op not in other_pred_ops:
                            same_flag = False
                            break
                    if same_flag:
                        other_op.add_input_channel(channel)
                        processed_op.add(other_op.name)
        output_channel = self._gen_channel(channel_name_gen)
        channels.append(output_channel)
        last_op.add_output_channel(output_channel)
        _LOGGER.info("last op:{} add output channel".format(last_op.name))

        pack_func, unpack_func = None, None
        pack_func = response_op.pack_response_package

        # Copy so that appending real ops does not mutate virtual_ops.
        actual_ops = list(virtual_ops)
        for op in used_ops:
            if len(op.get_input_ops()) == 0:
                # set special features of the request op.
                # 1.set unpack function.
                # 2.set output channel.
                unpack_func = op.unpack_request_package
                op.add_output_channel(input_channel)
                continue
            actual_ops.append(op)

        for c in channels:
            _LOGGER.debug("Channel({}):\n\t- producers: {}\n\t- consumers: {}"
                          .format(c.name, c.get_producers(), c.get_consumers()))

        return (actual_ops, channels, input_channel, output_channel, pack_func,
                unpack_func)

    def get_channels(self):
        """Return the channels created by build(); only valid after build()."""
        return self._channels

    def build(self):
        """
        Interface for building one DAG outside.

        Args:
            None

        Returns:
            _input_channel: the channel of first OP
            _output_channel: the channel of last OP
            _pack_func: pack_response_package function of response_op
            _unpack_func: unpack_request_package function of request_op
        """
        (actual_ops, channels, input_channel, output_channel, pack_func,
         unpack_func) = self._build_dag(self._response_op)
        _LOGGER.info("[DAG] Succ build DAG")

        self._actual_ops = actual_ops
        self._channels = channels
        self._input_channel = input_channel
        self._output_channel = output_channel
        self._pack_func = pack_func
        self._unpack_func = unpack_func

        if self._tracer is not None:
            self._tracer.set_channels(self._channels)

        return self._input_channel, self._output_channel, self._pack_func, self._unpack_func

    def start(self):
        """
        Each OP starts threads or processes, depending on _is_thread_op.

        Args:
            None

        Returns:
            _threads_or_proces: threads or process list.
        """
        self._threads_or_proces = []
        for op in self._actual_ops:
            op.use_profiler(self._use_profile)
            op.set_tracer(self._tracer)
            if self._is_thread_op:
                self._threads_or_proces.extend(op.start_with_thread())
            else:
                self._threads_or_proces.extend(op.start_with_process())
        _LOGGER.info("[DAG] start")

        # not join yet
        return self._threads_or_proces

    def join(self):
        """
        All threads or processes started by start() join.

        Args:
            None

        Returns:
            None
        """
        for x in self._threads_or_proces:
            if x is not None:
                x.join()

    def stop(self):
        """
        Stopping all channels, then cleaning every op's input and output
        channels.

        Args:
            None

        Returns:
            None
        """
        for chl in self._channels:
            chl.stop()
        for op in self._actual_ops:
            op.clean_input_channel()
            op.clean_output_channels()