dag.py 23.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import threading
import multiprocessing
import sys
B
barrierye 已提交
18
import copy
19 20 21 22 23 24 25 26 27 28
if sys.version_info.major == 2:
    import Queue
elif sys.version_info.major == 3:
    import queue as Queue
else:
    raise Exception("Error Python version")
import os
import logging

from .operator import Op, RequestOp, ResponseOp, VirtualOp
B
barrierye 已提交
29 30
from .channel import (ThreadChannel, ProcessChannel, ChannelData,
                      ChannelDataEcode, ChannelDataType, ChannelStopError)
B
barriery 已提交
31
from .profiler import TimeProfiler, PerformanceTracer
32
from .util import NameGenerator
B
barriery 已提交
33
from .proto import pipeline_service_pb2
34

35
_LOGGER = logging.getLogger(__name__)
36 37 38


class DAGExecutor(object):
B
barriery 已提交
39 40 41 42 43
    def __init__(self, response_op, server_conf):
        """Build and start the DAG behind `response_op`, then set up request state.

        Args:
            response_op: Op handed to `DAG` as the response end of the graph.
            server_conf: mapping read for the keys "build_dag_each_worker",
                "worker_num" and "dag". The "dag" entry is read for "retry",
                "client_type", "use_profile", "channel_size", "is_thread_op"
                and "tracer" (which must contain "interval_s").
        """
        each_worker_builds_dag = server_conf["build_dag_each_worker"]
        worker_num = server_conf["worker_num"]
        dag_conf = server_conf["dag"]

        self._retry = dag_conf["retry"]
        client_type = dag_conf["client_type"]
        self._server_use_profile = dag_conf["use_profile"]
        channel_size = dag_conf["channel_size"]
        self._is_thread_op = dag_conf["is_thread_op"]
        interval_s = dag_conf["tracer"]["interval_s"]

        self.name = "@DAGExecutor"

        self._profiler = TimeProfiler()
        self._profiler.enable(True)

        # A tracer only exists when the configured interval is >= 1;
        # otherwise tracing stays disabled (None).
        if interval_s >= 1:
            self._tracer = PerformanceTracer(self._is_thread_op, interval_s,
                                             worker_num)
        else:
            self._tracer = None

        self._dag = DAG(self.name, response_op, self._server_use_profile,
                        self._is_thread_op, client_type, channel_size,
                        each_worker_builds_dag, self._tracer)
        built = self._dag.build()
        in_channel, out_channel, pack_rpc_func, unpack_rpc_func = built
        self._dag.start()

        # The executor produces into the DAG input channel and consumes
        # from its output channel.
        self._set_in_channel(in_channel)
        self._set_out_channel(out_channel)
        self._pack_rpc_func = pack_rpc_func
        self._unpack_rpc_func = unpack_rpc_func

        if self._tracer is not None:
            self._tracer.start()

        # Request-id generation state.
        self._id_lock = threading.Lock()
        self._id_counter = 0
        self._reset_max_id = 1000000000000000000

        # Per-request condition variables and result slots, keyed by id;
        # the receive thread is created later by start().
        self._cv_pool = {}
        self._cv_for_cv_pool = threading.Condition()
        self._fetch_buffer = {}
        self._recive_func = None

        # Key/value pair looked up in requests to decide whether the
        # client asked for profile data.
        self._client_profile_key = "pipeline.profile"
        self._client_profile_value = "1"

88
    def start(self):
        """Create and start the daemon thread running the receive loop."""
        recv_thread = threading.Thread(
            target=DAGExecutor._recive_out_channel_func, args=(self, ))
        # Daemon thread: it must not keep the interpreter alive on exit.
        recv_thread.daemon = True
        self._recive_func = recv_thread
        recv_thread.start()
        _LOGGER.debug("[DAG Executor] Start recive thread")
94 95 96 97

    def stop(self):
        """Stop the underlying DAG, join it, and log the shutdown."""
        dag = self._dag
        dag.stop()
        dag.join()
        _LOGGER.info("[DAG Executor] Stop")
99 100

    def _get_next_data_id(self):
B
bug fix  
barriery 已提交
101
        data_id = None
102 103
        with self._id_lock:
            if self._id_counter >= self._reset_max_id:
B
barriery 已提交
104
                _LOGGER.info("[DAG Executor] Reset request id")
105
                self._id_counter -= self._reset_max_id
B
bug fix  
barriery 已提交
106
            data_id = self._id_counter
107
            self._id_counter += 1
B
bug fix  
barriery 已提交
108 109 110
        cond_v = threading.Condition()
        with self._cv_for_cv_pool:
            self._cv_pool[data_id] = cond_v
111
            self._fetch_buffer[data_id] = None
B
bug fix  
barriery 已提交
112
        return data_id, cond_v
113 114 115

    def _set_in_channel(self, in_channel):
        """Register this executor as producer of `in_channel` and keep it.

        Logs a critical message and terminates the process via `os._exit(-1)`
        when `in_channel` is neither a ThreadChannel nor a ProcessChannel.
        """
        if isinstance(in_channel, (ThreadChannel, ProcessChannel)):
            in_channel.add_producer(self.name)
            self._in_channel = in_channel
            return
        _LOGGER.critical("[DAG Executor] Failed to set in_channel: "
                         "in_channel must be Channel type, but get {}".format(
                             type(in_channel)))
        os._exit(-1)

    def _set_out_channel(self, out_channel):
        """Register this executor as consumer of `out_channel` and keep it.

        Logs a critical message and terminates the process via `os._exit(-1)`
        when `out_channel` is neither a ThreadChannel nor a ProcessChannel.
        """
        if isinstance(out_channel, (ThreadChannel, ProcessChannel)):
            out_channel.add_consumer(self.name)
            self._out_channel = out_channel
            return
        _LOGGER.critical("[DAG Executor] Failed to set out_channel: "
                         "must be Channel type, but get {}".format(
                             type(out_channel)))
        os._exit(-1)

    def _recive_out_channel_func(self):
        """Receive loop: deliver each result from the output channel to its waiter.

        Repeatedly takes the front item of `_out_channel`, stores it in
        `_fetch_buffer` under its id and notifies the condition variable
        registered for that id. When the channel raises `ChannelStopError`,
        every registered id receives a CLOSED_ERROR `ChannelData` and the
        loop ends. Malformed channel output terminates the process via
        `os._exit(-1)`.
        """
        while True:
            try:
                channeldata_dict = self._out_channel.front(self.name)
            except ChannelStopError:
                _LOGGER.info("[DAG Executor] Stop.")
                # The DAG was stopped: hand a CLOSED_ERROR result to every
                # request that is still registered and wake it up.
                with self._cv_for_cv_pool:
                    for data_id, cv in self._cv_pool.items():
                        closed_error_data = ChannelData(
                            ecode=ChannelDataEcode.CLOSED_ERROR.value,
                            error_info="dag closed.",
                            data_id=data_id)
                        with cv:
                            self._fetch_buffer[data_id] = closed_error_data
                            cv.notify_all()
                break

            if len(channeldata_dict) != 1:
                _LOGGER.critical(
                    "[DAG Executor] Failed to fetch result: out_channel "
                    "cannot have multiple input ops")
                os._exit(-1)
            (_, channeldata), = channeldata_dict.items()
            if not isinstance(channeldata, ChannelData):
                # The message used to be one single-quoted literal with
                # stray double quotes and a backslash continuation, which
                # leaked quote marks and indentation into the log line.
                _LOGGER.critical(
                    "[DAG Executor] Failed to fetch result: data in "
                    "out_channel must be ChannelData type, but get {}"
                    .format(type(channeldata)))
                os._exit(-1)

            data_id = channeldata.id
            _LOGGER.debug("(logid={}) [recive thread] Fetched data".format(
                data_id))
            with self._cv_for_cv_pool:
                cond_v = self._cv_pool.get(data_id)
            if cond_v is None:
                # Nobody is registered for this id any more (for example the
                # request was already answered). Dropping the item keeps the
                # receive thread alive instead of dying on a KeyError.
                _LOGGER.warning(
                    "(logid={}) [recive thread] Drop data: no request is "
                    "waiting for this id".format(data_id))
                continue
            with cond_v:
                self._fetch_buffer[data_id] = channeldata
                cond_v.notify_all()
171

B
bug fix  
barriery 已提交
172
    def _get_channeldata_from_fetch_buffer(self, data_id, cond_v):
        """Block until the result for `data_id` is available and return it.

        Args:
            data_id: id whose slot in `_fetch_buffer` is awaited.
            cond_v: the condition variable registered for `data_id`; the
                receive thread fills the slot and notifies while holding it.

        Returns:
            The object stored in `_fetch_buffer[data_id]`. The entries for
            `data_id` are removed from `_cv_pool` and `_fetch_buffer`.
        """
        with cond_v:
            # Re-check the slot after every wakeup so that a wakeup without
            # data can never return None. The slot is only filled while
            # `cond_v` is held, so checking it here cannot miss a notify.
            while self._fetch_buffer[data_id] is None:
                cond_v.wait()
            ready_data = self._fetch_buffer[data_id]

        # Take the pool lock only after releasing `cond_v`: the receive
        # thread's shutdown path acquires the pool lock first and each
        # condition second, so nesting them the other way round here could
        # deadlock.
        with self._cv_for_cv_pool:
            self._cv_pool.pop(data_id)
            self._fetch_buffer.pop(data_id)
        _LOGGER.debug("(logid={}) [resp thread] Got data".format(data_id))
        return ready_data
191

B
barrierye 已提交
192
    def _pack_channeldata(self, rpc_request, data_id):
        """Turn an RPC request into the `ChannelData` pushed into the DAG.

        Args:
            rpc_request: request object; passed to `_unpack_rpc_func`, and
                its `key` / `value` sequences are searched for the client
                profile flag.
            data_id: id attached to the produced `ChannelData`.

        Returns:
            A DICT-typed `ChannelData` on success, or one carrying
            RPC_PACKAGE_ERROR when unpacking raises.
        """
        try:
            unpacked = self._unpack_rpc_func(rpc_request)
        except Exception as err:
            _LOGGER.error(
                "(logid={}) Failed to parse RPC request package: {}"
                .format(data_id, err),
                exc_info=True)
            return ChannelData(
                ecode=ChannelDataEcode.RPC_PACKAGE_ERROR.value,
                error_info="rpc package error: {}".format(err),
                data_id=data_id)

        # unpack_rpc_func can be rewritten by the user, so the profile flag
        # is looked up directly in the raw request (first matching key wins).
        request_keys = list(rpc_request.key)
        if self._client_profile_key in request_keys:
            pos = request_keys.index(self._client_profile_key)
            profile_value = rpc_request.value[pos]
        else:
            profile_value = None
        client_need_profile = (profile_value == self._client_profile_value)
        _LOGGER.debug("(logid={}) Need profile in client: {}".format(
            data_id, client_need_profile))
        return ChannelData(
            datatype=ChannelDataType.DICT.value,
            dictdata=unpacked,
            data_id=data_id,
            client_need_profile=client_need_profile)
221 222

    def call(self, rpc_request):
        """Run one RPC request through the DAG and return the RPC response.

        The request is packed into a `ChannelData`, pushed into the input
        channel and awaited. A TIMEOUT result is retried, for at most
        `self._retry` pushes in total; any other result ends the loop.

        Args:
            rpc_request: request object handed to `_pack_channeldata`.

        Returns:
            The object produced by `_pack_for_rpc_resp`. On the normal path
            the client profile key and a (possibly empty) profile string are
            appended to its `key` / `value` fields. If the input channel is
            already stopped, a CLOSED_ERROR response is returned directly.
        """
        if self._tracer is not None:
            trace_buffer = self._tracer.data_buffer()

        data_id, cond_v = self._get_next_data_id()
        _LOGGER.info("(logid={}) Succ generate id".format(data_id))

        if not self._is_thread_op:
            start_call = self._profiler.record("call_{}#DAG-{}_0".format(
                data_id, data_id))
        else:
            start_call = self._profiler.record("call_{}#DAG_0".format(data_id))

        _LOGGER.debug("(logid={}) Parsing RPC request package".format(data_id))
        self._profiler.record("prepack_{}#{}_0".format(data_id, self.name))
        req_channeldata = self._pack_channeldata(rpc_request, data_id)
        self._profiler.record("prepack_{}#{}_1".format(data_id, self.name))

        resp_channeldata = None
        for i in range(self._retry):
            if i > 0:
                # Collecting the previous attempt's result removed this id
                # from the pools. Register it again so the receive thread
                # can deliver the retry's result instead of hitting KeyError.
                with self._cv_for_cv_pool:
                    self._cv_pool[data_id] = cond_v
                    self._fetch_buffer[data_id] = None

            _LOGGER.debug("(logid={}) Pushing data into Graph engine".format(
                data_id))
            try:
                self._in_channel.push(req_channeldata, self.name)
            except ChannelStopError:
                _LOGGER.debug("[DAG Executor] Stop")
                # Drop both bookkeeping entries; the defaults tolerate an
                # entry that is already gone.
                with self._cv_for_cv_pool:
                    self._cv_pool.pop(data_id, None)
                    self._fetch_buffer.pop(data_id, None)
                return self._pack_for_rpc_resp(
                    ChannelData(
                        ecode=ChannelDataEcode.CLOSED_ERROR.value,
                        error_info="dag closed.",
                        data_id=data_id))

            _LOGGER.debug("(logid={}) Wait for Graph engine...".format(data_id))
            resp_channeldata = self._get_channeldata_from_fetch_buffer(data_id,
                                                                       cond_v)

            if resp_channeldata.ecode == ChannelDataEcode.OK.value:
                _LOGGER.info("(logid={}) Succ predict".format(data_id))
                break
            _LOGGER.error("(logid={}) Failed to predict: {}"
                          .format(data_id, resp_channeldata.error_info))
            # Only a timeout is worth another attempt.
            if resp_channeldata.ecode != ChannelDataEcode.TIMEOUT.value:
                break

            if i + 1 < self._retry:
                _LOGGER.warning("(logid={}) DAGExecutor retry({}/{})".format(
                    data_id, i + 1, self._retry))

        _LOGGER.debug("(logid={}) Packing RPC response package".format(data_id))
        self._profiler.record("postpack_{}#{}_0".format(data_id, self.name))
        rpc_resp = self._pack_for_rpc_resp(resp_channeldata)
        self._profiler.record("postpack_{}#{}_1".format(data_id, self.name))
        if not self._is_thread_op:
            end_call = self._profiler.record("call_{}#DAG-{}_1".format(data_id,
                                                                       data_id))
        else:
            end_call = self._profiler.record("call_{}#DAG_1".format(data_id))

        if self._tracer is not None:
            succ = (resp_channeldata.ecode == ChannelDataEcode.OK.value)
            trace_buffer.put(("DAG", "call_{}".format(data_id), succ,
                              end_call - start_call))

        profile_str = self._profiler.gen_profile_str()
        if self._server_use_profile:
            sys.stderr.write(profile_str)

        # Add profile info into rpc_resp; the key is always appended and the
        # value is empty unless the client asked for profiling.
        profile_value = ""
        if resp_channeldata.client_need_profile:
            profile_set = resp_channeldata.profile_data_set
            profile_set.add(profile_str)
            profile_value = "".join(profile_set)
        rpc_resp.key.append(self._client_profile_key)
        rpc_resp.value.append(profile_value)

        return rpc_resp

    def _pack_for_rpc_resp(self, channeldata):
B
barriery 已提交
308 309 310 311 312 313 314 315 316 317 318
        try:
            return self._pack_rpc_func(channeldata)
        except Exception as e:
            _LOGGER.error(
                "(logid={}) Failed to pack RPC response package: {}"
                .format(channeldata.id, e),
                exc_info=True)
            resp = pipeline_service_pb2.Response()
            resp.ecode = ChannelDataEcode.RPC_PACKAGE_ERROR.value
            resp.error_info = "rpc package error: {}".format(e)
            return resp
319 320 321


class DAG(object):
B
barrierye 已提交
322
    def __init__(self, request_name, response_op, use_profile, is_thread_op,
                 client_type, channel_size, build_dag_each_worker, tracer):
        """Store the graph configuration; nothing is built or started here.

        All arguments are kept on the instance under the same name with a
        leading underscore. When `is_thread_op` is false a
        `multiprocessing.Manager` is also created and kept as `_manager`.
        """
        # Graph endpoints.
        self._request_name = request_name
        self._response_op = response_op
        # Execution options.
        self._is_thread_op = is_thread_op
        self._build_dag_each_worker = build_dag_each_worker
        self._client_type = client_type
        self._channel_size = channel_size
        # Instrumentation.
        self._use_profile = use_profile
        self._tracer = tracer

        if not self._is_thread_op:
            self._manager = multiprocessing.Manager()
        _LOGGER.info("[DAG] Succ init")
335 336 337

    def get_use_ops(self, response_op):
        unique_names = set()
338
        used_ops = set()
339 340 341 342 343 344 345 346 347 348
        succ_ops_of_use_op = {}  # {op_name: succ_ops}
        que = Queue.Queue()
        que.put(response_op)
        while que.qsize() != 0:
            op = que.get()
            for pred_op in op.get_input_ops():
                if pred_op.name not in succ_ops_of_use_op:
                    succ_ops_of_use_op[pred_op.name] = []
                if op != response_op:
                    succ_ops_of_use_op[pred_op.name].append(op)
349
                if pred_op not in used_ops:
350
                    que.put(pred_op)
351
                    used_ops.add(pred_op)
352 353
                    # check the name of op is globally unique
                    if pred_op.name in unique_names:
B
barriery 已提交
354 355
                        _LOGGER.critical("Failed to get used Ops: the"
                                         " name of Op must be unique: {}".
356 357
                                         format(pred_op.name))
                        os._exit(-1)
358
                    unique_names.add(pred_op.name)
359
        return used_ops, succ_ops_of_use_op
360 361 362

    def _gen_channel(self, name_gen):
        """Create a channel named by `name_gen.next()`.

        Returns a `ThreadChannel` when `_is_thread_op` is set, otherwise a
        `ProcessChannel` backed by `_manager`; both use `_channel_size` as
        `maxsize`.
        """
        channel_name = name_gen.next()
        if self._is_thread_op:
            channel = ThreadChannel(
                name=channel_name, maxsize=self._channel_size)
        else:
            channel = ProcessChannel(
                self._manager, name=channel_name, maxsize=self._channel_size)
        _LOGGER.debug("[DAG] Generate channel: {}".format(channel.name))
        return channel

    def _gen_virtual_op(self, name_gen):
        """Create a `VirtualOp` named by `name_gen.next()` and return it."""
        op_name = name_gen.next()
        virtual_op = VirtualOp(name=op_name)
        _LOGGER.debug("[DAG] Generate virtual_op: {}".format(virtual_op.name))
        return virtual_op
376 377 378 379 380 381 382 383 384

    def _topo_sort(self, used_ops, response_op, out_degree_ops):
        out_degree_num = {
            name: len(ops)
            for name, ops in out_degree_ops.items()
        }
        que_idx = 0  # scroll queue 
        ques = [Queue.Queue() for _ in range(2)]
        zero_indegree_num = 0
385
        for op in used_ops:
386 387 388
            if len(op.get_input_ops()) == 0:
                zero_indegree_num += 1
        if zero_indegree_num != 1:
B
barriery 已提交
389 390
            _LOGGER.critical("Failed to topo sort: DAG contains "
                             "multiple RequestOps")
391
            os._exit(-1)
392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413
        last_op = response_op.get_input_ops()[0]
        ques[que_idx].put(last_op)

        # topo sort to get dag_views
        dag_views = []
        sorted_op_num = 0
        while True:
            que = ques[que_idx]
            next_que = ques[(que_idx + 1) % 2]
            dag_view = []
            while que.qsize() != 0:
                op = que.get()
                dag_view.append(op)
                sorted_op_num += 1
                for pred_op in op.get_input_ops():
                    out_degree_num[pred_op.name] -= 1
                    if out_degree_num[pred_op.name] == 0:
                        next_que.put(pred_op)
            dag_views.append(dag_view)
            if next_que.qsize() == 0:
                break
            que_idx = (que_idx + 1) % 2
414
        if sorted_op_num < len(used_ops):
B
barriery 已提交
415
            _LOGGER.critical("Failed to topo sort: not legal DAG")
416
            os._exit(-1)
417 418 419

        return dag_views, last_op

420
    def _build_dag(self, response_op):
        """Wire the Ops reachable from `response_op` together with channels.

        The Ops are layered with `_topo_sort`; between consecutive views a
        channel is created per distinct group of predecessor Ops, and a
        `VirtualOp` is inserted wherever a successor does not sit in the
        directly following view.

        Returns:
            (actual_ops, channels, input_channel, output_channel, pack_func,
            unpack_func). `actual_ops` holds the virtual Ops followed by
            every used Op that has input Ops; `pack_func` comes from
            `response_op` and `unpack_func` from the Op without inputs.
            Terminates the process via `os._exit(-1)` when `response_op` is
            None or at most one Op is used.
        """
        if response_op is None:
            _LOGGER.critical("Failed to build DAG: ResponseOp"
                             " has not been set.")
            os._exit(-1)
        used_ops, out_degree_ops = self.get_use_ops(response_op)
        log_topology = not self._build_dag_each_worker
        if log_topology:
            _LOGGER.info("================= USED OP =================")
            for op in used_ops:
                if op.name != self._request_name:
                    _LOGGER.info(op.name)
            _LOGGER.info("-------------------------------------------")
        if len(used_ops) <= 1:
            _LOGGER.critical(
                "Failed to build DAG: besides RequestOp and ResponseOp, "
                "there should be at least one Op in DAG.")
            os._exit(-1)
        if self._build_dag_each_worker:
            _LOGGER.info("Because `build_dag_each_worker` mode is used, "
                         "Auto-batching is set to the default config: "
                         "batch_size=1, auto_batching_timeout=None")
            for op in used_ops:
                op.use_default_auto_batching_config()

        dag_views, last_op = self._topo_sort(used_ops, response_op,
                                             out_degree_ops)
        # _topo_sort starts at the response end; flip to request-first order.
        dag_views = dag_views[::-1]
        if log_topology:
            _LOGGER.debug("================== DAG ====================")
            for idx, view in enumerate(dag_views):
                _LOGGER.debug("(VIEW {})".format(idx))
                for op in view:
                    _LOGGER.debug("  [{}]".format(op.name))
                    for out_op in out_degree_ops[op.name]:
                        _LOGGER.debug("    - {}".format(out_op.name))
            _LOGGER.debug("-------------------------------------------")

        # create channels and virtual ops
        virtual_op_name_gen = NameGenerator("vir")
        channel_name_gen = NameGenerator("chl")
        virtual_ops = []
        channels = []
        input_channel = None
        actual_view = None
        for v_idx in range(len(dag_views) - 1):
            next_view = dag_views[v_idx + 1]
            if actual_view is None:
                actual_view = dag_views[v_idx]
            actual_next_view = []
            pred_op_of_next_view_op = {}
            for op in actual_view:
                # find actual succ op in next view and create virtual op
                for succ_op in out_degree_ops[op.name]:
                    if succ_op in next_view:
                        if succ_op not in actual_next_view:
                            actual_next_view.append(succ_op)
                        pred_op_of_next_view_op.setdefault(succ_op.name,
                                                           []).append(op)
                    else:
                        # The successor lives further down: bridge the gap
                        # with a virtual op placed in the next view.
                        virtual_op = self._gen_virtual_op(virtual_op_name_gen)
                        virtual_ops.append(virtual_op)
                        out_degree_ops[virtual_op.name] = [succ_op]
                        actual_next_view.append(virtual_op)
                        pred_op_of_next_view_op[virtual_op.name] = [op]
                        virtual_op.add_virtual_pred_op(op)
            actual_view = actual_next_view
            # create channel
            processed_op = set()
            for o_idx, op in enumerate(actual_next_view):
                if op.name in processed_op:
                    continue
                channel = self._gen_channel(channel_name_gen)
                channels.append(channel)
                op.add_input_channel(channel)
                pred_ops = pred_op_of_next_view_op[op.name]
                if v_idx == 0:
                    # The channel after the first view is the DAG input.
                    input_channel = channel
                else:
                    # if pred_op is virtual op, it will use ancestors as
                    # producers to channel
                    for pred_op in pred_ops:
                        pred_op.add_output_channel(channel)
                processed_op.add(op.name)
                # find same input op to combine channel
                for other_op in actual_next_view[o_idx + 1:]:
                    if other_op.name in processed_op:
                        continue
                    other_pred_ops = pred_op_of_next_view_op[other_op.name]
                    if len(other_pred_ops) != len(pred_ops):
                        continue
                    if all(p in other_pred_ops for p in pred_ops):
                        other_op.add_input_channel(channel)
                        processed_op.add(other_op.name)
        output_channel = self._gen_channel(channel_name_gen)
        channels.append(output_channel)
        last_op.add_output_channel(output_channel)

        pack_func = response_op.pack_response_package
        unpack_func = None

        # Virtual ops come first, followed by every used Op that has
        # inputs; the Op without inputs only contributes the unpack
        # function.
        actual_ops = virtual_ops
        for op in used_ops:
            if len(op.get_input_ops()) == 0:
                unpack_func = op.unpack_request_package
                continue
            actual_ops.append(op)

        for c in channels:
            _LOGGER.debug("Channel({}):\n\t- producers: {}\n\t- consumers: {}"
                          .format(c.name, c.get_producers(), c.get_consumers()))

        return (actual_ops, channels, input_channel, output_channel, pack_func,
                unpack_func)

B
barriery 已提交
542 543 544
    def get_channels(self):
        return self._channels

545 546
    def build(self):
        """Build the DAG for the stored response Op and remember the result.

        Returns:
            (input_channel, output_channel, pack_func, unpack_func) as
            produced by `_build_dag`.
        """
        (actual_ops, channels, input_channel, output_channel, pack_func,
         unpack_func) = self._build_dag(self._response_op)
        _LOGGER.info("[DAG] Succ build DAG")

        self._actual_ops = actual_ops
        self._channels = channels
        self._input_channel = input_channel
        self._output_channel = output_channel
        self._pack_func = pack_func
        self._unpack_func = unpack_func

        # The tracer is optional: DAGExecutor passes None when tracing is
        # disabled, so only hand the channels over when one exists.
        if self._tracer is not None:
            self._tracer.set_channels(self._channels)

        return (self._input_channel, self._output_channel, self._pack_func,
                self._unpack_func)

    def start(self):
        """Start every Op of the built DAG.

        Each Op is configured with the profiler flag and the tracer, then
        started with threads or processes depending on `_is_thread_op`.

        Returns:
            The list of everything returned by the Ops' start calls; nothing
            has been joined yet (see `join()`).
        """
        workers = []
        self._threads_or_proces = workers
        for op in self._actual_ops:
            op.use_profiler(self._use_profile)
            op.set_tracer(self._tracer)
            if self._is_thread_op:
                started = op.start_with_thread(self._client_type)
            else:
                started = op.start_with_process(self._client_type)
            workers.extend(started)
        _LOGGER.info("[DAG] start")

        # not join yet
        return workers

    def join(self):
        for x in self._threads_or_proces:
            x.join()

    def stop(self):
        for chl in self._channels:
            chl.stop()
584 585 586
        for op in self._actual_ops:
            op.clean_input_channel()
            op.clean_output_channels()