#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import threading
import multiprocessing
import sys
import copy
if sys.version_info.major == 2:
    import Queue
elif sys.version_info.major == 3:
    import queue as Queue
else:
    raise Exception("Unsupported Python version")
import os
import logging

from .operator import Op, RequestOp, ResponseOp, VirtualOp
from .channel import (ThreadChannel, ProcessChannel, ChannelData,
                      ChannelDataEcode, ChannelDataType, ChannelStopError)
from .profiler import TimeProfiler, PerformanceTracer
from .util import NameGenerator
from .proto import pipeline_service_pb2

_LOGGER = logging.getLogger()


class DAGExecutor(object):
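    """Executes one RPC request over the DAG: packs the request into
    ChannelData, pushes it into the input channel, and blocks on a
    per-request Condition until the receive thread pulls the result
    from the output channel."""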
    def __init__(self, response_op, dag_conf):
        self._retry = dag_conf["retry"]
        client_type = dag_conf["client_type"]
        self._server_use_profile = dag_conf["use_profile"]
        channel_size = dag_conf["channel_size"]
        self._is_thread_op = dag_conf["is_thread_op"]
        build_dag_each_worker = dag_conf["build_dag_each_worker"]

        self.name = "@G"
        self._profiler = TimeProfiler()
        self._profiler.enable(True)

        self._tracer = PerformanceTracer()

        self._dag = DAG(self.name, response_op, self._server_use_profile,
                        self._is_thread_op, client_type, channel_size,
                        build_dag_each_worker, self._tracer)
        (in_channel, out_channel, pack_rpc_func,
         unpack_rpc_func) = self._dag.build()
        self._dag.start()

        self._tracer.set_channels(self._dag.get_channels())
        self._tracer.start()

        self._set_in_channel(in_channel)
        self._set_out_channel(out_channel)
        self._pack_rpc_func = pack_rpc_func
        self._unpack_rpc_func = unpack_rpc_func

        self._id_lock = threading.Lock()
        self._id_counter = 0
        self._reset_max_id = 1000000000000000000
        self._cv_pool = {}
        self._cv_for_cv_pool = threading.Condition()
        self._fetch_buffer = {}
        self._recive_func = None

        self._client_profile_key = "pipeline.profile"
        self._client_profile_value = "1"

    def start(self):
        self._recive_func = threading.Thread(
            target=DAGExecutor._recive_out_channel_func, args=(self, ))
        self._recive_func.daemon = True
        self._recive_func.start()
        _LOGGER.debug("[DAG Executor] Start recive thread")

    def stop(self):
        self._dag.stop()
        self._dag.join()
        _LOGGER.info("[DAG Executor] Stop")

    def _get_next_data_id(self):
        data_id = None
        with self._id_lock:
            if self._id_counter >= self._reset_max_id:
                _LOGGER.info("[DAG Executor] Reset request id")
                self._id_counter -= self._reset_max_id
            data_id = self._id_counter
            self._id_counter += 1
        cond_v = threading.Condition()
        with self._cv_for_cv_pool:
            self._cv_pool[data_id] = cond_v
            self._fetch_buffer[data_id] = None
        return data_id, cond_v
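
    # Note: each request id is paired with its own threading.Condition:
    # call() blocks on it in _get_channeldata_from_fetch_buffer(), and the
    # receive thread sets _fetch_buffer[data_id] and notifies it once the
    # result arrives.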

    def _set_in_channel(self, in_channel):
        if not isinstance(in_channel, (ThreadChannel, ProcessChannel)):
            _LOGGER.critical("[DAG Executor] Failed to set in_channel: "
                             "in_channel must be Channel type, but get {}".
                             format(type(in_channel)))
            os._exit(-1)
        in_channel.add_producer(self.name)
        self._in_channel = in_channel

    def _set_out_channel(self, out_channel):
        if not isinstance(out_channel, (ThreadChannel, ProcessChannel)):
            _LOGGER.critical("[DAG Executor] Failed to set out_channel: "
                             "must be Channel type, but get {}".format(
                                 type(out_channel)))
            os._exit(-1)
        out_channel.add_consumer(self.name)
        self._out_channel = out_channel

    def _recive_out_channel_func(self):
        cv = None
        while True:
            try:
                channeldata_dict = self._out_channel.front(self.name)
            except ChannelStopError:
                _LOGGER.info("[DAG Executor] Stop.")
                with self._cv_for_cv_pool:
                    for data_id, cv in self._cv_pool.items():
                        closed_error_data = ChannelData(
                            ecode=ChannelDataEcode.CLOSED_ERROR.value,
                            error_info="dag closed.",
                            data_id=data_id)
                        with cv:
                            self._fetch_buffer[data_id] = closed_error_data
                            cv.notify_all()
                break

            if len(channeldata_dict) != 1:
                _LOGGER.critical(
                    "[DAG Executor] Failed to fetch result: out_channel "
                    "cannot have multiple input ops")
                os._exit(-1)
            (_, channeldata), = channeldata_dict.items()
            if not isinstance(channeldata, ChannelData):
                _LOGGER.critical(
                    "[DAG Executor] Failed to fetch result: data in "
                    "out_channel must be ChannelData type, but got {}"
                    .format(type(channeldata)))
                os._exit(-1)

            data_id = channeldata.id
            _LOGGER.debug("(logid={}) [recive thread] Fetched data".format(
                data_id))
            with self._cv_for_cv_pool:
                cond_v = self._cv_pool[data_id]
            with cond_v:
                self._fetch_buffer[data_id] = channeldata
                cond_v.notify_all()

    def _get_channeldata_from_fetch_buffer(self, data_id, cond_v):
        ready_data = None

        with cond_v:
            with self._cv_for_cv_pool:
                if self._fetch_buffer[data_id] is not None:
                    # The requested data is already ready
                    ready_data = self._fetch_buffer[data_id]
                    self._cv_pool.pop(data_id)
                    self._fetch_buffer.pop(data_id)
            if ready_data is None:
                # Wait for data ready
                cond_v.wait()
                with self._cv_for_cv_pool:
                    ready_data = self._fetch_buffer[data_id]
                    self._cv_pool.pop(data_id)
                    self._fetch_buffer.pop(data_id)
        _LOGGER.debug("(logid={}) [resp thread] Got data".format(data_id))
        return ready_data

    def _pack_channeldata(self, rpc_request, data_id):
        dictdata = None
        try:
            dictdata = self._unpack_rpc_func(rpc_request)
        except Exception as e:
            _LOGGER.error(
                "(logid={}) Failed to parse RPC request package: {}"
                .format(data_id, e),
                exc_info=True)
            return ChannelData(
                ecode=ChannelDataEcode.RPC_PACKAGE_ERROR.value,
                error_info="rpc package error: {}".format(e),
                data_id=data_id)
        else:
            # unpack_rpc_func may be overridden by the user, so look up the
            # client_profile_key field in rpc_request directly
            profile_value = None
            for idx, key in enumerate(rpc_request.key):
                if key == self._client_profile_key:
                    profile_value = rpc_request.value[idx]
                    break
            client_need_profile = (profile_value == self._client_profile_value)
            _LOGGER.debug("(logid={}) Need profile in client: {}".format(
                data_id, client_need_profile))
            return ChannelData(
                datatype=ChannelDataType.DICT.value,
                dictdata=dictdata,
                data_id=data_id,
                client_need_profile=client_need_profile)

    def call(self, rpc_request):
        data_buffer = self._tracer.data_buffer()

        data_id, cond_v = self._get_next_data_id()
        _LOGGER.info("(logid={}) Succ generate id".format(data_id))

        start_call, end_call = None, None
        if not self._is_thread_op:
            start_call = self._profiler.record("call_{}#DAG-{}_0".format(
                data_id, data_id))
        else:
            start_call = self._profiler.record("call_{}#DAG_0".format(data_id))
        data_buffer.put(("DAG", "call_{}".format(data_id), 0, start_call))

        _LOGGER.debug("(logid={}) Parsing RPC request package".format(data_id))
        self._profiler.record("prepack_{}#{}_0".format(data_id, self.name))
        req_channeldata = self._pack_channeldata(rpc_request, data_id)
        self._profiler.record("prepack_{}#{}_1".format(data_id, self.name))

        resp_channeldata = None
        for i in range(self._retry):
            _LOGGER.debug("(logid={}) Pushing data into Graph engine".format(
                data_id))
            try:
                self._in_channel.push(req_channeldata, self.name)
            except ChannelStopError:
                _LOGGER.debug("[DAG Executor] Stop")
                with self._cv_for_cv_pool:
                    self._cv_pool.pop(data_id)
                return self._pack_for_rpc_resp(
                    ChannelData(
                        ecode=ChannelDataEcode.CLOSED_ERROR.value,
                        error_info="dag closed.",
                        data_id=data_id))

            _LOGGER.debug("(logid={}) Wait for Graph engine...".format(data_id))
            resp_channeldata = self._get_channeldata_from_fetch_buffer(data_id,
                                                                       cond_v)

            if resp_channeldata.ecode == ChannelDataEcode.OK.value:
                _LOGGER.debug("(logid={}) Succ predict".format(data_id))
                break
            else:
                _LOGGER.error("(logid={}) Failed to predict: {}"
                              .format(data_id, resp_channeldata.error_info))
                if resp_channeldata.ecode != ChannelDataEcode.TIMEOUT.value:
                    break

            if i + 1 < self._retry:
                _LOGGER.warning("(logid={}) DAGExecutor retry({}/{})".format(
                    data_id, i + 1, self._retry))

        _LOGGER.debug("(logid={}) Packing RPC response package".format(data_id))
        self._profiler.record("postpack_{}#{}_0".format(data_id, self.name))
        rpc_resp = self._pack_for_rpc_resp(resp_channeldata)
        self._profiler.record("postpack_{}#{}_1".format(data_id, self.name))
        if not self._is_thread_op:
            end_call = self._profiler.record("call_{}#DAG-{}_1".format(data_id,
                                                                       data_id))
        else:
            end_call = self._profiler.record("call_{}#DAG_1".format(data_id))
        data_buffer.put(("DAG", "call_{}".format(data_id), 1, end_call))

        profile_str = self._profiler.gen_profile_str()
        if self._server_use_profile:
            sys.stderr.write(profile_str)

        # add profile info into rpc_resp
        profile_value = ""
        if resp_channeldata.client_need_profile:
            profile_set = resp_channeldata.profile_data_set
            profile_set.add(profile_str)
            profile_value = "".join(list(profile_set))
        rpc_resp.key.append(self._client_profile_key)
        rpc_resp.value.append(profile_value)

        return rpc_resp
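
    # A minimal usage sketch (hypothetical wiring; the real server builds
    # dag_conf and response_op from its config file):
    #
    #     executor = DAGExecutor(response_op, dag_conf)
    #     executor.start()
    #     rpc_resp = executor.call(rpc_request)  # blocks until the DAG replies
    #     executor.stop()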

    def _pack_for_rpc_resp(self, channeldata):
        try:
            return self._pack_rpc_func(channeldata)
        except Exception as e:
            _LOGGER.error(
                "(logid={}) Failed to pack RPC response package: {}"
                .format(channeldata.id, e),
                exc_info=True)
            resp = pipeline_service_pb2.Response()
            resp.ecode = ChannelDataEcode.RPC_PACKAGE_ERROR.value
            resp.error_info = "rpc package error: {}".format(e)
            return resp


class DAG(object):
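    """Builds the executable graph behind DAGExecutor: discovers the used
    Ops, topologically sorts them into views, wires them together with
    channels (inserting VirtualOps where edges skip views), and starts
    one thread or process per Op."""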
    def __init__(self, request_name, response_op, use_profile, is_thread_op,
                 client_type, channel_size, build_dag_each_worker, tracer):
        self._request_name = request_name
        self._response_op = response_op
        self._use_profile = use_profile
        self._is_thread_op = is_thread_op
        self._channel_size = channel_size
        self._client_type = client_type
        self._build_dag_each_worker = build_dag_each_worker
        self._tracer = tracer
        if not self._is_thread_op:
            self._manager = multiprocessing.Manager()
        _LOGGER.info("[DAG] Succ init")

    def get_use_ops(self, response_op):
        unique_names = set()
        used_ops = set()
        succ_ops_of_use_op = {}  # {op_name: succ_ops}
        que = Queue.Queue()
        que.put(response_op)
        while que.qsize() != 0:
            op = que.get()
            for pred_op in op.get_input_ops():
                if pred_op.name not in succ_ops_of_use_op:
                    succ_ops_of_use_op[pred_op.name] = []
                if op != response_op:
                    succ_ops_of_use_op[pred_op.name].append(op)
                if pred_op not in used_ops:
                    que.put(pred_op)
                    used_ops.add(pred_op)
                    # check that the Op name is globally unique
                    if pred_op.name in unique_names:
                        _LOGGER.critical("Failed to get used Ops: the"
                                         " name of Op must be unique: {}".
                                         format(pred_op.name))
                        os._exit(-1)
                    unique_names.add(pred_op.name)
        return used_ops, succ_ops_of_use_op
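
    # Note: get_use_ops walks the graph backwards (BFS from response_op via
    # get_input_ops), so used_ops holds every Op reachable from the
    # ResponseOp (including the RequestOp), and succ_ops_of_use_op maps
    # each Op name to its successors, excluding response_op itself.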

    def _gen_channel(self, name_gen):
        channel = None
        if self._is_thread_op:
            channel = ThreadChannel(
                name=name_gen.next(), maxsize=self._channel_size)
        else:
            channel = ProcessChannel(
                self._manager, name=name_gen.next(), maxsize=self._channel_size)
        _LOGGER.debug("[DAG] Generate channel: {}".format(channel.name))
        return channel

    def _gen_virtual_op(self, name_gen):
        vir_op = VirtualOp(name=name_gen.next())
        _LOGGER.debug("[DAG] Generate virtual_op: {}".format(vir_op.name))
        return vir_op

    def _topo_sort(self, used_ops, response_op, out_degree_ops):
        out_degree_num = {
            name: len(ops)
            for name, ops in out_degree_ops.items()
        }
        que_idx = 0  # ping-pong between the two queues
        ques = [Queue.Queue() for _ in range(2)]
        zero_indegree_num = 0
        for op in used_ops:
            if len(op.get_input_ops()) == 0:
                zero_indegree_num += 1
        if zero_indegree_num != 1:
            _LOGGER.critical("Failed to topo sort: DAG contains "
                             "multiple RequestOps")
            os._exit(-1)
        last_op = response_op.get_input_ops()[0]
        ques[que_idx].put(last_op)

        # topo sort to get dag_views
        dag_views = []
        sorted_op_num = 0
        while True:
            que = ques[que_idx]
            next_que = ques[(que_idx + 1) % 2]
            dag_view = []
            while que.qsize() != 0:
                op = que.get()
                dag_view.append(op)
                sorted_op_num += 1
                for pred_op in op.get_input_ops():
                    out_degree_num[pred_op.name] -= 1
                    if out_degree_num[pred_op.name] == 0:
                        next_que.put(pred_op)
            dag_views.append(dag_view)
            if next_que.qsize() == 0:
                break
            que_idx = (que_idx + 1) % 2
        if sorted_op_num < len(used_ops):
            _LOGGER.critical("Failed to topo sort: not legal DAG")
            os._exit(-1)

        return dag_views, last_op
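
    # Worked example (hypothetical Ops): for Request -> A -> B -> Response
    # with an extra branch A -> C -> B, the sort starts from B (the single
    # input of the ResponseOp) and repeatedly peels off Ops whose out-degree
    # has dropped to zero, giving dag_views = [[B], [C], [A], [Request]];
    # _build_dag later reverses this into execution order.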

    def _build_dag(self, response_op):
        if response_op is None:
            _LOGGER.critical("Failed to build DAG: ResponseOp"
                             " has not been set.")
            os._exit(-1)
        used_ops, out_degree_ops = self.get_use_ops(response_op)
        if not self._build_dag_each_worker:
            _LOGGER.info("================= USED OP =================")
            for op in used_ops:
                if op.name != self._request_name:
                    _LOGGER.info(op.name)
            _LOGGER.info("-------------------------------------------")
        if len(used_ops) <= 1:
            _LOGGER.critical(
                "Failed to build DAG: besides RequestOp and ResponseOp, "
                "there should be at least one Op in DAG.")
            os._exit(-1)
        if self._build_dag_each_worker:
            _LOGGER.info("Because `build_dag_each_worker` mode is used, "
                         "Auto-batching is set to the default config: "
                         "batch_size=1, auto_batching_timeout=None")
            for op in used_ops:
                op.use_default_auto_batching_config()

        dag_views, last_op = self._topo_sort(used_ops, response_op,
                                             out_degree_ops)
        dag_views = list(reversed(dag_views))
        if not self._build_dag_each_worker:
            _LOGGER.debug("================== DAG ====================")
            for idx, view in enumerate(dag_views):
                _LOGGER.debug("(VIEW {})".format(idx))
                for op in view:
                    _LOGGER.debug("  [{}]".format(op.name))
                    for out_op in out_degree_ops[op.name]:
                        _LOGGER.debug("    - {}".format(out_op.name))
            _LOGGER.debug("-------------------------------------------")

        # create channels and virtual ops
        virtual_op_name_gen = NameGenerator("vir")
        channel_name_gen = NameGenerator("chl")
        virtual_ops = []
        channels = []
        input_channel = None
        actual_view = None
        for v_idx, view in enumerate(dag_views):
            if v_idx + 1 >= len(dag_views):
                break
            next_view = dag_views[v_idx + 1]
            if actual_view is None:
                actual_view = view
            actual_next_view = []
            pred_op_of_next_view_op = {}
            for op in actual_view:
                # find actual succ op in next view and create virtual op
                for succ_op in out_degree_ops[op.name]:
                    if succ_op in next_view:
                        if succ_op not in actual_next_view:
                            actual_next_view.append(succ_op)
                        if succ_op.name not in pred_op_of_next_view_op:
                            pred_op_of_next_view_op[succ_op.name] = []
                        pred_op_of_next_view_op[succ_op.name].append(op)
                    else:
                        # create virtual op
                        virtual_op = self._gen_virtual_op(virtual_op_name_gen)
                        virtual_ops.append(virtual_op)
                        out_degree_ops[virtual_op.name] = [succ_op]
                        actual_next_view.append(virtual_op)
                        pred_op_of_next_view_op[virtual_op.name] = [op]
                        virtual_op.add_virtual_pred_op(op)
            actual_view = actual_next_view
            # create channel
            processed_op = set()
            for o_idx, op in enumerate(actual_next_view):
                if op.name in processed_op:
                    continue
                channel = self._gen_channel(channel_name_gen)
                channels.append(channel)
                op.add_input_channel(channel)
                pred_ops = pred_op_of_next_view_op[op.name]
                if v_idx == 0:
                    input_channel = channel
                else:
                    # if pred_op is a virtual op, its ancestors act as the channel's producers
                    for pred_op in pred_ops:
                        pred_op.add_output_channel(channel)
                processed_op.add(op.name)
                # find Ops with the same predecessors so they can share this channel
                for other_op in actual_next_view[o_idx + 1:]:
                    if other_op.name in processed_op:
                        continue
                    other_pred_ops = pred_op_of_next_view_op[other_op.name]
                    if len(other_pred_ops) != len(pred_ops):
                        continue
                    same_flag = True
                    for pred_op in pred_ops:
                        if pred_op not in other_pred_ops:
                            same_flag = False
                            break
                    if same_flag:
                        other_op.add_input_channel(channel)
                        processed_op.add(other_op.name)
        output_channel = self._gen_channel(channel_name_gen)
        channels.append(output_channel)
        last_op.add_output_channel(output_channel)

        pack_func, unpack_func = None, None
        pack_func = response_op.pack_response_package

        actual_ops = virtual_ops
        for op in used_ops:
            if len(op.get_input_ops()) == 0:
                unpack_func = op.unpack_request_package
                continue
            actual_ops.append(op)

        for c in channels:
            _LOGGER.debug("Channel({}):\n\t- producers: {}\n\t- consumers: {}"
                          .format(c.name, c.get_producers(), c.get_consumers()))

        return (actual_ops, channels, input_channel, output_channel, pack_func,
                unpack_func)

    def get_channels(self):
        return self._channels

    def build(self):
        (actual_ops, channels, input_channel, output_channel, pack_func,
         unpack_func) = self._build_dag(self._response_op)
        _LOGGER.info("[DAG] Succ build DAG")

        self._actual_ops = actual_ops
        self._channels = channels
        self._input_channel = input_channel
        self._output_channel = output_channel
        self._pack_func = pack_func
        self._unpack_func = unpack_func

        return self._input_channel, self._output_channel, self._pack_func, self._unpack_func
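
    # Typical lifecycle, as driven by DAGExecutor above (a sketch, not a
    # new API):
    #
    #     in_ch, out_ch, pack_func, unpack_func = dag.build()
    #     dag.start()
    #     ...serve requests...
    #     dag.stop()
    #     dag.join()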

    def start(self):
        self._threads_or_proces = []
        for op in self._actual_ops:
            op.use_profiler(self._use_profile)
            op.set_tracer(self._tracer)
            if self._is_thread_op:
                self._threads_or_proces.extend(
                    op.start_with_thread(self._client_type))
            else:
                self._threads_or_proces.extend(
                    op.start_with_process(self._client_type))
        _LOGGER.info("[DAG] start")

        # threads/processes are started but not joined here; see join()
        return self._threads_or_proces

    def join(self):
        for x in self._threads_or_proces:
            x.join()

    def stop(self):
        for chl in self._channels:
            chl.stop()
        for op in self._actual_ops:
            op.clean_input_channel()
            op.clean_output_channels()