#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import threading
import multiprocessing
import sys
import copy
if sys.version_info.major == 2:
    import Queue
elif sys.version_info.major == 3:
    import queue as Queue
else:
    raise Exception("Error Python version")
import os
import logging

from .operator import Op, RequestOp, ResponseOp, VirtualOp
from .channel import (ThreadChannel, ProcessChannel, ChannelData,
                      ChannelDataEcode, ChannelDataType, ChannelStopError)
from .profiler import TimeProfiler
from .util import NameGenerator

_LOGGER = logging.getLogger()


class DAGExecutor(object):
    def __init__(self, response_op, dag_config, show_info):
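        """Complete dag_config with default values, build and start the
        DAG, and wire up its input/output channels and the RPC
        pack/unpack functions."""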
        default_conf = {
            "retry": 1,
            "client_type": "brpc",
            "use_profile": False,
            "channel_size": 0,
            "is_thread_op": True,
            # default for build_dag_each_worker, which is read below and
            # would otherwise raise KeyError when not set in dag_config
            "build_dag_each_worker": False
        }

        for key, val in default_conf.items():
            if dag_config.get(key) is None:
                _LOGGER.warning("[CONF] {} not set, use default: {}"
                        .format(key, val))
                dag_config[key] = val

        self._retry = dag_config["retry"]
        client_type = dag_config["client_type"]
        self._server_use_profile = dag_config["use_profile"]
        channel_size = dag_config["channel_size"]
        self._is_thread_op = dag_config["is_thread_op"]
        build_dag_each_worker = dag_config["build_dag_each_worker"]

        if show_info:
            _LOGGER.info("=============== DAGExecutor ===============")
            for key in default_conf.keys():
                _LOGGER.info("{}: {}".format(key, dag_config[key]))            
            _LOGGER.info("-------------------------------------------")

        self.name = "@G"
        self._profiler = TimeProfiler()
        self._profiler.enable(True)

        self._dag = DAG(self.name, response_op, self._server_use_profile,
                        self._is_thread_op, client_type, channel_size,
                        show_info, build_dag_each_worker)
        (in_channel, out_channel, pack_rpc_func,
         unpack_rpc_func) = self._dag.build()
        self._dag.start()

        self._set_in_channel(in_channel)
        self._set_out_channel(out_channel)
        self._pack_rpc_func = pack_rpc_func
        self._unpack_rpc_func = unpack_rpc_func

        self._id_lock = threading.Lock()
        self._id_counter = 0
        self._reset_max_id = 1000000000000000000
        self._cv_pool = {}
        self._cv_for_cv_pool = threading.Condition()
        self._fetch_buffer = None
        self._recive_func = None

        self._client_profile_key = "pipeline.profile"
        self._client_profile_value = "1"

    def start(self):
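        """Start the background thread that receives results from the DAG
        output channel."""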
        self._recive_func = threading.Thread(
            target=DAGExecutor._recive_out_channel_func, args=(self, ))
        self._recive_func.start()
        _LOGGER.debug("[DAG Executor] start receive thread")

    def stop(self):
        self._dag.stop()
        self._dag.join()
        _LOGGER.info("[DAG Executor] succ stop")

    def _get_next_data_id(self):
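        """Generate the next request id under a lock; the counter wraps
        around at _reset_max_id to avoid growing without bound."""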
        with self._id_lock:
            if self._id_counter >= self._reset_max_id:
                self._id_counter -= self._reset_max_id
            self._id_counter += 1
            return self._id_counter - 1

    def _set_in_channel(self, in_channel):
        if not isinstance(in_channel, (ThreadChannel, ProcessChannel)):
            raise TypeError(
                "in_channel must be Channel type, but got {}".format(
                    type(in_channel)))
        in_channel.add_producer(self.name)
        self._in_channel = in_channel

    def _set_out_channel(self, out_channel):
        if not isinstance(out_channel, (ThreadChannel, ProcessChannel)):
            raise TypeError(
                "out_channel must be Channel type, but got {}".format(
                    type(out_channel)))
        out_channel.add_consumer(self.name)
        self._out_channel = out_channel

    def _recive_out_channel_func(self):
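        """Receive-thread loop: fetch ChannelData from the output channel,
        store it in _fetch_buffer and notify the caller waiting on the
        condition variable registered in _cv_pool for that data id. When
        the channel is stopped, wake all pending callers with CLOSED_ERROR."""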
        cv = None
        while True:
            try:
                channeldata_dict = self._out_channel.front(self.name)
            except ChannelStopError:
                _LOGGER.debug("[DAG Executor] channel stop.")
                with self._cv_for_cv_pool:
                    for data_id, cv in self._cv_pool.items():
                        closed_error_data = ChannelData(
                            ecode=ChannelDataEcode.CLOSED_ERROR.value,
                            error_info="dag closed.",
                            data_id=data_id)
                        with cv:
                            self._fetch_buffer = closed_error_data
                            cv.notify_all()
                break

            if len(channeldata_dict) != 1:
                _LOGGER.error("[DAG Executor] out_channel cannot have multiple input ops")
                os._exit(-1)
            (_, channeldata), = channeldata_dict.items()
            if not isinstance(channeldata, ChannelData):
                _LOGGER.error('[DAG Executor] data must be ChannelData type, '
                        'but got {}'.format(type(channeldata)))
                os._exit(-1)

            data_id = channeldata.id
            _LOGGER.debug("receive thread fetch data[{}]".format(data_id))
            with self._cv_for_cv_pool:
                cv = self._cv_pool[data_id]
            with cv:
                self._fetch_buffer = channeldata
                cv.notify_all()

    def _get_channeldata_from_fetch_buffer(self, data_id):
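        """Register a condition variable for data_id, block until the
        receive thread fills _fetch_buffer and notifies it, then deep-copy
        the response and unregister the condition variable."""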
        resp = None
        cv = threading.Condition()
        with self._cv_for_cv_pool:
            self._cv_pool[data_id] = cv
        with cv:
            cv.wait()
        with self._cv_for_cv_pool:
            resp = copy.deepcopy(self._fetch_buffer)
            _LOGGER.debug("resp thread get resp data[{}]".format(data_id))
            self._cv_pool.pop(data_id)
        return resp

    def _pack_channeldata(self, rpc_request, data_id):
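        """Unpack the RPC request with the user-defined unpack function and
        wrap the result in a ChannelData; also check whether the client
        asked for profiling via the profile key/value pair."""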
        dictdata = None
        try:
            dictdata = self._unpack_rpc_func(rpc_request)
        except Exception as e:
            _LOGGER.error("parse RPC package to data[{}] Error: {}"
                    .format(data_id, e))
            return ChannelData(
                ecode=ChannelDataEcode.RPC_PACKAGE_ERROR.value,
                error_info="rpc package error: {}".format(e),
                data_id=data_id)
        else:
            # unpack_rpc_func may be overridden by the user, so we look for
            # the client_profile_key field in rpc_request directly
            profile_value = None
            for idx, key in enumerate(rpc_request.key):
                if key == self._client_profile_key:
                    profile_value = rpc_request.value[idx]
                    break
            client_need_profile = (profile_value == self._client_profile_value)
            _LOGGER.debug("request[{}] need profile: {}".format(data_id, client_need_profile))
            return ChannelData(
                datatype=ChannelDataType.DICT.value,
                dictdata=dictdata,
                data_id=data_id,
                client_need_profile=client_need_profile)

    def call(self, rpc_request):
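        """Process one RPC request synchronously: pack it into ChannelData,
        push it into the input channel, wait for the result from the
        receive thread, retry on timeout up to self._retry times, and pack
        the result (plus optional profile info) into the RPC response."""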
        data_id = self._get_next_data_id()
        _LOGGER.debug("generate id: {}".format(data_id))

        if not self._is_thread_op:
            self._profiler.record("call_{}#DAG-{}_0".format(data_id, data_id))
        else:
            self._profiler.record("call_{}#DAG_0".format(data_id))

        _LOGGER.debug("try parse RPC package to channeldata[{}]".format(data_id))
        self._profiler.record("prepack_{}#{}_0".format(data_id, self.name))
        req_channeldata = self._pack_channeldata(rpc_request, data_id)
        self._profiler.record("prepack_{}#{}_1".format(data_id, self.name))

        resp_channeldata = None
        for i in range(self._retry):
            _LOGGER.debug("push data[{}] into Graph engine".format(data_id))
            try:
                self._in_channel.push(req_channeldata, self.name)
            except ChannelStopError:
                _LOGGER.debug("[DAG Executor] channel stop.")
                return self._pack_for_rpc_resp(
                    ChannelData(
                        ecode=ChannelDataEcode.CLOSED_ERROR.value,
                        error_info="dag closed.",
                        data_id=data_id))

            _LOGGER.debug("wait for Graph engine for data[{}]...".format(data_id))
            resp_channeldata = self._get_channeldata_from_fetch_buffer(data_id)

            if resp_channeldata.ecode == ChannelDataEcode.OK.value:
                _LOGGER.debug("Graph engine predict data[{}] succ".format(data_id))
                break
            else:
                _LOGGER.warning("Graph engine predict data[{}] failed: {}"
                        .format(data_id, resp_channeldata.error_info))
                if resp_channeldata.ecode != ChannelDataEcode.TIMEOUT.value:
                    break

            if i + 1 < self._retry:
                _LOGGER.warning("retry({}/{}) data[{}]".format(
                    i + 1, self._retry, data_id))

        _LOGGER.debug("unpack channeldata[{}] into RPC resp package".format(data_id))
        self._profiler.record("postpack_{}#{}_0".format(data_id, self.name))
        rpc_resp = self._pack_for_rpc_resp(resp_channeldata)
        self._profiler.record("postpack_{}#{}_1".format(data_id, self.name))
        if not self._is_thread_op:
            self._profiler.record("call_{}#DAG-{}_1".format(data_id, data_id))
        else:
            self._profiler.record("call_{}#DAG_1".format(data_id))

        profile_str = self._profiler.gen_profile_str()
        if self._server_use_profile:
            sys.stderr.write(profile_str)

        # add profile info into rpc_resp
        profile_value = ""
        if resp_channeldata.client_need_profile:
            profile_set = resp_channeldata.profile_data_set
            profile_set.add(profile_str)
            profile_value = "".join(list(profile_set))
        rpc_resp.key.append(self._client_profile_key)
        rpc_resp.value.append(profile_value)

        return rpc_resp

    def _pack_for_rpc_resp(self, channeldata):
        return self._pack_rpc_func(channeldata)


class DAG(object):
    def __init__(self, request_name, response_op, use_profile, is_thread_op,
                 client_type, channel_size, show_info, build_dag_each_worker):
        self._request_name = request_name
        self._response_op = response_op
        self._use_profile = use_profile
        self._is_thread_op = is_thread_op
        self._channel_size = channel_size
        self._client_type = client_type
        self._show_info = show_info
        self._build_dag_each_worker = build_dag_each_worker
        if not self._is_thread_op:
            self._manager = multiprocessing.Manager()
        _LOGGER.info("[DAG] succ init")

    def get_use_ops(self, response_op):
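        """Walk the graph backwards from response_op with BFS to collect
        the set of used Ops and a map from each Op name to its successor
        Ops, checking that Op names are globally unique."""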
        unique_names = set()
        used_ops = set()
        succ_ops_of_use_op = {}  # {op_name: succ_ops}
        que = Queue.Queue()
        que.put(response_op)
        while que.qsize() != 0:
            op = que.get()
            for pred_op in op.get_input_ops():
                if pred_op.name not in succ_ops_of_use_op:
                    succ_ops_of_use_op[pred_op.name] = []
                if op != response_op:
                    succ_ops_of_use_op[pred_op.name].append(op)
                if pred_op not in used_ops:
                    que.put(pred_op)
                    used_ops.add(pred_op)
                    # check that the op name is globally unique
                    if pred_op.name in unique_names:
                        raise Exception("the name of Op must be unique: {}".
                                        format(pred_op.name))
                    unique_names.add(pred_op.name)
        return used_ops, succ_ops_of_use_op

    def _gen_channel(self, name_gen):
        channel = None
        if self._is_thread_op:
            channel = ThreadChannel(
                name=name_gen.next(), maxsize=self._channel_size)
        else:
            channel = ProcessChannel(
                self._manager, name=name_gen.next(), maxsize=self._channel_size)
        _LOGGER.debug("[DAG] gen Channel: {}".format(channel.name))
        return channel

    def _gen_virtual_op(self, name_gen):
        vir_op = VirtualOp(name=name_gen.next())
        _LOGGER.debug("[DAG] gen VirtualOp: {}".format(vir_op.name))
        return vir_op

    def _topo_sort(self, used_ops, response_op, out_degree_ops):
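        """Topologically sort the DAG backwards from the Op ahead of
        response_op. Two queues are used alternately to split the Ops into
        views (layers): an Op enters the next view once all of its
        successors have been sorted. Raises if the DAG does not have
        exactly one input Op or contains a cycle."""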
        out_degree_num = {
            name: len(ops)
            for name, ops in out_degree_ops.items()
        }
        que_idx = 0  # toggle between the two queues
        ques = [Queue.Queue() for _ in range(2)]
        zero_indegree_num = 0
        for op in used_ops:
            if len(op.get_input_ops()) == 0:
                zero_indegree_num += 1
        if zero_indegree_num != 1:
            raise Exception("DAG contains multiple input Ops")
        last_op = response_op.get_input_ops()[0]
        ques[que_idx].put(last_op)

        # topo sort to get dag_views
        dag_views = []
        sorted_op_num = 0
        while True:
            que = ques[que_idx]
            next_que = ques[(que_idx + 1) % 2]
            dag_view = []
            while que.qsize() != 0:
                op = que.get()
                dag_view.append(op)
                sorted_op_num += 1
                for pred_op in op.get_input_ops():
                    out_degree_num[pred_op.name] -= 1
                    if out_degree_num[pred_op.name] == 0:
                        next_que.put(pred_op)
            dag_views.append(dag_view)
            if next_que.qsize() == 0:
                break
            que_idx = (que_idx + 1) % 2
        if sorted_op_num < len(used_ops):
            raise Exception("not legal DAG")

        return dag_views, last_op

    def _build_dag(self, response_op):
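        """Build the executable graph: collect the used Ops, topo-sort them
        into views, insert VirtualOps for edges that skip views, create and
        connect the Channels, and pick the request unpack / response pack
        functions. Returns (actual_ops, channels, input_channel,
        output_channel, pack_func, unpack_func)."""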
        if response_op is None:
            raise Exception("response_op has not been set.")
        used_ops, out_degree_ops = self.get_use_ops(response_op)
        if self._show_info:
            _LOGGER.info("================= USED OP =================")
            for op in used_ops:
                if op.name != self._request_name:
                    _LOGGER.info(op.name)
            _LOGGER.info("-------------------------------------------")
        if len(used_ops) <= 1:
            raise Exception(
                "Besides RequestOp and ResponseOp, there should be at least one Op in DAG."
            )
        if self._build_dag_each_worker:
            _LOGGER.info("Because `build_dag_each_worker` mode is used, "
                    "auto-batching is set to the default config: "
                    "batch_size=1, auto_batching_timeout=None")
            for op in used_ops:
                op.use_default_auto_batching_config()

        dag_views, last_op = self._topo_sort(used_ops, response_op,
                                             out_degree_ops)
        dag_views = list(reversed(dag_views))
        if self._show_info:
            _LOGGER.info("================== DAG ====================")
            for idx, view in enumerate(dag_views):
                _LOGGER.info("(VIEW {})".format(idx))
                for op in view:
                    _LOGGER.info("  [{}]".format(op.name))
                    for out_op in out_degree_ops[op.name]:
                        _LOGGER.info("    - {}".format(out_op.name))
            _LOGGER.info("-------------------------------------------")

        # create channels and virtual ops
        virtual_op_name_gen = NameGenerator("vir")
        channel_name_gen = NameGenerator("chl")
        virtual_ops = []
        channels = []
        input_channel = None
        actual_view = None
        for v_idx, view in enumerate(dag_views):
            if v_idx + 1 >= len(dag_views):
                break
            next_view = dag_views[v_idx + 1]
            if actual_view is None:
                actual_view = view
            actual_next_view = []
            pred_op_of_next_view_op = {}
            for op in actual_view:
                # find actual succ op in next view and create virtual op
                for succ_op in out_degree_ops[op.name]:
                    if succ_op in next_view:
                        if succ_op not in actual_next_view:
                            actual_next_view.append(succ_op)
                        if succ_op.name not in pred_op_of_next_view_op:
                            pred_op_of_next_view_op[succ_op.name] = []
                        pred_op_of_next_view_op[succ_op.name].append(op)
                    else:
                        # create virtual op
                        virtual_op = self._gen_virtual_op(virtual_op_name_gen)
                        virtual_ops.append(virtual_op)
                        out_degree_ops[virtual_op.name] = [succ_op]
                        actual_next_view.append(virtual_op)
                        pred_op_of_next_view_op[virtual_op.name] = [op]
                        virtual_op.add_virtual_pred_op(op)
            actual_view = actual_next_view
            # create channel
            processed_op = set()
            for o_idx, op in enumerate(actual_next_view):
                if op.name in processed_op:
                    continue
                channel = self._gen_channel(channel_name_gen)
                channels.append(channel)
                _LOGGER.debug("[DAG] Channel({}) => Op({})"
                        .format(channel.name, op.name))
                op.add_input_channel(channel)
                pred_ops = pred_op_of_next_view_op[op.name]
                if v_idx == 0:
                    input_channel = channel
                else:
                    # if pred_op is a virtual op, its ancestors are added
                    # as the producers of this channel
                    for pred_op in pred_ops:
                        _LOGGER.debug("[DAG] Op({}) => Channel({})"
                                .format(pred_op.name, channel.name))
                        pred_op.add_output_channel(channel)
                processed_op.add(op.name)
                # find same input op to combine channel
                for other_op in actual_next_view[o_idx + 1:]:
                    if other_op.name in processed_op:
                        continue
                    other_pred_ops = pred_op_of_next_view_op[other_op.name]
                    if len(other_pred_ops) != len(pred_ops):
                        continue
                    same_flag = True
                    for pred_op in pred_ops:
                        if pred_op not in other_pred_ops:
                            same_flag = False
                            break
                    if same_flag:
                        _LOGGER.debug("[DAG] Channel({}) => Op({})"
                                .format(channel.name, other_op.name))
                        other_op.add_input_channel(channel)
                        processed_op.add(other_op.name)
        output_channel = self._gen_channel(channel_name_gen)
        channels.append(output_channel)
        last_op.add_output_channel(output_channel)

        pack_func, unpack_func = None, None
        pack_func = response_op.pack_response_package

        actual_ops = virtual_ops
        for op in used_ops:
            if len(op.get_input_ops()) == 0:
                unpack_func = op.unpack_request_package
                continue
            actual_ops.append(op)

        for c in channels:
            _LOGGER.debug("Channel({}):\n -producers: {}\n -consumers: {}"
                    .format(c.name, c.get_producers(), c.get_consumers()))

        return (actual_ops, channels, input_channel, output_channel, pack_func,
                unpack_func)

    def build(self):
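        """Build the DAG and cache its parts; return the input/output
        channels and the RPC pack/unpack functions for the DAGExecutor."""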
        (actual_ops, channels, input_channel, output_channel, pack_func,
         unpack_func) = self._build_dag(self._response_op)
        _LOGGER.info("[DAG] succ build dag")

        self._actual_ops = actual_ops
        self._channels = channels
        self._input_channel = input_channel
        self._output_channel = output_channel
        self._pack_func = pack_func
        self._unpack_func = unpack_func

        return self._input_channel, self._output_channel, self._pack_func, self._unpack_func

    def start(self):
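        """Start every actual Op with threads or processes (depending on
        is_thread_op) and return the handles without joining them."""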
        self._threads_or_proces = []
        for op in self._actual_ops:
            op.use_profiler(self._use_profile)
            if self._is_thread_op:
                self._threads_or_proces.extend(
                    op.start_with_thread(self._client_type))
            else:
                self._threads_or_proces.extend(
                    op.start_with_process(self._client_type))
        _LOGGER.info("[DAG] start")

        # not join yet
        return self._threads_or_proces

    def join(self):
        for x in self._threads_or_proces:
            x.join()

    def stop(self):
        for chl in self._channels:
            chl.stop()
        for op in self._actual_ops:
            op.clean_input_channel()
            op.clean_output_channels()