# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc

from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole

from ..dist_attribute import OperatorDistAttr
from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_corresponding_rank, is_optimize_op

# Registry of distributed operator implementation containers keyed by op
# type; filled by register_distributed_operator_impl_container().
_g_distributed_operator_impl_containers = {}

# Op types that is_elementwise_op() treats as elementwise. Any op type whose
# name contains "elementwise" is treated as elementwise as well.
_g_elementwise_ops = [
    "elementwise",
    "gelu",
    "dropout",
    "cast",
    "gather",
    "concat",
    "fused_softmax_mask_upper_triangle",
]
# NOTE(review): judging by the name, these op types are distributed only in
# the backward phase; no consumer is visible in this file — confirm.
BACKWARD_ONLY_DIST_OPS = {'check_finite_and_unscale', 'update_loss_scaling'}


class ParallelMode:
    """
    String tags naming the parallel mode a communication or auxiliary
    operator was inserted for.

    The values are written into an op's ``op_namescope`` attribute (see
    ``sync_and_scale_gradients``) and later matched by substring (see
    ``is_data_parallel_scale_op``), so they must stay byte-for-byte stable.
    """

    DataParallel = "auto_parallel/data_parallel"
    ModelParallel = "auto_parallel/model_parallel"
    # NOTE: both the attribute name and its value keep the historical
    # "Paralel" misspelling, since other code may reference or match them.
    PipelineParalel = "auto_parallel/pipeline_paralel"
    # Correctly spelled alias of ``PipelineParalel`` (identical value).
    PipelineParallel = PipelineParalel
    MoEParallel = "auto_parallel/moe_parallel"


class SyncMode:
    """
    String tags naming the synchronization mode a communication or
    auxiliary operator was inserted for.

    They are matched by substring against an op's ``op_namescope``
    attribute (see ``is_amp_flag_sync_op`` / ``is_global_norm_sync_op``).
    """

    # NOTE: the values below misspell "synchronization"; they are kept as-is
    # because they are compared against op attributes.
    AmpFlagSync = "auto_parallel/amp_flag_synchorization"
    GlobalNormSync = "auto_parallel/global_norm_synchorization"


def is_elementwise_op(op_type):
    """
    Return whether ``op_type`` should be handled as an elementwise op.

    An op is elementwise when it is listed in ``_g_elementwise_ops`` or when
    its type name contains the substring ``"elementwise"``.

    Args:
        op_type (str): the operator type name.

    Returns:
        bool: True for elementwise ops, otherwise False.
    """
    return op_type in _g_elementwise_ops or "elementwise" in op_type


class DistributedOperatorImplContainer:
    """
    Holds every distributed implementation registered for one op type.

    Implementations are kept in registration order, and each one's ``idx``
    is set to its position in that order when it is registered.
    """

    def __init__(self, op_type):
        self._type = op_type
        self._impls = []

    @property
    def type(self):
        return self._type

    @type.setter
    def type(self, op_type):
        self._type = op_type

    @property
    def impls(self):
        return self._impls

    def register_impl(self, dist_impl):
        """
        Append ``dist_impl`` and assign it the next free index.

        Raises:
            AssertionError: if ``dist_impl.type`` differs from this
                container's op type.
        """
        assert (
            self.type == dist_impl.type
        ), "Op type of container must be same as that of the implementation."
        dist_impl.idx = len(self.impls)
        self._impls.append(dist_impl)

    def get_impl(self, impl_idx):
        return self._impls[impl_idx]

    def _select_impls(self, is_compatible):
        # Keep the impls accepted by ``is_compatible``, in registration order.
        return [impl for impl in self.impls if is_compatible(impl)]

    def get_input_compatible_impls(self, dist_op):
        """Return the impls whose ``is_input_compatible(dist_op)`` is truthy."""
        return self._select_impls(
            lambda impl: impl.is_input_compatible(dist_op)
        )

    def get_output_compatible_impls(self, dist_op):
        """Return the impls whose ``is_output_compatible(dist_op)`` is truthy."""
        return self._select_impls(
            lambda impl: impl.is_output_compatible(dist_op)
        )

    def get_compatible_impls(self, dist_op):
        """Return the impls whose ``is_auto_compatible(dist_op)`` is truthy."""
        return self._select_impls(
            lambda impl: impl.is_auto_compatible(dist_op)
        )


class DistributedOperatorImpl(abc.ABC):
    """
    Abstract base for one distributed implementation of an operator type.

    Subclasses provide the three compatibility checks and the static
    ``forward``/``backward`` hooks. ``type`` and ``idx`` start as None;
    ``register_distributed_operator_impl`` sets ``type`` and
    ``DistributedOperatorImplContainer.register_impl`` sets ``idx``.
    """

    def __init__(self, name):
        self._name = name
        self._type = None
        self._idx = None
        self._forward_implemented = False
        self._backward_implemented = False

    @property
    def name(self):
        return self._name

    @name.setter
    def name(self, name):
        self._name = name

    @property
    def type(self):
        return self._type

    @type.setter
    def type(self, op_type):
        self._type = op_type

    @property
    def idx(self):
        return self._idx

    @idx.setter
    def idx(self, impl_idx):
        self._idx = impl_idx

    @abc.abstractmethod
    def is_input_compatible(self, dist_op):
        """Subclass hook used by ``get_input_compatible_impls``."""
        raise NotImplementedError("Please Implement this method in Subclass.")

    @abc.abstractmethod
    def is_output_compatible(self, dist_op):
        """Subclass hook used by ``get_output_compatible_impls``."""
        raise NotImplementedError("Please Implement this method in Subclass.")

    @abc.abstractmethod
    def is_auto_compatible(self, dist_op):
        """Subclass hook used by ``get_compatible_impls``."""
        raise NotImplementedError("Please Implement this method in Subclass.")

    # ``abstractmethod`` must be the innermost decorator when combined with
    # ``staticmethod``.
    @staticmethod
    @abc.abstractmethod
    def forward(dist_ctx, *args, **kwargs):
        raise NotImplementedError("Please Implement this method in Subclass.")

    @staticmethod
    @abc.abstractmethod
    def backward(dist_ctx, *grad_outputs, **kwargs):
        raise NotImplementedError("Please Implement this method in Subclass.")

    def update_dims_mapping(self, dist_op):
        # Optional hook (not abstract): subclasses that do not override it
        # raise NotImplementedError when it is called.
        raise NotImplementedError("Please Implement this method in Subclass.")


def register_distributed_operator_impl_container(container):
    """
    Register ``container`` in the module registry under ``container.type``.

    If a container is already registered for that op type, it is replaced.
    """
    # Item assignment mutates the dict in place, so no ``global`` is needed.
    _g_distributed_operator_impl_containers[container.type] = container


def get_distributed_operator_impl_container(op_type):
    """Return the container registered for ``op_type``, or None if absent."""
    return _g_distributed_operator_impl_containers.get(op_type)


def register_distributed_operator_impl(op_type, dist_impl):
    """
    Register ``dist_impl`` as a distributed implementation of ``op_type``.

    Sets ``dist_impl.type`` to ``op_type`` and appends the impl to the
    container previously registered for that op type.

    Raises:
        AssertionError: if no container has been registered for ``op_type``.
    """
    dist_op_impl_container = get_distributed_operator_impl_container(op_type)
    if dist_op_impl_container is None:
        # Raise explicitly: ``assert False`` is removed under ``python -O``,
        # which would make the registration silently do nothing.
        raise AssertionError(
            "Must register distributed operator registry first."
        )
    dist_impl.type = op_type
    dist_op_impl_container.register_impl(dist_impl)


def find_compatible_distributed_operator_impls(dist_op, fwd=True, partial=True):
    """
    Collect every distributed implementation compatible with ``dist_op``.

    Containers are searched in priority order:
      1. the container registered for the op's own type;
      2. the "elementwise" container, only when the op is elementwise;
      3. the "default" container.
    Containers that are not registered are skipped. No ranking is applied
    yet (a cost model may do so in the future): all matches are returned,
    in search order.

    Args:
        dist_op: the distributed op; its type is read from
            ``dist_op.serial_op.type``.
        fwd (bool): only used when ``partial`` is True. If True, select
            input-compatible impls; otherwise select output-compatible impls.
        partial (bool): if False, select fully compatible impls
            (``get_compatible_impls``) and ignore ``fwd``.

    Returns:
        list | None: the compatible impls, or None if there are none.
    """
    op_type = dist_op.serial_op.type
    dist_op_impl_container = get_distributed_operator_impl_container(op_type)
    dist_op_eltwise_impl_container = get_distributed_operator_impl_container(
        "elementwise"
    )
    dist_op_default_impl_container = get_distributed_operator_impl_container(
        "default"
    )

    # Choose the container query that matches the requested compatibility.
    if not partial:
        query_name = "get_compatible_impls"
    elif fwd:
        query_name = "get_input_compatible_impls"
    else:
        query_name = "get_output_compatible_impls"

    search_order = [dist_op_impl_container]
    if dist_op_eltwise_impl_container and is_elementwise_op(op_type):
        search_order.append(dist_op_eltwise_impl_container)
    search_order.append(dist_op_default_impl_container)

    compatible_impls = []
    for container in search_order:
        if container:
            compatible_impls.extend(getattr(container, query_name)(dist_op))

    return compatible_impls or None


def is_parameter_related(varname, block):
    """
    Return whether ``varname`` is a parameter, or a variable derived from one.

    For each marker ``.subprog_``, ``.cast_fp``, ``.cast_bf`` and
    ``.quantized`` (in that order), if the marker occurs in the name, the
    name is cut at its first occurrence. The resulting base name is then
    looked up in ``block`` and its ``is_parameter`` flag is returned.
    (Presumably these markers are appended to copies, casts or quantized
    versions of a variable — TODO confirm against the producers.)

    Args:
        varname (str): the variable name to check.
        block: a program block providing ``_find_var_recursive`` and
            ``_var_recursive``.

    Returns:
        the ``is_parameter`` attribute of the base variable.

    Raises:
        AssertionError: if the base variable is not found in ``block``.
    """
    for marker in (".subprog_", ".cast_fp", ".cast_bf", ".quantized"):
        if marker in varname:
            varname = varname[: varname.index(marker)]
    # NOTE(review): "@RESHARD" suffixes are not stripped; stripping them was
    # previously present only as commented-out code.
    assert block._find_var_recursive(varname)
    var = block._var_recursive(varname)
    return var.is_parameter


def infer_shape(block, src_var, src_var_dist_attr, op_input_dist_attr):
    """
    Re-derive the per-rank shape of ``src_var`` for an op input's sharding.

    For each dimension of the variable's shape (read from ``block``):
      * if ``src_var_dist_attr`` maps it to a mesh axis (dims_mapping != -1),
        it is multiplied by that axis size of the source process mesh;
      * if ``op_input_dist_attr`` maps it to a mesh axis, the result is then
        floor-divided by that axis size of the input's process mesh.
    Dimensions of size -1 (unknown) are returned as -1 unchanged; scaling
    them would produce meaningless negative sizes.

    Args:
        block: the block that owns ``src_var`` (``_var_recursive`` is used).
        src_var: the variable whose shape is inferred (its ``name`` is read).
        src_var_dist_attr: dist attr describing how ``src_var`` is sharded.
        op_input_dist_attr: dist attr of the op input to infer the shape for.

    Returns:
        list: the inferred shape, one entry per dimension of ``src_var``.
    """
    var_shape = block._var_recursive(src_var.name).shape
    src_mesh_shape = src_var_dist_attr.process_mesh.shape
    src_dims_mapping = src_var_dist_attr.dims_mapping
    input_mesh_shape = op_input_dist_attr.process_mesh.shape
    input_dims_mapping = op_input_dist_attr.dims_mapping

    exact_shape = []
    for idx, size in enumerate(var_shape):
        src_axis = src_dims_mapping[idx]
        input_axis = input_dims_mapping[idx]
        if size == -1:
            exact_shape.append(size)
            continue
        if src_axis != -1:
            size = size * src_mesh_shape[src_axis]
        if input_axis != -1:
            size = size // input_mesh_shape[input_axis]
        exact_shape.append(size)

    return exact_shape


def set_comm_op_dist_attr_for_program(
    new_op, process_mesh, tensor_dist_attr, ctx
):
    """
    Attach a distributed attribute to a newly inserted communication op.

    The op is placed on ``process_mesh``, and every one of its input and
    output arguments is given the same ``tensor_dist_attr``. The resulting
    ``OperatorDistAttr`` is recorded in ``ctx``.

    Args:
        new_op: the inserted op (its ``desc`` argument names are read).
        process_mesh: the process mesh for the op; must not be None.
        tensor_dist_attr: dist attr applied to every argument; must not be
            None.
        ctx: the distributed context that stores the op's dist attr.

    Raises:
        AssertionError: if ``process_mesh`` or ``tensor_dist_attr`` is None.
    """
    assert process_mesh is not None
    assert tensor_dist_attr is not None

    new_op_dist_attr = OperatorDistAttr()
    new_op_dist_attr.process_mesh = process_mesh
    op_desc = new_op.desc
    for input_varname in op_desc.input_arg_names():
        new_op_dist_attr.set_input_dist_attr(input_varname, tensor_dist_attr)
    for output_varname in op_desc.output_arg_names():
        new_op_dist_attr.set_output_dist_attr(output_varname, tensor_dist_attr)
    ctx.set_op_dist_attr_for_program(new_op, new_op_dist_attr)


def naive_copy_op_dist_attr_for_program(new_op, ref_op, ctx):
    """
    Copy ``ref_op``'s distributed attribute onto ``new_op``, slot by slot.

    ``new_op`` must have every input/output slot name that ``ref_op`` has,
    and every such slot must hold exactly one variable on both ops. The
    tensor dist attr of ``ref_op``'s variable in a slot is assigned to
    ``new_op``'s variable in the same slot. The process mesh is copied too,
    and the new ``OperatorDistAttr`` is recorded in ``ctx``.

    Raises:
        AssertionError: if the two ops' slot layouts differ as described.
    """
    ref_dist_attr = ctx.get_op_dist_attr_for_program(ref_op)
    new_op_dist_attr = OperatorDistAttr()
    new_op_dist_attr.process_mesh = ref_dist_attr.process_mesh

    for input_name in ref_op.input_names:
        assert input_name in new_op.input_names
        ref_vars = ref_op.input(input_name)
        new_vars = new_op.input(input_name)
        assert len(ref_vars) == 1
        assert len(new_vars) == 1

        ref_tensor_dist_attr = ref_dist_attr.get_input_dist_attr(ref_vars[0])
        new_op_dist_attr.set_input_dist_attr(new_vars[0], ref_tensor_dist_attr)

    for output_name in ref_op.output_names:
        assert output_name in new_op.output_names
        ref_vars = ref_op.output(output_name)
        new_vars = new_op.output(output_name)
        assert len(ref_vars) == 1
        assert len(new_vars) == 1

        ref_tensor_dist_attr = ref_dist_attr.get_output_dist_attr(ref_vars[0])
        new_op_dist_attr.set_output_dist_attr(
            new_vars[0], ref_tensor_dist_attr
        )

    ctx.set_op_dist_attr_for_program(new_op, new_op_dist_attr)


def get_data_parallel_group(dist_ctx, op, act_grad_names, rank):
    """
    Deduce the data-parallel communication group for the current operator.

    The first activation gradient whose dim 0 is sharded along a mesh axis
    of size > 1 determines the group: the ranks sharing ``rank``'s position
    along that axis, as computed by ``_get_comm_group``.

    Args:
        dist_ctx (DistributedContext): dist context.
        op (Operator): the current (backward) operator.
        act_grad_names (list): names of the input activation gradients of
            the current operator.
        rank (int): global rank index of the current process.

    Returns:
        the process group created by ``new_process_group``, or None if no
        activation gradient is data-parallel sharded.
    """
    op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op)
    process_mesh = op_dist_attr.process_mesh
    mesh_shape = process_mesh.shape
    # FIXME Hack for Pipeline Parallelism where the current operator
    # does not belong to the mesh that the current rank belongs to.
    if rank not in process_mesh.process_ids:
        rank = _get_corresponding_rank(dist_ctx, process_mesh, rank)

    for var_name in act_grad_names:
        var_dim_mapping = op_dist_attr.get_input_dims_mapping(var_name)
        # The dims_mapping may be empty (e.g. when the variable's shape is
        # unknown); treat such a variable as not sharded.
        # TODO utilize the batch_dim attr instead of "0" in future
        batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1

        if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
            group_ranks = _get_comm_group(
                process_mesh.process_ids,
                process_mesh.shape,
                batch_size_axis,
                rank,
            )
            return new_process_group(group_ranks)

    return None


def sync_and_scale_gradients(dist_ctx, op, dp_group, allreduce_var_names):
    """
    Insert the allreduce and scale ops for the parameter gradients of an
    operator under data parallelism.

    For each gradient, a ``c_allreduce_sum`` over ``dp_group`` is appended
    to the current work block. When ``dist_ctx.gradient_scale`` is truthy, a
    ``scale`` op with factor ``1 / len(dp_group.ranks)`` follows it. Both ops
    are tagged with the data-parallel ``op_namescope``. Each gets a dist
    attr with ``op``'s process mesh and the gradient's output dims_mapping.

    Args:
        dist_ctx (DistributedContext): dist context.
        op (Operator): the current (backward) operator.
        dp_group: the data-parallel process group (``id`` and ``ranks``
            are read).
        allreduce_var_names (list): names of the parameter gradients among
            the current operator's outputs.

    Raises:
        AssertionError: if ``op``'s dist attr has no output dims_mapping
            for one of the gradients.
    """
    op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op)
    process_mesh = op_dist_attr.process_mesh
    main_block = dist_ctx.dist_op_context.work_block
    dp_degree = len(dp_group.ranks)
    dp_namescope = '/' + ParallelMode.DataParallel

    for var_name in allreduce_var_names:
        grad_var = main_block.var(var_name)

        # Validate before modifying the program, so a failure does not leave
        # half-inserted communication ops behind.
        dims_mapping = op_dist_attr.get_output_dims_mapping(grad_var.name)
        assert (
            dims_mapping is not None
        ), "Unexpected: dims_mapping of output [{}] of op [{}] is None".format(
            grad_var.name, op_dist_attr.op_type
        )

        added_ops = []
        allreduce_op = main_block.append_op(
            type='c_allreduce_sum',
            inputs={'X': [grad_var]},
            outputs={'Out': [grad_var]},
            attrs={
                'ring_id': dp_group.id,
                'use_calc_stream': True,
                OP_ROLE_KEY: OpRole.Backward,
            },
        )
        allreduce_op._set_attr('op_namescope', dp_namescope)
        added_ops.append(allreduce_op)

        if dist_ctx.gradient_scale:
            scale_op = main_block.append_op(
                type='scale',
                inputs={'X': grad_var},
                outputs={'Out': grad_var},
                attrs={'scale': 1.0 / dp_degree, OP_ROLE_KEY: OpRole.Backward},
            )
            scale_op._set_attr('op_namescope', dp_namescope)
            added_ops.append(scale_op)

        # NOTE auxiliary op's dist attr should follow dist_op not dist_tensor
        for new_op in added_ops:
            new_op_attr = OperatorDistAttr()
            new_op_attr.process_mesh = process_mesh
            new_op_attr.set_output_dims_mapping(grad_var.name, dims_mapping)
            new_op_attr.set_input_dims_mapping(grad_var.name, dims_mapping)
            dist_ctx.set_op_dist_attr_for_program(new_op, new_op_attr)


def gradient_synchronization(
    dist_ctx, op, act_grad_names, out_grad_names, rank
):
    """
    Allreduce the parameter gradients produced by ``op`` under data
    parallelism, and scale them by the data-parallel degree.

    Nothing is inserted when any of these holds: the context is not in its
    backward phase; ``op`` is an optimizer op; ``op`` has no input
    activation gradients or no output parameter gradients; or no
    data-parallel group can be deduced.

    Args:
        dist_ctx (DistributedContext): dist context.
        op (Operator): the current (backward) operator.
        act_grad_names (list): names of the input activation gradients of
            the current operator.
        out_grad_names (list): names of the output parameter gradients of
            the current operator.
        rank (int): global rank index of the current process.
    """
    if not is_in_backward_phase(dist_ctx):
        return

    if is_optimize_op(op) or not act_grad_names or not out_grad_names:
        return

    dp_group = get_data_parallel_group(dist_ctx, op, act_grad_names, rank)
    if not dp_group:
        return

    sync_and_scale_gradients(dist_ctx, op, dp_group, out_grad_names)


def is_data_parallel_scale_op(op):
    """
    Return whether ``op`` is a ``scale`` op whose ``op_namescope`` attribute
    contains ``ParallelMode.DataParallel``.
    """
    return (
        op.type == "scale"
        and op.desc.has_attr("op_namescope")
        and ParallelMode.DataParallel in op.desc.attr("op_namescope")
    )


def is_data_parallel_reduce_op(op):
    """
    Return whether ``op`` is a ``c_reduce_sum``/``c_allreduce_sum`` op whose
    ``op_namescope`` attribute contains ``ParallelMode.DataParallel``.
    """
    return (
        op.type in ("c_reduce_sum", "c_allreduce_sum")
        and op.desc.has_attr("op_namescope")
        and ParallelMode.DataParallel in op.desc.attr("op_namescope")
    )


def is_amp_flag_sync_op(op):
    """
    Return whether ``op`` is a ``c_allreduce_max`` op whose ``op_namescope``
    attribute contains ``SyncMode.AmpFlagSync``.
    """
    return (
        op.type == "c_allreduce_max"
        and op.desc.has_attr("op_namescope")
        and SyncMode.AmpFlagSync in op.desc.attr("op_namescope")
    )


def is_global_norm_sync_op(op):
    """
    Tell whether ``op`` is a ``c_allreduce_sum`` tagged for global-norm
    synchronization, i.e. its ``op_namescope`` attribute contains
    ``SyncMode.GlobalNormSync``.
    """
    if op.type != "c_allreduce_sum":
        return False
    op_desc = op.desc
    has_namescope = op_desc.has_attr("op_namescope")
    if not has_namescope:
        # Propagate has_attr's falsy result, as the original ``and`` chain did.
        return has_namescope
    return SyncMode.GlobalNormSync in op_desc.attr("op_namescope")


def is_in_backward_phase(dist_ctx):
    """
    Return whether the distributed context is currently in the backward
    phase, as reported by ``dist_ctx.dist_op_context.in_backward_phase()``.
    """
    # NOTE: high-order differentiation in Paddle currently does NOT
    # distinguish gradient-computation operators in the Forward phase from
    # operators in the Backward phase (both have op_role=1). That would
    # mislead auto parallel into adding gradient synchronization for
    # gradient-computation operators in the Forward phase, so this flag is
    # used to tell the two phases apart for now.
    return dist_ctx.dist_op_context.in_backward_phase()