# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

15
import abc
16

17
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
18

19
from ..dist_attribute import OperatorDistAttr
20
from ..process_group import new_process_group
21
from ..utils import _get_comm_group, _get_corresponding_rank, is_optimize_op
22

# Registry of DistributedOperatorImplContainer objects, keyed by op type.
_g_distributed_operator_impl_containers = {}

# Op types treated as elementwise by ``is_elementwise_op`` (in addition to
# any op type whose name contains "elementwise").
_g_elementwise_ops = [
    "elementwise",
    "gelu",
    "dropout",
    "cast",
    "gather",
    "concat",
    "fused_softmax_mask_upper_triangle",
]
# Op types regarded as backward-only distributed ops.
BACKWARD_ONLY_DIST_OPS = {'check_finite_and_unscale', 'update_loss_scaling'}


class ParallelMode:
    """
    Namescope tags identifying which kind of parallelism a communication or
    auxiliary operator belongs to.
    """

    DataParallel = "auto_parallel/data_parallel"
    ModelParallel = "auto_parallel/model_parallel"
    # NOTE: the misspelled attribute name and value are kept unchanged because
    # existing code may reference them; ``PipelineParallel`` is a correctly
    # spelled alias with the identical value.
    PipelineParalel = "auto_parallel/pipeline_paralel"
    PipelineParallel = PipelineParalel
    MoEParallel = "auto_parallel/moe_parallel"


class SyncMode:
    """
    Namescope tags identifying the synchronization purpose of a communication
    or auxiliary operator.
    """

    # NOTE: the values keep their historical spelling ("synchorization");
    # they are matched as namescope substrings, so they must not change.
    AmpFlagSync = "auto_parallel/amp_flag_synchorization"
    GlobalNormSync = "auto_parallel/global_norm_synchorization"


def is_elementwise_op(op_type):
    """Return whether ``op_type`` should be handled as an elementwise op.

    An op type qualifies if it is listed in ``_g_elementwise_ops`` or if its
    name contains the substring "elementwise".

    Args:
        op_type (str): the operator type name.

    Returns:
        bool: True for elementwise op types, False otherwise.
    """
    return op_type in _g_elementwise_ops or "elementwise" in op_type


class DistributedOperatorImplContainer:
    """Holds the distributed implementations registered for one op type.

    Implementations are kept in registration order, and each one's ``idx`` is
    set to its position in that order when it is registered.
    """

    def __init__(self, op_type):
        self._type = op_type
        self._impls = []

    @property
    def type(self):
        return self._type

    @type.setter
    def type(self, op_type):
        self._type = op_type

    @property
    def impls(self):
        return self._impls

    def register_impl(self, dist_impl):
        """Append ``dist_impl`` and set its ``idx`` to its list position.

        Raises:
            AssertionError: if ``dist_impl.type`` differs from this
                container's type. Raised explicitly rather than via
                ``assert`` so the check is not stripped under ``python -O``.
        """
        if self.type != dist_impl.type:
            raise AssertionError(
                "Op type of container must be same as that of the implementation."
            )
        dist_impl.idx = len(self._impls)
        self._impls.append(dist_impl)

    def get_impl(self, impl_idx):
        """Return the implementation registered at position ``impl_idx``."""
        return self._impls[impl_idx]

    def get_input_compatible_impls(self, dist_op):
        """Return impls whose ``is_input_compatible(dist_op)`` is truthy."""
        return [impl for impl in self.impls if impl.is_input_compatible(dist_op)]

    def get_output_compatible_impls(self, dist_op):
        """Return impls whose ``is_output_compatible(dist_op)`` is truthy."""
        return [
            impl for impl in self.impls if impl.is_output_compatible(dist_op)
        ]

    def get_compatible_impls(self, dist_op):
        """Return impls whose ``is_auto_compatible(dist_op)`` is truthy."""
        return [impl for impl in self.impls if impl.is_auto_compatible(dist_op)]


class DistributedOperatorImpl(abc.ABC):
    """Abstract base for one distributed implementation of an operator.

    Subclasses must implement the three compatibility checks and the static
    ``forward`` / ``backward`` methods. ``type`` and ``idx`` start as None;
    they are assigned when the implementation is registered.
    """

    def __init__(self, name):
        self._name = name
        self._type = None
        self._idx = None
        self._forward_implemented = False
        self._backward_implemented = False

    @property
    def name(self):
        return self._name

    @name.setter
    def name(self, name):
        self._name = name

    @property
    def type(self):
        return self._type

    @type.setter
    def type(self, op_type):
        self._type = op_type

    @property
    def idx(self):
        return self._idx

    @idx.setter
    def idx(self, impl_idx):
        self._idx = impl_idx

    @abc.abstractmethod
    def is_input_compatible(self, dist_op):
        """Return whether this impl is compatible with ``dist_op``'s inputs."""
        raise NotImplementedError("Please Implement this method in Subclass.")

    @abc.abstractmethod
    def is_output_compatible(self, dist_op):
        """Return whether this impl is compatible with ``dist_op``'s outputs."""
        raise NotImplementedError("Please Implement this method in Subclass.")

    @abc.abstractmethod
    def is_auto_compatible(self, dist_op):
        """Return whether this impl is fully compatible with ``dist_op``."""
        raise NotImplementedError("Please Implement this method in Subclass.")

    @staticmethod
    @abc.abstractmethod
    def forward(dist_ctx, *args, **kwargs):
        raise NotImplementedError("Please Implement this method in Subclass.")

    @staticmethod
    @abc.abstractmethod
    def backward(dist_ctx, *grad_outputs, **kwargs):
        raise NotImplementedError("Please Implement this method in Subclass.")

    def update_dims_mapping(self, dist_op):
        # Not abstract: subclasses that support dims-mapping updates override it.
        raise NotImplementedError("Please Implement this method in Subclass.")


def register_distributed_operator_impl_container(container):
    """Register ``container`` under ``container.type``.

    Any container previously registered for the same op type is replaced.

    Args:
        container (DistributedOperatorImplContainer): the container to add.
    """
    # Only an item of the module-level dict is assigned, so no ``global``
    # declaration is required.
    _g_distributed_operator_impl_containers[container.type] = container


def get_distributed_operator_impl_container(op_type):
    """Return the container registered for ``op_type``, or None if absent."""
    return _g_distributed_operator_impl_containers.get(op_type)


def register_distributed_operator_impl(op_type, dist_impl):
    """Register ``dist_impl`` into the container registered for ``op_type``.

    ``dist_impl.type`` is set to ``op_type`` before it is registered.

    Raises:
        AssertionError: if no container has been registered for ``op_type``.
    """
    dist_op_impl_container = get_distributed_operator_impl_container(op_type)
    if dist_op_impl_container is None:
        raise AssertionError(
            "Must register distributed operator registry first."
        )
    dist_impl.type = op_type
    dist_op_impl_container.register_impl(dist_impl)


def find_compatible_distributed_operator_impls(dist_op, fwd=True, partial=True):
    """Collect every registered implementation compatible with ``dist_op``.

    Candidates are gathered from up to three containers, in this order:

    1. the container registered for ``dist_op.serial_op.type``;
    2. the ``"elementwise"`` container, only when the op type is elementwise
       (per ``is_elementwise_op``);
    3. the ``"default"`` container.

    Args:
        dist_op: the distributed operator to match.
        fwd (bool): with ``partial=True``, True selects by input
            compatibility and False by output compatibility. Ignored when
            ``partial`` is False.
        partial (bool): True checks only inputs or only outputs (see
            ``fwd``); False uses each implementation's ``is_auto_compatible``.

    Returns:
        list | None: all compatible implementations in the order above, or
        None if there are none. Note that the whole list is returned, not
        just its first element; picking a single best implementation is
        intended to be done by a cost model in the future.
    """
    if not partial:
        getter_name = "get_compatible_impls"
    elif fwd:
        getter_name = "get_input_compatible_impls"
    else:
        getter_name = "get_output_compatible_impls"

    op_type = dist_op.serial_op.type
    dist_op_impl_container = get_distributed_operator_impl_container(op_type)
    dist_op_eltwise_impl_container = get_distributed_operator_impl_container(
        "elementwise"
    )
    dist_op_default_impl_container = get_distributed_operator_impl_container(
        "default"
    )

    # Priority order: op-specific, elementwise (elementwise ops only), default.
    candidate_containers = [dist_op_impl_container]
    if dist_op_eltwise_impl_container and is_elementwise_op(op_type):
        candidate_containers.append(dist_op_eltwise_impl_container)
    candidate_containers.append(dist_op_default_impl_container)

    compatible_impls = []
    for container in candidate_containers:
        if container:
            compatible_impls.extend(getattr(container, getter_name)(dist_op))

    return compatible_impls or None


def is_parameter_related(varname, block, dist_context=None):
    """Return whether ``varname`` names a parameter or a renamed copy of one.

    Everything from the first occurrence of each of the suffixes
    ``.subprog_``, ``.cast_fp``, ``.cast_bf`` and ``.quantized`` is stripped,
    in that order, to recover the base variable name, which is then looked up
    in ``block``. When ``dist_context`` is given and the name contains
    ``@RESHARD``, the part before ``@RESHARD`` is looked up in the global block
    of ``dist_context.serial_main_program`` instead.

    Args:
        varname (str): the variable name to check.
        block: the block used to resolve the (suffix-stripped) name.
        dist_context: optional distributed context used for ``@RESHARD`` names.

    Returns:
        bool: the resolved variable's ``is_parameter``, or False if a
        ``@RESHARD`` name cannot be resolved in the serial program.

    Raises:
        AssertionError: if the stripped name is not found in ``block``.
    """
    # TODO(zhaoyingli): maintain a dict in dist_context to record all
    # variables which are renamed
    for suffix in (".subprog_", ".cast_fp", ".cast_bf", ".quantized"):
        # partition()[0] is the text before the first occurrence of suffix,
        # or the whole string when suffix is absent.
        varname = varname.partition(suffix)[0]

    # Raised explicitly (not via ``assert``) so the check survives ``-O``.
    if not block._find_var_recursive(varname):
        raise AssertionError(f"cannot find var {varname} in cur block")
    var = block._var_recursive(varname)

    # NOTE(hack method): to find the param which is resharded
    if dist_context and "@RESHARD" in varname:
        varname = varname.partition("@RESHARD")[0]
        serial_program = dist_context.serial_main_program
        var = serial_program.global_block()._find_var_recursive(varname)
        if var is None:
            return False

    return var.is_parameter


def infer_shape(block, src_var, src_var_dist_attr, op_input_dist_attr):
    """Infer the shape ``src_var`` takes under ``op_input_dist_attr``.

    The shape recorded in ``block`` for ``src_var`` is first multiplied, along
    every dim that ``src_var_dist_attr`` maps to a process-mesh axis, by that
    axis' size. The result is then floor-divided, along every dim that
    ``op_input_dist_attr`` maps to a mesh axis, by that axis' size. A
    dims-mapping entry of -1 leaves the dimension unchanged.

    Args:
        block: block used to look up ``src_var`` by name.
        src_var: the source variable (only ``.name`` is read).
        src_var_dist_attr: dist attr describing how ``src_var`` is sharded.
        op_input_dist_attr: dist attr describing the op input's sharding.

    Returns:
        list: the inferred shape, one entry per dimension of ``src_var``.
    """
    var_shape = block._var_recursive(src_var.name).shape
    var_topology = src_var_dist_attr.process_mesh.shape
    var_dims_mapping = src_var_dist_attr.dims_mapping

    # Undo the sharding described by src_var_dist_attr.
    complete_shape = [
        size
        if var_dims_mapping[idx] == -1
        else size * var_topology[var_dims_mapping[idx]]
        for idx, size in enumerate(var_shape)
    ]

    # Apply the sharding described by op_input_dist_attr.
    input_topology = op_input_dist_attr.process_mesh.shape
    input_dims_mapping = op_input_dist_attr.dims_mapping
    return [
        size
        if input_dims_mapping[idx] == -1
        else size // input_topology[input_dims_mapping[idx]]
        for idx, size in enumerate(complete_shape)
    ]


def set_comm_op_dist_attr_for_program(
    new_op, process_mesh, tensor_dist_attr, ctx
):
    """Attach a uniform dist attr to a newly inserted communication op.

    Every input and output variable of ``new_op`` gets ``tensor_dist_attr``,
    and the op's process mesh is set to ``process_mesh``. The resulting
    ``OperatorDistAttr`` is recorded in ``ctx``.

    Args:
        new_op: the operator to annotate.
        process_mesh: process mesh for the op (must not be None).
        tensor_dist_attr: dist attr applied to every input/output (must not
            be None).
        ctx: the distributed context receiving the op dist attr.
    """
    assert process_mesh is not None
    assert tensor_dist_attr is not None

    new_op_dist_attr = OperatorDistAttr()
    new_op_dist_attr.process_mesh = process_mesh
    for input_varname in new_op.desc.input_arg_names():
        new_op_dist_attr.set_input_dist_attr(input_varname, tensor_dist_attr)
    for output_varname in new_op.desc.output_arg_names():
        new_op_dist_attr.set_output_dist_attr(output_varname, tensor_dist_attr)
    ctx.set_op_dist_attr_for_program(new_op, new_op_dist_attr)


def naive_copy_op_dist_attr_for_program(new_op, ref_op, ctx):
    """Copy ``ref_op``'s dist attr onto ``new_op`` slot by slot.

    Every input/output slot name of ``ref_op`` must exist on ``new_op`` and
    hold exactly one variable on both ops. The tensor dist attr of
    ``ref_op``'s variable in each slot is assigned to ``new_op``'s variable in
    the same slot, and the process mesh is copied too. The new
    ``OperatorDistAttr`` is recorded in ``ctx``.

    Args:
        new_op: the operator receiving the dist attr.
        ref_op: the operator whose dist attr (from ``ctx``) is copied.
        ctx: the distributed context.
    """
    ref_dist_attr = ctx.get_op_dist_attr_for_program(ref_op)
    new_op_dist_attr = OperatorDistAttr()
    new_op_dist_attr.process_mesh = ref_dist_attr.process_mesh

    for input_name in ref_op.input_names:
        assert (
            input_name in new_op.input_names
        ), f"input slot [{input_name}] of ref_op is missing in new_op"
        assert (
            len(ref_op.input(input_name)) == 1
        ), f"expected exactly one variable in input slot [{input_name}] of ref_op"
        assert (
            len(new_op.input(input_name)) == 1
        ), f"expected exactly one variable in input slot [{input_name}] of new_op"

        ref_tensor_dist_attr = ref_dist_attr.get_input_dist_attr(
            ref_op.input(input_name)[0]
        )
        new_op_dist_attr.set_input_dist_attr(
            new_op.input(input_name)[0], ref_tensor_dist_attr
        )

    for output_name in ref_op.output_names:
        assert (
            output_name in new_op.output_names
        ), f"output slot [{output_name}] of ref_op is missing in new_op"
        assert (
            len(ref_op.output(output_name)) == 1
        ), f"expected exactly one variable in output slot [{output_name}] of ref_op"
        assert (
            len(new_op.output(output_name)) == 1
        ), f"expected exactly one variable in output slot [{output_name}] of new_op"

        ref_tensor_dist_attr = ref_dist_attr.get_output_dist_attr(
            ref_op.output(output_name)[0]
        )
        new_op_dist_attr.set_output_dist_attr(
            new_op.output(output_name)[0], ref_tensor_dist_attr
        )

    ctx.set_op_dist_attr_for_program(new_op, new_op_dist_attr)


def get_data_parallel_group(dist_ctx, op, act_grad_names, rank):
    """
    Deduce the data parallel communication group for the current operator.

    The first activation gradient in ``act_grad_names`` whose dim 0 is mapped
    to a process-mesh axis of size greater than 1 determines the group, which
    is built from ``_get_comm_group`` along that mesh axis.

    Args:
        dist_ctx (DistributedContext): dist context.
        op (Operator): the current (backward) operator.
        act_grad_names (list): names of the input activation-gradient
            variables of the current operator.
        rank (int): global rank index of the current process.

    Returns:
        The process group created by ``new_process_group``, or None when no
        activation gradient's dim 0 is mapped to a mesh axis of size > 1.
    """
    op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op)
    process_mesh = op_dist_attr.process_mesh
    mesh_shape = process_mesh.shape

    # FIXME Hack for Pipeline Parallelism where the current operator
    # does not belong to the mesh the current rank belongs to.
    if rank not in process_mesh.process_ids:
        rank = _get_corresponding_rank(dist_ctx, process_mesh, rank)

    for var_name in act_grad_names:
        var_dim_mapping = op_dist_attr.get_input_dims_mapping(var_name)
        # consider that the variable's shape is None
        # TODO utilize the batch_dim attr instead of "0" in future
        batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1

        if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
            group_ranks = _get_comm_group(
                process_mesh.process_ids,
                process_mesh.shape,
                batch_size_axis,
                rank,
            )
            return new_process_group(group_ranks)

    return None


def sync_and_scale_gradients(dist_ctx, op, dp_group, allreduce_var_names):
    """
    Insert the allreduce (and optional scale) ops for gradients of model
    parameters produced by an operator under data parallelism.

    For every gradient name a ``c_allreduce_sum`` over ``dp_group`` is
    appended to the work block. If ``dist_ctx.gradient_scale`` is truthy, a
    ``scale`` op multiplying by ``1 / len(dp_group.ranks)`` follows it. Both
    ops are tagged with the data-parallel namescope, and their dist attrs
    follow ``op``'s process mesh and the gradient's output dims mapping.

    Args:
        dist_ctx (DistributedContext): dist context.
        op (Operator): the current (backward) operator producing the grads.
        dp_group: the data-parallel process group (``.id`` and ``.ranks``
            are read).
        allreduce_var_names (list): names of the parameter gradients in the
            current operator's outputs.

    Raises:
        AssertionError: if ``op`` has no output dims mapping for a gradient.
            The check runs before any op is inserted for that gradient.
    """
    op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op)
    process_mesh = op_dist_attr.process_mesh
    main_block = dist_ctx.dist_op_context.work_block
    dp_degree = len(dp_group.ranks)

    for var_name in allreduce_var_names:
        grad_var = main_block.var(var_name)

        # Validate before appending any op so a failure does not leave a
        # half-synchronized program behind.
        dims_mapping = op_dist_attr.get_output_dims_mapping(grad_var.name)
        assert dims_mapping is not None, (
            f"Unexpected: dims_mapping of output [{grad_var.name}] of op "
            f"[{op_dist_attr.op_type}] is None"
        )

        allreduce_op = main_block.append_op(
            type='c_allreduce_sum',
            inputs={'X': [grad_var]},
            outputs={'Out': [grad_var]},
            attrs={
                'ring_id': dp_group.id,
                'use_calc_stream': True,
                OP_ROLE_KEY: OpRole.Backward,
            },
        )
        allreduce_op._set_attr('op_namescope', '/' + ParallelMode.DataParallel)
        added_ops = [allreduce_op]

        if dist_ctx.gradient_scale:
            scale_op = main_block.append_op(
                type='scale',
                inputs={'X': grad_var},
                outputs={'Out': grad_var},
                attrs={'scale': 1.0 / dp_degree, OP_ROLE_KEY: OpRole.Backward},
            )
            scale_op._set_attr('op_namescope', '/' + ParallelMode.DataParallel)
            added_ops.append(scale_op)

        # NOTE auxiliary op's dist attr should follow dist_op not dist_tensor
        for new_op in added_ops:
            new_op_attr = OperatorDistAttr()
            new_op_attr.process_mesh = process_mesh
            new_op_attr.set_output_dims_mapping(grad_var.name, dims_mapping)
            new_op_attr.set_input_dims_mapping(grad_var.name, dims_mapping)
            dist_ctx.set_op_dist_attr_for_program(new_op, new_op_attr)


def gradient_synchronization(
    dist_ctx, op, act_grad_names, out_grad_names, rank
):
    """
    Conduct the allreduce and scaling (by dp size) for gradients of model
    parameters for an operator under data parallelism.

    Nothing is done if the context is not in the backward phase, if ``op`` is
    an optimize op, if either name list is empty, or if no data-parallel
    group can be deduced for ``op``.

    Args:
        dist_ctx (DistributedContext): dist context.
        op (Operator): the current (backward) operator which might need
            gradient synchronization.
        act_grad_names (list): names of the input activation-gradient
            variables of the current operator.
        out_grad_names (list): names of the output parameter-gradient
            variables of the current operator.
        rank (int): global rank index of the current process.
    """
    if not is_in_backward_phase(dist_ctx):
        return

    if (
        is_optimize_op(op)
        or len(act_grad_names) == 0
        or len(out_grad_names) == 0
    ):
        return

    dp_group = get_data_parallel_group(dist_ctx, op, act_grad_names, rank)
    if not dp_group:
        return

    sync_and_scale_gradients(dist_ctx, op, dp_group, out_grad_names)


def is_data_parallel_scale_op(op):
    """Return whether ``op`` is a ``scale`` op tagged with the data-parallel
    namescope (``ParallelMode.DataParallel``)."""
    return (
        op.type == "scale"
        and op.desc.has_attr("op_namescope")
        and ParallelMode.DataParallel in op.desc.attr("op_namescope")
    )


def is_data_parallel_reduce_op(op):
    """Return whether ``op`` is a ``c_reduce_sum`` / ``c_allreduce_sum`` op
    tagged with the data-parallel namescope (``ParallelMode.DataParallel``)."""
    return (
        op.type in ("c_reduce_sum", "c_allreduce_sum")
        and op.desc.has_attr("op_namescope")
        and ParallelMode.DataParallel in op.desc.attr("op_namescope")
    )


def is_amp_flag_sync_op(op):
    """Return whether ``op`` is a ``c_allreduce_max`` op tagged with the
    AMP-flag synchronization namescope (``SyncMode.AmpFlagSync``)."""
    return (
        op.type == "c_allreduce_max"
        and op.desc.has_attr("op_namescope")
        and SyncMode.AmpFlagSync in op.desc.attr("op_namescope")
    )


def is_global_norm_sync_op(op):
    """Tell whether ``op`` is the ``c_allreduce_sum`` that synchronizes the
    global norm, i.e. carries ``SyncMode.GlobalNormSync`` in its namescope."""
    if op.type != "c_allreduce_sum":
        return False
    desc = op.desc
    if not desc.has_attr("op_namescope"):
        return False
    return SyncMode.GlobalNormSync in desc.attr("op_namescope")


def is_in_backward_phase(dist_ctx):
    """Return ``dist_ctx.dist_op_context.in_backward_phase()``.

    NOTE: currently high-order differentiation in Paddle does NOT distinguish
    gradient computation operators in the Forward phase from operators in
    the Backward phase (both have op_role=1). That would mislead auto
    parallel into adding gradient synchronization for gradient computation
    operators in the Forward phase, so this flag is used to distinguish the
    two phases temporarily.
    """
    return dist_ctx.dist_op_context.in_backward_phase()