#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import threading
import multiprocessing
import sys
import copy
if sys.version_info.major == 2:
    import Queue
elif sys.version_info.major == 3:
    import queue as Queue
else:
    raise Exception("Unsupported Python version")
import os
import logging
import collections
import json

from .error_catch import ErrorCatch, CustomException, CustomExceptionCode, ParamChecker, ParamVerify
from .operator import Op, RequestOp, ResponseOp, VirtualOp
from .channel import (ThreadChannel, ProcessChannel, ChannelData,
                      ChannelDataType, ChannelStopError)
from .error_catch import ProductErrCode
from .error_catch import CustomExceptionCode as ChannelDataErrcode
from .profiler import TimeProfiler, PerformanceTracer
from .util import NameGenerator, ThreadIdGenerator, PipelineProcSyncManager
from .proto import pipeline_service_pb2

_LOGGER = logging.getLogger(__name__)


class DAGExecutor(object):
    """
    DAG Executor, the service entrance of the DAG.
    """
    def __init__(self, response_op, server_conf, worker_idx):
        """
        Initialize DAGExecutor.

        Args:
            response_op: the ResponseOp of the DAG
            server_conf: server configuration, loaded from config.yaml
            worker_idx: DAGExecutor index; PipelineServer creates multiple
                DAGExecutors when _build_dag_each_worker is true.

        Returns:
            None.
        """
        build_dag_each_worker = server_conf["build_dag_each_worker"]
        server_worker_num = server_conf["worker_num"]
        dag_conf = server_conf["dag"]

        self._retry = dag_conf["retry"]
        self._server_use_profile = dag_conf["use_profile"]
        channel_size = dag_conf["channel_size"]
        channel_recv_frist_arrive = dag_conf["channel_recv_frist_arrive"]
        self._is_thread_op = dag_conf["is_thread_op"]
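        # A minimal sketch of the config.yaml fields consumed by this
        # constructor; the key names are exactly those read here, while the
        # values are only illustrative:
        #
        #   worker_num: 1
        #   build_dag_each_worker: false
        #   dag:
        #       is_thread_op: true
        #       retry: 1
        #       use_profile: false
        #       channel_size: 0
        #       channel_recv_frist_arrive: false
        #       tracer:
        #           interval_s: -1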

        tracer_conf = dag_conf["tracer"]
        tracer_interval_s = tracer_conf["interval_s"]

        self.name = "@DAGExecutor"
        self._profiler = TimeProfiler()
        self._profiler.enable(True)

        self._tracer = None
        if tracer_interval_s >= 1:
            self._tracer = PerformanceTracer(
                self._is_thread_op, tracer_interval_s, server_worker_num)

        self._dag = DAG(self.name, response_op, self._server_use_profile,
                        self._is_thread_op, channel_size, build_dag_each_worker,
                        self._tracer, channel_recv_frist_arrive)
        (in_channel, out_channel, pack_rpc_func,
         unpack_rpc_func) = self._dag.build()
        self._dag.start()

        self._set_in_channel(in_channel)
        self._set_out_channel(out_channel)
        self._pack_rpc_func = pack_rpc_func
        self._unpack_rpc_func = unpack_rpc_func

        if self._tracer is not None:
            self._tracer.start()

        # generate id
        # data_id: server-side unique ID, automatically generated by the framework
        # log_id: traces one product request; may be empty and is not unique
        base_counter = 0
        gen_id_step = 1
        if build_dag_each_worker:
            base_counter = worker_idx
            gen_id_step = server_worker_num
        self._id_generator = ThreadIdGenerator(
            max_id=1000000000000000000,
            base_counter=base_counter,
            step=gen_id_step)
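        # With build_dag_each_worker enabled, each worker draws data_ids from
        # a disjoint arithmetic sequence (worker_idx, worker_idx + worker_num,
        # worker_idx + 2 * worker_num, ...), so ids never collide across
        # worker processes.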

        self._cv_pool = {}
        self._cv_for_cv_pool = threading.Condition()
        self._fetch_buffer = {}
        self._receive_func = None

        self._client_profile_key = "pipeline.profile"
        self._client_profile_value = "1"

    @ErrorCatch
    def start(self):
        """
        Start a background thread that receives data from the output channel.

        Args:
            None

        Returns:
            None
        """
        self._receive_func = threading.Thread(
            target=DAGExecutor._receive_out_channel_func, args=(self, ))
        self._receive_func.daemon = True
        self._receive_func.start()
        _LOGGER.debug("[DAG Executor] Start receive thread")

    def stop(self):
        """
        Stop the DAG executor.

        Args:
            None

        Returns:
            None
        """
        self._dag.stop()
        self._dag.join()
        _LOGGER.info("[DAG Executor] Stop")

    def _get_next_data_id(self):
        """
        Generate a data_id incrementally and uniquely.

        Args:
            None

        Returns:
            data_id: unique id
            cond_v: condition variable
        """
        data_id = self._id_generator.next()
        cond_v = threading.Condition()
        with self._cv_for_cv_pool:
            self._cv_pool[data_id] = cond_v
            self._fetch_buffer[data_id] = None
        return data_id, cond_v

    def _set_in_channel(self, in_channel):
        """
        Set the in_channel of the DAG.

        Args:
            in_channel: input channel of DAG

        Returns:
            None
        """
        if not isinstance(in_channel, (ThreadChannel, ProcessChannel)):
            _LOGGER.critical("[DAG Executor] Failed to set in_channel: "
                             "in_channel must be Channel type, but got {}".
                             format(type(in_channel)))
            os._exit(-1)

        self._in_channel = in_channel
        _LOGGER.info("[DAG] set in channel succ, name [{}]".format(self.name))

    def _set_out_channel(self, out_channel):
        """
        Set the out_channel of the DAG.

        Args:
            out_channel: output channel of DAG

        Returns:
            None
        """
        if not isinstance(out_channel, (ThreadChannel, ProcessChannel)):
            _LOGGER.critical("[DAG Executor] Failed to set out_channel: "
                             "must be Channel type, but got {}".format(
                                 type(out_channel)))
            os._exit(-1)
        out_channel.add_consumer(self.name)
        self._out_channel = out_channel

    def _receive_out_channel_func(self):
        """
        Receive data from the output channel and push it into _fetch_buffer.
        _get_channeldata_from_fetch_buffer then picks the data up by data_id.

        Args:
            None

        Returns:
            None
        """
        cv = None
        while True:
            try:
                channeldata_dict = self._out_channel.front(self.name)
            except ChannelStopError:
                _LOGGER.info("[DAG Executor] Stop.")
                with self._cv_for_cv_pool:
                    for data_id, cv in self._cv_pool.items():
                        closed_error_data = ChannelData(
                            error_code=ChannelDataErrcode.CLOSED_ERROR.value,
                            error_info="dag closed.",
                            data_id=data_id)
                        with cv:
                            self._fetch_buffer[data_id] = closed_error_data
                            cv.notify_all()
                break
            if len(channeldata_dict) != 1:
                _LOGGER.critical(
                    "[DAG Executor] Failed to fetch result: out_channel "
                    "cannot have multiple input ops")
                os._exit(-1)
            (_, channeldata), = channeldata_dict.items()
            if not isinstance(channeldata, ChannelData):
                _LOGGER.critical(
                    "[DAG Executor] Failed to fetch result: data in out_channel"
                    " must be ChannelData type, but got {}"
                    .format(type(channeldata)))
                os._exit(-1)

            data_id = channeldata.id
            _LOGGER.debug("(logid={}) [receive thread] Fetched data".format(
                data_id))
            with self._cv_for_cv_pool:
                cond_v = self._cv_pool[data_id]
            with cond_v:
                self._fetch_buffer[data_id] = channeldata
                cond_v.notify_all()

    def _get_channeldata_from_fetch_buffer(self, data_id, cond_v):
        """
        Get the channel data from _fetch_buffer.

        Args:
            data_id: search key
            cond_v: condition variable

        Returns:
            ready_data: one processed channel data
        """
        ready_data = None

        with cond_v:
            with self._cv_for_cv_pool:
                if self._fetch_buffer[data_id] is not None:
                    # The requested data is already ready
                    ready_data = self._fetch_buffer[data_id]
                    self._cv_pool.pop(data_id)
                    self._fetch_buffer.pop(data_id)
            if ready_data is None:
                # Wait for data ready
                cond_v.wait()
                with self._cv_for_cv_pool:
                    ready_data = self._fetch_buffer[data_id]
                    self._cv_pool.pop(data_id)
                    self._fetch_buffer.pop(data_id)
        _LOGGER.debug("(data_id={}) [resp thread] Got data".format(data_id))
        return ready_data

    def _pack_channeldata(self, rpc_request, data_id):
        """
        Unpack data from the RPC request and create one ChannelData.

        Args:
           rpc_request: one RPC request
           data_id: data id, unique

        Returns:
            ChannelData: one channel data to be processed
        """
        dictdata = None
        log_id = None
        try:
            dictdata, log_id, prod_errcode, prod_errinfo = self._unpack_rpc_func(
                rpc_request)
        except Exception as e:
            _LOGGER.error(
                "(logid={}) Failed to parse RPC request package: {}"
                .format(data_id, e),
                exc_info=True)
            return ChannelData(
                error_code=ChannelDataErrcode.RPC_PACKAGE_ERROR.value,
                error_info="rpc package error: {}".format(e),
                data_id=data_id,
                log_id=log_id)
        else:
            # because unpack_rpc_func is rewritten by the user, we need to
            # look for prod_errcode in its returns, and for the
            # client_profile_key field in rpc_request
            if prod_errcode is not None:
                # product errors occurred
                _LOGGER.error("unpack_rpc_func prod_errcode:{}".format(
                    prod_errcode))
                return ChannelData(
                    error_code=ChannelDataErrcode.PRODUCT_ERROR.value,
                    error_info="",
                    prod_error_code=prod_errcode,
                    prod_error_info=prod_errinfo,
                    data_id=data_id,
                    log_id=log_id)

            profile_value = dictdata.get(self._client_profile_key)
            client_need_profile = (profile_value == self._client_profile_value)
            return ChannelData(
                datatype=ChannelDataType.DICT.value,
                dictdata=dictdata,
                data_id=data_id,
                log_id=log_id,
                client_need_profile=client_need_profile)

    def call(self, rpc_request):
        """
        DAGExecutor entrance function. There are 5 steps:
        1._get_next_data_id: Generate an incremental ID
        2._pack_channeldata: pack the channel data from request.
        3.retry loop:
            a. push channel_data into _in_channel
            b. get_channeldata_from_fetch_buffer: get results.
        4._pack_for_rpc_resp: pack RPC responses
        5.profile: generate profile string and pack into response.

        Args:
            rpc_request: one RPC request

        Returns:
            rpc_resp: one RPC response
        """
        if self._tracer is not None:
            trace_buffer = self._tracer.data_buffer()

        data_id, cond_v = self._get_next_data_id()

        start_call, end_call = None, None
        if not self._is_thread_op:
            start_call = self._profiler.record("call_{}#DAG-{}_0".format(
                data_id, data_id))
        else:
            start_call = self._profiler.record("call_{}#DAG_0".format(data_id))

        self._profiler.record("prepack_{}#{}_0".format(data_id, self.name))
        req_channeldata = self._pack_channeldata(rpc_request, data_id)
        self._profiler.record("prepack_{}#{}_1".format(data_id, self.name))

        log_id = req_channeldata.log_id
        _LOGGER.info("(data_id={} log_id={}) Succ Generate ID".format(data_id,
                                                                      log_id))

        resp_channeldata = None
        for i in range(self._retry):
            _LOGGER.debug("(data_id={}) Pushing data into Graph engine".format(
                data_id))
            try:
                if req_channeldata is None:
                    _LOGGER.critical(
                        "(data_id={} log_id={}) req_channeldata is None"
                        .format(data_id, log_id))
                if not isinstance(self._in_channel,
                                  (ThreadChannel, ProcessChannel)):
                    _LOGGER.critical(
                        "(data_id={} log_id={})[DAG Executor] Failed to "
                        "set in_channel: in_channel must be Channel type, but got {}".
                        format(data_id, log_id, type(self._in_channel)))
                self._in_channel.push(req_channeldata, self.name)
            except ChannelStopError:
                _LOGGER.error("(data_id:{} log_id={})[DAG Executor] Stop".
                              format(data_id, log_id))
                with self._cv_for_cv_pool:
                    self._cv_pool.pop(data_id)
                return self._pack_for_rpc_resp(
                    ChannelData(
                        error_code=ChannelDataErrcode.CLOSED_ERROR.value,
                        error_info="dag closed.",
                        data_id=data_id))

            _LOGGER.debug("(data_id={} log_id={}) Wait for Graph engine...".
                          format(data_id, log_id))
            resp_channeldata = self._get_channeldata_from_fetch_buffer(data_id,
                                                                       cond_v)

            if resp_channeldata.error_code == ChannelDataErrcode.OK.value:
                _LOGGER.info("(data_id={} log_id={}) Succ predict".format(
                    data_id, log_id))
                break
            else:
                _LOGGER.error("(data_id={} log_id={}) Failed to predict: {}"
                              .format(data_id, log_id,
                                      resp_channeldata.error_info))
                if resp_channeldata.error_code != ChannelDataErrcode.TIMEOUT.value:
                    break

            if i + 1 < self._retry:
                _LOGGER.warning(
                    "(data_id={} log_id={}) DAGExecutor retry({}/{})"
                    .format(data_id, log_id, i + 1, self._retry))

        _LOGGER.debug("(data_id={} log_id={}) Packing RPC response package"
                      .format(data_id, log_id))
        self._profiler.record("postpack_{}#{}_0".format(data_id, self.name))
        rpc_resp = self._pack_for_rpc_resp(resp_channeldata)
        self._profiler.record("postpack_{}#{}_1".format(data_id, self.name))
        if not self._is_thread_op:
            end_call = self._profiler.record("call_{}#DAG-{}_1".format(data_id,
                                                                       data_id))
        else:
            end_call = self._profiler.record("call_{}#DAG_1".format(data_id))

        if self._tracer is not None:
            trace_buffer.put({
                "name": "DAG",
                "id": data_id,
                "succ":
                resp_channeldata.error_code == ChannelDataErrcode.OK.value,
                "actions": {
                    "call_{}".format(data_id): end_call - start_call,
                },
            })

        profile_str = self._profiler.gen_profile_str()
        if self._server_use_profile:
            sys.stderr.write(profile_str)

        # add profile info into rpc_resp
        if resp_channeldata.client_need_profile:
            profile_set = resp_channeldata.profile_data_set
            profile_set.add(profile_str)
            profile_value = "".join(list(profile_set))
            rpc_resp.key.append(self._client_profile_key)
            rpc_resp.value.append(profile_value)

        return rpc_resp

    def _pack_for_rpc_resp(self, channeldata):
        """
        Pack one RPC response.

        Args:
            channeldata: one channel data to be packed

        Returns:
            resp: one RPC response
        """
        try:
            return self._pack_rpc_func(channeldata)
        except Exception as e:
            _LOGGER.error(
                "(logid={}) Failed to pack RPC response package: {}"
                .format(channeldata.id, e),
                exc_info=True)
            resp = pipeline_service_pb2.Response()
            resp.err_no = ChannelDataErrcode.RPC_PACKAGE_ERROR.value
            resp.err_msg = "rpc package error: {}".format(e)
            return resp


class DAG(object):
    """
    Directed Acyclic Graph(DAG) engine, builds one DAG topology.
    """

    def __init__(self, request_name, response_op, use_profile, is_thread_op,
                 channel_size, build_dag_each_worker, tracer,
                 channel_recv_frist_arrive):
        _LOGGER.info("{}, {}, {}, {}, {}, {}, {}, {}".format(
            request_name, response_op, use_profile, is_thread_op,
            channel_size, build_dag_each_worker, tracer,
            channel_recv_frist_arrive))

        @ErrorCatch
        @ParamChecker
        def init_helper(self, request_name: str,
                        response_op,
                        use_profile: [bool, None],
                        is_thread_op: bool,
                        channel_size,
                        build_dag_each_worker: [bool, None],
                        tracer,
                        channel_recv_frist_arrive):
            self._request_name = request_name
            self._response_op = response_op
            self._use_profile = use_profile
            self._is_thread_op = is_thread_op
            self._channel_size = channel_size
            self._build_dag_each_worker = build_dag_each_worker
            self._tracer = tracer
            self._channel_recv_frist_arrive = channel_recv_frist_arrive
            if not self._is_thread_op:
                self._manager = PipelineProcSyncManager()

        init_helper(self, request_name, response_op, use_profile, is_thread_op,
                    channel_size, build_dag_each_worker, tracer,
                    channel_recv_frist_arrive)
        print("[DAG] Succ init")
        _LOGGER.info("[DAG] Succ init")

    @staticmethod
    def get_use_ops(response_op):
        """
        Starting from ResponseOp, recursively traverse the front OPs, getting
        all used ops and the successor op list of each op (excluding
        ResponseOp).

        Args:
            response_op: ResponseOp

        Returns:
            used_ops: used ops, set
            succ_ops_of_use_op: dict mapping op name to its successor op list
        """
        unique_names = set()
        used_ops = set()
        succ_ops_of_use_op = {}  # {op_name: succ_ops}
        que = Queue.Queue()
        que.put(response_op)
        while que.qsize() != 0:
            op = que.get()
            for pred_op in op.get_input_ops():
                if pred_op.name not in succ_ops_of_use_op:
                    succ_ops_of_use_op[pred_op.name] = []
                if op != response_op:
                    succ_ops_of_use_op[pred_op.name].append(op)
                if pred_op not in used_ops:
                    que.put(pred_op)
                    used_ops.add(pred_op)
                    # check that the name of each op is globally unique
                    if pred_op.name in unique_names:
                        _LOGGER.critical("Failed to get used Ops: the"
                                         " name of Op must be unique: {}".
                                         format(pred_op.name))
                        os._exit(-1)
                    unique_names.add(pred_op.name)
        return used_ops, succ_ops_of_use_op
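    # For example, on a linear DAG RequestOp -> A -> B -> ResponseOp, the
    # traversal above yields used_ops = {RequestOp, A, B} and
    # succ_ops_of_use_op = {"<request_op_name>": [A], "A": [B], "B": []};
    # ResponseOp itself is never recorded as a successor.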

    def _gen_channel(self, name_gen):
        """
        Generate one ThreadChannel or ProcessChannel.

        Args:
            name_gen: channel name generator

        Returns:
            channel: one generated channel
        """
        channel = None
        if self._is_thread_op:
            channel = ThreadChannel(
                name=name_gen.next(),
                maxsize=self._channel_size,
                channel_recv_frist_arrive=self._channel_recv_frist_arrive)
        else:
            channel = ProcessChannel(
                self._manager,
                name=name_gen.next(),
                maxsize=self._channel_size,
                channel_recv_frist_arrive=self._channel_recv_frist_arrive)
        _LOGGER.debug("[DAG] Generate channel: {}".format(channel.name))
        return channel

    def _gen_virtual_op(self, name_gen):
        """
        Generate one virtual Op.

        Args:
            name_gen: Op name generator

        Returns:
            vir_op: one virtual Op object
        """
        vir_op = VirtualOp(name=name_gen.next())
        _LOGGER.debug("[DAG] Generate virtual_op: {}".format(vir_op.name))
        return vir_op

    def _topo_sort(self, used_ops, response_op, out_degree_ops):
        """
        Topological sort of the DAG, creating inverted multi-layer views.

        Args:
            used_ops: ops used in the DAG
            response_op: response op
            out_degree_ops: successor op list for each op, dict, the output
                of get_use_ops()

        Returns:
            dag_views: the inverted hierarchical topology list. Example:
                DAG :[A -> B -> C -> E]
                            \-> D /
                dag_views: [[E], [C, D], [B], [A]]

            last_op: the last op in front of ResponseOp
        """
        out_degree_num = {
            name: len(ops)
            for name, ops in out_degree_ops.items()
        }
        que_idx = 0  # scroll queue
        ques = [Queue.Queue() for _ in range(2)]
        zero_indegree_num = 0
        for op in used_ops:
            if len(op.get_input_ops()) == 0:
                zero_indegree_num += 1
        if zero_indegree_num != 1:
            _LOGGER.critical("Failed to topo sort: DAG contains "
                             "multiple RequestOps")
            os._exit(-1)
        last_op = response_op.get_input_ops()[0]
        ques[que_idx].put(last_op)

        # topo sort to get dag_views
        dag_views = []
        sorted_op_num = 0
        while True:
            que = ques[que_idx]
            next_que = ques[(que_idx + 1) % 2]
            dag_view = []
            while que.qsize() != 0:
                op = que.get()
                dag_view.append(op)
                sorted_op_num += 1
                for pred_op in op.get_input_ops():
                    out_degree_num[pred_op.name] -= 1
                    if out_degree_num[pred_op.name] == 0:
                        next_que.put(pred_op)
            dag_views.append(dag_view)
            if next_que.qsize() == 0:
                break
            que_idx = (que_idx + 1) % 2
        if sorted_op_num < len(used_ops):
            _LOGGER.critical("Failed to topo sort: not a legal DAG")
            os._exit(-1)

        return dag_views, last_op
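    # Tracing the docstring example: out_degree_num starts as
    # {A: 1, B: 2, C: 1, D: 1, E: 0}. E, the single op feeding ResponseOp,
    # seeds the queue; an op is emitted once all of its successors have been
    # emitted, which yields [[E], [C, D], [B], [A]]. _build_dag then reverses
    # this list to recover the execution order.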

    def _build_dag(self, response_op):
        """
        Build the DAG; the most important function in class DAG. Core steps:
        1.get_use_ops: get the used ops and the out-degree op list of each op.
        2._topo_sort: topological sort creates inverted multi-layer views.
        3.create channels and virtual ops.

        Args:
            response_op: ResponseOp

        Returns:
            actual_ops: all OPs used in the DAG, including virtual OPs
            channels: all channels used in the DAG
            input_channel: the channel of the first OP
            output_channel: the channel of the last OP
            pack_func: pack_response_package function of response_op
            unpack_func: unpack_request_package function of request_op
        """
        if response_op is None:
            _LOGGER.critical("Failed to build DAG: ResponseOp"
                             " has not been set.")
            os._exit(-1)
        used_ops, out_degree_ops = DAG.get_use_ops(response_op)
        if not self._build_dag_each_worker:
            _LOGGER.info("================= USED OP =================")
            for op in used_ops:
                if not isinstance(op, RequestOp):
                    _LOGGER.info(op.name)
            _LOGGER.info("-------------------------------------------")
        if len(used_ops) <= 1:
            _LOGGER.critical(
                "Failed to build DAG: besides RequestOp and ResponseOp, "
                "there should be at least one Op in DAG.")
            os._exit(-1)
        if self._build_dag_each_worker:
            _LOGGER.info("Because `build_dag_each_worker` mode is used, "
                         "Auto-batching is set to the default config: "
                         "batch_size=1, auto_batching_timeout=None")
            for op in used_ops:
                op.use_default_auto_batching_config()

        dag_views, last_op = self._topo_sort(used_ops, response_op,
                                             out_degree_ops)
        dag_views = list(reversed(dag_views))
        if not self._build_dag_each_worker:
            _LOGGER.info("================== DAG ====================")
            for idx, view in enumerate(dag_views):
                _LOGGER.info("(VIEW {})".format(idx))
                for op in view:
                    _LOGGER.info("  [{}]".format(op.name))
                    for out_op in out_degree_ops[op.name]:
                        _LOGGER.info("    - {}".format(out_op.name))
            _LOGGER.info("-------------------------------------------")

        # create channels and virtual ops
        virtual_op_name_gen = NameGenerator("vir")
        channel_name_gen = NameGenerator("chl")
        virtual_ops = []
        channels = []
        input_channel = None
        actual_view = None
        for v_idx, view in enumerate(dag_views):
            if v_idx + 1 >= len(dag_views):
                break
            next_view = dag_views[v_idx + 1]
            if actual_view is None:
                actual_view = view
            actual_next_view = []
            pred_op_of_next_view_op = {}
            for op in actual_view:
                # find the actual succ op in the next view, or create a
                # virtual op to bridge non-adjacent views
                for succ_op in out_degree_ops[op.name]:
                    if succ_op in next_view:
                        if succ_op not in actual_next_view:
                            actual_next_view.append(succ_op)
                        if succ_op.name not in pred_op_of_next_view_op:
                            pred_op_of_next_view_op[succ_op.name] = []
                        pred_op_of_next_view_op[succ_op.name].append(op)
                    else:
                        # create virtual op
                        virtual_op = self._gen_virtual_op(virtual_op_name_gen)
                        virtual_ops.append(virtual_op)
                        out_degree_ops[virtual_op.name] = [succ_op]
                        actual_next_view.append(virtual_op)
                        pred_op_of_next_view_op[virtual_op.name] = [op]
                        virtual_op.add_virtual_pred_op(op)
            actual_view = actual_next_view
            # create channels
            processed_op = set()
            for o_idx, op in enumerate(actual_next_view):
                if op.name in processed_op:
                    continue
                channel = self._gen_channel(channel_name_gen)
                channels.append(channel)
                op.add_input_channel(channel)
                _LOGGER.info("op:{} add input channel.".format(op.name))
                pred_ops = pred_op_of_next_view_op[op.name]
                if v_idx == 0:
                    input_channel = channel
                else:
                    # if pred_op is a virtual op, it uses its ancestors as
                    # producers to the channel
                    for pred_op in pred_ops:
                        pred_op.add_output_channel(channel)
                        _LOGGER.info("pred_op:{} add output channel".format(
                            pred_op.name))
                processed_op.add(op.name)
                # ops with identical predecessor sets share one input channel
                for other_op in actual_next_view[o_idx + 1:]:
                    if other_op.name in processed_op:
                        continue
                    other_pred_ops = pred_op_of_next_view_op[other_op.name]
                    if len(other_pred_ops) != len(pred_ops):
                        continue
                    same_flag = True
                    for pred_op in pred_ops:
                        if pred_op not in other_pred_ops:
                            same_flag = False
                            break
                    if same_flag:
                        other_op.add_input_channel(channel)
                        processed_op.add(other_op.name)
        output_channel = self._gen_channel(channel_name_gen)
        channels.append(output_channel)
        last_op.add_output_channel(output_channel)
        _LOGGER.info("last op:{} add output channel".format(last_op.name))

        pack_func, unpack_func = None, None
        pack_func = response_op.pack_response_package

        actual_ops = virtual_ops
        for op in used_ops:
            if len(op.get_input_ops()) == 0:
                # set special features of the request op:
                # 1. set the unpack function.
                # 2. set the output channel.
                unpack_func = op.unpack_request_package
                op.add_output_channel(input_channel)
                continue
            actual_ops.append(op)

        for c in channels:
            _LOGGER.debug("Channel({}):\n\t- producers: {}\n\t- consumers: {}"
                          .format(c.name, c.get_producers(), c.get_consumers()))

        return (actual_ops, channels, input_channel, output_channel, pack_func,
                unpack_func)

    def get_channels(self):
        return self._channels

    def build(self):
        """
        Interface for building one DAG from outside this class.

        Args:
            None

        Returns:
            _input_channel: the channel of the first OP
            _output_channel: the channel of the last OP
            _pack_func: pack_response_package function of response_op
            _unpack_func: unpack_request_package function of request_op
        """
        (actual_ops, channels, input_channel, output_channel, pack_func,
         unpack_func) = self._build_dag(self._response_op)
        _LOGGER.info("[DAG] Succ build DAG")

        self._actual_ops = actual_ops
        self._channels = channels
        self._input_channel = input_channel
        self._output_channel = output_channel
        self._pack_func = pack_func
        self._unpack_func = unpack_func

        if self._tracer is not None:
            self._tracer.set_channels(self._channels)

        return self._input_channel, self._output_channel, self._pack_func, self._unpack_func

    def start(self):
        """
        Each OP starts a thread or a process according to _is_thread_op.

        Args:
            None

        Returns:
            _threads_or_processes: list of threads or processes.
        """
        self._threads_or_processes = []
        for op in self._actual_ops:
            op.use_profiler(self._use_profile)
            op.set_tracer(self._tracer)
            if self._is_thread_op:
                self._threads_or_processes.extend(op.start_with_thread())
            else:
                self._threads_or_processes.extend(op.start_with_process())
        _LOGGER.info("[DAG] start")

        # not joined yet
        return self._threads_or_processes

    def join(self):
        """
        Join all threads or processes.

        Args:
            None

        Returns:
            None
        """
        for x in self._threads_or_processes:
            if x is not None:
                x.join()

    def stop(self):
        """
        Stop and clean all channels.

        Args:
            None

        Returns:
            None
        """
        for chl in self._channels:
            chl.stop()
        for op in self._actual_ops:
            op.clean_input_channel()
            op.clean_output_channels()
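

# A minimal usage sketch, assuming a hypothetical Op subclass `SomeOp`; in
# practice PipelineServer constructs DAGExecutor from config.yaml:
#
#   read_op = RequestOp()
#   some_op = SomeOp(name="some", input_ops=[read_op])  # hypothetical Op
#   response_op = ResponseOp(input_ops=[some_op])
#   executor = DAGExecutor(response_op, server_conf, worker_idx=0)
#   executor.start()
#   rpc_resp = executor.call(rpc_request)
#   executor.stop()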