# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import copy
import six
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import core, unique_name
from paddle.fluid import framework
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.fluid.backward import append_backward, _some_in_set_, _append_grad_suffix_

from paddle.distributed.auto_parallel.operators.common import get_distributed_operator_impl_container

from paddle.fluid.clip import GradientClipBase, GradientClipByNorm, error_clip_callback, append_gradient_clip_ops, ClipGradByGlobalNorm
from paddle.distributed.fleet.base.distributed_strategy import DistributedStrategy

from paddle.distributed.auto_parallel.dist_context import DistributedContext, DistributedOperatorContext

from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op, is_backward_op, is_optimizer_op
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY

from .dist_attribute import OperatorDistributedAttribute
from .process_group import new_process_group
from .utils import print_program_with_dist_attr

from paddle.distributed.auto_parallel.completion import complete_backward_annotation, complete_update_annotation

__varname_not_in_block__ = ["lod_tensor_blocking_queue_0"]


class Partitioner(object):
    """
    warning:: Partitioner is experimental and subject to change.

    Partitioner converts a program into another program.
    Given a serial program which has been auto-completed with shard annotations, the Partitioner
    converts the serial program into a "distributed" program. The Partitioner will modify the serial
    program in the following two ways, which are also the major differences between a serial and a distributed program:
        1. partition op: replace a serial op with its corresponding dist op inferred from the shard annotation
        2. partition var: if a var is sharded, modify its shape according to its shard annotation

    Partitioner is supposed to be called by the auto parallel framework, not directly by the user.

    Example:
        ....
            import paddle.distributed.auto_parallel as auto
            from paddle.fluid.distributed_attribute import get_default_distributed_context
            from paddle.distributed import fleet
            from paddle.distributed.auto_parallel.partitioner import Partitioner

            # create serial program with forward only 
            with static.program_guard(serial_main_program, serial_start_program):
                model = create_model(config)
                tokens = static.data(name="tokens", shape=[batch_size, sequence_len], dtype='int64')
                labels = static.data(name="labels", shape=[batch_size, sequence_len], dtype='int64')
                loss_mask = static.data(name="loss_mask", shape=[batch_size, sequence_len], dtype='int64')
                preds = model(tokens)
                loss = criterion(preds, labels, loss_mask)

            # auto completion
            auto.ProcessMesh(shape=[2, 4], process_group=[0, 1, 2, 3, 4, 5, 6, 7])
            annotated_main_program = auto.complete_annotation(serial_main_program)
            dist_context = get_default_distributed_context()

            # distributed strategy & rank info
            rank_id = paddle.distributed.get_rank()
            dist_strategy = fleet.DistributedStrategy()

            # create partitioner
            partitioner = Partitioner(dist_strategy, dist_context, rank_id)

            # create dist program with forward only
            # for distributed inference, using partitioned_main_prog from here
            partitioned_main_prog, partitioned_startup_prog = partitioner.transpile_forward(annotated_main_program, serial_start_program)

            # create dist program with forward/backward/update
            # for distributed training, using partitioned_main_prog from here
            dist_params_grads = partitioner.apply_backward(loss, annotated_main_program, serial_start_program, partitioned_main_prog, partitioned_startup_prog)
            optimizer = paddle.fluid.optimizer.AdamOptimizer(
                learning_rate=0.00001,
                beta1=0.9,
                beta2=0.999,
                epsilon=1e-08,
                grad_clip=None)
            opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads, partitioned_main_prog, partitioned_startup_prog)
    """

    def __init__(self, dist_strategy, dist_context, rank_id=0):
        """
        Args:
            dist_strategy (paddle.fleet.distributed_strategy): used to determine the user defined distributed strategy.
            dist_context (paddle.fluid.DistributedContext): used to access the distributed_attr of var & op. Every Partitioner object may maintain its own DistributedContext member and partition the program based on that shard scenario.
            rank_id (int): global rank id to which the partitioned distributed program belongs.
        """

        if not isinstance(dist_strategy, DistributedStrategy):
            raise TypeError(
                "dist_strategy should be paddle.fleet.base.DistributedStrategy, but got %s here"
                % type(dist_strategy))

        if not isinstance(dist_context, DistributedContext):
            raise TypeError(
                "dist_context should be paddle.fluid.DistributedContext, but got %s here"
                % type(dist_context))

        self._dist_strategy = dist_strategy
        self._dist_context = dist_context
        self._rank_id = rank_id
        self._serial2dist_varname_mapping = {}
        self._dist_varname_suffix = ""

        # TODO if there is some dist op that is not compatible 
        # with auto_backward in forward, the following flag 
        # should be set to False
        self._compatible_with_auto_backward = True

    def transpile_forward(self, serial_main_program, serial_startup_program):
        """
        Take serial forward programs with shard annotations and create new distributed forward programs based on the serial ones.
        Instead of modifying the input programs in place, this function preserves the inputs and creates new programs as output.

        Besides replacing each serial op with its dist op, if the user has defined other strategies in fleet.distributed_strategy that need
        to transpile (modify) the forward network program, those forward program modifications should also be done within this
        function in the auto parallel scenario, in order to facilitate distributed inference/evaluation, which needs to DECOUPLE strategy-specific forward transpilation from fleet.distributed_optimizer.minimize().

        By now, the fleet.distributed_strategy options that need to transpile the forward program are the following:
            1. (optimizer) sharding

        Args:
            serial_main_program (paddle.fluid.framework.program): serial main program with forward network only
            serial_startup_program (paddle.fluid.framework.program): serial startup program with forward network only

        return:
            main_program (paddle.fluid.framework.program): distributed main program with forward network only
            startup_program (paddle.fluid.framework.program): distributed startup program with forward network only
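
        Examples:
            A minimal usage sketch (the names below reuse the placeholders from the class-level example above):

                partitioner = Partitioner(dist_strategy, dist_context, rank_id)
                dist_main_prog, dist_startup_prog = partitioner.transpile_forward(
                    annotated_main_program, serial_start_program)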
        """

        dist_main_program, dist_startup_program = self.transpile_forward_impl(
            serial_main_program, serial_startup_program)
        return dist_main_program, dist_startup_program

    def apply_backward(self,
                       serial_loss,
                       serial_main_program,
                       serial_startup_program,
                       dist_main_program,
                       dist_startup_program,
                       parameter_list=None,
                       no_grad_set=None,
                       callbacks=None):
        """
        A complete training neural network is made up of forward and backward propagation. 
        This function is to generate the dist backward program for the distributed forward program.

        By now, the automatic backward mechanism in the paddle framework might NOT handle the backward generation for
        some dist ops correctly, so we now have two ways to generate the backward program:
            1. dist_forward_program --> auto_backward --> dist_backward_program (if auto_backward could handle all dist op)
            2. serial_forward_program --> auto_backward --> serial_backward_program --> dist_op_backward_transpile --> dist_backward_program (if auto_backward could not handle all dist op)

        The backward program is appended to the input dist program in place.

        Args:
            serial_loss (Variable): the loss in the serial program to be minimized
            serial_main_program (paddle.fluid.framework.program): serial main program with forward network only
            serial_startup_program (paddle.fluid.framework.program): serial startup program with forward network only
            dist_main_program (paddle.fluid.framework.program): dist main program with forward network only
            dist_startup_program (paddle.fluid.framework.program): dist startup program with forward network only
            parameter_list (Iterable, optional): Iterable of ``Variable`` or ``Variable.name`` to update
                to minimize ``loss``. The default value is None, at this time all parameters
                will be updated.
            no_grad_set (set, optional): Set of ``Variable``  or ``Variable.name`` that don't need
                to be updated. The default value is None.
            callbacks (list, optional): list of callable objects to run when appending backward
                operator for one parameter. The default value is None.
        
        return:
            params_grads (list): list of (param, grad) tuples
        """
        params_grads = self.apply_backward_impl(
            serial_loss, serial_main_program, serial_startup_program,
            dist_main_program, dist_startup_program)
        return params_grads

    def apply_optimize(self, user_define_optimizer, params_grads,
                       dist_main_program, dist_startup_program):
        """
        Append update-related ops (e.g. gradient clip, weight decay) to the program,
        filter optimize ops if sharding is enabled, and perform naive gradient
        synchronization before the update.

        Args:
            user_define_optimizer (paddle.fluid.optimizer): the user defined optimizer
            params_grads (list): list of (param, grad) tuples
            dist_main_program (paddle.fluid.framework.program): dist main program with forward & backward network
            dist_startup_program (paddle.fluid.framework.program): dist startup program with forward & backward network
        """

        optimize_ops = self.apply_optimize_impl(user_define_optimizer,
                                                params_grads, dist_main_program,
                                                dist_startup_program)

        return optimize_ops

    def transpile_forward_impl(self, main_program, startup_program):

        if not isinstance(main_program, (Program)):
            raise TypeError(
                "main_program should be paddle.fluid.framework.Program, but got %s here" %
                type(main_program))

        if not isinstance(startup_program, (Program)):
            raise TypeError(
                "startup_program should be paddle.fluid.framework.Program, but got %s here" %
                type(startup_program))

        # check if shard annotated serial program valid
        if not self._is_valid_annotated_program(main_program):
            raise RuntimeError(
                "Not all vars or ops are annotated in main program !")

        # dist op & partition vars
        new_main_prog, new_startup_program = self._dist_var_op_forward_transpile(
            main_program, startup_program)

        # Sharding
        if self._dist_strategy.sharding:
            new_main_prog, new_startup_program = self._sharding_forward_transpile(
                new_main_prog, new_startup_program)

        return new_main_prog, new_startup_program

    def apply_backward_impl(self,
                            serial_loss,
                            serial_main_program,
                            serial_startup_program,
                            dist_main_program,
                            dist_startup_program,
                            parameter_list=None,
                            no_grad_set=None,
                            callbacks=None):
        """
        Implementation of apply_backward; see apply_backward for the argument semantics.
        """

        params_grads = self._dist_var_op_backward_transpile(
            serial_loss, serial_main_program, serial_startup_program,
            dist_main_program, dist_startup_program)
        # Sharding
        if self._dist_strategy.sharding:
            self._sharding_backward_transpile(dist_main_program,
                                              dist_startup_program)

        return params_grads

    def apply_optimize_impl(self, user_define_optimizer, params_grads,
                            dist_main_program, dist_startup_program):
        """
        Append update-related ops (e.g. gradient clip, weight decay) to the program,
        filter optimize ops if sharding is enabled, and perform naive gradient
        synchronization before the update.

        Args:
            user_define_optimizer (paddle.fluid.optimizer): the user defined optimizer
            params_grads (list): list of (param, grad) tuples
            dist_main_program (paddle.fluid.framework.program): dist main program with forward & backward network
            dist_startup_program (paddle.fluid.framework.program): dist startup program with forward & backward network
        """

        if self._dist_strategy.sharding:
            params_grads = self._sharding_optimize_transpile(
                params_grads, dist_main_program, dist_startup_program)

        optimize_ops = self._optimize_transpile(user_define_optimizer,
                                                params_grads, dist_main_program,
                                                dist_startup_program)

        return optimize_ops

    def _dist_var_op_forward_transpile(self,
                                       serial_main_program,
                                       serial_startup_program=None):
        """
        1. partition variables
        2. replace local op with corresponding dist op
        """

        partitioned_main_prog = fluid.Program()
        partitioned_global_block = partitioned_main_prog.global_block()
        serial_main_block = serial_main_program.global_block()
        serial_ops = serial_main_program.global_block().ops

        # transpile startup program
        if serial_startup_program is None:
            partitioned_startup_prog = None
        else:
            partitioned_startup_prog = fluid.Program()
            # create parameter
            partitioned_startup_global_block = partitioned_startup_prog.global_block(
            )
            param2shape = {}
            temp_varname_map = {}
            for var in serial_startup_program.list_vars():
                if isinstance(var, Parameter):
                    # TODO vars that do not belong to this rank should be filtered out
                    serial_main_var = serial_main_block.var(var.name)
                    dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
                        serial_main_var)
                    target_shape = _get_dist_shape(serial_main_var, dist_attr)
                    new_name = var.name + self._dist_varname_suffix
                    temp_varname_map[var.name] = new_name
                    _partition_parameter(self._dist_context, serial_main_var,
                                         partitioned_startup_global_block,
                                         new_name, target_shape)
                    param2shape[new_name] = target_shape

            # copy initializer
            for op in serial_startup_program.global_block().ops:
                # TODO vars that do not belong to this rank should be filtered out
                output_vars = op.desc.output_arg_names()
                assert len(
                    output_vars
                ) == 1, "initializer should output only ONE variable, but got [{}]".format(
                    str(op.desc))
                assert temp_varname_map[output_vars[
                    0]] in param2shape, "try to initialize [{}] which is not a Parameter".format(
                        output_vars[0])
                new_op_desc = partitioned_startup_global_block.desc.append_op()
                new_op_desc.copy_from(op.desc)
                new_op_desc._rename_output(output_vars[0],
                                           temp_varname_map[output_vars[0]])
                new_op_desc._set_attr(
                    "shape", param2shape[temp_varname_map[output_vars[0]]])
                partitioned_startup_global_block._sync_with_cpp()

                # set distributed attribute
                new_op = partitioned_startup_global_block.ops[-1]
                assert new_op.type == new_op_desc.type()
                assert new_op.desc == new_op_desc
                output_var = partitioned_startup_global_block.var(output_vars[
                    0])
                output_var_attr = self._dist_context.get_tensor_dist_attr_for_program(
                    output_var)
                op_attr = OperatorDistributedAttribute()
                op_attr.process_mesh = output_var_attr.process_mesh
                op_attr.set_output_dims_mapping(output_var.name,
                                                output_var_attr.dims_mapping)
                op_attr.set_input_dims_mapping(output_var.name,
                                               output_var_attr.dims_mapping)
                self._dist_context.set_op_dist_attr_for_program(new_op, op_attr)

        # TODO move helper init to a comm place
        dist_op_context = self._dist_context.dist_op_context
        dist_op_context.set_dst_main_program(partitioned_main_prog)
        dist_op_context.set_dst_startup_program(partitioned_startup_prog)
        dist_op_context.set_varname_mapping(self._serial2dist_varname_mapping)
        dist_op_context.set_rank_id(self._rank_id)

        # transpile main program
        for op in serial_ops:

            # partititon input variables
            for serial_input_varname in op.desc.input_arg_names():
                if serial_input_varname not in self._serial2dist_varname_mapping:
                    new_varname = serial_input_varname + self._dist_varname_suffix
                    if serial_main_block.has_var(serial_input_varname):
                        _partition_var(self._dist_context, serial_main_block,
                                       partitioned_global_block,
                                       serial_input_varname, new_varname)
                    else:
                        assert serial_input_varname in __varname_not_in_block__

                    self._serial2dist_varname_mapping[
                        serial_input_varname] = new_varname

            # partition output vars
            for serial_output_varname in op.desc.output_arg_names():
                if serial_output_varname not in self._serial2dist_varname_mapping:
                    new_varname = serial_output_varname + self._dist_varname_suffix
                    _partition_var(self._dist_context, serial_main_block,
                                   partitioned_global_block,
                                   serial_output_varname, new_varname)
                    self._serial2dist_varname_mapping[
                        serial_output_varname] = new_varname

            # partition op
            kinputs, koutputs = dist_op_context.prepare_forward_context(op)
            dist_attr = self._dist_context.get_op_dist_attr_for_program(op)
            if _is_dist_op_forward_implement(self._dist_context, op):
                dist_ops = get_distributed_operator_impl_container(op.type)
                dist_op_impl = dist_ops.get_impl(dist_attr.impl_idx)
                dist_op_impl.forward(self._dist_context, **kinputs, **koutputs)

            else:
                # replicate op
                dist_ops = get_distributed_operator_impl_container("default")
                dist_op_impl = dist_ops.get_impl(0)
                dist_op_impl.forward(self._dist_context, **kinputs, **koutputs)

        return partitioned_main_prog, partitioned_startup_prog

    def _dist_var_op_backward_transpile(self,
                                        serial_loss,
                                        serial_main_program,
                                        serial_startup_program,
                                        dist_main_program,
                                        dist_startup_program,
                                        parameter_list=None,
                                        no_grad_set=None,
                                        callbacks=None):
        """
        So far, the auto_backward path only guarantees the correctness of the backward ops for certain dist ops:
            1. NV-Megatron-like parallel embedding
            2. NV-Megatron-like row parallel linear
            3. NV-Megatron-like col parallel linear
        """

        if self._compatible_with_auto_backward:
            assert isinstance(
                serial_loss, Variable), "The target loss should be a Variable."
            dist_loss = self._serial_varname2dist_var(serial_loss.name,
                                                      dist_main_program)

            assert len(dist_loss.shape) == 1 and dist_loss.shape[0] == 1, \
                "The dist loss.shape should be (1L,), but the current dist loss.shape is {}. " \
                "Maybe you should call fluid.layers.mean to process the current loss.".format(
                    dist_loss.shape)

            # update parameter list
            if parameter_list:
                parameter_list = [
                    self._serial_varname2dist_var(param.name, dist_main_program)
                    for param in parameter_list
                ]

            # update parameter no_grad_set
            if no_grad_set:
                no_grad_set = [
                    self._serial_varname2dist_var(param.name, dist_main_program)
                    for param in no_grad_set
                ]

            dist_op_context = self._dist_context.dist_op_context
            params_and_grads = _auto_backward(
                dist_loss,
                dist_startup_program,
                parameter_list=parameter_list,
                no_grad_set=no_grad_set,
                callbacks=callbacks,
                distop_context=dist_op_context)

            # backward completion
            complete_backward_annotation(
                dist_main_program, dist_context=self._dist_context)

            # transpiler backward for dist op
            # get backward ops
            ops = dist_main_program.global_block().ops
            first_backward_op_idx = -1
            forward_op_id2forward_op = {}
            for idx in range(len(ops)):
                if is_forward_op(ops[idx]):
                    forward_op_id2forward_op[ops[idx].desc.id()] = ops[idx]

                if int(ops[idx].attr('op_role')) == int(OpRole.Backward):
                    first_backward_op_idx = idx
                    break
            assert first_backward_op_idx >= 0, "no backward op found in the program"
            assert len(forward_op_id2forward_op
                       ) > 0, "no forward op found in the program"

            backward_ops = ops[first_backward_op_idx:]
            for backward_op in backward_ops:
                # if the backward op has a corresponding forward op
                if backward_op.desc.id() in dist_op_context.gradopidx2opidx:
                    forward_op_id = dist_op_context.gradopidx2opidx[
                        backward_op.desc.id()]
                    forward_op = forward_op_id2forward_op[forward_op_id]
                    # TODO backward attr should have _impl_idx
                    forward_op_dist_attr = self._dist_context.get_op_dist_attr_for_program(
                        forward_op)
                    # TODO use the backward op itself to find the dist op
                    dist_ops = get_distributed_operator_impl_container(
                        forward_op.type)
                    kinputs, koutputs = dist_op_context.prepare_backward_context(
                        backward_op)

                    # TODO use backward op itself to determine impl idx
                    if _is_dist_op_backward_implement(self._dist_context,
                                                      forward_op):
                        dist_op_impl = dist_ops.get_impl(
                            forward_op_dist_attr.impl_idx)
                        dist_op_impl.backward(self._dist_context, **kinputs,
                                              **koutputs)
                    else:
                        # replicate op
                        dist_ops = get_distributed_operator_impl_container(
                            "default")
                        dist_op_impl = dist_ops.get_impl(0)
                        dist_op_impl.backward(self._dist_context, **kinputs,
                                              **koutputs)

            return params_and_grads
        # replace dist grad ops
        else:
            raise RuntimeError("transpile NOT implemented !")

    def _optimize_transpile(self, user_define_optimizer, params_grads,
                            main_program, startup_program):

        with program_guard(main_program, startup_program):
            optimize_ops = user_define_optimizer.apply_gradients(params_grads)

        # update completion
        complete_update_annotation(
            main_program, dist_context=self._dist_context)

        return optimize_ops

    def _is_valid_annotated_program(self, program):

        # TODO (ZJ-LIANG) should check all block
        ops = program.global_block().ops
        vars_ = program.list_vars()
        op_dist_attrs = [
            self._dist_context.get_op_dist_attr_for_program(op) for op in ops
        ]
        var_dist_attrs = [
            self._dist_context.get_tensor_dist_attr_for_program(var)
            for var in vars_
        ]

        all_ops_annotated = all(dist_attr is not None
                                for dist_attr in op_dist_attrs)
        all_vars_annotated = all(dist_attr is not None
                                 for dist_attr in var_dist_attrs)

        return all_ops_annotated and all_vars_annotated

    def _serial_varname2dist_var(self, serial_varname, dist_program):
        assert serial_varname in self._serial2dist_varname_mapping, "The serial var [{}] is not found in var name mapping".format(
            serial_varname)
        dist_varname = self._serial2dist_varname_mapping[serial_varname]

        assert dist_program.global_block().has_var(
            dist_varname
        ), "The dist var [{}] is not found in dist program".format(dist_varname)
        dist_var = dist_program.global_block().var(dist_varname)

        return dist_var

    def _is_var_distributed(self, var):

        dist_attr = self._dist_context.get_tensor_dist_attr_for_program(var)
        assert dist_attr is not None, "dist_attr of var [{}] is None".format(
            var.name)
        return _is_distributed(dist_attr)

    def _sharding_forward_transpile(self, main_prog, startup_program):
        """
        This transpile conducts the modifications to the forward program needed by the sharding strategy,
        which mainly include:
            1. partition the parameter
            2. insert broadcast op
            3. insert sync op

        NOTE the transpile modification is done in place on the input program
        """

        raise NotImplementedError(
            "Sharding is NOT supported in AutoParallel yet!")

    def _sharding_backward_transpile(self, main_prog, startup_program):
        """
        This transpile conducts the modifications to the backward program needed by the sharding strategy,
        which mainly include:
            1. partition the gradient
            2. insert broadcast op
            3. insert sync op

        NOTE the transpile modification is done in place on the input program
        """

        raise NotImplementedError(
            "Sharding is NOT supported in AutoParallel yet!")

    def _sharding_optimize_transpile(self, params_grads, dist_main_program,
                                     dist_startup_program):
        """
        shard params_grads
        append the broadcast to sync parameters 
        """
        raise RuntimeError("sharding transpile is NOT implemented !")


def _get_no_grad_set_name(no_grad_set):
    no_grad_set_name = set()
    if no_grad_set is not None:
        if isinstance(no_grad_set, (set, list, tuple)):
            for i, no_grad_var in enumerate(no_grad_set):
                if isinstance(no_grad_var, framework.Variable):
                    no_grad_set_name.add(no_grad_var.name)
                elif isinstance(no_grad_var, six.string_types):
                    no_grad_set_name.add(no_grad_var)
                else:
                    raise TypeError(
                        "The type of no_grad_set's member must be paddle.fluid.Variable or str, but received %s."
                        % (type(no_grad_var)))
        else:
            raise TypeError(
                "The type of no_grad_set should be set or list or tuple, but received {}".
                format(type(no_grad_set)))
    return no_grad_set_name


def _get_no_grad_set(loss, no_grad_set=None):
    no_grad_set = _get_no_grad_set_name(no_grad_set)
    parameters = loss.block.program.global_block().all_parameters()
    param_no_trainable = set(
        [param.name for param in parameters if param.trainable is False])
    # If the parameter is no trainable, it should not have a gradient.
    no_grad_set.update(param_no_trainable)

    return no_grad_set


def _is_dist_op_forward_implement(dist_context, op):
    dist_attr = dist_context.get_op_dist_attr_for_program(op)
    dist_ops = get_distributed_operator_impl_container(op.type)
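    # An op is treated as having a specialized dist forward implementation when a
    # container is registered for its type and the impl selected by the annotated
    # impl_idx declares _forward_implemented.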
    return dist_ops and dist_attr.impl_idx >= 0 and dist_ops.get_impl( \
        dist_attr.impl_idx)._forward_implemented


def _is_dist_op_backward_implement(dist_context, op):
    dist_attr = dist_context.get_op_dist_attr_for_program(op)
    dist_ops = get_distributed_operator_impl_container(op.type)
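    # Mirror of _is_dist_op_forward_implement for the backward pass: the impl
    # selected by the annotated impl_idx must declare _backward_implemented.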
    return dist_ops and dist_attr.impl_idx >= 0 and dist_ops.get_impl( \
        dist_attr.impl_idx)._backward_implemented


def _auto_backward(loss,
                   startup_program=None,
                   parameter_list=None,
                   no_grad_set=None,
                   callbacks=None,
                   distop_context=None):
    """
    Modification is done in place.
    """
    act_no_grad_set = _get_no_grad_set(loss, no_grad_set)
    assert isinstance(loss, Variable), "The target loss should be a Variable."

    if callbacks is None:
        callbacks = [error_clip_callback]
    else:
        assert (isinstance(callbacks, list))

    assert len(loss.shape) == 1 and loss.shape[0] == 1, \
        "The loss.shape should be (1L,), but the current loss.shape is {}. " \
        "Maybe you should call fluid.layers.mean to process the current loss.".format(
            loss.shape)

    program = loss.block.program

    with program_guard(program, startup_program):
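        # NOTE: when distop_context is provided, append_backward is expected to
        # record the mapping from each generated grad op to its forward op in it
        # (see gradopidx2opidx in _dist_var_op_backward_transpile), which the
        # partitioner later uses to dispatch dist backward implementations.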
        params_grads = append_backward(
            loss,
            parameter_list,
            act_no_grad_set,
            callbacks,
            distop_context=distop_context)

    return params_grads


def _is_distributed(dist_attr):
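    """
    Return True if any dim of the tensor is mapped onto a mesh axis whose size
    is greater than 1, i.e. the tensor is actually sharded along some dim.
    """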

    mapping = dist_attr.dims_mapping
    mesh = dist_attr.process_mesh.topology
    for idx in range(len(mapping)):
        if mapping[idx] >= 0 and mesh[mapping[idx]] > 1:
            return True

    return False


def _get_dist_shape(var, dist_attr):
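    """
    Compute the local (per-rank) shape of `var` from its dims_mapping and the
    process mesh topology. Illustrative example (shapes chosen arbitrarily):
    a var of shape [6, 12] with dims_mapping [-1, 0] on a mesh of topology
    [2, 4] keeps dim 0 and splits dim 1 over mesh axis 0, giving [6, 6].
    """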

    var_shape = var.shape
    mapping = dist_attr.dims_mapping
    mesh = dist_attr.process_mesh.topology
    assert len(var_shape) == len(
        mapping
    ), "variable shape [{}] and dim_mapping [{}] is NOT match !".format(
        var_shape, mapping)
    new_shape = []
    for idx in range(len(var_shape)):
        if var_shape[idx] == -1 or mapping[idx] == -1:
            new_shape.append(var_shape[idx])
        else:
            assert var_shape[idx] % mesh[mapping[
                idx]] == 0, "uneven partition: var_shape[idx]=[{}], mesh[{}]".format(
                    var_shape[idx], mesh[mapping[idx]])
            new_shape.append(var_shape[idx] // mesh[mapping[idx]])

    return new_shape


def _partition_parameter(dist_context, src_var, dst_block, dst_varname,
                         dst_shape):
    # NOTE hack to copy a Parameter
    # the new parameter is not initialized here and needs to be initialized later
    copied_kwargs = {}
    copied_kwargs['trainable'] = src_var.trainable
    copied_kwargs['optimize_attr'] = src_var.optimize_attr
    copied_kwargs['regularizer'] = src_var.regularizer
    copied_kwargs['do_model_average'] = src_var.do_model_average
    copied_kwargs['need_clip'] = src_var.need_clip

    param = Parameter(
        block=dst_block,
        type=src_var.type,
        name=dst_varname,
        shape=dst_shape,
        dtype=src_var.dtype,
        lod_level=src_var.lod_level,
        error_clip=src_var.error_clip,
        stop_gradient=src_var.stop_gradient,
        is_data=src_var.is_data,
        belong_to_optimizer=src_var.belong_to_optimizer,
        **copied_kwargs)

    # set dist attr uid
    # distributed_attr_uid = src_var.desc.get_distributed_attr_uid()
    # param.desc.set_distributed_attr_uid(distributed_attr_uid)
    dist_attr = copy.deepcopy(
        dist_context.get_tensor_dist_attr_for_program(src_var))
    assert dist_attr is not None
    dist_context.set_tensor_dist_attr_for_program(param, dist_attr)


def _partition_intermediate_var(dist_context, src_var, dst_block, dst_varname,
                                dst_shape):
    var = dst_block.create_var(
        type=src_var.type,
        name=dst_varname,
        shape=dst_shape,
        dtype=src_var.dtype,
        lod_level=src_var.lod_level,
        persistable=src_var.persistable,
        error_clip=src_var.error_clip,
        stop_gradient=src_var.stop_gradient,
        is_data=src_var.is_data,
        belong_to_optimizer=src_var.belong_to_optimizer)

    # set dist attr uid
    # distributed_attr_uid = src_var.desc.get_distributed_attr_uid()
    # var.desc.set_distributed_attr_uid(distributed_attr_uid)
    dist_attr = copy.deepcopy(
        dist_context.get_tensor_dist_attr_for_program(src_var))
    assert dist_attr is not None
    dist_context.set_tensor_dist_attr_for_program(var, dist_attr)


def _partition_var(dist_context, src_block, dst_block, src_varname,
                   dst_varname):
    """
    Partitioning includes two cases: split the var if it is sharded, otherwise replicate it.
    """
    src_var = src_block.var(src_varname)

    if src_var.type == core.VarDesc.VarType.READER:
        dst_block.create_var(
            type=src_var.type,
            name=dst_varname,
            persistable=True,
            stop_gradient=True)
    else:
        dist_attr = dist_context.get_tensor_dist_attr_for_program(src_var)
        target_shape = _get_dist_shape(src_var, dist_attr)

        if isinstance(src_var, Parameter):
            _partition_parameter(dist_context, src_var, dst_block, dst_varname,
                                 target_shape)
        else:
            _partition_intermediate_var(dist_context, src_var, dst_block,
                                        dst_varname, target_shape)


def _insert_src_op(src_op, dst_block, varname_mapping):
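    """
    Copy a serial op into dst_block, renaming its input and output variables
    according to varname_mapping.
    """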

    new_op_desc = dst_block.desc.append_op()
    new_op_desc.copy_from(src_op.desc)
    for local_varname in src_op.desc.input_arg_names():
        new_op_desc._rename_input(local_varname, varname_mapping[local_varname])
    for local_varname in src_op.desc.output_arg_names():
        new_op_desc._rename_output(local_varname,
                                   varname_mapping[local_varname])
    dst_block._sync_with_cpp()


def _insert_dist_op(src_op, dst_block, varname_mapping, dist_context, rank_id):
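    """
    Append the distributed counterpart of a serial op to dst_block, dispatching to
    the dist op implementation selected by the op's annotated impl_idx.
    """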

    # build input varname mapping
    input_mapping = {}
    for input_name in src_op.desc.input_names():
        varnames = []
        for varname in src_op.desc.input(input_name):
            varnames.append(varname_mapping[varname])
        input_mapping[input_name] = varnames

    # build output varname mapping
    output_mapping = {}
    for output_name in src_op.desc.output_names():
        varnames = []
        for varname in src_op.desc.output(output_name):
            varnames.append(varname_mapping[varname])
        output_mapping[output_name] = varnames

    # append dist op 
    dist_attr = dist_context.get_op_dist_attr_for_program(src_op)
    dist_ops = get_distributed_operator_impl_container(src_op.type)
    append_op_handle = dist_ops.get_impl(dist_attr.impl_idx).forward(src_op)
    append_op_handle(
        dst_block,
        src_op,
        dist_attr,
        input_mapping,
        output_mapping,
        rank_id=rank_id)


def is_forward_op(op):
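    # An op counts as forward if its op_role is Forward, or Forward combined with
    # the Loss bit (the op that produces the loss).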
    role1 = int(core.op_proto_and_checker_maker.OpRole.Forward) | int(
        core.op_proto_and_checker_maker.OpRole.Loss)
    role2 = int(core.op_proto_and_checker_maker.OpRole.Forward)
    op_role = int(op.attr('op_role'))
    return op_role == role2 or op_role == role1