# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import reduce

from paddle.framework import core
from paddle.fluid import unique_name
from .pass_base import PassBase, register_pass
from paddle.distributed.fleet.meta_optimizers.common import (
    is_backward_op,
    is_optimizer_op,
)
from paddle.distributed.auto_parallel.process_group import new_process_group
from paddle.distributed.auto_parallel.operators.common import (
    is_parameter_related,
    is_data_parallel_reduce_op,
)
from paddle.distributed.auto_parallel.utils import (
    _get_comm_group,
    naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
    set_var_dist_attr,
)

OpRole = core.op_proto_and_checker_maker.OpRole
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()

# Op types ignored when collecting the data parallel groups of the network.
_skip_ops = [
    'create_py_reader',
    'create_double_buffer_reader',
    'read',
    'slice',
    'split',
    'assign',
    "send_v2",
]

# Optimizer op types whose ops and states can be sharded.
# Update here to support new optimizers.
_supported_optimizer_type = [
    "adam",
    "adamax",
    "adamw",
    "decayed_adagrad",
    "momentum",
    "dgc_momentum",
    "lars_momentum",
    "merged_momentum",
    "lamb",
    "sgd",
]


61
def _is_reshard_op(op):
62 63 64
    return op.desc.has_attr(
        "op_namescope"
    ) and "/auto_parallel/reshard" in op.desc.attr('op_namescope')
65 66


# NOTE: the "auto_parallel" prefix in the pass name indicates that this pass
# must obey the constraints of auto_parallel, for example: all ops and vars
# should have a dist attr before and after the pass, and dist ops should be
# used instead of custom communication ops.
@register_pass("auto_parallel_sharding")
class ShardingPass(PassBase):
    """Shard training state across the ranks of the data parallel group.

    The work done depends on the "stage" attribute:
      * every stage shards optimizer ops and optimizer states;
      * stage >= 2 additionally shards AMP gradient casts / nan-inf checks
        and turns parameter-gradient allreduce into reduce ops;
      * stage >= 3 additionally shards the parameters themselves, inserting
        broadcasts before their uses.
    For stage <= 2 the updated parameters are broadcast from their owning
    rank after the optimizer ops.
    """

    def __init__(self):
        super().__init__()
        self.set_attr("dist_context", None)
        self.set_attr("stage", None)
        self.set_attr("sharding_degree", None)  # for parallelizer
        self.set_attr("degree", None)  # for parallelizer_v2
        self.set_attr("params_grads", [])
        self.set_attr("global_rank", -1)
        self.dp_groups = set()
        self.sharding_infos = []
        self.varname_to_sharding_info = {}
        self.partial_sharding = False
        self.outer_dp_group = None
        # (param, grad) pairs whose optimizer ops stay on this rank
        self.shared_params_grads = []

    def _check_self(self):
        """Return True when all pass attributes are present and valid."""
        if self.get_attr("dist_context") is None:
            return False

        if self.get_attr("stage") not in [1, 2, 3]:
            return False

        # "sharding_degree" (parallelizer) takes precedence over "degree"
        # (parallelizer_v2); whichever is set must be an int greater than 1.
        if self.get_attr("sharding_degree") is not None:
            degree = self.get_attr("sharding_degree")
        elif self.get_attr("degree") is not None:
            degree = self.get_attr("degree")
        else:
            return False
        if not isinstance(degree, int) or degree <= 1:
            return False

        if len(self.get_attr("params_grads")) <= 0:
            return False

        global_rank = self.get_attr("global_rank")
        if not isinstance(global_rank, int) or global_rank < 0:
            return False

        return True

    def _check_conflict(self, other_pass):
        return True

    def _apply_single_impl(self, main_program, startup_program, context):
        self._dist_context = self.get_attr("dist_context")
        self.sharding_world_size = int(
            self.get_attr("sharding_degree") or self.get_attr("degree")
        )
        self.stage = int(self.get_attr("stage"))
        self.global_rank = int(self.get_attr("global_rank"))
        params_grads = self.get_attr("params_grads")
        main_block = main_program.global_block()
        startup_block = startup_program.global_block()

        self._build_sharding_groups(main_block, params_grads)
        self._shard_optimizer(main_block, startup_block, params_grads, context)
        self._shard_gradient_synchronization(main_block)
        self._shard_parameter(main_block, startup_block)

        # expose the (param, grad) pairs kept on this rank via the pass context
        context.set_attr("params_grads", self.shared_params_grads)

    def _build_sharding_groups(self, main_block, params_grads):
        self._collective_data_parallel_groups(main_block)
        self._build_sharding_infos(params_grads)

    def _collective_data_parallel_groups(self, main_block):
        """Collect the data parallel group(s) inferred from forward ops.

        Raises:
            NotImplementedError: unless exactly one group is found.
        """
        for op in main_block.ops:
            if not _is_forward_op(op) or op.type in _skip_ops:
                continue
            # NOTE: ops inserted by reshard carry no dist_attr and must be
            # skipped by sharding.
            if _is_reshard_op(op):
                continue
            group = _inference_data_parallel_group_for_operator(
                self.global_rank, op, self._dist_context
            )
            if group is not None:
                self.dp_groups.add(group)

        # TODO(JZ-LIANG) allow more than one dp groups in network, support
        # more general distribution generated by auto search
        if len(self.dp_groups) != 1:
            raise NotImplementedError(
                "So far Only and Exactly one data parallel group in network are supported, but got [{}] different data parallel groups".format(
                    len(self.dp_groups)
                )
            )

    def _build_sharding_infos(self, params_grads):
        """Create the sharding group and ShardingInfo for each dp group.

        When the dp group is larger than the sharding degree, the dp group is
        split into an inner sharding group and an outer dp group (hybrid
        sharding + data parallelism).
        """
        for dp_group in self.dp_groups:
            assert (
                dp_group.nranks >= self.sharding_world_size
            ), "sharding world size [{}] should not larger than dp world size [{}]".format(
                self.sharding_world_size, dp_group.nranks
            )
            assert (
                dp_group.nranks % self.sharding_world_size == 0
            ), "dp world size [{}] should be divisible by sharding world size [{}]".format(
                dp_group.nranks, self.sharding_world_size
            )
            assert (
                self.global_rank in dp_group.ranks
            ), "current ranks [{}] does NOT belong to the data parallel group [{}]".format(
                self.global_rank, dp_group.ranks
            )
            assert (
                len(params_grads) >= self.sharding_world_size
            ), "number of parameters [{}] is not enough to be shard among [{}] ranks".format(
                len(params_grads), self.sharding_world_size
            )

            # sharding hybrid data parallel: partial sharding param within
            if dp_group.nranks > self.sharding_world_size:
                self.partial_sharding = True
                assert (
                    len(self.dp_groups) == 1
                ), "hybrid sharding and data parallelism are supported only when there is excatly one data parallel group in the network"
                outer_dp_group, sharding_group = _get_dp_and_sharding_groups(
                    dp_group.ranks, self.sharding_world_size, self.global_rank
                )
                sharding_group = new_process_group(sharding_group)
                self.outer_dp_group = new_process_group(outer_dp_group)
            else:
                sharding_group = dp_group

            self._dist_context._sharding_group = sharding_group
            # TODO(JZ-LIANG) when support multiple dp groups in future, should
            # group param and bind them to corresponding dp group
            sharding_info = ShardingInfo(
                sharding_group, self.global_rank, params_grads
            )
            self.sharding_infos.append(sharding_info)
            for param in sharding_info.params:
                self.varname_to_sharding_info[param.name] = sharding_info

    def _shard_optimizer(
        self, main_block, startup_block, params_grads, pass_context
    ):
        """
        Shard all optimizer related ops and vars:
        AMP related ops & vars, weight decay ops & vars,
        optimizer ops and states, and the parameter broadcasts.
        NOTE: gradient clip sharding (_shard_gradient_clip) is currently
        not invoked.
        """
        self._shard_amp_related_op_and_vars(main_block, pass_context)
        self._shard_weight_decay(main_block)
        self._shard_optimizer_ops_and_states(main_block, startup_block)
        self._insert_optimizer_broadcasts(main_block, startup_block)

    def _shard_amp_related_op_and_vars(self, main_block, pass_context):
        """Stage >= 2: drop AMP ops/vars that belong to non-local params."""
        if self.stage < 2:
            return

        for idx, op in reversed(list(enumerate(main_block.ops))):
            # shard amp related param_grad cast
            if _is_param_grad_fp32_cast_op(main_block, op):
                output_name = op.output_arg_names[0]
                param_name = output_name[: output_name.find("@")]
                if not self._is_parameter_in_local_shard(param_name):
                    main_block._remove_op(idx, sync=False)
                    main_block._remove_var(output_name, sync=False)

            # shard check nan inf
            elif op.type in ["check_finite_and_unscale", "update_loss_scaling"]:
                reserved_x = []
                for input_name in op.desc.input('X'):
                    param_name = input_name[: input_name.find("@")]
                    if self._is_parameter_in_local_shard(param_name):
                        reserved_x.append(input_name)

                # NOTE: when `reserved_x` is empty, check_finite_and_unscale is
                # replaced by a `fill_constant` op that sets its
                # FoundInfinite output to False (0).
                if reserved_x:
                    op.desc.set_input('X', reserved_x)
                    op.desc.set_output('Out', reserved_x)
                elif op.type == "check_finite_and_unscale":
                    op_role = op.attr('op_role')
                    out_name = op.output("FoundInfinite")[0]
                    out_var = main_block.vars[out_name]
                    main_block._remove_op(idx, sync=False)
                    main_block._insert_op_without_sync(
                        idx,
                        type="fill_constant",
                        outputs={"Out": out_var},
                        attrs={
                            "shape": out_var.shape,
                            "dtype": out_var.dtype,
                            "value": 0,
                            OP_ROLE_KEY: op_role,
                        },
                    )
                else:
                    main_block._remove_op(idx, sync=False)

        main_block._sync_with_cpp()

    def _shard_gradient_clip(self, main_block):
        """Stage >= 2: drop gradient clip ops of non-local params and
        allreduce the clip sum over each sharding group.

        NOTE: currently not called from _shard_optimizer.
        """
        if self.stage < 2:
            return

        # TODO (JZ-LIANG) support calculate global norm with tensor parallelism
        removed_op_type = ['elementwise_mul', 'squared_l2_norm', 'clip_by_norm']
        removed_op_idx = set()
        removed_tmp_var = set()

        for idx, op in list(enumerate(main_block.ops)):
            if not _is_gradient_clip_op(op):
                continue

            if op.type in removed_op_type:
                input_name = op.input("X")[0]
                param_name = input_name[: input_name.find("@GRAD")]
                if not self._is_parameter_in_local_shard(param_name):
                    removed_op_idx.add(idx)
                    if op.type in ['squared_l2_norm', 'clip_by_norm']:
                        for output_name in op.output_arg_names:
                            removed_tmp_var.add(output_name)

        for idx, op in reversed(list(enumerate(main_block.ops))):
            if not _is_gradient_clip_op(op):
                continue
            if idx in removed_op_idx:
                main_block._remove_op(idx, sync=False)

        for varname in removed_tmp_var:
            main_block._remove_var(varname, sync=False)

        for idx, op in list(enumerate(main_block.ops)):
            if not _is_gradient_clip_op(op):
                continue
            if op.type == 'sum':
                reserved_vars = []
                for input_name in op.input_arg_names:
                    if input_name not in removed_tmp_var:
                        reserved_vars.append(input_name)
                op.desc.set_input("X", reserved_vars)

                sum_op_output = op.desc.output_arg_names()[0]
                for i, sharding_info in enumerate(self.sharding_infos):
                    new_op = main_block._insert_op(
                        idx + i + 1,
                        type='c_allreduce_sum',
                        inputs={'X': [sum_op_output]},
                        outputs={'Out': [sum_op_output]},
                        attrs={
                            'ring_id': sharding_info.group.id,
                            'op_namescope': "/gradient_clip_model_parallelism",
                            'use_calc_stream': True,
                            OP_ROLE_KEY: OpRole.Optimize,
                        },
                    )
                    dist_attr = (
                        self._dist_context.get_tensor_dist_attr_for_program(
                            main_block.var(sum_op_output)
                        )
                    )
                    # TODO: set the dist attr of new_op, e.g.
                    # assert dist_attr is not None
                    # naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                    #     new_op, dist_attr.process_mesh, dist_attr.dims_mapping,
                    #     self._dist_context)
                break

        main_block._sync_with_cpp()

    def _shard_weight_decay(self, main_block):
        """Stage >= 2: weight decay ops are not supported yet.

        Raises:
            NotImplementedError: if any weight decay op is present.
        """
        if self.stage < 2:
            return

        if any(_is_weight_decay_op(op) for op in main_block.ops):
            raise NotImplementedError("weight decay is NOT supported by now")
        main_block._sync_with_cpp()

    def _shard_optimizer_ops_and_states(self, main_block, startup_block):
        """Remove optimizer ops and their states for non-local params.

        Only the trailing run of optimizer ops in `main_block` is inspected.
        (param, grad) pairs of local params are recorded in
        `self.shared_params_grads`.
        """
        should_removed_optimizer_states = []
        for idx, op in reversed(list(enumerate(main_block.ops))):
            if not is_optimizer_op(op):
                break

            if op.type in _supported_optimizer_type:
                assert "Param" in op.input_names
                assert len(op.input("Param")) == 1
                param_name = op.input("Param")[0]
                if not self._is_parameter_in_local_shard(param_name):
                    should_removed_optimizer_states.extend(
                        varname
                        for varname in op.output_arg_names
                        if varname != param_name
                    )
                    main_block._remove_op(idx, sync=False)
                else:
                    self.shared_params_grads.append(
                        self._get_param_grad(param_name)
                    )

        # set for O(1) membership tests in the startup block scan
        removed_states = set(should_removed_optimizer_states)
        for idx, op in reversed(list(enumerate(startup_block.ops))):
            if (
                len(op.output_arg_names) == 1
                and op.output_arg_names[0] in removed_states
            ):
                startup_block._remove_op(idx, sync=False)

        for varname in should_removed_optimizer_states:
            if main_block.has_var(varname):
                main_block._remove_var(varname, sync=False)
            if startup_block.has_var(varname):
                startup_block._remove_var(varname, sync=False)

        main_block._sync_with_cpp()
        startup_block._sync_with_cpp()

    def _insert_optimizer_broadcasts(self, main_block, startup_block):
        """Stage <= 2: append a c_broadcast of every parameter from the rank
        that owns it, so that all ranks receive the updated values."""
        if self.stage > 2:
            return

        for sharding_info in self.sharding_infos:
            for param in sharding_info.params:
                assert main_block.has_var(param.name)
                assert startup_block.has_var(param.name)

                new_op = main_block.append_op(
                    type='c_broadcast',
                    inputs={'X': param},
                    outputs={'Out': param},
                    attrs={
                        'ring_id': sharding_info.group.id,
                        'root': sharding_info.get_var_rank(param.name),
                        'use_calc_stream': True,
                        OP_ROLE_KEY: OpRole.Optimize,
                    },
                )
                param_dist_attr = (
                    self._dist_context.get_tensor_dist_attr_for_program(param)
                )
                assert param_dist_attr is not None
                naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                    new_op,
                    param_dist_attr.process_mesh,
                    param_dist_attr.dims_mapping,
                    self._dist_context,
                )
        main_block._sync_with_cpp()

    def _is_parameter_in_local_shard(self, param_name):
        assert param_name in self.varname_to_sharding_info
        sharding_info = self.varname_to_sharding_info[param_name]
        return sharding_info.is_in_local_shard(param_name)

    def _get_param_grad(self, param_name):
        assert param_name in self.varname_to_sharding_info
        sharding_info = self.varname_to_sharding_info[param_name]
        p_g = sharding_info.get_param_grad(param_name)
        assert p_g is not None
        return p_g

    def _shard_gradient_synchronization(self, main_block):
        """Stage >= 2: reduce each parameter gradient to its owning rank.

        A c_reduce_sum is inserted before every parameter-gradient allreduce.
        The allreduce is removed, except under hybrid (partial) sharding for
        locally owned params, where it is kept on the outer dp group.
        Gradient `sum` ops of non-local params are removed as well.
        """
        if self.stage < 2:
            return

        for idx, op in reversed(list(enumerate(main_block.ops))):
            if _is_param_grad_allreduce_op(op, main_block):
                input_name = op.input_arg_names[0]
                base_name = _get_base_name_from_grad_name(input_name)
                sharding_info = self.varname_to_sharding_info[base_name]
                _insert_reduce_op(
                    main_block,
                    idx,
                    input_name,
                    sharding_info.group.id,
                    sharding_info.get_var_rank(base_name),
                    self._dist_context,
                )
                # the insert above shifted the allreduce op to idx + 1
                if (
                    not self.partial_sharding
                    or not sharding_info.is_in_local_shard(base_name)
                ):
                    main_block._remove_op(idx + 1, sync=False)
                else:
                    op._set_attr("ring_id", self.outer_dp_group.id)
                # `op` may just have been removed (its desc destroyed), and an
                # allreduce op is never a gradient sum op: do not inspect it
                # any further.
                continue

            # NOTE:
            # var@GRAD = sum(var@GRAD@RENAME@0, var@GRAD@RENAME@1)
            # If the var is not in the local shard and it is the output of
            # several ops (i.e. the var is renamed), the sum op is removed.
            if _is_param_grad_sum_op(op, main_block):
                out_name = op.output_arg_names[0]
                base_name = _get_base_name_from_grad_name(out_name)
                sharding_info = self.varname_to_sharding_info[base_name]
                if not sharding_info.is_in_local_shard(base_name):
                    main_block._remove_op(idx, sync=False)

        main_block._sync_with_cpp()

    def _shard_parameter(self, main_block, startup_block):
        """Stage 3: keep only locally owned parameters on each rank.

        Uses of sharded params are preceded by a broadcast from the owning
        rank (into a fresh temporary on non-owning ranks); cast ops, startup
        initializers/broadcasts and vars of non-local params are removed.
        """
        if self.stage < 3:
            return

        dp_ring_ids = {group.id for group in self.dp_groups}
        for sharding_info in self.sharding_infos:
            (
                need_broadcast_vars,
                param_usage,
            ) = sharding_info.get_broadcast_vars_and_param_usage(main_block)
            # non-local params whose recorded usage count is zero
            unused_param_names = {
                param_name
                for param_name in param_usage
                if param_usage[param_name] == 0
                and sharding_info.get_var_rank(param_name)
                != sharding_info.local_rank
            }

            for idx, op in reversed(list(enumerate(main_block.ops))):
                if is_optimizer_op(op):
                    continue

                input_names = op.desc.input_arg_names()
                # NOTE hack for embedding op when AMP 02-3:
                # paddle amp forces embedding (lookup table) to run on fp32.
                # The check depends only on `op`, so evaluate it once per op.
                if not input_names or _is_param_fp16_cast_op(
                    main_block, op, sharding_info.param_names
                ):
                    continue

                for input_name in input_names:
                    if input_name not in need_broadcast_vars:
                        continue
                    root_rank = sharding_info.get_var_rank(input_name)
                    if root_rank == sharding_info.local_rank:
                        broadcast_varname = input_name
                    else:
                        broadcast_varname = unique_name.generate(
                            input_name + "@BroadCast"
                        )
                        input_var = main_block.var(input_name)
                        new_var = main_block.create_var(
                            name=broadcast_varname,
                            shape=input_var.shape,
                            dtype=input_var.dtype,
                            persistable=False,
                        )
                        ref_dist_attr = (
                            self._dist_context.get_tensor_dist_attr_for_program(
                                input_var
                            )
                        )
                        set_var_dist_attr(
                            self._dist_context,
                            new_var,
                            ref_dist_attr.dims_mapping,
                            ref_dist_attr.process_mesh,
                        )
                        op._rename_input(input_name, broadcast_varname)

                    _insert_init_and_broadcast_op(
                        main_block,
                        idx,
                        broadcast_varname,
                        sharding_info.local_rank,
                        root_rank,
                        sharding_info.group.id,
                        op.attr('op_role'),
                        self._dist_context,
                    )

            for idx, op in reversed(list(enumerate(main_block.ops))):
                if op.type != "cast":
                    continue
                input_name = op.input_arg_names[0]
                output_name = op.output_arg_names[0]
                if input_name in unused_param_names:
                    main_block._remove_op(idx, sync=False)
                    main_block._remove_var(output_name, sync=False)

            for idx, op in reversed(list(enumerate(startup_block.ops))):
                assert len(op.output_arg_names) == 1
                output_name = op.output_arg_names[0]

                if (
                    op.type == "c_broadcast"
                    and op.attr("ring_id") in dp_ring_ids
                ):
                    if (
                        self.outer_dp_group
                        and sharding_info.get_var_rank(output_name)
                        == sharding_info.local_rank
                    ):
                        op._set_attr("ring_id", self.outer_dp_group.id)
                    else:
                        startup_block._remove_op(idx, sync=False)
                    continue

                if (
                    op.type != "c_broadcast"
                    and output_name in param_usage
                    and sharding_info.get_var_rank(output_name)
                    != sharding_info.local_rank
                ):
                    startup_block._remove_op(idx, sync=False)

            for param_name in param_usage:
                if (
                    sharding_info.get_var_rank(param_name)
                    != sharding_info.local_rank
                ):
                    main_block._remove_var(param_name, sync=False)
                    startup_block._remove_var(param_name, sync=False)

        main_block._sync_with_cpp()
        startup_block._sync_with_cpp()


def _insert_init_and_broadcast_op(
    block,
    insert_idx,
    varname,
    local_rank,
    root_rank,
    ring_id,
    op_role,
    dist_context,
):
    """
    Insert a ``c_broadcast`` of `varname` from `root_rank` over ring `ring_id`
    at position `insert_idx` of `block` (without syncing the block).

    On non-root ranks, an ``empty`` op with the variable's shape and dtype is
    also inserted at the same index, i.e. right before the broadcast, to
    initialize the receiving variable. Both new ops get the variable's
    process mesh and dims mapping as their dist attr.
    """
    broadcast_var = block.var(varname)
    broadcast_var_dist_attr = dist_context.get_tensor_dist_attr_for_program(
        broadcast_var
    )

    new_op = block._insert_op_without_sync(
        insert_idx,
        type='c_broadcast',
        inputs={'X': varname},
        outputs={'Out': varname},
        attrs={
            'ring_id': ring_id,
            'root': root_rank,
            'use_calc_stream': True,
            OP_ROLE_KEY: op_role,
        },
    )
    naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
        new_op,
        broadcast_var_dist_attr.process_mesh,
        broadcast_var_dist_attr.dims_mapping,
        dist_context,
    )

    if local_rank != root_rank:
        # same insert index -> this op ends up before the broadcast
        new_op = block._insert_op_without_sync(
            insert_idx,
            type="empty",
            outputs={"Out": broadcast_var.name},
            attrs={
                "shape": broadcast_var.shape,
                "dtype": broadcast_var.dtype,
                OP_ROLE_KEY: op_role,
            },
        )
        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            new_op,
            broadcast_var_dist_attr.process_mesh,
            broadcast_var_dist_attr.dims_mapping,
            dist_context,
        )


def _insert_reduce_op(
    block,
    insert_idx,
    reduce_var,
    ring_id,
    root_id,
    dist_context,
    op_role=OpRole.Backward,
    use_calc_stream=True,
):
    """
    Insert an in-place ``c_reduce_sum`` of `reduce_var` with root `root_id`
    over ring `ring_id` at position `insert_idx` of `block` (without syncing
    the block), and give it the dist attr of `reduce_var`.
    """
    assert (
        root_id >= 0
    ), "root id should be a non-negative int, but now root id is {}".format(
        root_id
    )
    new_op = block._insert_op_without_sync(
        insert_idx,
        type='c_reduce_sum',
        inputs={'X': [reduce_var]},
        outputs={'Out': [reduce_var]},
        attrs={
            'ring_id': ring_id,
            'root_id': root_id,
            'use_calc_stream': use_calc_stream,
            OP_ROLE_KEY: op_role,
        },
    )

    dist_attr = dist_context.get_tensor_dist_attr_for_program(
        block.var(reduce_var)
    )
    naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
        new_op, dist_attr.process_mesh, dist_attr.dims_mapping, dist_context
    )


def _get_dp_and_sharding_groups(origin_group, sharding_group_size, rank):
    """Split `origin_group` into the outer dp group and the sharding group of `rank`.

    The ranks are viewed as a ``[len(origin_group) // sharding_group_size,
    sharding_group_size]`` grid; axis 0 gives the data parallel group and
    axis 1 the sharding group.

    Returns:
        (dp_group, sharding_group) as returned by ``_get_comm_group``.
    """
    grid_shape = [len(origin_group) // sharding_group_size, sharding_group_size]
    outer_dp = _get_comm_group(origin_group, grid_shape, 0, rank)
    inner_sharding = _get_comm_group(origin_group, grid_shape, 1, rank)
    return outer_dp, inner_sharding


def _is_gradient_clip_op(op):
708 709 710
    return op.desc.has_attr("op_namescope") and op.desc.attr(
        "op_namescope"
    ).startswith("/gradient_clip")
J
JZ-LIANG 已提交
711 712 713


def _is_weight_decay_op(op):
714 715 716
    return op.desc.has_attr("op_namescope") and op.desc.attr(
        "op_namescope"
    ).startswith("/regularization")
J
JZ-LIANG 已提交
717 718 719 720 721


def _is_param_grad_fp32_cast_op(block, op):
    """Return True if `op` is a backward FP16 -> FP32 cast producing a
    parameter gradient.

    The parameter name is the output name up to its first "@"
    (e.g. ``w@GRAD`` -> ``w``), and it must be a parameter of `block`.
    """
    if not is_backward_op(op):
        return False
    if not _is_desired_cast_op(
        block, op, core.VarDesc.VarType.FP16, core.VarDesc.VarType.FP32
    ):
        return False
    output_name = op.desc.output_arg_names()[0]
    # NOTE(review): if the name has no "@", find() returns -1 and the last
    # character is dropped; kept as-is to match the slicing used by callers.
    base_name = output_name[: output_name.find("@")]
    if not block.has_var(base_name):
        return False
    return block.var(base_name).is_parameter


def _is_param_fp16_cast_op(block, op, params):
    """Return True if ``op`` casts one of ``params`` from fp32 to fp16.

    Optimizer ops are excluded; the cast check uses the default
    FP32 -> FP16 types of ``_is_desired_cast_op``. ``params`` only needs to
    support ``in``.
    """
    if is_optimizer_op(op) or not _is_desired_cast_op(block, op):
        return False
    source_name = op.desc.input_arg_names()[0]
    return source_name in params


745 746 747 748 749 750
def _is_desired_cast_op(
    block,
    op,
    src_var_type=core.VarDesc.VarType.FP32,
    dst_var_type=core.VarDesc.VarType.FP16,
):
    """Check whether ``op`` is a ``cast`` from ``src_var_type`` to ``dst_var_type``.

    Non-cast ops return False. For cast ops, asserts that the op has exactly
    one input and one output, then compares their dtypes as looked up in
    ``block``.
    """
    if op.type != "cast":
        return False

    in_names = op.desc.input_arg_names()
    out_names = op.desc.output_arg_names()
    assert len(in_names) == 1
    assert len(out_names) == 1

    src = block.var(in_names[0])
    dst = block.var(out_names[0])
    if src.dtype == src_var_type and dst.dtype == dst_var_type:
        return True
    return False


def _get_base_name_from_grad_name(grad_name):
    base_name = None
    if ".cast_fp16@GRAD" in grad_name:
767
        base_name = grad_name[: grad_name.find(".cast_fp16@GRAD")]
J
JZ-LIANG 已提交
768
    elif "@GRAD" in grad_name:
769
        base_name = grad_name[: grad_name.find("@GRAD")]
J
JZ-LIANG 已提交
770 771 772
    return base_name


773 774 775 776 777 778 779 780 781 782 783 784 785 786
def _is_param_grad_allreduce_op(op, block):
    """Return True if ``op`` is a data-parallel reduce on a parameter's gradient.

    The parameter name is recovered from the op's first output with
    ``_get_base_name_from_grad_name``.
    """
    if is_data_parallel_reduce_op(op):
        base_name = _get_base_name_from_grad_name(op.output_arg_names[0])
        if block.has_var(base_name):
            return block.var(base_name).is_parameter
    return False


787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802
def _is_param_grad_sum_op(op, block):
    """Return True if ``op`` is a backward ``sum`` whose output is a parameter's grad.

    The parameter name is recovered from the op's first output with
    ``_get_base_name_from_grad_name``.
    """
    if not (is_backward_op(op) and op.type == "sum"):
        return False
    base_name = _get_base_name_from_grad_name(op.output_arg_names[0])
    return block.var(base_name).is_parameter if block.has_var(base_name) else False


J
JZ-LIANG 已提交
803 804 805 806
def _is_forward_op(op):
    return op.attr("op_role") == 0


J
JZ-LIANG 已提交
807 808 809 810 811 812 813 814 815 816 817 818
def _inference_data_parallel_group_for_operator(rank_id, op, dist_context):
    """Infer the data-parallel process group ``op`` runs in, from ``rank_id``'s view.

    Scans the op's non-parameter inputs. For the first one whose leading
    dims-mapping entry (treated as the batch axis) maps to a process-mesh
    axis of size > 1, builds a process group from the ranks along that mesh
    axis. Returns None if no such input exists.
    """
    dp_group = None
    dist_attr = None
    for input_name in op.input_arg_names:
        if is_parameter_related(input_name, op.block):
            continue
        if dist_attr is None:
            # Fetched lazily and once: ops whose inputs are all parameters
            # never need it, and it does not depend on the input.
            dist_attr = dist_context.get_op_dist_attr_for_program(op)
        process_mesh = dist_attr.process_mesh
        input_dim_mapping = dist_attr.get_input_dims_mapping(input_name)
        if not input_dim_mapping:
            # No dims-mapping entries (e.g. presumably a 0-D input), so there
            # is no leading axis to read; indexing [0] would raise.
            continue
        mesh_shape = process_mesh.topology
        # TODO(JZ-LIANG) replace with specific batch size dimension
        batch_size_axis = input_dim_mapping[0]
        if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
            group_ranks = _get_comm_group(
                process_mesh.processes,
                process_mesh.topology,
                batch_size_axis,
                rank_id,
            )
            dp_group = new_process_group(group_ranks)
            break

    return dp_group


def shard_parameters(params, group_size):
    """Greedily assign parameters to ``group_size`` ranks by element count.

    Parameters are visited in order; each goes to the rank currently holding
    the fewest elements (the lowest such rank on ties).

    Args:
        params: iterable of parameters exposing ``name`` and ``shape``.
        group_size: number of ranks in the sharding group.

    Returns:
        dict mapping every rank in ``range(group_size)`` to the list of
        parameters assigned to it.

    Raises:
        AssertionError: if a parameter's element count is not positive
            (e.g. its shape contains 0 or a negative dimension).
    """
    # TODO(JZ-LIANG) support multiple partition methods
    # method1: greedy even but unordered
    # method2: roughly even with order
    mapping = {rank_: [] for rank_ in range(group_size)}
    sizes = [0] * group_size
    for param in params:
        rank = sizes.index(min(sizes))
        mapping[rank].append(param)
        # Initializer 1 makes a 0-D (empty-shape) param count as one element
        # instead of making reduce() raise TypeError on an empty sequence.
        numel = reduce(lambda x, y: x * y, param.shape, 1)
        assert (
            numel > 0
        ), "param [{}] should larger than 0, but it is [{}]".format(
            param.name, numel
        )
        sizes[rank] += numel

    return mapping


class ShardingInfo(object):
854
    def __init__(self, group, rank, params_grads):
J
JZ-LIANG 已提交
855
        self.group = group
856
        self.params_grads = dict([(p.name, (p, g)) for p, g in params_grads])
857 858 859
        assert len(self.params_grads) == len(
            set(self.params_grads)
        ), "found duplicated param in params_grads"
860 861

        self.params = [p for p, _ in params_grads]
J
JZ-LIANG 已提交
862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884
        self.param_names = [p.name for p in self.params]
        self.group_size = group.nranks
        self.global_rank = rank
        self.local_rank = group.ranks.index(self.global_rank)
        # rank in below mapping are local rank in this sharding group
        self.rank_to_params = shard_parameters(self.params, self.group_size)
        # include fp32 and fp16 param
        self.param_to_rank = dict()
        self._map_param_to_rank()

    def _map_param_to_rank(self):
        """
        mapping parameters to the rank which holds it.
        """
        for rank, params in self.rank_to_params.items():
            for param in params:
                self.param_to_rank[param.name] = rank

    def get_var_rank(self, varname):
        if varname in self.param_to_rank:
            return self.param_to_rank[varname]
        return -1

885
    # determine whether a param (fp32, or its registered fp16 cast copy) is held by this rank
J
JZ-LIANG 已提交
886 887 888
    def is_in_local_shard(self, param_name):
        return self.get_var_rank(param_name) == self.local_rank

889 890 891 892
    # NOTE the following logic is designed to support AMP O1, where
    # a param is cast to fp16 before being used for calculation,
    # so sharding should only broadcast the cast fp16 param
    # instead of the original fp32 param.
J
JZ-LIANG 已提交
893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920
    def get_broadcast_vars_and_param_usage(self, block):
        """Find the vars to broadcast and count non-optimizer uses of each param.

        Every occurrence of a param among the inputs of a non-optimizer op in
        ``block`` counts as one use. Each fp32->fp16 cast of a param adds the
        fp16 output to the broadcast set, subtracts one use from the fp32
        param, and registers the fp16 name in ``self.param_to_rank`` under
        the same rank. Any param still used afterwards (i.e. used directly,
        not only through its cast) is broadcast as well.

        Returns:
            ``(broadcast_vars, param_usage)``: a set of var names and a dict
            from param name to its remaining use count.
        """
        # Set built once: membership is tested for every input of every op.
        param_name_set = set(self.param_names)
        param_usage = {name: 0 for name in self.param_names}
        for op in block.ops:
            if is_optimizer_op(op):
                continue
            for input_name in op.desc.input_arg_names():
                if input_name in param_name_set:
                    param_usage[input_name] += 1

        broadcast_vars = set()
        for op in block.ops:
            if not _is_param_fp16_cast_op(block, op, param_name_set):
                continue
            input_name = op.input_arg_names[0]
            output_name = op.output_arg_names[0]
            broadcast_vars.add(output_name)
            # The cast itself was counted as a use above; discount it so only
            # direct uses of the fp32 param remain.
            param_usage[input_name] -= 1
            self.param_to_rank[output_name] = self.param_to_rank[input_name]

        broadcast_vars.update(
            name for name, usage in param_usage.items() if usage > 0
        )
        return broadcast_vars, param_usage
921 922 923 924

    def get_param_grad(self, param_name):
        if not self.is_in_local_shard(param_name):
            raise ValueError(
925 926
                "param[{}] not in current rank.".format(param_name)
            )
927 928 929
        if param_name not in self.params_grads:
            raise ValueError('param[{}] not in params_grads'.format(param_name))
        return self.params_grads.get(param_name, None)