ps_trainer_pass.py 40.9 KB
Newer Older
Z
ziyoujiyi 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import paddle
import paddle.compat as cpt
from ..ps.utils.public import *
from paddle.framework import core
from .pass_base import PassBase, register_pass
from paddle.fluid.transpiler.details.program_utils import delete_ops
from paddle.fluid.transpiler.collective import SingleProcessMultiThread

# Name-scope strings used when tagging ops.
OP_NAME_SCOPE = "op_namescope"
CLIP_OP_NAME_SCOPE = "gradient_clip"
# Queue name that marks the step-counter ``send`` op (it is sent with no
# input variables).
STEP_COUNTER = "@PS_STEP_COUNTER@"

# Op-role attribute names and values exposed by the C++ op proto maker.
_op_proto_maker = core.op_proto_and_checker_maker
OP_ROLE_VAR_ATTR_NAME = _op_proto_maker.kOpRoleVarAttrName()
RPC_OP_ROLE_ATTR_VALUE = _op_proto_maker.OpRole.RPC
LR_SCHED_OP_ROLE_ATTR_VALUE = _op_proto_maker.OpRole.LRSched
OPT_OP_ROLE_ATTR_VALUE = _op_proto_maker.OpRole.Optimize
op_role_attr_name = _op_proto_maker.kOpRoleAttrName()
backward = _op_proto_maker.OpRole.Backward

# Sparse lookup op type -> input slot that holds the embedding parameter.
SPARSE_OP_TYPE_DICT = {"lookup_table": "W", "lookup_table_v2": "W"}
# Grad counterparts of the lookups above, keyed by "<op_type>_grad" and
# mapped to the same parameter slot.
SPARSE_GRAD_OP_TYPE_DICT = {
    op_type + "_grad": param_slot
    for op_type, param_slot in SPARSE_OP_TYPE_DICT.items()
}
DEVICE_LIST = ["cpu", "gpu", "xpu"]
COMMUNICATE_OPS_TYPE = ["send", "recv", "fetch_barrier", "send_barrier"]
DEFAULT_DEVICE = 'cpu'


@register_pass("append_send_ops_pass")
class AppendSendOpsPass(PassBase):  # this pass is reused by several PS modes
    """Append ``send`` ops for the trainer's send contexts.

    Dense contexts are always sent. Sparse contexts are sent only in GEO mode.
    In SYNC / HALF_ASYNC mode every send writes a dummy control variable, and
    a final ``send_barrier`` op waits on all of them.
    """

    def __init__(self):
        super(AppendSendOpsPass, self).__init__()

    def _check_self(self):
        return True

    def _check_conflict(self, other_pass):
        return True

    def _append_send_op(self, program, union_vars, queue, is_sparse, table_id,
                        ps_mode):
        """Append one ``send`` op to the global block of ``program``.

        Args:
            program: program whose global block receives the op.
            union_vars: names of the variables merged into this send
                (ignored for the step-counter queue).
            queue: merged send name, stored in the ``send_varnames`` attr.
            is_sparse: 0 for dense, 1 for sparse, 2 for distributed sparse.
            table_id: id of the target table.
            ps_mode: a ``DistributedMode`` value.

        Returns:
            The dummy control variable written by the op in SYNC /
            HALF_ASYNC mode (a dependency of the barrier), else ``[]``.
        """
        if queue == STEP_COUNTER:
            # The step counter is sent without any input variables.
            send_input_vars = []
        else:
            send_input_vars = [
                program.global_block().vars[union_var]
                for union_var in union_vars
            ]

        dummy_output = []
        if ps_mode in [DistributedMode.SYNC, DistributedMode.HALF_ASYNC]:
            dummy_output = program.global_block().create_var(
                name=framework.generate_control_dev_var_name())

        program.global_block().append_op(
            type="send",
            inputs={"X": send_input_vars},
            outputs={"Out": dummy_output},
            attrs={
                "send_varnames": [queue],
                "is_sparse": is_sparse,
                "table_id": table_id,
                RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
            })

        return dummy_output

    def _append_barrier_op(self, program, dummys, trainer_id=0):
        """Append a ``send_barrier`` op that depends on ``dummys``.

        Args:
            program: program whose global block receives the op.
            dummys: control variables returned by ``_append_send_op``.
            trainer_id: id of this trainer, written to the op's
                ``trainer_id`` attr. The previous implementation referenced
                an undefined name here. The default keeps two-argument
                callers working.
        """
        program.global_block().append_op(
            type="send_barrier",
            inputs={"X": dummys},
            outputs={"Out": []},
            attrs={
                "trainer_id": trainer_id,
                "half_async": True,
                RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
            })

    @staticmethod
    def _get_trainer_id(attrs):
        """Return this trainer's id for the ``send_barrier`` op.

        Uses the role maker stored in ``attrs['role_maker']`` when present.
        Otherwise it falls back to the ``trainer_id`` reported by
        ``get_dist_env()``.
        """
        role_maker = attrs.get('role_maker')
        if role_maker is not None:
            # Newer role makers expose ``_role_id``; fall back to the public
            # ``role_id`` spelling otherwise.
            role_id = getattr(role_maker, '_role_id', None)
            if role_id is None:
                role_id = role_maker.role_id
            return role_id()
        return get_dist_env()["trainer_id"]

    def _apply_single_impl(self, main_program, startup_program, pass_ctx):
        attrs = pass_ctx._attrs
        ps_mode = attrs['ps_mode']
        if ps_mode == DistributedMode.GEO:
            send_ctx = get_geo_trainer_send_context(attrs)  # geo mode
        else:
            # async, sync and the other modes
            send_ctx = get_the_one_send_context(attrs)
        dummys = []
        for merged_name, send in send_ctx.items():
            # Sparse contexts are only sent by this pass in GEO mode.
            if send.is_sparse() and ps_mode != DistributedMode.GEO:
                continue
            is_sparse = 1 if send.is_sparse() else 0
            is_sparse = 2 if send.is_distributed() else is_sparse
            dummys.append(
                self._append_send_op(main_program,
                                     send.origin_varnames(), merged_name,
                                     is_sparse, send.table_id(), ps_mode))

        if ps_mode in [DistributedMode.SYNC, DistributedMode.HALF_ASYNC]:
            trainer_id = self._get_trainer_id(attrs)
            self._append_barrier_op(main_program, dummys, trainer_id)


@register_pass("distributed_ops_pass")
class DistributedOpsPass(PassBase):
    """Fuse remote-prefetch embedding lookups and their grads into
    distributed sparse pull/push ops.

    ``lookup_table`` / ``lookup_table_v2`` ops whose ``remote_prefetch`` attr
    is True are grouped per parameter and replaced by
    ``distributed_lookup_table`` (or ``pull_box_sparse`` when
    ``attrs['use_ps_gpu']``). The matching ``*_grad`` ops of each group are
    replaced by one ``distributed_push_sparse`` op.
    """

    def __init__(self):
        super(DistributedOpsPass, self).__init__()
        # param key -> sparse table id; written by _pull_sparse_fuse and read
        # by _push_sparse_fuse.
        self.w_2_table_id = {}
        # param key -> W.shape[1]; written by _pull_sparse_fuse and read by
        # _push_sparse_fuse (the "size" attr of distributed_push_sparse).
        self.emb_size = {}

    def _check_self(self):
        return True

    def _check_conflict(self, other_pass):
        return True

    def _push_sparse_fuse(self, _program, push_sparse_ops, attrs):
        """Replace each parameter's sparse grad ops with one
        ``distributed_push_sparse`` op appended to the global block.

        Does nothing in PS-GPU mode or when ``push_sparse_ops`` is empty.
        Show/click inputs come from a ``show_click_entry:<show>:<click>``
        ``entry`` attr on the first grad op when both named vars exist.
        Otherwise int64 ``show`` (=1) and ``clk`` (=0) vars are created and
        filled by ``fill_constant`` ops inserted at index 0.

        Relies on ``self.w_2_table_id`` / ``self.emb_size`` having been filled
        by ``_pull_sparse_fuse`` for the same parameter keys.
        """
        if attrs['use_ps_gpu']:
            return
        if len(push_sparse_ops) == 0:
            return
        show = None
        clk = None
        use_entry = False
        # Only the first grad op of the first parameter is inspected for the
        # entry configuration.
        for param, ops in push_sparse_ops.items():
            op_first = ops[0]
            break
        if op_first.has_attr("entry"):
            entry = op_first.attr("entry")
            entry = entry.split(':')
            if len(entry) == 3 and entry[0] == 'show_click_entry':
                show_var_name = entry[1]
                click_var_name = entry[2]
                if show_var_name in _program.global_block(
                ).vars and click_var_name in _program.global_block().vars:
                    show = _program.global_block().vars[show_var_name]
                    clk = _program.global_block().vars[click_var_name]
                    use_entry = True
                else:
                    warnings.warn(
                        'ShowClickEntry configured, but cannot find show/click var, will not use'
                    )

        if not use_entry:
            # NOTE(review): this message is also printed when an entry was
            # configured but its vars were missing (after the warning above).
            print('ShowClickEntry not configured, will not use')
            show = _program.global_block().create_var(
                name="show",
                dtype=core.VarDesc.VarType.INT64,
                persistable=False,
                stop_gradient=True)
            _program.global_block()._insert_op(
                index=0,
                type='fill_constant',
                inputs={},
                outputs={'Out': show},
                attrs={
                    'shape': [1],
                    'dtype': show.dtype,
                    'value': 1,
                })

            clk = _program.global_block().create_var(
                name="clk",
                dtype=core.VarDesc.VarType.INT64,
                persistable=False,
                stop_gradient=True)
            _program.global_block()._insert_op(
                index=0,
                type='fill_constant',
                inputs={},
                outputs={'Out': clk},
                attrs={
                    'shape': [1],
                    'dtype': clk.dtype,
                    'value': 0,
                })

        for param, ops in push_sparse_ops.items():
            # Indexes are recomputed per parameter because earlier iterations
            # removed ops from the block.
            all_ops = _program.global_block().ops
            op_idxs = [all_ops.index(op) for op in ops]
            inputs = [
                _program.global_block().vars[op.input("Ids")[0]] for op in ops
            ]
            w = _program.global_block().vars[ops[0].output("W@GRAD")[0]]
            table_id = self.w_2_table_id[param]

            padding_idx = ops[0].attr("padding_idx")
            is_distributed = ops[0].attr("is_distributed")
            op_type = ops[0].type
            outputs = [
                _program.global_block().vars[op.input("Out@GRAD")[0]]
                for op in ops
            ]

            # Remove back-to-front so earlier indexes stay valid.
            for idx in op_idxs[::-1]:
                _program.global_block()._remove_op(idx)

            _program.global_block().append_op(
                type="distributed_push_sparse",
                inputs={
                    "Ids": inputs,
                    'W': w,
                    "Outputs": outputs,
                    "Shows": show,
                    "Clicks": clk
                },
                outputs={"Outputs": outputs},
                attrs={
                    "is_distributed": is_distributed,
                    "padding_idx": padding_idx,
                    "table_id": table_id,
                    "size": self.emb_size[param]
                })

    def _pull_sparse_fuse(self, _program, pull_sparse_ops, attrs, send_ctx):
        """Replace each parameter's remote-prefetch lookup ops with fused
        pull ops and record the table id / embedding width per parameter.

        Raises:
            ValueError: if no send context lists the parameter's grad name.
        """

        def dag_check_up_and_reorder(program, inputs, outputs):
            """Try to move the producers of ``inputs`` (plus their upstream
            producers) ahead of the first consumer of ``outputs``.

            This makes a single fused lookup position possible. The function
            warns and returns without finishing if an op both consumes an
            output and produces an input, or if a needed producer is itself a
            consumer. Op descs and the Python op list are moved together.
            """
            global_block = program.global_block()
            min_output_index = len(global_block.ops)
            max_input_index = -1
            # input_indexes[i] == 1: op i writes one of `inputs` (Ids).
            # output_indexes[i] == 1: op i reads one of `outputs`.
            input_indexes = [0] * len(global_block.ops)
            output_indexes = [0] * len(global_block.ops)
            for idx, op in enumerate(global_block.ops):
                for i in range(0, len(op.output_names)):
                    if input_indexes[idx] == 1:
                        break
                    outs = op.output(op.output_names[i])
                    for in_id, in_var in enumerate(inputs):
                        if in_var.name in outs:
                            input_indexes[idx] = 1
                            max_input_index = max(max_input_index, idx)
                            break

                for i in range(0, len(op.input_names)):
                    if output_indexes[idx] == 1:
                        break
                    ins = op.input(op.input_names[i])
                    for out_id, out_var in enumerate(outputs):
                        if out_var.name in ins:
                            output_indexes[idx] = 1
                            min_output_index = min(min_output_index, idx)

            for i in range(len(global_block.ops)):
                if input_indexes[i] == 1 and output_indexes[i] == 1:
                    warnings.warn(
                        "unable to re-arrange dags order to combine distributed embedding ops because a op both needs embedding table's output as input and produces ids as the same embedding table's input"
                    )
                    return

            # Reordering is only needed when some Ids producer comes after
            # the first consumer of the lookup outputs.
            if min_output_index < max_input_index:
                move_ops = []
                for i in range(min_output_index + 1, len(input_indexes)):
                    if input_indexes[i] == 1:
                        move_ops.append((global_block.ops[i], i))
                for i, op in enumerate(move_ops):
                    # Breadth-first walk backwards from this producer,
                    # collecting every op in [min_output_index, pos) whose
                    # outputs it (transitively) reads.
                    queue = list()
                    visited = set()
                    queue.append(op[1])
                    visited.add(op[0])
                    start = 0
                    while start < len(queue):
                        pos = queue[start]
                        op = global_block.ops[pos]
                        op_inputs = []
                        for k in range(0, len(op.input_names)):
                            ins = op.input(op.input_names[k])
                            op_inputs.append(ins)
                        for j in range(pos - 1, min_output_index - 1, -1):
                            op1 = global_block.ops[j]
                            if op1 in visited:
                                continue
                            found = False
                            for k in range(0, len(op1.output_names)):
                                outs = op1.output(op1.output_names[k])
                                for t in range(len(op_inputs)):
                                    for y in op_inputs[t]:
                                        if y in outs:
                                            found = True
                                            break
                                    if found:
                                        break
                                if found:
                                    break
                            if found:
                                # A required upstream op consumes the lookup
                                # outputs, so it cannot be hoisted.
                                if output_indexes[j] == True:
                                    warnings.warn(
                                        "unable to re-arrange dags order to combine distributed embedding ops"
                                    )
                                    return
                                queue.append(j)
                                visited.add(global_block.ops[j])
                        start = start + 1

                    # Move collected ops (in original order) to just before
                    # the first consumer. The desc is updated in lockstep:
                    # insert a copy at min_output_index, then remove the old
                    # desc, now shifted to index + 1.
                    queue.sort()
                    for index in queue:
                        desc = global_block.desc._insert_op(min_output_index)
                        desc.copy_from(global_block.ops[index].desc)
                        global_block.desc._remove_op(index + 1, index + 2)
                        global_block.ops[index].desc = desc
                        insert_op = global_block.ops.pop(index)
                        input_state = input_indexes.pop(index)
                        output_state = output_indexes.pop(index)
                        global_block.ops.insert(min_output_index, insert_op)
                        input_indexes.insert(min_output_index, input_state)
                        output_indexes.insert(min_output_index, output_state)
                        min_output_index = min_output_index + 1

                # Sanity check: C++ desc and Python op list must still agree.
                assert global_block.desc.op_size() == len(global_block.ops)
                for i in range(len(global_block.ops)):
                    assert global_block.desc.op(i) == global_block.ops[i].desc

        for param, ops in pull_sparse_ops.items():
            # Same list object as the block's ops, so it reflects the
            # reordering done by dag_check_up_and_reorder below.
            all_ops = _program.global_block().ops
            op_device = ""
            if attrs['is_heter_ps_mode']:
                op_device = ops[0].attr("op_device")
            inputs = [
                _program.global_block().vars[op.input("Ids")[0]] for op in ops
            ]
            w = _program.global_block().vars[ops[0].input("W")[0]]
            self.emb_size[param] = w.shape[1]

            grad_name = attrs['param_name_to_grad_name'][w.name]

            table_id = -1

            # The table is the send context that carries this param's grad
            # (the last matching context wins).
            for name, ctx in send_ctx.items():
                if grad_name in ctx.origin_varnames():
                    table_id = ctx.table_id()

            if table_id == -1:
                raise ValueError(
                    "can not find suitable sparse table, please check")

            self.w_2_table_id[param] = table_id
            padding_idx = ops[0].attr("padding_idx")
            is_distributed = ops[0].attr("is_distributed")
            op_type = ops[0].type

            outputs = [
                _program.global_block().vars[op.output("Out")[0]] for op in ops
            ]

            dag_check_up_and_reorder(_program, inputs, outputs)

            op_idxs = [all_ops.index(op) for op in ops]

            for idx in op_idxs[::-1]:
                _program.global_block()._remove_op(idx)

            # With the lookups removed, find the last producer of each Ids var
            # and the first consumer of each output var.
            inputs_idxs = [-1] * len(inputs)
            outputs_idxs = [len(_program.global_block().ops) + 1] * len(outputs)

            for idx, op in enumerate(_program.global_block().ops):
                for i in range(0, len(op.output_names)):
                    outs = op.output(op.output_names[i])
                    for in_id, in_var in enumerate(inputs):
                        if in_var.name in outs:
                            inputs_idxs[in_id] = max(idx, inputs_idxs[in_id])
                for i in range(0, len(op.input_names)):
                    ins = op.input(op.input_names[i])
                    for out_id, out_var in enumerate(outputs):
                        if out_var.name in ins:
                            outputs_idxs[out_id] = min(idx,
                                                       outputs_idxs[out_id])

            if min(outputs_idxs) - max(inputs_idxs) >= 1:
                # One position satisfies all dependencies: emit a single
                # fused op right after the last Ids producer (or at the first
                # removed lookup's index if nothing produces the Ids).
                if max(inputs_idxs) == -1:
                    distributed_idx = min(op_idxs)
                else:
                    distributed_idx = max(inputs_idxs) + 1

                if attrs['use_ps_gpu']:
                    _program.global_block()._insert_op(
                        index=distributed_idx,
                        type="pull_box_sparse",
                        inputs={"Ids": inputs,
                                'W': w},
                        outputs={"Out": outputs},
                        attrs={
                            "size": w.shape[1],
                            "is_distributed": True,
                            "is_sparse": True
                        })
                else:
                    _program.global_block()._insert_op(
                        index=distributed_idx,
                        type="distributed_lookup_table",
                        inputs={"Ids": inputs,
                                'W': w},
                        outputs={"Outputs": outputs},
                        attrs={
                            "is_distributed": is_distributed,
                            "padding_idx": padding_idx,
                            "table_id": table_id,
                            "lookup_table_version": op_type,
                            "op_device": op_device
                        })
            else:
                # No common position: put one unfused lookup back at each
                # original op index (ascending, restoring the old layout).
                for i in range(len(inputs_idxs)):
                    distributed_idx = op_idxs[i]

                    _program.global_block()._insert_op(
                        index=distributed_idx,
                        type="distributed_lookup_table",
                        inputs={"Ids": [inputs[i]],
                                'W': w},
                        outputs={"Outputs": [outputs[i]]},
                        attrs={
                            "is_distributed": is_distributed,
                            "padding_idx": padding_idx,
                            "table_id": table_id,
                            "lookup_table_version": op_type,
                            "op_device": op_device
                        })

    def _get_pull_sparse_ops(self, _program, attrs):
        """Group remote-prefetch lookup ops and their grad ops by parameter.

        Returns:
            ``(pull_sparse_ops, push_sparse_ops)``: dicts mapping a parameter
            key to the list of lookup ops / matching grad ops, in block order.
            A grad op is included only if its parameter key and Ids var match
            a collected lookup.
        """
        pull_sparse_ops = {}
        pull_sparse_ids = {}
        push_sparse_ops = {}
        # NOTE(review): unused; `ops` is rebound inside the loops below.
        ops = {}
        for op in _program.global_block().ops:
            # Only lookups marked for remote prefetch are collected.
            if op.type in SPARSE_OP_TYPE_DICT.keys() \
                    and op.attr('remote_prefetch') is True:
                param_name = op.input(SPARSE_OP_TYPE_DICT[op.type])[0]
                if attrs['is_heter_ps_mode']:
                    # trick for matchnet, need to modify
                    # (appends the first character of the Ids var name to
                    # the key, so grad ops keyed by the plain param name will
                    # not match these groups)
                    param_name += op.input("Ids")[0][0]
                ops = pull_sparse_ops.get(param_name, [])
                ops.append(op)
                pull_sparse_ops[param_name] = ops
                ids = pull_sparse_ids.get(param_name, [])
                ids.append(op.input("Ids")[0])
                pull_sparse_ids[param_name] = ids
        for op in _program.global_block().ops:
            if op.type in SPARSE_GRAD_OP_TYPE_DICT.keys():
                param_name = op.input(SPARSE_GRAD_OP_TYPE_DICT[op.type])[0]
                if param_name in pull_sparse_ids and op.input("Ids")[
                        0] in pull_sparse_ids[param_name]:
                    ops = push_sparse_ops.get(param_name, [])
                    ops.append(op)
                    push_sparse_ops[param_name] = ops

        return pull_sparse_ops, push_sparse_ops

    def _apply_single_impl(self, main_program, startup_program, pass_ctx):
        attrs = pass_ctx._attrs
        pull_sparse_ops, push_sparse_ops = self._get_pull_sparse_ops(
            main_program, attrs)
        send_ctx = get_the_one_send_context(
            attrs, split_dense_table=attrs['is_heter_ps_mode'])
        # Order matters: the pull fuse fills w_2_table_id / emb_size, which
        # the push fuse reads.
        self._pull_sparse_fuse(main_program, pull_sparse_ops, attrs, send_ctx)
        self._push_sparse_fuse(main_program, push_sparse_ops, attrs)


@register_pass("delete_optimizer_pass")
class DeleteOptimizesPass(PassBase):
    """Strip optimizer and LR-scheduler ops from the main program.

    Each variable read by those ops that does not appear in any of their
    ``op_role_var`` attrs is removed as well. When the origin program carries
    an ``lr_sheduler`` attribute, the learning-rate variable is re-created as
    a persistable var afterwards.
    """

    def __init__(self):
        super(DeleteOptimizesPass, self).__init__()

    def _check_self(self):
        return True

    def _check_conflict(self, other_pass):
        return True

    def _delete_optimizer_op_and_vars(self, _program, optimize_ops):
        """Delete ``optimize_ops`` and the vars only they reference."""
        block = _program.global_block()
        read_names = set()
        role_names = set()
        for opt_op in optimize_ops:
            read_names.update(opt_op.input_arg_names)
            role_names.update(opt_op.attr("op_role_var"))
        # Vars read by the ops but not part of a (param, grad) role pair.
        stale_names = read_names - role_names

        delete_ops(block, optimize_ops)
        for name in stale_names:
            if block.has_var(name):
                block._remove_var(name)

    def _add_lr_var(self, main_program, attrs):
        """Re-create the origin program's learning-rate var as persistable."""
        # TODO: the var name is hard coded for pe
        origin_block = attrs['origin_main_program'].global_block()
        lr_var = origin_block.vars["learning_rate_0"]
        main_program.global_block().create_var(
            name=lr_var.name,
            shape=lr_var.shape,
            dtype=lr_var.dtype,
            type=lr_var.type,
            lod_level=lr_var.lod_level,
            persistable=True)

    def _apply_single_impl(self, main_program, startup_program, pass_ctx):
        attrs = pass_ctx._attrs
        ops_to_delete = get_optimize_ops(main_program)
        ops_to_delete.extend(get_lr_ops(main_program))
        self._delete_optimizer_op_and_vars(main_program, ops_to_delete)

        origin_program = attrs['origin_main_program']
        if hasattr(origin_program, 'lr_sheduler'):
            self._add_lr_var(main_program, attrs)


@register_pass("delete_extra_optimizer_pass")
class DeleteExtraOptimizerPass(PassBase):
    """Remove the startup-program initialisation of optimizer-only vars.

    A var qualifies if an optimizer op of the main program reads it and it is
    not listed in any optimizer op's ``op_role_var``. Both the startup ops
    that write such vars and the vars themselves are deleted.
    """

    def __init__(self):
        super(DeleteExtraOptimizerPass, self).__init__()

    def _check_self(self):
        return True

    def _check_conflict(self, other_pass):
        return True

    def _apply_single_impl(self, main_program, startup_program, pass_ctx):
        read_vars = set()
        role_vars = set()
        for op in get_optimize_ops(main_program):
            read_vars.update(op.input_arg_names)
            role_vars.update(op.attr("op_role_var"))
        need_delete_optimize_vars = read_vars - role_vars

        startup_block = startup_program.global_block()
        # Single pass over the startup ops. An op that initialises several of
        # these vars is collected once. Previously it was collected once per
        # var, so delete_ops tried to remove it again after it was gone.
        init_ops = [
            op for op in startup_block.ops
            if any(name in need_delete_optimize_vars
                   for name in op.output_arg_names)
        ]
        delete_ops(startup_block, init_ops)

        for var in need_delete_optimize_vars:
            if startup_block.has_var(var):
                startup_block._remove_var(var)


@register_pass("fake_init_ops_pass")
class FakeInitOpsPass(PassBase):
    """Swap the startup initialiser of each sparse table for ``fake_init``.

    The table names are those returned by ``get_sparse_tablenames`` for the
    origin main program (both the ``True`` and ``False`` variants). For each
    table, exactly one startup op must write the table var. That op is
    deleted and a ``fake_init`` op with the same ``shape`` attr is appended.
    """

    def __init__(self):
        super(FakeInitOpsPass, self).__init__()

    def _check_self(self):
        return True

    def _check_conflict(self, other_pass):
        return True

    def _get_sparse_table_names(self, attrs):
        """Return the de-duplicated union of both sparse-table name lists."""
        origin_program = attrs['origin_main_program']
        table_names = set(get_sparse_tablenames(origin_program, True))
        table_names.update(get_sparse_tablenames(origin_program, False))
        return list(table_names)

    def _fake_init_sparsetable(self, program, sparse_table_names):
        """Replace each table's init op in ``program`` with ``fake_init``.

        Raises:
            ValueError: if a table is written by a number of ops other than 1.
        """
        block = program.global_block()
        for table_name in sparse_table_names:
            table_var = block.vars[table_name]
            writers = [
                op for op in block.ops if table_name in op.output_arg_names
            ]
            if len(writers) != 1:
                raise ValueError("table init op num should be 1, now is " +
                                 str(len(writers)))
            shape = writers[0].attr('shape')
            block.append_op(
                type="fake_init",
                inputs={},
                outputs={"Out": table_var},
                attrs={"shape": shape})
            delete_ops(block, writers)

    def _apply_single_impl(self, main_program, startup_program, pass_ctx):
        table_names = self._get_sparse_table_names(pass_ctx._attrs)
        self._fake_init_sparsetable(startup_program, table_names)


@register_pass("ps_gpu_pass")
class PsGpuPass(PassBase):
    """Adapt a trainer program to PS-GPU (box sparse) training.

    The pass appends the grad ops of every ``pull_box_sparse`` op. It then
    drops the optimizer-only vars of embedding params. Finally it removes the
    ``lookup_table_grad`` ops, their ``W@GRAD`` outputs, and every other op
    (except ``pull_box_sparse``) that reads W or W@GRAD.
    """

    def __init__(self):
        super(PsGpuPass, self).__init__()

    def _check_self(self):
        return True

    def _check_conflict(self, other_pass):
        return True

    def _add_push_box_sparse_op(self, program):
        """Append the grad op descs of each ``pull_box_sparse`` op, tagged
        with the Backward op role."""
        block_desc = program.global_block().desc
        pull_ops = (op for op in program.global_block().ops
                    if op.type == "pull_box_sparse")
        for pull_op in pull_ops:
            grad_descs, _ = core.get_grad_op_desc(
                pull_op.desc, cpt.to_text(set()), [])
            for grad_desc in grad_descs:
                appended = block_desc.append_op()
                appended.copy_from(grad_desc)
                appended._set_attr(op_role_attr_name, backward)

    def _remove_optimizer_var(self, program):
        """Remove vars read by optimizers of embedding params.

        Only optimizer ops whose ``Param`` is a ``lookup_table_grad`` W are
        considered. Their ``LearningRate`` input and any var listed in their
        ``op_role_var`` are kept.
        """
        block = program.global_block()
        embedding_params = {
            name
            for op in block.ops if op.type == "lookup_table_grad"
            for name in op.input("W")
        }

        candidate_vars = set()
        kept_vars = set()
        for opt_op in get_optimize_ops(program):
            for param_name in opt_op.input("Param"):
                if param_name not in embedding_params:
                    continue
                kept_vars.update(opt_op.attr("op_role_var"))
                for slot in opt_op.input_names:
                    if slot != "LearningRate":
                        candidate_vars.update(opt_op.input(slot))

        for name in candidate_vars - kept_vars:
            if block.has_var(name):
                block._remove_var(name)

    def _remove_lookup_table_grad_op_and_var(self, program):
        """Delete ``lookup_table_grad`` ops, their W@GRAD vars, and all
        non-``pull_box_sparse`` ops that read W or W@GRAD."""
        block = program.global_block()
        tracked_vars = set()
        doomed_idxs = set()
        doomed_vars = []
        for idx, op in enumerate(block.ops):
            if op.type != "lookup_table_grad":
                continue
            for grad_name in op.output("W@GRAD"):
                tracked_vars.add(grad_name)
                doomed_idxs.add(idx)
                doomed_vars.append(grad_name)
            tracked_vars.update(op.input("W"))

        for idx, op in enumerate(block.ops):
            if op.type == "pull_box_sparse":
                continue
            reads_tracked = any(var in tracked_vars
                                for slot in op.input_names
                                for var in op.input(slot))
            if reads_tracked:
                doomed_idxs.add(idx)

        # Remove back-to-front so the remaining indexes stay valid.
        for idx in sorted(doomed_idxs, reverse=True):
            block._remove_op(idx)
        for grad_name in doomed_vars:
            block._remove_var(grad_name)

    def _apply_single_impl(self, main_program, startup_program, pass_ctx):
        for step in (self._add_push_box_sparse_op,
                     self._remove_optimizer_var,
                     self._remove_lookup_table_grad_op_and_var):
            step(main_program)


@register_pass("ps_transpile_pass")
class PsTranspilePass(PassBase):
    """Transpile the main/startup programs with ``SingleProcessMultiThread``.

    Rank and endpoints are read from ``get_dist_env()``; the pass context
    attributes are not consulted.
    """

    def __init__(self):
        super(PsTranspilePass, self).__init__()

    def _check_self(self):
        return True

    def _check_conflict(self, other_pass):
        return True

    def _apply_single_impl(self, main_program, startup_program, pass_ctx):
        transpiler = SingleProcessMultiThread()
        env = get_dist_env()
        transpiler.transpile(
            startup_program=startup_program,
            main_program=main_program,
            rank=env["trainer_id"],
            endpoints=env["trainer_endpoints"],
            current_endpoint=env['current_endpoint'],
            wait_port=False)


@register_pass("split_heter_worker_ops_pass")
class SplitHeterWorkerOpsPass(PassBase):
    """Build the heter-worker program for this worker's pipeline stage.

    The stage's forward/backward ops (as grouped by ``find_heter_ops``) are
    copied into a new program, connected to neighbouring stages with
    communicate ops, and served by a ``heter_listen_and_serv`` op.
    """

    def __init__(self):
        super(SplitHeterWorkerOpsPass, self).__init__()

    def _check_self(self):
        return True

    def _check_conflict(self, other_pass):
        return True

    def _create_heter_program(self, program, attrs, heter_program,
                              program_block_ops_list, heter_ops,
                              block_var_detail):
        # This function mainly includes the following contents:
        # 1. For every heter block:
        #     a) copy heter device op from origin program
        #     b) create variables which belong to heter op:
        #         -> if variable is persistable, clone it in global_scope
        #         -> if variable is temp, create it in heter block
        #     c) create communicate related op as follow:
        #         joint_var.0_1 -> slice -> reshape -> origin_var
        #         origin_var -> origin_program
        #         reshape -> concat -> joint_var.1_2
        #     d) copy send op from origin program for var@grad located in
        #        the current heter block
        #     e) re-check every op in the current block whose device is not
        #        the current heter device
        # 2. Create send op for step counter in last heter-block
        # 3. Create Listen&Serv OP and Send&Recv OP for distributed training
        # 4. update CompileTimeStrategy for heter_program
        # NOTE(review): steps 1e, 2 and 4 are not performed by this body.
        # `heter_ops` is unused here; kept for interface compatibility.
        optimizer_block = []
        grad_to_block_id = []

        pre_block_idx = heter_program.num_blocks - 1
        role_maker = attrs['role_maker']
        current_device = role_maker._heter_device_type().lower()
        stage_id = int(role_maker._get_stage_id())
        # Per-stage lists are indexed with stage_id - 1.
        stage_ops = program_block_ops_list[stage_id - 1]
        stage_vars = block_var_detail[stage_id - 1]

        # Forward block of this stage.
        heter_block = heter_program._create_block(pre_block_idx)
        optimizer_block.append(heter_block)
        for op in stage_ops["forward"]:
            block_append_op(heter_program, program, heter_block, op)

        entrance_vars = stage_vars["forward"]["entrance"]
        add_vars_by_var_list(entrance_vars, program, heter_program, heter_block)
        add_vars_by_var_list(stage_vars["forward"]["exit"], program,
                             heter_program, heter_block)

        # Communicate ops for the forward part go after the forward ops.
        first_op_index_fp = len(heter_block.ops)

        is_last_stage = stage_id >= len(program_block_ops_list)
        if is_last_stage:
            # The last stage keeps its backward ops in the forward block.
            heter_block_bp = heter_block
        else:
            heter_block_bp = heter_program._create_block(pre_block_idx)
            optimizer_block.append(heter_block_bp)

        for op in stage_ops["backward"]:
            block_append_op(heter_program, program, heter_block_bp, op)
        bp_entrance_vars = stage_vars["backward"]["entrance"]
        add_vars_by_var_list(bp_entrance_vars, program, heter_program,
                             heter_block_bp)
        add_vars_by_var_list(stage_vars["backward"]["exit"], program,
                             heter_program, heter_block_bp)

        if not is_last_stage:
            backward_comm_info = get_communicate_var_info(
                program, stage_id, bp_entrance_vars, type="backward")
            grad_to_block_id.append(backward_comm_info["block_input_var_name"]
                                    + ":" + str(heter_block_bp.idx))

        forward_comm_info = get_communicate_var_info(
            program, stage_id, entrance_vars, type="forward")
        grad_to_block_id.append(forward_comm_info["block_input_var_name"] +
                                ":" + str(heter_block.idx))

        first_op_index_bp = len(heter_block_bp.ops)

        # Forward communicate ops only when block_var_detail has an entry
        # after this stage.
        if stage_id <= len(block_var_detail) - 1:
            insert_communicate_op(program, role_maker, heter_block, stage_id,
                                  first_op_index_fp, block_var_detail,
                                  current_device)
        insert_communicate_op(program, role_maker, heter_block_bp, stage_id,
                              first_op_index_bp, block_var_detail,
                              current_device, False)

        # add send op (step 1d)
        add_heter_send_op(program, heter_program, heter_block_bp, stage_vars)

        # Separate name so the pass `attrs` is not shadowed.
        listen_attrs = {
            "message_to_block_id": grad_to_block_id,
            "optimize_blocks": optimizer_block,
            # runtime attribute
            "endpoint": get_heter_worker_endpoint(role_maker),
            "fanin": len(get_previous_stage_trainers(role_maker)),
            "pserver_id": get_role_id(role_maker),
            "distributed_mode": attrs['ps_mode'],
            "rpc_exec_thread_num": int(os.getenv("CPU_NUM", 32)),
            RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
        }
        # append the listen_and_serv op
        heter_program.global_block().append_op(
            type="heter_listen_and_serv",
            inputs={'X': []},
            outputs={},
            attrs=listen_attrs)
        # TODO check heter program

    def _apply_single_impl(self, main_program, startup_program, pass_ctx):
        """
        split heter worker program from origin-program
        1. find heter op (located on different device)
        2. find input&output of every heter-block
        3. create heter worker program, add listen&serv op

        Returns the new heter worker program, or the program returned by
        ``find_heter_ops`` when no heter op is found. Rebinding the
        ``main_program`` parameter would not be visible to the caller, so
        the result is returned instead.
        NOTE(review): confirm how the program builder is meant to pick up
        the heter program.
        """
        attrs = pass_ctx._attrs
        default_device = "cpu"
        program, heter_ops, _, program_block_ops = find_heter_ops(
            main_program, default_device)
        if not heter_ops:
            warnings.warn(
                "Currently running in Heter Parameter Server mode, but no OP running on heterogeneous devices, Please check your code."
            )
            return program

        program_block_ops = union_forward_gradient_op(program_block_ops)
        block_vars_detail = find_block_joints(program, program_block_ops,
                                              heter_ops)
        heter_program = framework.Program()
        self._create_heter_program(program, attrs, heter_program,
                                   program_block_ops, heter_ops,
                                   block_vars_detail)
        return heter_program


@register_pass("split_trainer_ops_pass")
class SplitTrainerOpsPass(PassBase):
    """Build the CPU-trainer program.

    Heter-device op groups are replaced by communicate ops, and a
    ``heter_listen_and_serv`` op is inserted at the front of the global
    block.
    """

    def __init__(self):
        super(SplitTrainerOpsPass, self).__init__()

    def _check_self(self):
        return True

    def _check_conflict(self, other_pass):
        return True

    def _create_trainer_program(self, program, origin_program, attrs,
                                program_block_ops_list, block_var_detail):
        # This function mainly includes the following contents:
        # 1. For every heter block in origin program
        #     a) delete heter op and related variables
        #     b) add send&recv op
        #     c) add communicate ops as follows:
        #         origin_var -> reshape -> concat -> joint_var.0_1
        #         send&recv op(send joint_var.0_1; recv joint_var.1_2)
        #         joint_var.1_2 -> slice -> reshape -> origin_var
        #     d) remove send op which related var@grad is not in trainer program
        # 2. check every op's device
        # NOTE(review): step 2 is not performed yet (see TODO at the end).
        static_var = []
        # Index 0 is the trainer's own block; later indices are heter blocks.
        for heter_block_index in range(1, len(program_block_ops_list)):
            block_ops = program_block_ops_list[heter_block_index]
            ops_list = block_ops["forward"] + block_ops["backward"]
            static_var += replace_ops_by_communicate_op(
                program, attrs, heter_block_index, ops_list, block_var_detail)
            remove_trainer_send_op(program, attrs, heter_block_index,
                                   block_var_detail)

        bp_ops_list = program_block_ops_list[0]["backward"]
        delete_same_ops(program.global_block(), bp_ops_list)
        delete_trainer_useless_var(attrs, program, static_var)
        backward_block = create_backward_block(program, origin_program, attrs,
                                               bp_ops_list, block_var_detail)

        bp_entrance_vars = block_var_detail[0]["backward"]["entrance"]
        backward_comm_info = get_communicate_var_info(
            origin_program, 1, bp_entrance_vars, type="backward")

        grad_to_block_id = [
            backward_comm_info["block_input_var_name"] + ":" +
            str(backward_block.idx)
        ]
        optimizer_block = [backward_block]
        role_maker = attrs['role_maker']
        # Separate name so the pass `attrs` is not shadowed.
        listen_attrs = {
            "message_to_block_id": grad_to_block_id,
            "optimize_blocks": optimizer_block,
            # runtime attribute
            "endpoint":
            get_trainer_endpoint(role_maker),  ## get trainer endpoint
            "fanin": 0,  ## get heter worker
            "pserver_id": get_role_id(role_maker),
            "distributed_mode": attrs['ps_mode'],
            "rpc_exec_thread_num": int(os.getenv("CPU_NUM", 32)),
            RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
        }
        # insert the listen_and_serv op at the front of the global block
        program.global_block()._insert_op(
            index=0,
            type="heter_listen_and_serv",
            inputs={'X': []},
            outputs={},
            attrs=listen_attrs)

        # TODO: add check for bp block, e.g.
        # check_op_device(program.global_block(), DEFAULT_DEVICE)

    def _apply_single_impl(self, main_program, startup_program, pass_ctx):
        """
        split cpu-trainer program from origin-program
        1. find heter op (located on different device)
        2. find input&output of every heter-block
        3. create cpu-trainer program, add send&recv op

        Returns the new trainer program. Rebinding the ``main_program``
        parameter would not be visible to the caller, so the result is
        returned instead.
        NOTE(review): confirm how the program builder is meant to pick up
        the trainer program.
        """
        attrs = pass_ctx._attrs
        default_device = 'cpu'
        program, heter_ops, _, program_block_ops = find_heter_ops(
            main_program, default_device)
        program_block_ops = union_forward_gradient_op(program_block_ops)

        block_vars_detail = find_block_joints(program, program_block_ops,
                                              heter_ops)
        trainer_program = program.clone()
        self._create_trainer_program(trainer_program, program, attrs,
                                     program_block_ops, block_vars_detail)
        return trainer_program


@register_pass("set_heter_pipeline_opt_pass")
class SetHeterPipelineOptPass(PassBase):
    """Attach ``_heter_pipeline_opt`` dicts to the origin startup and main
    programs stored in the pass attrs, describing this worker's stage of
    the heter pipeline.
    """

    def __init__(self):
        super().__init__()

    def _check_self(self):
        return True

    def _check_conflict(self, other_pass):
        return True

    def _apply_single_impl(self, main_program, startup_program, pass_ctx):
        attrs = pass_ctx._attrs
        role_maker = attrs['role_maker']
        strategy = attrs['user_defined_strategy']
        num_microbatches = strategy.pipeline_configs['accumulate_steps']
        # pipeline_stage is the role maker's stage id shifted down by one.
        pipeline_stage = int(role_maker._get_stage_id()) - 1
        heter_place = role_maker._heter_device()

        startup_opt = {
            "startup_program": startup_program,
            "pipeline_stage": pipeline_stage,
            "heter_place": heter_place,
        }
        attrs['origin_startup_program']._heter_pipeline_opt = startup_opt

        attrs['origin_main_program']._heter_pipeline_opt = {
            "trainer": "HeterPipelineTrainer",
            "device_worker": "HeterSection",
            # trainer num in each stage
            "trainers": role_maker._get_stage_trainers(),
            "trainer_id": int(role_maker._role_id()),
            "pipeline_stage": pipeline_stage,
            "num_pipeline_stages": int(role_maker._get_num_stage()),
            "section_program": main_program,
            "num_microbatches": num_microbatches,
            "heter_place": heter_place,
        }