# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import copy
import six
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid import framework as framework
from paddle.fluid import unique_name
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.fluid.backward import append_backward, _some_in_set_, _append_grad_suffix_
from paddle.distributed.auto_parallel.operators.common import get_distributed_operator
from paddle.distributed.auto_parallel.operators.common import find_best_compatible_distributed_operator_impl
from paddle.fluid.clip import GradientClipBase, GradientClipByNorm, error_clip_callback, append_gradient_clip_ops, ClipGradByGlobalNorm
from paddle.distributed.fleet.base.distributed_strategy import DistributedStrategy
from paddle.distributed.auto_parallel.context import DistributedContext
from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op, is_backward_op, is_optimizer_op
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from .process import new_process_group
from .interface import _g_process_mesh_map
from .utils import _get_comm_group

__varname_not_in_block__ = ["lod_tensor_blocking_queue_0"]


class Partitioner(object):
    """
    warning:: Partitioner is experimental and subject to change.

    Partitioner converts a program into another program.
    Given a serial program that has been auto completed with shard annotations, the Partitioner
    converts it into a "distributed" program. The Partitioner modifies the serial
    program in the following two ways, which is also the major difference between a serial and a distributed program:
        1. partition op: replace a serial op with its corresponding dist op inferred from the shard annotation
        2. partition var: if a var is sharded, modify its shape according to the shard annotation

    Partitioner is supposed to be called by the auto parallel framework, and not to be called directly by the user.

    Example:
        ....
            import paddle.distributed.auto_parallel as auto
            from paddle.fluid.distributed_attribute import get_default_distributed_context
            from paddle.distributed import fleet
            from paddle.distributed.auto_parallel.partitioner import Partitioner

            # create serial program with forward only 
            with static.program_guard(serial_main_program, serial_start_program):
                model = create_model(config)
                tokens = static.data(name="tokens", shape=[batch_size, sequence_len], dtype='int64')
                labels = static.data(name="labels", shape=[batch_size, sequence_len], dtype='int64')
                loss_mask = static.data(name="loss_mask", shape=[batch_size, sequence_len], dtype='int64')
                preds = model(tokens)
                loss = criterion(preds, labels, loss_mask)

            # auto completion
            auto.ProcessMesh(shape=[2, 4], process_group=[0, 1, 2, 3, 4, 5, 6, 7])
            annotated_main_program = auto.complete_annotation(serial_main_program)
            auto_parallel_context = get_default_distributed_context()
                
            # distributed strategy & rank info
            rank_id = paddle.distributed.get_rank()
            dist_strategy = fleet.DistributedStrategy()
    
            # create partitioner
            partitioner = Partitioner(dist_strategy, auto_parallel_context, rank_id)

            # create dist program with forward only
            # for distributed inference, using partitioned_main_prog from here
            partitioned_main_prog, partitioned_startup_prog = partitioner.transpile_forward(annotated_main_program, serial_start_program)

            # create dist program with forward/backward/update
            # for distributed training, using partitioned_main_prog from here
            dist_params_grads = partitioner.apply_backward(loss, annotated_main_program, serial_start_program, partitioned_main_prog, partitioned_startup_prog)
            optimizer = paddle.fluid.optimizer.AdamOptimizer(
                learning_rate=0.00001,
                beta1=0.9,
                beta2=0.999,
                epsilon=1e-08,
                grad_clip=None)
            opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads, partitioned_main_prog, partitioned_startup_prog)
    """

    def __init__(self, dist_strategy, auto_parallel_context, rank_id=0):
        """
        Args:
            dist_strategy (paddle.fleet.distributed_strategy): used to determine the user-defined distributed strategy.
            auto_parallel_context (paddle.fluid.DistributedContext): used to access the distributed_attr of vars & ops. Every Partitioner object may maintain its own DistributedContext member and partition the program based on that shard scenario.
            rank_id (int): global rank id to which the partitioned distributed program belongs.
        """

        if not isinstance(dist_strategy, DistributedStrategy):
            raise TypeError(
                "dist_strategy be paddle.fleet.base.DistributedStrategy, got %s here"
                % type(dist_strategy))

        if not isinstance(auto_parallel_context, DistributedContext):
            raise TypeError(
                "auto_parallel_context be paddle.fluid.DistributedContext, got %s here"
                % type(auto_parallel_context))

        self._dist_strategy = dist_strategy
        self._auto_parallel_context = auto_parallel_context
        self._rank_id = rank_id
        self._serial2dist_varname_mapping = {}
        self._dist_varname_suffix = ""

        # TODO if there is some dist op that is not compatible 
        # with auto_backward in forward, the following flag 
        # should be set to False
        self._compatible_with_auto_backward = True

        # data parallelism        
        self._enable_data_parallel = False
        self._dp_degree = 0
        self._dp_group = None

        # tensor parallelism        
        self._enable_tensor_parallel = False
        self._tp_degree = 0
        self._tp_group = None

    def transpile_forward(self, serial_main_program, serial_startup_program):
        """
        Take serial forward programs with shard annotations and create new distributed forward programs based on the serial ones.
        Instead of modifying the input programs in place, this function preserves the inputs and creates new programs as output.

        Besides replacing each serial op with its dist op, if the user has defined other strategies in fleet.distributed_strategy that
        need to transpile (modify) the forward network program, those forward program modifications should also be done within this
        function in the auto parallel scenario, in order to facilitate distributed inference/evaluation, which needs to DECOUPLE strategy-specific forward transpilation from fleet.distributed_optimizer.minimize().

        By now the fleet.distributed_strategy options that need to transpile the forward program are the following:
            1. (optimizer) sharding

        Args:
            main_program (paddle.fluid.framework.program): serial main program with forward network only
            startup_program (paddle.fluid.framework.program): serial startup program with forward network only
        
        return:
            main_program (paddle.fluid.framework.program): distributed main program with forward network only
            startup_program (paddle.fluid.framework.program): distributed startup program with forward network only
        """

        dist_main_program, dist_startup_program = self.transpile_forward_impl(
            serial_main_program, serial_startup_program)
        return dist_main_program, dist_startup_program

    def apply_backward(self,
                       serial_loss,
                       serial_main_program,
                       serial_startup_program,
                       dist_main_program,
                       dist_startup_program,
                       parameter_list=None,
                       no_grad_set=None,
                       callbacks=None):
        """
        A complete training neural network is made up of forward and backward propagation.
        This function generates the dist backward program for the distributed forward program.

        By now, the automatic backward mechanism in the paddle framework might NOT handle the backward generation for
        some dist ops correctly, so we now have two ways to generate the backward program:
            1. dist_forward_program --> auto_backward --> dist_backward_program (if auto_backward could handle all dist ops)
            2. serial_forward_program --> auto_backward --> serial_backward_program --> dist_op_backward_transpile --> dist_backward_program (if auto_backward could not handle all dist ops)

        The backward program is appended to the input dist program in place.

        Args:
            serial_loss (Variable): the loss in the serial program to be minimized
            serial_main_program (paddle.fluid.framework.program): serial main program with forward network only
            serial_startup_program (paddle.fluid.framework.program): serial startup program with forward network only
            dist_main_program (paddle.fluid.framework.program): dist main program with forward network only
            dist_startup_program (paddle.fluid.framework.program): dist startup program with forward network only
            parameter_list (Iterable, optional): Iterable of ``Variable`` or ``Variable.name`` to update
                to minimize ``loss``. The default value is None, at this time all parameters
                will be updated.
            no_grad_set (set, optional): Set of ``Variable``  or ``Variable.name`` that don't need
                to be updated. The default value is None.
            callbacks (list, optional): list of callable objects to run when appending backward
                operator for one parameter. The default value is None.
        
        return:
            params_grads (list): list of (param, grad) variable tuples
        """
        params_grads = self.apply_backward_impl(
            serial_loss, serial_main_program, serial_startup_program,
            dist_main_program, dist_startup_program)
        return params_grads

    def apply_optimize(self, user_define_optimizer, params_grads,
                       dist_main_program, dist_startup_program):
        """
        Append update-related ops (gradient clip, weight decay, optimizer ops) to the program,
        filter optimize ops if sharding is enabled, and
        perform naive gradient synchronization before the update.

        Args:
            user_define_optimizer (paddle.fluid.optimizer): the optimizer defined by the user
            params_grads (list): list of (param, grad) variable tuples
            dist_main_program (paddle.fluid.framework.program): dist main program with forward & backward network
            dist_startup_program (paddle.fluid.framework.program): dist startup program with forward & backward network
        """

        optimize_ops = self.apply_optimize_impl(user_define_optimizer,
                                                params_grads, dist_main_program,
                                                dist_startup_program)

        return optimize_ops

    def transpile_forward_impl(self, main_program, startup_program):

        if not isinstance(main_program, (Program)):
            raise TypeError(
                "dist_strategy be paddle.fluid.framework.program, got %s here" %
                type(main_program))

        if not isinstance(startup_program, (Program)):
            raise TypeError(
                "auto_parallel_context be paddle.fluid.framework.program, got %s here"
                % type(startup_program))

        # check if shard annotated serial program valid
        if not self._is_valid_annotated_program(main_program):
            raise RuntimeError(
                "Not all vars or ops are annotated in main program !")

        # determine parallelism mode
        self._determine_parallel_mode(main_program)

        # dist op & partition vars
        new_main_prog, new_startup_program = self._dist_var_op_forward_transpile(
            main_program, startup_program)

        # Sharding
        if self._dist_strategy.sharding:
            new_main_prog, new_startup_program = self._sharding_forward_transpile(
                new_main_prog, new_startup_program)

        return new_main_prog, new_startup_program

    def apply_backward_impl(self,
                            serial_loss,
                            serial_main_program,
                            serial_startup_program,
                            dist_main_program,
                            dist_startup_program,
                            parameter_list=None,
                            no_grad_set=None,
                            callbacks=None):
        """
        """

        params_grads = self._dist_var_op_backward_transpile(
            serial_loss, serial_main_program, serial_startup_program,
            dist_main_program, dist_startup_program)
        # Sharding
        if self._dist_strategy.sharding:
            self._sharding_backward_transpile(dist_main_program,
                                              dist_startup_program)

        # Data Parallel pass
        if self._enable_data_parallel:
            self._gradient_sync_transpile(dist_main_program,
                                          dist_startup_program)

        return params_grads

    def apply_optimize_impl(self, user_define_optimizer, params_grads,
                            dist_main_program, dist_startup_program):
        """
        Append update-related ops (gradient clip, weight decay, optimizer ops) to the program,
        filter optimize ops if sharding is enabled, and
        perform naive gradient synchronization before the update.

        Args:
            user_define_optimizer (paddle.fluid.optimizer): the optimizer defined by the user
            params_grads (list): list of (param, grad) variable tuples
            dist_main_program (paddle.fluid.framework.program): dist main program with forward & backward network
            dist_startup_program (paddle.fluid.framework.program): dist startup program with forward & backward network
        """

        if self._dist_strategy.sharding:
            params_grads = self._sharding_optimize_transpile(
                params_grads, dist_main_program, dist_startup_program)

        optimize_ops = self._optimize_transpile(user_define_optimizer,
                                                params_grads, dist_main_program,
                                                dist_startup_program)

        return optimize_ops

    def _dist_var_op_forward_transpile(self,
                                       serial_main_program,
                                       serial_startup_program=None):
        """
        1. partition variables
        2. replace local op with corresponding dist op
        """

        partitioned_main_prog = fluid.Program()
        partitioned_global_block = partitioned_main_prog.global_block()
        serial_global_block = serial_main_program.global_block()
        serial_ops = serial_main_program.global_block().ops

        # transpile main program
        for op in serial_ops:

            # partition input variables
            for serial_input_varname in op.desc.input_arg_names():
                if serial_input_varname not in self._serial2dist_varname_mapping:
                    new_varname = serial_input_varname + self._dist_varname_suffix
                    if serial_global_block.has_var(serial_input_varname):
                        _partition_var(self._auto_parallel_context,
                                       serial_global_block,
                                       partitioned_global_block,
                                       serial_input_varname, new_varname)
                    else:
                        assert serial_input_varname in __varname_not_in_block__

                    self._serial2dist_varname_mapping[
                        serial_input_varname] = new_varname

            # partition output vars
            for serial_output_varname in op.desc.output_arg_names():
                if serial_output_varname not in self._serial2dist_varname_mapping:
                    new_varname = serial_output_varname + self._dist_varname_suffix
                    _partition_var(self._auto_parallel_context,
                                   serial_global_block,
                                   partitioned_global_block,
                                   serial_output_varname, new_varname)
                    self._serial2dist_varname_mapping[
                        serial_output_varname] = new_varname

            # partition op
            if _found_match_dist_op(self._auto_parallel_context, op):
                # replace with corresponding dist op
                _insert_dist_op(op, partitioned_global_block,
                                self._serial2dist_varname_mapping,
                                self._auto_parallel_context, self._rank_id)
            else:
                # replicate op
                _insert_src_op(op, partitioned_global_block,
                               self._serial2dist_varname_mapping)

        # transpile startup program
        if serial_startup_program is None:
            partitioned_startup_prog = None
        else:
            partitioned_startup_prog = fluid.Program()
            # create parameter
            partitioned_startup_global_block = partitioned_startup_prog.global_block(
            )
            param2shape = {}
            for var in partitioned_main_prog.list_vars():
                if isinstance(var, Parameter):
                    _partition_parameter(self._auto_parallel_context, var,
                                         partitioned_startup_global_block,
                                         var.name, var.shape)
                    param2shape[var.name] = var.shape

            # copy initializer
            for op in serial_startup_program.global_block().ops:
                output_vars = op.desc.output_arg_names()
                assert len(
                    output_vars
                ) == 1, "initializer should output only ONE variable, but got [{}]".format(
                    str(op.desc))
                assert self._serial2dist_varname_mapping[output_vars[
                    0]] in param2shape, "try to initialize [{}] which is not a Parameter".format(
                        output_vars[0])
                new_op_desc = partitioned_startup_global_block.desc.append_op()
                new_op_desc.copy_from(op.desc)
                new_op_desc._rename_output(
                    output_vars[0],
                    self._serial2dist_varname_mapping[output_vars[0]])
                new_op_desc._set_attr("shape", param2shape[
                    self._serial2dist_varname_mapping[output_vars[0]]])
                partitioned_startup_global_block._sync_with_cpp()

            # MP: broadcast parameters that are not split
            # NOTE Theoretically, the MP param init broadcast should be handled by
            # each dist op itself. But if we inserted the broadcast op at that moment, the broadcast
            # would run before the initializer, which leads to an undetermined case.
            if self._enable_tensor_parallel:
                param_to_sync = []
                for param in partitioned_startup_prog.all_parameters():
                    if not self._is_var_distributed(param):
                        param_to_sync.append(param)
                        # FIXME the ring id should be set by the autoparallel.mapping module;
                        # it should be determined by the corresponding process group but is fixed here as a hack
                        partitioned_startup_global_block.append_op(
                            type='c_broadcast',
                            inputs={'X': param},
                            outputs={'Out': param},
                            attrs={
                                'ring_id': self._tp_group.id,
                                'root': 0,
                                'use_calc_stream': True,
                                OP_ROLE_KEY: OpRole.Forward
                            })
                partitioned_startup_global_block.append_op(
                    type='c_sync_comm_stream',
                    inputs={'X': param_to_sync},
                    outputs={'Out': param_to_sync},
                    attrs={
                        'ring_id': self._tp_group.id,
                        OP_ROLE_KEY: OpRole.Forward
                    })
                partitioned_startup_global_block._sync_with_cpp()

            # DP init param broadcast
            if self._enable_data_parallel:
                # parameters initialization synchronization 
                param_to_sync = []

                for param in partitioned_startup_global_block.all_parameters():
                    param_to_sync.append(param)

                    # FIXME the ring id should be set by the autoparallel.mapping module;
                    # it should be determined by dp groups but is fixed here as a hack
                    partitioned_startup_global_block.append_op(
                        type='c_broadcast',
                        inputs={'X': param},
                        outputs={'Out': param},
                        attrs={
                            'ring_id': self._dp_group.id,
                            'root': 0,
                            'use_calc_stream': True,
                            OP_ROLE_KEY: OpRole.Forward
                        })
                partitioned_startup_global_block.append_op(
                    type='c_sync_comm_stream',
                    inputs={'X': param_to_sync},
                    outputs={'Out': param_to_sync},
                    attrs={
                        'ring_id': self._dp_group.id,
                        OP_ROLE_KEY: OpRole.Forward
                    })
                partitioned_startup_global_block._sync_with_cpp()

        return partitioned_main_prog, partitioned_startup_prog

    def _dist_var_op_backward_transpile(self,
                                        serial_loss,
                                        serial_main_program,
                                        serial_startup_program,
                                        dist_main_program,
                                        dist_startup_program,
                                        parameter_list=None,
                                        no_grad_set=None,
                                        callbacks=None):
        """
        So far, the auto_backward case only guarantees the correctness of backward ops for certain dist ops:
            1. NV-Megatron-like parallel embedding
            2. NV-Megatron-like row parallel linear
            3. NV-Megatron-like col parallel linear
        """

        if self._compatible_with_auto_backward:
            assert isinstance(
                serial_loss, Variable), "The target loss should be a Variable."
            dist_loss = self._serial_varname2dist_var(serial_loss.name,
                                                      dist_main_program)

            assert len(dist_loss.shape) == 1 and dist_loss.shape[0] == 1, \
                "The dist loss.shape should be (1L,), but the current dist loss.shape is {}. " \
                "Maybe you should call fluid.layers.mean to process the current loss.".format(
                    dist_loss.shape)

            # update parameter list
            if parameter_list:
                parameter_list = [
                    self._serial_varname2dist_var(param.name, dist_main_program)
                    for param in parameter_list
                ]

            # update parameter no_grad_set
            if no_grad_set:
                no_grad_set = [
                    self._serial_varname2dist_var(param.name, dist_main_program)
                    for param in no_grad_set
                ]

            return _auto_backward(
                dist_loss,
                dist_startup_program,
                parameter_list=parameter_list,
                no_grad_set=no_grad_set,
                callbacks=callbacks)
        # replace dist grad ops
        else:
            raise RuntimeError("transpile NOT implemented !")

    def _optimize_transpile(self, user_define_optimizer, params_grads,
                            main_program, startup_program):

        with program_guard(main_program, startup_program):
            optimize_ops = user_define_optimizer.apply_gradients(params_grads)

        return optimize_ops

    def _is_valid_annotated_program(self, program):

        # TODO (ZJ-LIANG) should check all block
        ops = program.global_block().ops
        vars_ = program.list_vars()
        op_dist_attrs = [
            self._auto_parallel_context.get_op_distributed_attr_for_program(op)
            for op in ops
        ]
        var_dist_attrs = [
            self._auto_parallel_context.get_tensor_distributed_attr_for_program(
                var) for var in vars_
        ]

        all_ops_annotated = all(dist_attr is not None
                                for dist_attr in op_dist_attrs)
        all_vars_annotated = all(dist_attr is not None
                                 for dist_attr in var_dist_attrs)

        return all_ops_annotated and all_vars_annotated

    def _serial_varname2dist_var(self, serial_varname, dist_program):
        assert serial_varname in self._serial2dist_varname_mapping, "The serial var [{}] is not found in var name mapping".format(
            serial_varname)
        dist_varname = self._serial2dist_varname_mapping[serial_varname]

        assert dist_program.global_block().has_var(
            dist_varname
        ), "The dist var [{}] is not found in dist program".format(dist_varname)
        dist_var = dist_program.global_block().var(dist_varname)

        return dist_var

    def _determine_parallel_mode(self, program):
        """
        Determine which parallelism modes are enabled.
        NOTE this is a hard rule and should be updated in the future.
        """

        for param in program.all_parameters():
            if self._is_var_distributed(param):
                self._enable_tensor_parallel = True
                break

        for var in program.list_vars():
            var_dist_attr = self._auto_parallel_context.get_tensor_distributed_attr_for_program(
                var)
            if not var_dist_attr.is_parameter():
                mapping = var_dist_attr.get_dims_mapping()
                mesh = var_dist_attr.get_process_mesh().topology
                if mapping[0] >= 0 and mesh[mapping[0]] > 1:
                    self._enable_data_parallel = True
                    break

        # tensor parallelism
        if self._enable_tensor_parallel:
            model_parallel_axis, process_mesh = self._auto_parallel_context._get_model_parallel_info(
            )
            group_ranks = _get_comm_group(process_mesh.process_group,
                                          process_mesh.topology,
                                          model_parallel_axis, self._rank_id)
            self._tp_degree = len(group_ranks)
            self._tp_group = new_process_group(group_ranks)

        # data parallelism
        data_parallel_axis, process_mesh = self._auto_parallel_context._get_data_parallel_info(
        )
        if self._enable_data_parallel:
            group_ranks = _get_comm_group(process_mesh.process_group,
                                          process_mesh.topology,
                                          data_parallel_axis, self._rank_id)
            self._dp_degree = len(group_ranks)
            self._dp_group = new_process_group(group_ranks)

    def _is_var_distributed(self, var):

        dist_attr = self._auto_parallel_context.get_tensor_distributed_attr_for_program(
            var)
        assert dist_attr is not None, "dist_attr of var [{}] is None".format(
            var.name)
        return _is_distributed(dist_attr)

    def _sharding_forward_transpile(self, main_prog, startup_program):
        """
        This transpile conducts the modifications to the forward program needed by the sharding strategy,
        which mainly include:
            1. partition the parameters
            2. insert broadcast ops
            3. insert sync ops

        NOTE the transpile modification is in place on the input program
        """

        raise NotImplementedError(
            "Sharding is NOT supported in AutoParallel yet!")

    def _sharding_backward_transpile(self, main_prog, startup_program):
        """
        This transpile conducts the modifications to the backward program needed by the sharding strategy,
        which mainly include:
            1. partition the gradients
            2. insert broadcast ops
            3. insert sync ops

        NOTE the transpile modification is in place on the input program
        """

        raise NotImplementedError(
            "Sharding is NOT supported in AutoParallel yet!")

    def _sharding_optimize_transpile(self, params_grads, dist_main_program,
                                     dist_startup_program):
        """
        shard params_grads and append broadcast ops to sync parameters
        """
        raise RuntimeError("sharding transpile is NOT implemented !")

    def _gradient_sync_transpile(self, main_program, startup_program):
        """
        append the gradient allreduce ops for all parameters' grad in case of Data Parallel
        """

        # scale loss by dp degree
        main_global_block = main_program.global_block()
        for idx, op in reversed(list(enumerate(main_global_block.ops))):
            if is_loss_grad_op(op):
                loss_grad_var = main_global_block.vars[op.output_arg_names[0]]
                main_global_block._insert_op_without_sync(
                    idx + 1,
                    type='scale',
                    inputs={'X': loss_grad_var},
                    outputs={'Out': loss_grad_var},
                    attrs={
                        'scale': 1.0 / self._dp_degree,
                        OP_ROLE_KEY: OpRole.Backward
                    })
                break
        main_global_block._sync_with_cpp()

        # gradient synchronization
        # NOTE naive gradient sync without overlapping,
        # so there is no need to sync between calc and comm
        # collecting grad var
        grad_to_sync = []
        for idx, op in reversed(list(enumerate(main_global_block.ops))):
            if is_backward_op(op) and \
                    OP_ROLE_VAR_KEY in op.attr_names:
                op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
                if len(op_role_var) != 0:
                    assert len(op_role_var) % 2 == 0
                    for i in range(0, len(op_role_var), 2):
                        param, reduced_grad = op_role_var[i], op_role_var[i + 1]
                        assert (reduced_grad not in grad_to_sync)
                        grad_to_sync.append(reduced_grad)
            if is_optimizer_op(op):
                first_optimize_op_idx = idx

        # insert allreduce
        for grad in grad_to_sync:
            # FIXME the ring id should be set by the autoparallel.mapping module;
            # it should be determined by dp groups but is fixed here as a hack
            main_global_block.append_op(
                type='c_allreduce_sum',
                inputs={'X': grad},
                outputs={'Out': grad},
                attrs={
                    'ring_id': self._dp_group.id,
                    'root': 0,
                    'use_calc_stream': True,
                    OP_ROLE_KEY: OpRole.Backward
                })
        main_global_block.append_op(
            type='c_sync_comm_stream',
            inputs={'X': grad_to_sync},
            outputs={'Out': grad_to_sync},
            attrs={'ring_id': self._dp_group.id,
                   OP_ROLE_KEY: OpRole.Backward})
        main_global_block._sync_with_cpp()


def _get_no_grad_set_name(no_grad_set):
    no_grad_set_name = set()
    if no_grad_set is not None:
        if isinstance(no_grad_set, (set, list, tuple)):
            for i, no_grad_var in enumerate(no_grad_set):
                if isinstance(no_grad_var, framework.Variable):
                    no_grad_set_name.add(no_grad_var.name)
                elif isinstance(no_grad_var, six.string_types):
                    no_grad_set_name.add(no_grad_var)
                else:
                    raise TypeError(
                        "The type of no_grad_set's member must be paddle.fluid.Variable or str, but received %s."
                        % (type(no_grad_var)))
        else:
            raise TypeError(
                "The type of no_grad_set should be set or list or tuple, but received {}".
                format(type(no_grad_set)))
    return no_grad_set_name


def _get_no_grad_set(loss, no_grad_set=None):
    no_grad_set = _get_no_grad_set_name(no_grad_set)
    parameters = loss.block.program.global_block().all_parameters()
    param_no_trainable = set(
        [param.name for param in parameters if param.trainable is False])
    # If the parameter is not trainable, it should not have a gradient.
    no_grad_set.update(param_no_trainable)

    return no_grad_set


def _found_match_dist_op(auto_paralle_context, op):
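    # True only when the op type has a registered dist operator whose annotated
    # impl provides a specialized forward implementation; otherwise the caller
    # falls back to replicating the serial op via _insert_src_op.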
    dist_attr = auto_paralle_context.get_op_distributed_attr_for_program(op)
    dist_ops = get_distributed_operator(op.type)

    return dist_ops and dist_attr.get_impl_idx() >= 0 and dist_ops.get_impl( \
        dist_attr.get_impl_idx())._forward_implemented


def _auto_backward(loss,
                   startup_program=None,
                   parameter_list=None,
                   no_grad_set=None,
                   callbacks=None):
    """
    modification is in place
    """
    act_no_grad_set = _get_no_grad_set(loss, no_grad_set)
    assert isinstance(loss, Variable), "The target loss should be an Variable."

    if callbacks is None:
        callbacks = [error_clip_callback]
    else:
        assert (isinstance(callbacks, list))

    assert len(loss.shape) == 1 and loss.shape[0] == 1, \
        "The loss.shape should be (1L,), but the current loss.shape is {}. " \
        "Maybe you should call fluid.layers.mean to process the current loss.".format(
            loss.shape)

    program = loss.block.program
    with program_guard(program, startup_program):
        params_grads = append_backward(loss, parameter_list, act_no_grad_set,
                                       callbacks)

    return params_grads


def _is_distributed(dist_attr):

    mapping = dist_attr.get_dims_mapping()
    mesh = dist_attr.get_process_mesh().topology
    for idx in range(len(mapping)):
        if mapping[idx] >= 0 and mesh[mapping[idx]] > 1:
            return True

    return False


def _get_dist_shape(var, dist_attr):
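    # Worked example: for var.shape == [8, 1024], dims_mapping == [-1, 1] and a
    # process mesh topology of [2, 4], dim 0 is not sharded while dim 1 is split
    # over mesh axis 1 (size 4), so the returned shape is [8, 256].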

    var_shape = var.shape
    mapping = dist_attr.get_dims_mapping()
    mesh = dist_attr.get_process_mesh().topology
    assert len(var_shape) == len(
        mapping
    ), "variable shape [{}] and dim_mapping [{}] do NOT match!".format(
        var_shape, mapping)
    new_shape = []
    for idx in range(len(var_shape)):
        if var_shape[idx] == -1 or mapping[idx] == -1:
            new_shape.append(var_shape[idx])
        else:
            assert var_shape[idx] % mesh[mapping[
                idx]] == 0, "uneven partition: var_shape[idx]=[{}], mesh dim=[{}]".format(
                    var_shape[idx], mesh[mapping[idx]])
            new_shape.append(var_shape[idx] // mesh[mapping[idx]])

    return new_shape


def _partition_parameter(auto_paralle_context, src_var, dst_block, dst_varname,
                         dst_shape):
    # NOTE hack to copy a Parameter: the copied parameter is not initialized
    # here and still needs to be initialized (by the startup program)
    copied_kwargs = {}
    copied_kwargs['trainable'] = src_var.trainable
    copied_kwargs['optimize_attr'] = src_var.optimize_attr
    copied_kwargs['regularizer'] = src_var.regularizer
    copied_kwargs['do_model_average'] = src_var.do_model_average
    copied_kwargs['need_clip'] = src_var.need_clip

    param = Parameter(
        block=dst_block,
        type=src_var.type,
        name=dst_varname,
        shape=dst_shape,
        dtype=src_var.dtype,
        lod_level=src_var.lod_level,
        error_clip=src_var.error_clip,
        stop_gradient=src_var.stop_gradient,
        is_data=src_var.is_data,
        belong_to_optimizer=src_var.belong_to_optimizer,
        **copied_kwargs)

    # set dist attr uid
    # distributed_attr_uid = src_var.desc.get_distributed_attr_uid()
    # param.desc.set_distributed_attr_uid(distributed_attr_uid)
    dist_attr = copy.deepcopy(
        auto_paralle_context.get_tensor_distributed_attr_for_program(src_var))
    dist_attr._owner_tensor = param
    dist_attr._owner_context = auto_paralle_context.get_tensor_distributed_attr_for_program(
        src_var)._owner_context
    auto_paralle_context.set_tensor_distributed_attr_for_program(param,
                                                                 dist_attr)


def _partition_intermediate_var(auto_paralle_context, src_var, dst_block,
                                dst_varname, dst_shape):
    var = dst_block.create_var(
        type=src_var.type,
        name=dst_varname,
        shape=dst_shape,
        dtype=src_var.dtype,
        lod_level=src_var.lod_level,
        persistable=src_var.persistable,
        error_clip=src_var.error_clip,
        stop_gradient=src_var.stop_gradient,
        is_data=src_var.is_data,
        belong_to_optimizer=src_var.belong_to_optimizer)

    # set dist attr uid
    # distributed_attr_uid = src_var.desc.get_distributed_attr_uid()
    # var.desc.set_distributed_attr_uid(distributed_attr_uid)
    dist_attr = copy.deepcopy(
        auto_paralle_context.get_tensor_distributed_attr_for_program(src_var))
    dist_attr._owner_tensor = var
    dist_attr._owner_context = auto_paralle_context.get_tensor_distributed_attr_for_program(
        src_var)._owner_context
    auto_paralle_context.set_tensor_distributed_attr_for_program(var, dist_attr)


def _partition_var(auto_paralle_context, src_block, dst_block, src_varname,
                   dst_varname):
    """
    partition include: split + replicate
    """
    src_var = src_block.var(src_varname)

    if src_var.type == core.VarDesc.VarType.READER:
        dst_block.create_var(
            type=src_var.type,
            name=dst_varname,
            persistable=True,
            stop_gradient=True)
    else:
        dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program(
            src_var)
        target_shape = _get_dist_shape(src_var, dist_attr)

        if isinstance(src_var, Parameter):
            _partition_parameter(auto_paralle_context, src_var, dst_block,
                                 dst_varname, target_shape)
        else:
            _partition_intermediate_var(auto_paralle_context, src_var,
                                        dst_block, dst_varname, target_shape)


def _insert_src_op(src_op, dst_block, varname_mapping):
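    # Replicate the serial op into dst_block unchanged, only renaming its
    # input/output arguments to the dist var names given by varname_mapping.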

    new_op_desc = dst_block.desc.append_op()
    new_op_desc.copy_from(src_op.desc)
    for local_varname in src_op.desc.input_arg_names():
        new_op_desc._rename_input(local_varname, varname_mapping[local_varname])
    for local_varname in src_op.desc.output_arg_names():
        new_op_desc._rename_output(local_varname,
                                   varname_mapping[local_varname])
    dst_block._sync_with_cpp()


def _insert_dist_op(src_op, dst_block, varname_mapping, auto_paralle_context,
                    rank_id):
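    # Build serial->dist varname mappings for the op's inputs and outputs, then
    # delegate op construction to the forward impl of the registered dist
    # operator for this op type (selected by the annotated impl index).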

    # build input varname mapping
    input_mapping = {}
    for input_name in src_op.desc.input_names():
        varnames = []
        for varname in src_op.desc.input(input_name):
            varnames.append(varname_mapping[varname])
        input_mapping[input_name] = varnames

    # build output varname mapping
    output_mapping = {}
    for output_name in src_op.desc.output_names():
        varnames = []
        for varname in src_op.desc.output(output_name):
            varnames.append(varname_mapping[varname])
        output_mapping[output_name] = varnames

    # append dist op 
    dist_attr = auto_paralle_context.get_op_distributed_attr_for_program(src_op)
    dist_ops = get_distributed_operator(src_op.type)
    append_op_handle = dist_ops.get_impl(dist_attr.get_impl_idx()).forward(
        src_op)
    append_op_handle(
        dst_block,
        src_op,
        dist_attr,
        input_mapping,
        output_mapping,
        rank_id=rank_id)