#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import threading
import multiprocessing
import sys
import copy
if sys.version_info.major == 2:
    import Queue
elif sys.version_info.major == 3:
    import queue as Queue
else:
    raise Exception("Unsupported Python version")
import os
import logging

from .operator import Op, RequestOp, ResponseOp, VirtualOp
from .channel import (ThreadChannel, ProcessChannel, ChannelData,
                      ChannelDataEcode, ChannelDataType, ChannelStopError)
from .profiler import TimeProfiler
from .util import NameGenerator

_LOGGER = logging.getLogger()


class DAGExecutor(object):
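    """Executes a built DAG for RPC requests: each request is packed into
    ChannelData, pushed into the input channel, and matched back to its
    caller by a background receive thread on the output channel."""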
    def __init__(self, response_op, dag_conf):
        self._retry = dag_conf["retry"]
        client_type = dag_conf["client_type"]
        self._server_use_profile = dag_conf["use_profile"]
        channel_size = dag_conf["channel_size"]
        self._is_thread_op = dag_conf["is_thread_op"]
        build_dag_each_worker = dag_conf["build_dag_each_worker"]

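        # the name under which this executor registers itself on the
        # input/output channels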
        self.name = "@G"
        self._profiler = TimeProfiler()
        self._profiler.enable(True)

        self._dag = DAG(self.name, response_op, self._server_use_profile,
                        self._is_thread_op, client_type, channel_size,
                        build_dag_each_worker)
        (in_channel, out_channel, pack_rpc_func,
         unpack_rpc_func) = self._dag.build()
        self._dag.start()

        self._set_in_channel(in_channel)
        self._set_out_channel(out_channel)
        self._pack_rpc_func = pack_rpc_func
        self._unpack_rpc_func = unpack_rpc_func

        self._id_lock = threading.Lock()
        self._id_counter = 0
        self._reset_max_id = 1000000000000000000
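        # per-request bookkeeping: the receive thread stores each response in
        # _fetch_buffer and wakes the waiting caller via its Condition in
        # _cv_pool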
        self._cv_pool = {}
        self._cv_for_cv_pool = threading.Condition()
        self._fetch_buffer = {}
        self._receive_func = None

        self._client_profile_key = "pipeline.profile"
        self._client_profile_value = "1"

    def start(self):
        self._receive_func = threading.Thread(
            target=DAGExecutor._receive_out_channel_func, args=(self, ))
        self._receive_func.start()
        _LOGGER.debug("[DAG Executor] start receive thread")

    def stop(self):
        self._dag.stop()
        self._dag.join()
        _LOGGER.info("[DAG Executor] succ stop")

    def _get_next_data_id(self):
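        """Allocate a unique request id (the counter wraps at _reset_max_id)
        along with a per-request Condition registered in _cv_pool."""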
        data_id = None
        with self._id_lock:
            if self._id_counter >= self._reset_max_id:
                self._id_counter -= self._reset_max_id
            data_id = self._id_counter
            self._id_counter += 1
        cond_v = threading.Condition()
        with self._cv_for_cv_pool:
            self._cv_pool[data_id] = cond_v
            self._fetch_buffer[data_id] = None
        return data_id, cond_v

    def _set_in_channel(self, in_channel):
        if not isinstance(in_channel, (ThreadChannel, ProcessChannel)):
            _LOGGER.critical("[DAG Executor] in_channel must be Channel"
                             " type, but got {}".format(type(in_channel)))
            os._exit(-1)
        in_channel.add_producer(self.name)
        self._in_channel = in_channel

    def _set_out_channel(self, out_channel):
        if not isinstance(out_channel, (ThreadChannel, ProcessChannel)):
            _LOGGER.critical("[DAG Executor] out_channel must be Channel"
                             " type, but got {}".format(type(out_channel)))
            os._exit(-1)
        out_channel.add_consumer(self.name)
        self._out_channel = out_channel

    def _receive_out_channel_func(self):
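        """Receive-thread loop: take each response from out_channel, store it
        in _fetch_buffer, and notify the waiting caller. On ChannelStopError,
        wake every pending caller with a CLOSED_ERROR and exit."""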
        cv = None
        while True:
            try:
                channeldata_dict = self._out_channel.front(self.name)
            except ChannelStopError:
                _LOGGER.info("[DAG Executor] channel stop.")
                with self._cv_for_cv_pool:
                    for data_id, cv in self._cv_pool.items():
                        closed_error_data = ChannelData(
                            ecode=ChannelDataEcode.CLOSED_ERROR.value,
                            error_info="dag closed.",
                            data_id=data_id)
                        with cv:
                            self._fetch_buffer[data_id] = closed_error_data
                            cv.notify_all()
                break

            if len(channeldata_dict) != 1:
                _LOGGER.critical(
                    "[DAG Executor] out_channel cannot have multiple input ops")
                os._exit(-1)
            (_, channeldata), = channeldata_dict.items()
            if not isinstance(channeldata, ChannelData):
                _LOGGER.critical(
                    '[DAG Executor] data must be ChannelData type, but got {}'
                    .format(type(channeldata)))
                os._exit(-1)

            data_id = channeldata.id
            _LOGGER.debug("receive thread fetch data[{}]".format(data_id))
            with self._cv_for_cv_pool:
                cond_v = self._cv_pool[data_id]
            with cond_v:
                self._fetch_buffer[data_id] = channeldata
                cond_v.notify_all()

    def _get_channeldata_from_fetch_buffer(self, data_id, cond_v):
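        """Block until the receive thread publishes the response for data_id,
        then remove the bookkeeping entries and return it."""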
        ready_data = None

        with cond_v:
            with self._cv_for_cv_pool:
                if self._fetch_buffer[data_id] is not None:
                    # The requested data is already ready
                    ready_data = self._fetch_buffer[data_id]
                    self._cv_pool.pop(data_id)
                    self._fetch_buffer.pop(data_id)
            if ready_data is None:
                # Wait for data ready
                cond_v.wait()
                with self._cv_for_cv_pool:
                    ready_data = self._fetch_buffer[data_id]
                    self._cv_pool.pop(data_id)
                    self._fetch_buffer.pop(data_id)
        _LOGGER.debug("resp thread get resp data[{}]".format(data_id))
        return ready_data

    def _pack_channeldata(self, rpc_request, data_id):
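        """Unpack the RPC request with the user-provided unpack function and
        wrap the result in ChannelData; also checks whether the client
        requested profiling."""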
        dictdata = None
        try:
            dictdata = self._unpack_rpc_func(rpc_request)
        except Exception as e:
            _LOGGER.error("failed to parse RPC request to data[{}]: {}"
                          .format(data_id, e))
            return ChannelData(
                ecode=ChannelDataEcode.RPC_PACKAGE_ERROR.value,
                error_info="rpc package error: {}".format(e),
                data_id=data_id)
        else:
            # Because unpack_rpc_func may be overridden by the user, we look
            # for the client_profile_key field in rpc_request ourselves
            profile_value = None
            for idx, key in enumerate(rpc_request.key):
                if key == self._client_profile_key:
                    profile_value = rpc_request.value[idx]
                    break
            client_need_profile = (profile_value == self._client_profile_value)
            _LOGGER.debug("request[{}] need profile: {}".format(
                data_id, client_need_profile))
            return ChannelData(
                datatype=ChannelDataType.DICT.value,
                dictdata=dictdata,
                data_id=data_id,
                client_need_profile=client_need_profile)

    def call(self, rpc_request):
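        """Synchronous entry point for one RPC request: pack it into
        ChannelData, push it into the graph, wait for the response (retrying
        timeouts), and pack the RPC response, appending profile info when
        the client asks for it."""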
        data_id, cond_v = self._get_next_data_id()
        _LOGGER.debug("generate Request id: {}".format(data_id))

        if not self._is_thread_op:
            self._profiler.record("call_{}#DAG-{}_0".format(data_id, data_id))
        else:
            self._profiler.record("call_{}#DAG_0".format(data_id))

        _LOGGER.debug("try parse RPC request to channeldata[{}]".format(
            data_id))
        self._profiler.record("prepack_{}#{}_0".format(data_id, self.name))
        req_channeldata = self._pack_channeldata(rpc_request, data_id)
        self._profiler.record("prepack_{}#{}_1".format(data_id, self.name))

        resp_channeldata = None
        for i in range(self._retry):
            _LOGGER.debug("push data[{}] into Graph engine".format(data_id))
            try:
                self._in_channel.push(req_channeldata, self.name)
            except ChannelStopError:
                _LOGGER.debug("[DAG Executor] channel stop.")
                with self._cv_for_cv_pool:
                    self._cv_pool.pop(data_id)
                return self._pack_for_rpc_resp(
                    ChannelData(
                        ecode=ChannelDataEcode.CLOSED_ERROR.value,
                        error_info="dag closed.",
                        data_id=data_id))

            _LOGGER.debug("wait Graph engine for data[{}]...".format(data_id))
            resp_channeldata = self._get_channeldata_from_fetch_buffer(data_id,
                                                                       cond_v)

            if resp_channeldata.ecode == ChannelDataEcode.OK.value:
                _LOGGER.debug("request[{}] succ predict".format(data_id))
                break
            else:
                _LOGGER.warning("request[{}] predict failed: {}"
                                .format(data_id, resp_channeldata.error_info))
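                # only TIMEOUT errors are retried; any other error is
                # returned to the client immediately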
                if resp_channeldata.ecode != ChannelDataEcode.TIMEOUT.value:
                    break

            if i + 1 < self._retry:
                _LOGGER.warning("retry({}/{}) data[{}]".format(
                    i + 1, self._retry, data_id))

        _LOGGER.debug("unpack channeldata[{}] into RPC response".format(
            data_id))
        self._profiler.record("postpack_{}#{}_0".format(data_id, self.name))
        rpc_resp = self._pack_for_rpc_resp(resp_channeldata)
        self._profiler.record("postpack_{}#{}_1".format(data_id, self.name))
        if not self._is_thread_op:
            self._profiler.record("call_{}#DAG-{}_1".format(data_id, data_id))
        else:
            self._profiler.record("call_{}#DAG_1".format(data_id))

        profile_str = self._profiler.gen_profile_str()
        if self._server_use_profile:
            sys.stderr.write(profile_str)

        # add profile info into rpc_resp
        profile_value = ""
        if resp_channeldata.client_need_profile:
            profile_set = resp_channeldata.profile_data_set
            profile_set.add(profile_str)
            profile_value = "".join(list(profile_set))
        rpc_resp.key.append(self._client_profile_key)
        rpc_resp.value.append(profile_value)

        return rpc_resp

    def _pack_for_rpc_resp(self, channeldata):
        return self._pack_rpc_func(channeldata)


class DAG(object):
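    """Builds and runs the op graph behind a ResponseOp: collects the used
    ops, topologically sorts them, inserts VirtualOps and channels, and
    starts each op as a thread or process."""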
    def __init__(self, request_name, response_op, use_profile, is_thread_op,
                 client_type, channel_size, build_dag_each_worker):
        self._request_name = request_name
        self._response_op = response_op
        self._use_profile = use_profile
        self._is_thread_op = is_thread_op
        self._channel_size = channel_size
        self._client_type = client_type
        self._build_dag_each_worker = build_dag_each_worker
        if not self._is_thread_op:
            self._manager = multiprocessing.Manager()
        _LOGGER.info("[DAG] succ init")

    def get_use_ops(self, response_op):
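        """Walk backwards from response_op to collect every op in use and a
        {op_name: successor ops} map; op names must be globally unique."""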
        unique_names = set()
        used_ops = set()
        succ_ops_of_use_op = {}  # {op_name: succ_ops}
        que = Queue.Queue()
        que.put(response_op)
        while que.qsize() != 0:
            op = que.get()
            for pred_op in op.get_input_ops():
                if pred_op.name not in succ_ops_of_use_op:
                    succ_ops_of_use_op[pred_op.name] = []
                if op != response_op:
                    succ_ops_of_use_op[pred_op.name].append(op)
                if pred_op not in used_ops:
                    que.put(pred_op)
                    used_ops.add(pred_op)
                    # check that the op name is globally unique
                    if pred_op.name in unique_names:
                        _LOGGER.critical("the name of Op must be unique: {}".
                                         format(pred_op.name))
                        os._exit(-1)
                    unique_names.add(pred_op.name)
        return used_ops, succ_ops_of_use_op

    def _gen_channel(self, name_gen):
        channel = None
        if self._is_thread_op:
            channel = ThreadChannel(
                name=name_gen.next(), maxsize=self._channel_size)
        else:
            channel = ProcessChannel(
                self._manager, name=name_gen.next(), maxsize=self._channel_size)
        _LOGGER.debug("[DAG] gen Channel: {}".format(channel.name))
        return channel

    def _gen_virtual_op(self, name_gen):
        vir_op = VirtualOp(name=name_gen.next())
        _LOGGER.debug("[DAG] gen VirtualOp: {}".format(vir_op.name))
        return vir_op

    def _topo_sort(self, used_ops, response_op, out_degree_ops):
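        """Topologically sort the reversed graph: starting from the op that
        feeds the ResponseOp, peel off ops whose out-degree reaches zero into
        successive views (levels); the caller reverses the view order."""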
        out_degree_num = {
            name: len(ops)
            for name, ops in out_degree_ops.items()
        }
        que_idx = 0  # toggle between the two queues
        ques = [Queue.Queue() for _ in range(2)]
        zero_indegree_num = 0
        for op in used_ops:
            if len(op.get_input_ops()) == 0:
                zero_indegree_num += 1
        if zero_indegree_num != 1:
            _LOGGER.critical("DAG must have exactly one RequestOp")
            os._exit(-1)
        last_op = response_op.get_input_ops()[0]
        ques[que_idx].put(last_op)

        # topo sort to get dag_views
        dag_views = []
        sorted_op_num = 0
        while True:
            que = ques[que_idx]
            next_que = ques[(que_idx + 1) % 2]
            dag_view = []
            while que.qsize() != 0:
                op = que.get()
                dag_view.append(op)
                sorted_op_num += 1
                for pred_op in op.get_input_ops():
                    out_degree_num[pred_op.name] -= 1
                    if out_degree_num[pred_op.name] == 0:
                        next_que.put(pred_op)
            dag_views.append(dag_view)
            if next_que.qsize() == 0:
                break
            que_idx = (que_idx + 1) % 2
        if sorted_op_num < len(used_ops):
            _LOGGER.critical("the graph is not a legal DAG")
            os._exit(-1)

        return dag_views, last_op

    def _build_dag(self, response_op):
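        """Collect the used ops, topologically sort them into views, bridge
        non-adjacent views with VirtualOps, wire up channels, and return the
        ops, channels, and the RPC pack/unpack functions."""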
        if response_op is None:
            _LOGGER.critical("ResponseOp has not been set.")
            os._exit(-1)
        used_ops, out_degree_ops = self.get_use_ops(response_op)
        if not self._build_dag_each_worker:
            _LOGGER.info("================= USED OP =================")
            for op in used_ops:
                if op.name != self._request_name:
                    _LOGGER.info(op.name)
            _LOGGER.info("-------------------------------------------")
        if len(used_ops) <= 1:
            _LOGGER.critical(
                "Besides RequestOp and ResponseOp, there should be at least one Op in DAG."
            )
            os._exit(-1)
        if self._build_dag_each_worker:
            _LOGGER.info("Because `build_dag_each_worker` mode is used, "
                         "Auto-batching is set to the default config: "
                         "batch_size=1, auto_batching_timeout=None")
            for op in used_ops:
                op.use_default_auto_batching_config()

        dag_views, last_op = self._topo_sort(used_ops, response_op,
                                             out_degree_ops)
        dag_views = list(reversed(dag_views))
        if not self._build_dag_each_worker:
            _LOGGER.debug("================== DAG ====================")
            for idx, view in enumerate(dag_views):
                _LOGGER.debug("(VIEW {})".format(idx))
                for op in view:
                    _LOGGER.debug("  [{}]".format(op.name))
                    for out_op in out_degree_ops[op.name]:
                        _LOGGER.debug("    - {}".format(out_op.name))
            _LOGGER.debug("-------------------------------------------")

        # create channels and virtual ops
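        # A VirtualOp is inserted whenever an edge would skip a view, so that
        # every edge in the wired graph connects adjacent views and can be
        # carried by a single channel.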
        virtual_op_name_gen = NameGenerator("vir")
        channel_name_gen = NameGenerator("chl")
        virtual_ops = []
        channels = []
        input_channel = None
        actual_view = None
        for v_idx, view in enumerate(dag_views):
            if v_idx + 1 >= len(dag_views):
                break
            next_view = dag_views[v_idx + 1]
            if actual_view is None:
                actual_view = view
            actual_next_view = []
            pred_op_of_next_view_op = {}
            for op in actual_view:
                # find actual succ op in next view and create virtual op
                for succ_op in out_degree_ops[op.name]:
                    if succ_op in next_view:
                        if succ_op not in actual_next_view:
                            actual_next_view.append(succ_op)
                        if succ_op.name not in pred_op_of_next_view_op:
                            pred_op_of_next_view_op[succ_op.name] = []
                        pred_op_of_next_view_op[succ_op.name].append(op)
                    else:
                        # create virtual op
                        virtual_op = self._gen_virtual_op(virtual_op_name_gen)
                        virtual_ops.append(virtual_op)
                        out_degree_ops[virtual_op.name] = [succ_op]
                        actual_next_view.append(virtual_op)
                        pred_op_of_next_view_op[virtual_op.name] = [op]
                        virtual_op.add_virtual_pred_op(op)
            actual_view = actual_next_view
            # create channel
            processed_op = set()
            for o_idx, op in enumerate(actual_next_view):
                if op.name in processed_op:
                    continue
                channel = self._gen_channel(channel_name_gen)
                channels.append(channel)
                _LOGGER.debug("[DAG] Channel({}) => Op({})"
                              .format(channel.name, op.name))
                op.add_input_channel(channel)
                pred_ops = pred_op_of_next_view_op[op.name]
                if v_idx == 0:
                    input_channel = channel
                else:
                    # if pred_op is a virtual op, it uses its ancestors as
                    # producers to the channel
                    for pred_op in pred_ops:
                        _LOGGER.debug("[DAG] Op({}) => Channel({})"
                                      .format(pred_op.name, channel.name))
                        pred_op.add_output_channel(channel)
                processed_op.add(op.name)
                # ops with the same predecessors share one input channel
                for other_op in actual_next_view[o_idx + 1:]:
                    if other_op.name in processed_op:
                        continue
                    other_pred_ops = pred_op_of_next_view_op[other_op.name]
                    if len(other_pred_ops) != len(pred_ops):
                        continue
                    same_flag = True
                    for pred_op in pred_ops:
                        if pred_op not in other_pred_ops:
                            same_flag = False
                            break
                    if same_flag:
                        _LOGGER.debug("[DAG] Channel({}) => Op({})"
                                      .format(channel.name, other_op.name))
                        other_op.add_input_channel(channel)
                        processed_op.add(other_op.name)
        output_channel = self._gen_channel(channel_name_gen)
        channels.append(output_channel)
        last_op.add_output_channel(output_channel)

        pack_func, unpack_func = None, None
        pack_func = response_op.pack_response_package

        actual_ops = virtual_ops
        for op in used_ops:
            if len(op.get_input_ops()) == 0:
                unpack_func = op.unpack_request_package
                continue
            actual_ops.append(op)

        for c in channels:
            _LOGGER.debug("Channel({}):\n\t-producers: {}\n\t-consumers: {}"
                          .format(c.name, c.get_producers(), c.get_consumers()))

        return (actual_ops, channels, input_channel, output_channel, pack_func,
                unpack_func)

    def build(self):
        (actual_ops, channels, input_channel, output_channel, pack_func,
         unpack_func) = self._build_dag(self._response_op)
        _LOGGER.info("[DAG] succ build dag")

        self._actual_ops = actual_ops
        self._channels = channels
        self._input_channel = input_channel
        self._output_channel = output_channel
        self._pack_func = pack_func
        self._unpack_func = unpack_func

        return self._input_channel, self._output_channel, self._pack_func, self._unpack_func

    def start(self):
        self._threads_or_processes = []
        for op in self._actual_ops:
            op.use_profiler(self._use_profile)
            if self._is_thread_op:
                self._threads_or_processes.extend(
                    op.start_with_thread(self._client_type))
            else:
                self._threads_or_processes.extend(
                    op.start_with_process(self._client_type))
        _LOGGER.info("[DAG] start")

        # not joined here; the caller (DAGExecutor.stop) joins them later
        return self._threads_or_processes

    def join(self):
        for x in self._threads_or_processes:
            x.join()

    def stop(self):
        for chl in self._channels:
            chl.stop()
        for op in self._actual_ops:
            op.clean_input_channel()
            op.clean_output_channels()
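

# A minimal usage sketch (the values below are hypothetical; in practice
# dag_conf comes from the pipeline server's yaml config and response_op from
# the user-defined ops):
#
#   dag_conf = {
#       "retry": 1, "client_type": "brpc", "use_profile": False,
#       "channel_size": 0, "is_thread_op": True,
#       "build_dag_each_worker": False,
#   }
#   executor = DAGExecutor(response_op, dag_conf)
#   executor.start()
#   rpc_resp = executor.call(rpc_request)
#   executor.stop()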