# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import reduce
from collections import OrderedDict
import numpy as np

import paddle
from paddle.framework import core
from paddle.fluid import unique_name
from .pass_base import PassBase, register_pass
from paddle.distributed.fleet.meta_optimizers.common import is_backward_op, is_optimizer_op
from paddle.distributed.auto_parallel.process_group import new_process_group
from paddle.distributed.auto_parallel.operators.common import is_parameter_related
from paddle.distributed.auto_parallel.utils import _get_comm_group, naive_set_dist_op_attr_for_program_by_mesh_and_mapping, set_var_dist_attr

# Alias of the framework's op-role enum (this file uses OpRole.Backward and OpRole.Optimize).
OpRole = core.op_proto_and_checker_maker.OpRole
# Attribute key under which the role of an op is stored; used in the `attrs` of inserted ops.
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
# Op types that are ignored when the data parallel groups are collected from the forward ops.
_skip_ops = ['create_py_reader', 'create_double_buffer_reader', 'read', 'slice']
# Optimizer op types whose ops and states get sharded.
# update here to support new optimizers
_supported_optimizer_type = [
    "adam", "adamax", "adamw", "decayed_adagrad", "momentum", "dgc_momentum",
    "lars_momentum", "merged_momentum", "lamb", "sgd"
]


# NOTE: we add the "auto_parallel" prefix to the pass name in order to
# indicate that this pass must obey some constraints imposed by auto_parallel,
# for example: all ops and vars should have a dist attr before and after the pass,
# and the pass should use dist ops instead of custom comm ops.
@register_pass("auto_parallel_sharding")
class ShardingPass(PassBase):
    def __init__(self):
        """Declare the pass attributes with their defaults and reset the sharding bookkeeping."""
        super(ShardingPass, self).__init__()
        # Attributes expected to be set on the pass; `_check_self` validates them.
        attr_defaults = (
            ("dist_context", None),
            ("stage", None),
            ("sharding_degree", None),
            ("params_grads", []),
            ("global_rank", -1), )
        for attr_name, default_value in attr_defaults:
            self.set_attr(attr_name, default_value)
        # Internal state, filled in while the pass is applied.
        self.dp_groups = set()
        self.sharding_infos = []
        self.varname_to_sharding_info = {}
        self.partial_sharding = False
        self.outer_dp_group = None

    def _check_self(self):
        """Return True only when every attribute of the pass holds a usable value.

        Requirements: `dist_context` is set, `stage` is 1, 2 or 3,
        `sharding_degree` is an int greater than 1, `params_grads` is not
        empty and `global_rank` is a non-negative int.
        """
        if self.get_attr("dist_context") is None:
            return False

        if self.get_attr("stage") not in (1, 2, 3):
            return False

        degree = self.get_attr("sharding_degree")
        if not isinstance(degree, int) or degree <= 1:
            return False

        if len(self.get_attr("params_grads")) < 1:
            return False

        rank = self.get_attr("global_rank")
        if not isinstance(rank, int) or rank < 0:
            return False

        return True

    def _check_conflict(self, other_pass):
        """Always return True; `other_pass` is not inspected."""
        return True

    def _apply_single_impl(self, main_program, startup_program, context):
        """Apply the sharding pass to the global blocks of both programs.

        The pass attributes are cached on the instance first; then the
        sharding groups are built and the optimizer, the gradient
        synchronization and the parameters are sharded, in that order.
        """
        self._dist_context = self.get_attr("dist_context")
        self.sharding_world_size = int(self.get_attr("sharding_degree"))
        self.stage = int(self.get_attr("stage"))
        self.global_rank = int(self.get_attr("global_rank"))
        params_grads = self.get_attr("params_grads")

        main_block = main_program.global_block()
        startup_block = startup_program.global_block()

        self._build_sharding_groups(main_block, params_grads)
        self._shard_optimizer(main_block, startup_block, params_grads, context)
        self._shard_gradient_synchronization(main_block)
        self._shard_parameter(main_block, startup_block)

    def _build_sharding_groups(self, main_block, params_grads):
        """Collect the data parallel groups, then derive the sharding infos from them."""
        self._collective_data_parallel_groups(main_block)
        self._build_sharding_infos(params_grads)

    def _collective_data_parallel_groups(self, main_block):
        """Collect the data parallel groups used by the forward ops of `main_block`.

        Every op accepted by `_is_forward_op` whose type is not listed in
        `_skip_ops` is inspected, and each group found is added to
        `self.dp_groups`.

        Raises:
            NotImplementedError: if the number of collected groups is not
                exactly one.
        """
        for op in main_block.ops:
            if not _is_forward_op(op) or op.type in _skip_ops:
                continue
            group = _inference_data_parallel_group_for_operator(
                self.global_rank, op, self._dist_context)
            if group is not None:
                self.dp_groups.add(group)

        # TODO(JZ-LIANG) allow more than one dp groups in network, support more general distribution
        # generated by auto search
        if len(self.dp_groups) != 1:
            raise NotImplementedError(
                "So far Only and Exactly one data parallel group in network are supported, but got [{}] different data parallel groups".
                format(len(self.dp_groups)))

    def _build_sharding_infos(self, params_grads):
        """Create one `ShardingInfo` per data parallel group and index it by parameter name.

        When the data parallel group is larger than the sharding world size,
        it is split into a sharding group and an outer data parallel group
        (stored in `self.outer_dp_group`) and `self.partial_sharding` is set.

        Args:
            params_grads: iterable of (parameter, gradient) pairs; only the
                parameters are used here.
        """
        for dp_group in self.dp_groups:

            assert dp_group.nranks >= self.sharding_world_size, "sharding world size [{}] should not larger than dp world size [{}]".format(
                self.sharding_world_size, dp_group.nranks)
            # The check below is `dp % sharding == 0`, i.e. it is the dp world size
            # that has to be a multiple of the sharding world size.
            assert dp_group.nranks % self.sharding_world_size == 0, "dp world size [{}] should be divisible by sharding world size [{}]".format(
                dp_group.nranks, self.sharding_world_size)
            assert self.global_rank in dp_group.ranks, "current ranks [{}] does NOT belong to the data parallel group [{}]".format(
                self.global_rank, dp_group.ranks)
            assert len(
                params_grads
            ) >= self.sharding_world_size, "number of parameters [{}] is not enough to be shard among [{}] ranks".format(
                len(params_grads), self.sharding_world_size)

            # sharding hybrid data parallel: parameters are sharded only within
            # a sub-group of the data parallel group
            if dp_group.nranks > self.sharding_world_size:
                self.partial_sharding = True
                assert len(
                    self.dp_groups
                ) == 1, "hybrid sharding and data parallelism are supported only when there is excatly one data parallel group in the network"
                outer_dp_group, sharding_group = _get_dp_and_sharding_groups(
                    dp_group.ranks, self.sharding_world_size, self.global_rank)
                sharding_group = new_process_group(sharding_group)
                self.outer_dp_group = new_process_group(outer_dp_group)
            else:
                sharding_group = dp_group

            # TODO(JZ-LIANG) when support multiple dp groups in future, should group param and bind them to corresponding dp group
            params_in_group = [p for p, g in params_grads]
            assert len(params_in_group) == len(set(
                params_in_group)), "found duplicated param in params_grads"
            sharding_info = ShardingInfo(sharding_group, self.global_rank,
                                         params_in_group)
            self.sharding_infos.append(sharding_info)
            for param in params_in_group:
                self.varname_to_sharding_info[param.name] = sharding_info

    def _shard_optimizer(self, main_block, startup_block, params_grads,
                         pass_context):
        """
        Shard every optimizer related op and var. The steps run in this order:
        amp related ops & vars
        weight decay ops & vars
        gradient clip ops & vars
        optimizer ops and states
        broadcasts of the updated parameters
        """
        self._shard_amp_related_op_and_vars(main_block, pass_context)
        self._shard_weight_decay(main_block)
        self._shard_gradient_clip(main_block)
        self._shard_optimizer_ops_and_states(main_block, startup_block)
        self._insert_optimizer_broadcasts(main_block, startup_block)

    def _shard_amp_related_op_and_vars(self, main_block, pass_context):
        """Drop the amp ops and vars of parameters that are not in the local shard.

        Does nothing when the stage is lower than 2. `pass_context` is not used.
        """
        if self.stage < 2:
            return

        loss_scaling_op_types = ("check_finite_and_unscale",
                                 "update_loss_scaling")
        # walk backwards so that removing an op does not shift pending indices
        for op_idx, op in reversed(list(enumerate(main_block.ops))):
            # shard amp related param_grad cast
            if _is_param_grad_fp32_cast_op(main_block, op):
                casted_grad_name = op.output_arg_names[0]
                owner_name = casted_grad_name[:casted_grad_name.find("@")]
                if not self._is_parameter_in_local_shard(owner_name):
                    main_block._remove_op(op_idx, sync=False)
                    main_block._remove_var(casted_grad_name, sync=False)
                continue

            # shard check nan inf
            if op.type not in loss_scaling_op_types:
                continue
            kept_names = [
                grad_name for grad_name in op.desc.input('X')
                if self._is_parameter_in_local_shard(grad_name[:grad_name.find(
                    "@")])
            ]
            op.desc.set_input('X', kept_names)
            op.desc.set_output('Out', kept_names)

        main_block._sync_with_cpp()

    def _shard_gradient_clip(self, main_block):
        """Restrict gradient clipping to the parameters held by the local shard.

        Gradient clip ops (`elementwise_mul`, `squared_l2_norm`,
        `clip_by_norm`) whose parameter is not in the local shard are removed
        together with the outputs of the two norm op types. The first gradient
        clip `sum` op then only sums the remaining inputs, and one
        `c_allreduce_sum` per sharding group is inserted right after it.

        Does nothing when the stage is lower than 2.
        """
        if self.stage < 2:
            return

        # TODO (JZ-LIANG) support calculate global norm with tensor parallelism
        removed_op_type = ['elementwise_mul', 'squared_l2_norm', 'clip_by_norm']
        removed_op_idx = set()
        removed_tmp_var = set()

        # 1. find the clip ops (and their temporary outputs) of non-local parameters
        for idx, op in enumerate(main_block.ops):
            if not _is_gradient_clip_op(op) or op.type not in removed_op_type:
                continue
            input_name = op.input("X")[0]
            param_name = input_name[:input_name.find("@GRAD")]
            if self._is_parameter_in_local_shard(param_name):
                continue
            removed_op_idx.add(idx)
            if op.type in ['squared_l2_norm', 'clip_by_norm']:
                removed_tmp_var.update(op.output_arg_names)

        # 2. remove them from the back so that the recorded indices stay valid
        for idx in sorted(removed_op_idx, reverse=True):
            main_block._remove_op(idx, sync=False)

        for varname in removed_tmp_var:
            main_block._remove_var(varname, sync=False)

        # 3. fix up the first gradient clip `sum` op and allreduce its output
        for idx, op in list(enumerate(main_block.ops)):
            if not _is_gradient_clip_op(op) or op.type != 'sum':
                continue
            reserved_vars = [
                input_name for input_name in op.input_arg_names
                if input_name not in removed_tmp_var
            ]
            op.desc.set_input("X", reserved_vars)

            sum_op_output = op.desc.output_arg_names()[0]
            for i, sharding_info in enumerate(self.sharding_infos):
                new_op = main_block._insert_op(
                    idx + i + 1,
                    type='c_allreduce_sum',
                    inputs={'X': [sum_op_output]},
                    outputs={'Out': [sum_op_output]},
                    attrs={
                        'ring_id': sharding_info.group.id,
                        'op_namescope': "/gradient_clip_model_parallelism",
                        'use_calc_stream': True,
                        OP_ROLE_KEY: OpRole.Optimize,
                    })
                dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
                    main_block.var(sum_op_output))
                assert dist_attr is not None
                naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                    new_op, dist_attr.process_mesh, dist_attr.dims_mapping,
                    self._dist_context)
            # only the first matching `sum` op is handled
            break

        main_block._sync_with_cpp()

    def _shard_weight_decay(self, main_block):
        """Reject programs that contain weight decay ops when the stage is 2 or higher.

        Raises:
            NotImplementedError: if any op of `main_block` is a weight decay op.
        """
        if self.stage < 2:
            return

        if any(_is_weight_decay_op(op) for op in reversed(main_block.ops)):
            raise NotImplementedError("weight decay is NOT supported by now")
        main_block._sync_with_cpp()

    def _shard_optimizer_ops_and_states(self, main_block, startup_block):
        """Remove the optimizer ops and states of parameters outside the local shard.

        The ops of `main_block` are scanned from the end and the scan stops at
        the first op that is not an optimizer op. For every supported optimizer
        op whose parameter is not held locally, the op is removed and its
        outputs other than the parameter itself are treated as optimizer
        states: the startup ops that have such a state as their single output
        and the state vars in both blocks are removed as well.
        """
        state_names = set()
        for idx, op in reversed(list(enumerate(main_block.ops))):
            if not is_optimizer_op(op):
                break

            if op.type not in _supported_optimizer_type:
                continue
            assert "Param" in op.input_names
            assert len(op.input("Param")) == 1
            param_name = op.input("Param")[0]
            if self._is_parameter_in_local_shard(param_name):
                continue
            state_names.update(varname for varname in op.output_arg_names
                               if varname != param_name)
            main_block._remove_op(idx, sync=False)

        # `state_names` is a set, so the lookup per startup op is O(1)
        for idx, op in reversed(list(enumerate(startup_block.ops))):
            output_names = op.output_arg_names
            if len(output_names) == 1 and output_names[0] in state_names:
                startup_block._remove_op(idx, sync=False)

        for varname in state_names:
            if main_block.has_var(varname):
                main_block._remove_var(varname, sync=False)
            if startup_block.has_var(varname):
                startup_block._remove_var(varname, sync=False)

        main_block._sync_with_cpp()
        startup_block._sync_with_cpp()

    def _insert_optimizer_broadcasts(self, main_block, startup_block):
        """Append one `c_broadcast` per parameter to the end of `main_block`.

        Each parameter is broadcast in place within its sharding group, with
        the rank recorded for it in the sharding info as root. Skipped when
        the stage is greater than 2.
        """
        if self.stage > 2:
            return

        for sharding_info in self.sharding_infos:
            for param in sharding_info.params:
                assert main_block.has_var(param.name)
                assert startup_block.has_var(param.name)

                broadcast_attrs = {
                    'ring_id': sharding_info.group.id,
                    'root': sharding_info.get_var_rank(param.name),
                    'use_calc_stream': True,
                    OP_ROLE_KEY: OpRole.Optimize
                }
                broadcast_op = main_block.append_op(
                    type='c_broadcast',
                    inputs={'X': param},
                    outputs={'Out': param},
                    attrs=broadcast_attrs)
                # the new op reuses the distributed attributes of the parameter
                dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
                    param)
                assert dist_attr is not None
                naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                    broadcast_op, dist_attr.process_mesh,
                    dist_attr.dims_mapping, self._dist_context)
        main_block._sync_with_cpp()

    def _is_parameter_in_local_shard(self, param_name):
        """Tell whether `param_name` is held by the current rank; the name must be a known parameter."""
        assert param_name in self.varname_to_sharding_info
        owner_info = self.varname_to_sharding_info[param_name]
        return owner_info.is_in_local_shard(param_name)

    def _shard_gradient_synchronization(self, main_block):
        """Insert a `c_reduce_sum` in front of every data parallel gradient allreduce.

        The reduce targets the rank recorded for the parameter within its
        sharding group. Without partial sharding the original allreduce is
        removed; with partial sharding it is kept and re-targeted to the outer
        data parallel group. Does nothing when the stage is lower than 2.
        """
        if self.stage < 2:
            return

        dp_ring_ids = [group.id for group in self.dp_groups]
        for op_idx, op in reversed(list(enumerate(main_block.ops))):
            if not _is_param_grad_allreduce_op(op, main_block, dp_ring_ids):
                continue

            grad_name = op.input_arg_names[0]
            param_name = _get_base_name_from_grad_name(grad_name)
            sharding_info = self.varname_to_sharding_info[param_name]
            _insert_reduce_op(main_block, op_idx, grad_name,
                              sharding_info.group.id,
                              sharding_info.get_var_rank(param_name),
                              self._dist_context)
            if self.partial_sharding:
                op._set_attr("ring_id", self.outer_dp_group.id)
            else:
                # the reduce was inserted at op_idx, the allreduce now sits right after it
                main_block._remove_op(op_idx + 1, sync=False)

        main_block._sync_with_cpp()

    def _shard_parameter(self, main_block, startup_block):
        """Shard the parameters themselves; does nothing when the stage is lower than 3.

        For every sharding info:
        1. a broadcast is inserted in front of each non-optimizer, non-cast op
           that reads a var which needs broadcasting; on ranks that do not
           hold the var the op reads a fresh non-persistable
           "<name>@BroadCast" var instead;
        2. cast ops whose input is a non-local parameter with a usage count of
           zero are removed together with their output var;
        3. in the startup block, `c_broadcast` ops on a data parallel ring are
           either re-targeted to the outer data parallel group or removed, and
           the other ops producing a non-local parameter are removed;
        4. non-local parameters are removed from both blocks.
        """
        if self.stage < 3:
            return

        dp_ring_ids = [group.id for group in self.dp_groups]
        for sharding_info in self.sharding_infos:
            need_broadcast_vars, param_usage = sharding_info.get_broadcast_vars_and_param_usage(
                main_block)
            not_used_param_names = [
                param_name for param_name, usage in param_usage.items()
                if usage == 0 and sharding_info.get_var_rank(param_name) !=
                sharding_info.local_rank
            ]

            for idx, op in reversed(list(enumerate(main_block.ops))):
                # cast ops are handled by the removal loop below
                if is_optimizer_op(op) or op.type == "cast":
                    continue

                for input_name in op.desc.input_arg_names():
                    if input_name not in need_broadcast_vars:
                        continue
                    root_rank = sharding_info.get_var_rank(input_name)
                    if root_rank == sharding_info.local_rank:
                        broadcast_varname = input_name
                    else:
                        broadcast_varname = unique_name.generate(input_name +
                                                                 "@BroadCast")
                        input_var = main_block.var(input_name)
                        new_var = main_block.create_var(
                            name=broadcast_varname,
                            shape=input_var.shape,
                            dtype=input_var.dtype,
                            persistable=False)
                        ref_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
                            input_var)
                        set_var_dist_attr(self._dist_context, new_var,
                                          ref_dist_attr.dims_mapping,
                                          ref_dist_attr.process_mesh)
                        op._rename_input(input_name, broadcast_varname)

                    _insert_init_and_broadcast_op(
                        main_block, idx, broadcast_varname,
                        sharding_info.local_rank, root_rank,
                        sharding_info.group.id,
                        op.attr('op_role'), self._dist_context)

            for idx, op in reversed(list(enumerate(main_block.ops))):
                if op.type != "cast":
                    continue
                input_name = op.input_arg_names[0]
                output_name = op.output_arg_names[0]
                if input_name in not_used_param_names:
                    main_block._remove_op(idx, sync=False)
                    main_block._remove_var(output_name, sync=False)

            for idx, op in reversed(list(enumerate(startup_block.ops))):
                assert len(op.output_arg_names) == 1
                output_name = op.output_arg_names[0]

                if op.type == "c_broadcast" and op.attr(
                        "ring_id") in dp_ring_ids:
                    if self.outer_dp_group and sharding_info.get_var_rank(
                            output_name) == sharding_info.local_rank:
                        op._set_attr("ring_id", self.outer_dp_group.id)
                    else:
                        startup_block._remove_op(idx, sync=False)
                    continue

                if op.type != "c_broadcast" and output_name in param_usage and sharding_info.get_var_rank(
                        output_name) != sharding_info.local_rank:
                    startup_block._remove_op(idx, sync=False)

            for param_name in param_usage:
                if sharding_info.get_var_rank(
                        param_name) != sharding_info.local_rank:
                    main_block._remove_var(param_name, sync=False)
                    startup_block._remove_var(param_name, sync=False)

        main_block._sync_with_cpp()
        startup_block._sync_with_cpp()


def _insert_init_and_broadcast_op(block, insert_idx, varname, local_rank,
                                  root_rank, ring_id, op_role, dist_context):
    """
    Insert an in-place `c_broadcast` of `varname` at `insert_idx` and, on
    ranks other than the root, an `empty` op that produces `varname`.

    Args:
        block: block that holds `varname` and receives the new ops.
        insert_idx: index at which both ops are inserted.
        varname: name of the var that is broadcast.
        local_rank: rank of the current process, compared against `root_rank`.
        root_rank: value of the `root` attribute of the broadcast.
        ring_id: value of the `ring_id` attribute of the broadcast.
        op_role: op role assigned to the inserted ops.
        dist_context: context used to read the dist attr of the var and to
            set the dist attr of the inserted ops.
    """
    broadcast_var = block.var(varname)
    broadcast_var_dist_attr = dist_context.get_tensor_dist_attr_for_program(
        broadcast_var)

    new_op = block._insert_op_without_sync(
        insert_idx,
        type='c_broadcast',
        inputs={'X': varname},
        outputs={'Out': varname},
        attrs={
            'ring_id': ring_id,
            'root': root_rank,
            'use_calc_stream': True,
            OP_ROLE_KEY: op_role
        })
    # the inserted op takes the process mesh and dims mapping of the var
    naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
        new_op, broadcast_var_dist_attr.process_mesh,
        broadcast_var_dist_attr.dims_mapping, dist_context)
    if local_rank != root_rank:

        # NOTE(review): inserted at the same index after the broadcast, so it
        # presumably ends up in front of it -- confirm against
        # `_insert_op_without_sync`
        new_op = block._insert_op_without_sync(
            insert_idx,
            type="empty",
            outputs={"Out": broadcast_var.name},
            attrs={
                "shape": broadcast_var.shape,
                "dtype": broadcast_var.dtype,
                OP_ROLE_KEY: op_role
            })
        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            new_op, broadcast_var_dist_attr.process_mesh,
            broadcast_var_dist_attr.dims_mapping, dist_context)
    return


def _insert_reduce_op(block,
                      insert_idx,
                      reduce_var,
                      ring_id,
                      root_id,
                      dist_context,
                      op_role=OpRole.Backward,
                      use_calc_stream=True):
    """Insert an in-place `c_reduce_sum` of `reduce_var` at `insert_idx`.

    Args:
        block: block that holds `reduce_var` and receives the new op.
        insert_idx: index at which the op is inserted.
        reduce_var: name of the var that is reduced.
        ring_id: value of the `ring_id` attribute of the new op.
        root_id: value of the `root_id` attribute; must not be negative.
        dist_context: context used to read the dist attr of the var and to
            set the dist attr of the new op.
        op_role: op role assigned to the new op.
        use_calc_stream: value of the `use_calc_stream` attribute.
    """
    assert root_id >= 0, "root id should be a positive int, but now root id is {}".format(
        root_id)

    reduce_attrs = {
        'ring_id': ring_id,
        'root_id': root_id,
        'use_calc_stream': use_calc_stream,
        OP_ROLE_KEY: op_role
    }
    reduce_op = block._insert_op_without_sync(
        insert_idx,
        type='c_reduce_sum',
        inputs={'X': [reduce_var]},
        outputs={'Out': [reduce_var]},
        attrs=reduce_attrs)

    # the new op takes the process mesh and dims mapping of the reduced var
    var_dist_attr = dist_context.get_tensor_dist_attr_for_program(
        block.var(reduce_var))
    naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
        reduce_op, var_dist_attr.process_mesh, var_dist_attr.dims_mapping,
        dist_context)


def _get_dp_and_sharding_groups(origin_group, sharding_group_size, rank):
    """Split `origin_group` into the dp group and the sharding group of `rank`.

    The ranks are viewed as a 2-D mesh of shape
    [len(origin_group) // sharding_group_size, sharding_group_size]; the dp
    group is taken along axis 0 and the sharding group along axis 1.

    Returns:
        A (dp_group, sharding_group) tuple as returned by `_get_comm_group`.
    """
    mesh_shape = [
        len(origin_group) // sharding_group_size, sharding_group_size
    ]
    dp_axis, sharding_axis = 0, 1

    dp_group = _get_comm_group(origin_group, mesh_shape, dp_axis, rank)
    sharding_group = _get_comm_group(origin_group, mesh_shape, sharding_axis,
                                     rank)
    return dp_group, sharding_group


def _is_gradient_clip_op(op):
    """Tell whether the `op_namescope` attribute of `op` starts with "/gradient_clip"."""
    desc = op.desc
    if not desc.has_attr("op_namescope"):
        return False
    return desc.attr("op_namescope").startswith("/gradient_clip")


def _is_weight_decay_op(op):
    """Tell whether the `op_namescope` attribute of `op` starts with "/regularization"."""
    desc = op.desc
    if not desc.has_attr("op_namescope"):
        return False
    return desc.attr("op_namescope").startswith("/regularization")


def _is_param_grad_fp32_cast_op(block, op):
    """Tell whether `op` is a backward FP16 -> FP32 cast whose output belongs to a parameter.

    The candidate parameter name is the part of the output name in front of
    the first "@"; it has to name a var of `block` that is a parameter.
    """
    fp16, fp32 = core.VarDesc.VarType.FP16, core.VarDesc.VarType.FP32
    if not is_backward_op(op) or not _is_desired_cast_op(block, op, fp16,
                                                        fp32):
        return False

    casted_name = op.desc.output_arg_names()[0]
    owner_name = casted_name[:casted_name.find("@")]
    if block.has_var(owner_name):
        return block.var(owner_name).is_parameter
    return False


def _is_param_fp16_cast_op(block, op, params):
    """Tell whether `op` is a non-optimizer cast (default FP32 -> FP16) whose input is in `params`."""
    if is_optimizer_op(op) or not _is_desired_cast_op(block, op):
        return False
    source_name = op.desc.input_arg_names()[0]
    return source_name in params


def _is_desired_cast_op(block,
                        op,
                        src_var_type=core.VarDesc.VarType.FP32,
                        dst_var_type=core.VarDesc.VarType.FP16):
    """Tell whether `op` is a `cast` from `src_var_type` to `dst_var_type`.

    The dtypes are read from the input and output vars of the op in `block`;
    a cast op is expected to have exactly one input and one output.
    """
    if op.type != "cast":
        return False

    input_names = op.desc.input_arg_names()
    output_names = op.desc.output_arg_names()
    assert (len(input_names) == 1)
    assert (len(output_names) == 1)

    source_var = block.var(input_names[0])
    target_var = block.var(output_names[0])
    dtype_mismatch = source_var.dtype != src_var_type or \
        target_var.dtype != dst_var_type
    return not dtype_mismatch


def _get_base_name_from_grad_name(grad_name):
    """Return the part of `grad_name` in front of its gradient suffix.

    ".cast_fp16@GRAD" is tried first, then "@GRAD"; the name is cut at the
    first occurrence of the suffix that matches. Returns None when the name
    contains neither suffix.
    """
    for suffix in (".cast_fp16@GRAD", "@GRAD"):
        cut = grad_name.find(suffix)
        if cut != -1:
            return grad_name[:cut]
    return None


def _is_param_grad_allreduce_op(op, block, dp_ring_ids):
    """Tell whether `op` is a backward `c_allreduce_sum` of a parameter gradient on a dp ring.

    The ring id of the op has to be in `dp_ring_ids`, and the base name
    derived from the first output has to name a parameter of `block`.
    """
    is_dp_allreduce = (is_backward_op(op) and op.type == "c_allreduce_sum" and
                       op.attr('ring_id') in dp_ring_ids)
    if not is_dp_allreduce:
        return False

    grad_name = op.output_arg_names[0]
    param_name = _get_base_name_from_grad_name(grad_name)
    if block.has_var(param_name):
        return block.var(param_name).is_parameter
    return False


def _is_forward_op(op):
    """Tell whether the `op_role` attribute of `op` equals 0."""
    role = op.attr("op_role")
    return role == 0


def _inference_data_parallel_group_for_operator(rank_id, op, dist_context):
    """Infer the data parallel group of `op` from its first suitable input.

    Inputs that are parameter related are skipped. For the other inputs,
    dimension 0 is treated as the batch dimension: when it is mapped to a
    mesh axis of size greater than one, a process group is created from the
    ranks that share that axis with `rank_id`, and the search stops.

    Returns:
        The new process group, or None when no input qualifies.
    """
    for input_name in op.input_arg_names:
        if is_parameter_related(input_name, op.block):
            continue

        op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
        mesh = op_dist_attr.process_mesh
        dims_mapping = op_dist_attr.get_input_dims_mapping(input_name)
        mesh_shape = mesh.topology
        # TODO(JZ-LIANG) replace with specific batch size dimension
        batch_axis = dims_mapping[0]
        if batch_axis > -1 and mesh_shape[batch_axis] > 1:
            group_ranks = _get_comm_group(mesh.processes, mesh.topology,
                                          batch_axis, rank_id)
            return new_process_group(group_ranks)

    return None


def shard_parameters(params, group_size):
    """Greedily partition `params` among `group_size` ranks.

    Each parameter, in the given order, is assigned to the rank that
    currently holds the fewest elements (the lowest rank wins ties).

    Args:
        params: iterable of objects that expose `.shape` and `.name`.
        group_size: number of ranks to partition among.

    Returns:
        A dict mapping every rank in `range(group_size)` to the list of
        parameters assigned to it.

    Raises:
        AssertionError: if the element count of a parameter is not positive.
    """
    # TODO(JZ-LIANG) support multiple partition methods
    # method1: greedy even but unordered
    # method2: roughly even with order
    mapping = {rank_: [] for rank_ in range(group_size)}
    sizes = [0] * group_size
    for param in params:
        # The initial value 1 makes a parameter with an empty shape count as
        # one element instead of raising TypeError in `reduce`.
        numel = reduce(lambda x, y: x * y, param.shape, 1)
        assert numel > 0, "param [{}] should larger than 0, but it is [{}]".format(
            param.name, numel)
        rank = sizes.index(min(sizes))
        mapping[rank].append(param)
        sizes[rank] += numel

    return mapping


class ShardingInfo(object):
    """Records how a list of parameters is partitioned among the ranks of one sharding group.

    Ranks stored in `rank_to_params` and `param_to_rank` are local ranks,
    i.e. positions within `group.ranks`.
    """

    def __init__(self, group, rank, params):
        """
        Args:
            group: process group; its `nranks`, `ranks` (and elsewhere `id`)
                are read.
            rank: global rank of the current process; must be in `group.ranks`.
            params: parameters to partition among the ranks of the group.
        """
        self.group = group
        self.params = params
        self.param_names = [p.name for p in self.params]
        self.group_size = group.nranks
        self.global_rank = rank
        self.local_rank = group.ranks.index(self.global_rank)
        # rank in below mapping are local rank in this sharding group
        self.rank_to_params = shard_parameters(self.params, self.group_size)
        # include fp32 and fp16 param
        self.param_to_rank = dict()
        self._map_param_to_rank()

    def _map_param_to_rank(self):
        """
        mapping parameters to the rank which holds it.
        """
        for rank, params in self.rank_to_params.items():
            for param in params:
                self.param_to_rank[param.name] = rank

    def get_var_rank(self, varname):
        """Return the local rank that holds `varname`, or -1 when the name is unknown."""
        return self.param_to_rank.get(varname, -1)

    def is_in_local_shard(self, param_name):
        """Tell whether `param_name` is held by the current rank."""
        return self.get_var_rank(param_name) == self.local_rank

    def get_broadcast_vars_and_param_usage(self, block):
        """Work out which vars have to be broadcast before use and how often each param is read.

        The usage of a parameter counts its appearances as an input of
        non-optimizer ops, minus one for every fp16 cast op of the parameter.
        The output of each such cast op is registered in `param_to_rank`
        with the rank of the cast's input and is always part of the
        broadcast vars; a parameter itself is part of them only when its
        usage stays positive.

        Returns:
            A (broadcast_vars, param_usage) tuple: a set of var names and a
            dict mapping every parameter name to its usage count.
        """
        param_usage = {name: 0 for name in self.param_names}
        for op in block.ops:
            if is_optimizer_op(op):
                continue
            for input_name in op.desc.input_arg_names():
                # the keys of param_usage are exactly self.param_names; the
                # dict gives an O(1) lookup instead of scanning the list
                if input_name in param_usage:
                    param_usage[input_name] += 1

        broadcast_vars = set()
        for op in block.ops:
            if not _is_param_fp16_cast_op(block, op, self.param_names):
                continue
            input_name = op.input_arg_names[0]
            output_name = op.output_arg_names[0]
            broadcast_vars.add(output_name)
            param_usage[input_name] -= 1
            self.param_to_rank[output_name] = self.param_to_rank[input_name]

        broadcast_vars.update(param for param, usage in param_usage.items()
                              if usage > 0)
        return broadcast_vars, param_usage