common.py 17.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import abc

from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole

from ..dist_attribute import OperatorDistributedAttribute
from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_corresponding_rank, is_optimize_op

# Registry mapping an operator type name to the container object registered
# for it; entries are added by register_distributed_operator_impl_container().
_g_distributed_operator_impl_containers = {}

# Operator type names that is_elementwise_op() treats as elementwise, in
# addition to any type whose name contains the substring "elementwise".
_g_elementwise_ops = [
    "elementwise",
    "gelu",
    "dropout",
    "cast",
    "gather",
    "concat",
    "fused_softmax_mask_upper_triangle",
]
# NOTE(review): presumably op types whose distributed handling exists only for
# the backward pass -- confirm against the modules that import this set.
BACKWARD_ONLY_DIST_OPS = {'check_finite_and_unscale', 'update_loss_scaling'}
35 36


37
class ParallelMode:
    """
    The parallel mode for communication or auxiliary operators.

    Each attribute is a plain string tag. In this file the DataParallel tag is
    written into the 'op_namescope' attribute of inserted ops and later
    searched for with a substring test.
    """

    DataParallel = "auto_parallel/data_parallel"
    ModelParallel = "auto_parallel/model_parallel"
    # NOTE: "Paralel" is misspelled in both the attribute name and its value.
    # Both are kept as-is because callers and already-tagged programs may
    # depend on the exact spelling.
    PipelineParalel = "auto_parallel/pipeline_paralel"
    MoEParallel = "auto_parallel/moe_parallel"


48
def is_elementwise_op(op_type):
    """
    Tell whether `op_type` is handled as an elementwise operator.

    Args:
        op_type (str): the operator type name.

    Returns:
        bool: True when `op_type` is listed in `_g_elementwise_ops` or when
        its name contains the substring "elementwise"; False otherwise.
    """
    return op_type in _g_elementwise_ops or "elementwise" in op_type
54 55


56
class DistributedOperatorImplContainer:
    """
    Ordered collection of the distributed implementations registered for one
    operator type.
    """

    def __init__(self, op_type):
        self._type = op_type
        # Implementations in registration order; the position in this list is
        # the `idx` assigned to each implementation by register_impl().
        self._impls = []

    @property
    def type(self):
        return self._type

    @type.setter
    def type(self, op_type):
        self._type = op_type

    @property
    def impls(self):
        return self._impls

    def register_impl(self, dist_impl):
        """
        Append `dist_impl` to the container and set its `idx` to its position.

        The implementation's `type` must equal the container's `type`.
        """
        assert (
            self.type == dist_impl.type
        ), "Op type of container must be same as that of the implementation."
        dist_impl.idx = len(self.impls)
        self._impls.append(dist_impl)

    def get_impl(self, impl_idx):
        """Return the implementation stored at position `impl_idx`."""
        return self._impls[impl_idx]

    def get_input_compatible_impls(self, dist_op):
        """Return the implementations whose `is_input_compatible(dist_op)` is truthy."""
        return [
            impl for impl in self.impls if impl.is_input_compatible(dist_op)
        ]

    def get_output_compatible_impls(self, dist_op):
        """Return the implementations whose `is_output_compatible(dist_op)` is truthy."""
        return [
            impl for impl in self.impls if impl.is_output_compatible(dist_op)
        ]

    def get_compatible_impls(self, dist_op):
        """Return the implementations whose `is_auto_compatible(dist_op)` is truthy."""
        return [
            impl for impl in self.impls if impl.is_auto_compatible(dist_op)
        ]


class DistributedOperatorImpl(abc.ABC):
    """
    Abstract base class of one distributed implementation of an operator.

    Subclasses must provide the three compatibility checks and the static
    `forward` / `backward` methods. `update_dims_mapping` is not abstract, but
    its base version raises NotImplementedError.
    """

    def __init__(self, name):
        self._name = name
        # `_type` and `_idx` start unset; both are exposed through writable
        # properties below.
        self._type = None
        self._idx = None
        self._forward_implemented = False
        self._backward_implemented = False

    @property
    def name(self):
        return self._name

    @name.setter
    def name(self, name):
        self._name = name

    @property
    def type(self):
        return self._type

    @type.setter
    def type(self, op_type):
        self._type = op_type

    @property
    def idx(self):
        return self._idx

    @idx.setter
    def idx(self, impl_idx):
        self._idx = impl_idx

    @abc.abstractmethod
    def is_input_compatible(self, dist_op):
        raise NotImplementedError("Please Implement this method in Subclass.")

    @abc.abstractmethod
    def is_output_compatible(self, dist_op):
        raise NotImplementedError("Please Implement this method in Subclass.")

    @abc.abstractmethod
    def is_auto_compatible(self, dist_op):
        raise NotImplementedError("Please Implement this method in Subclass.")

    @staticmethod
    @abc.abstractmethod
    def forward(dist_ctx, *args, **kwargs):
        raise NotImplementedError("Please Implement this method in Subclass.")

    @staticmethod
    @abc.abstractmethod
    def backward(dist_ctx, *grad_outputs, **kwargs):
        raise NotImplementedError("Please Implement this method in Subclass.")

    def update_dims_mapping(self, dist_op):
        # Not abstract: subclasses may leave this unimplemented, in which case
        # calling it raises.
        raise NotImplementedError("Please Implement this method in Subclass.")


164 165 166
def register_distributed_operator_impl_container(container):
    """
    Store `container` in the module-level registry under `container.type`.

    An entry already registered under the same type is replaced.
    """
    # Item assignment mutates the dict in place, so no `global` declaration is
    # required.
    op_type = container.type
    _g_distributed_operator_impl_containers[op_type] = container
167 168


169 170 171
def get_distributed_operator_impl_container(op_type):
    """
    Look up the container registered for `op_type`.

    Returns:
        The registered container, or None when nothing is registered under
        `op_type`.
    """
    # dict.get() already returns None for a missing key.
    registry = _g_distributed_operator_impl_containers
    return registry.get(op_type)
172 173


174 175
def register_distributed_operator_impl(op_type, dist_impl):
    """
    Register `dist_impl` in the container already registered for `op_type`.

    The implementation's `type` is set to `op_type` before it is handed to the
    container.

    Args:
        op_type (str): operator type name whose container receives the impl.
        dist_impl: the implementation object to register.

    Raises:
        AssertionError: if no container has been registered for `op_type`.
    """
    dist_op_impl_container = get_distributed_operator_impl_container(op_type)
    if dist_op_impl_container is None:
        # Raised explicitly instead of `assert False`: assert statements are
        # removed under `python -O`, which would turn a missing container into
        # a silent no-op.
        raise AssertionError(
            "Must register distributed operator registry first."
        )
    dist_impl.type = op_type
    dist_op_impl_container.register_impl(dist_impl)
181 182


183
def find_compatible_distributed_operator_impls(dist_op, fwd=True, partial=True):
    """
    Collect all registered implementations that are compatible with `dist_op`.

    Candidates are gathered, in this order, from the container registered for
    the op's own type, the "elementwise" container (only when the op type is
    an elementwise op) and the "default" container. No single best
    implementation is chosen here; the whole list is returned.
    This will be improved by cost model in the future.

    Args:
        dist_op: the distributed operator; its type is read from
            `dist_op.serial_op.type`.
        fwd (bool): only consulted when `partial` is true. True selects the
            input-compatibility check, False the output-compatibility check.
        partial (bool): when false, the containers' `get_compatible_impls`
            check is used and `fwd` is ignored.

    Returns:
        list or None: the compatible implementations, or None when there are
        none.
    """
    # Pick the container query once instead of repeating the three container
    # lookups for every (partial, fwd) combination.
    if not partial:
        query_name = "get_compatible_impls"
    elif fwd:
        query_name = "get_input_compatible_impls"
    else:
        query_name = "get_output_compatible_impls"

    op_type = dist_op.serial_op.type
    dist_op_impl_container = get_distributed_operator_impl_container(op_type)
    dist_op_eltwise_impl_container = get_distributed_operator_impl_container(
        "elementwise"
    )
    dist_op_default_impl_container = get_distributed_operator_impl_container(
        "default"
    )

    # First the op's own container, second the elementwise container, third
    # the default container.
    containers = [dist_op_impl_container]
    if dist_op_eltwise_impl_container and is_elementwise_op(op_type):
        containers.append(dist_op_eltwise_impl_container)
    containers.append(dist_op_default_impl_container)

    compatible_impls = []
    for container in containers:
        # A container that was never registered is None and is skipped.
        if container:
            compatible_impls.extend(getattr(container, query_name)(dist_op))

    return compatible_impls if compatible_impls else None
262 263


J
JZ-LIANG 已提交
264
def is_parameter_related(varname, block):
    """
    Tell whether `varname` is derived from a parameter variable.

    The name is truncated at the first occurrence of each known suffix marker
    (".subprog_", ".cast_fp", ".cast_bf", ".quantized", applied in that order)
    and the remaining name is looked up recursively in `block`.

    Args:
        varname (str): the variable name, possibly carrying suffixes.
        block: object providing `_find_var_recursive(name)` and
            `_var_recursive(name)`.

    Returns:
        The `is_parameter` attribute of the variable that was found.

    Raises:
        AssertionError: if the stripped name cannot be found in `block`.
    """
    # NOTE: stripping an "@RESHARD" suffix is currently disabled.
    for marker in (".subprog_", ".cast_fp", ".cast_bf", ".quantized"):
        # partition() leaves the name unchanged when the marker is absent and
        # otherwise keeps only the text before its first occurrence.
        varname = varname.partition(marker)[0]
    assert block._find_var_recursive(varname)
    var = block._var_recursive(varname)
    return var.is_parameter


Z
zhaoyingli 已提交
280
def infer_shape(block, src_var, src_var_dist_attr, op_input_dist_attr):
    """
    Compute the shape of `src_var` under the dist attr of an operator input.

    The shape stored in `block` for the variable is first multiplied, on every
    mapped dimension, by the size of the mesh axis it is mapped to in
    `src_var_dist_attr`; the result is then floor-divided, on every dimension
    mapped in `op_input_dist_attr`, by the size of that mesh axis. A
    dims-mapping entry of -1 leaves the dimension untouched.

    Args:
        block: object providing `_var_recursive(name)`.
        src_var: the variable; only its `name` is read.
        src_var_dist_attr: dist attr of the variable (`process_mesh.shape`
            and `dims_mapping` are read).
        op_input_dist_attr: dist attr of the operator input (same fields).

    Returns:
        list: the resulting shape, one entry per dimension.
    """
    var_shape = block._var_recursive(src_var.name).shape
    var_topology = src_var_dist_attr.process_mesh.shape
    var_dims_mapping = src_var_dist_attr.dims_mapping

    # Step 1: scale every mapped dimension up by its mesh-axis size.
    complete_shape = []
    for idx, dim in enumerate(var_shape):
        mesh_axis = var_dims_mapping[idx]
        if mesh_axis == -1:
            complete_shape.append(dim)
        else:
            complete_shape.append(dim * var_topology[mesh_axis])

    input_topology = op_input_dist_attr.process_mesh.shape
    input_dims_mapping = op_input_dist_attr.dims_mapping

    # Step 2: divide by the mesh-axis sizes of the operator input's mapping.
    exact_shape = []
    for idx, dim in enumerate(complete_shape):
        mesh_axis = input_dims_mapping[idx]
        if mesh_axis == -1:
            exact_shape.append(dim)
        else:
            exact_shape.append(dim // input_topology[mesh_axis])

    return exact_shape
304 305


306 307 308
def set_comm_op_dist_attr_for_program(
    new_op, process_mesh, tensor_dist_attr, ctx
):
    """
    Build a dist attr for `new_op` and record it in `ctx`.

    Every input and every output argument of the op receives the same
    `tensor_dist_attr`, and the op is placed on `process_mesh`.

    Args:
        new_op: the operator to annotate; argument names are read from
            `new_op.desc`.
        process_mesh: the process mesh to assign; must not be None.
        tensor_dist_attr: dist attr applied to all inputs and outputs; must
            not be None.
        ctx: the context receiving the new attribute through
            `set_op_dist_attr_for_program`.
    """
    assert process_mesh is not None
    assert tensor_dist_attr is not None

    new_op_dist_attr = OperatorDistributedAttribute()
    new_op_dist_attr.process_mesh = process_mesh
    for input_varname in new_op.desc.input_arg_names():
        new_op_dist_attr.set_input_dist_attr(input_varname, tensor_dist_attr)
    for output_varname in new_op.desc.output_arg_names():
        new_op_dist_attr.set_output_dist_attr(output_varname, tensor_dist_attr)
    ctx.set_op_dist_attr_for_program(new_op, new_op_dist_attr)


def naive_copy_op_dist_attr_for_program(new_op, ref_op, ctx):
    """
    Copy the dist attr of `ref_op` onto `new_op`, slot by slot.

    For every input and output slot of `ref_op`, the dist attr of the slot's
    single argument is applied to the single argument of the same-named slot
    of `new_op`. The process mesh is copied as well, and the result is
    recorded in `ctx`.

    Args:
        new_op: the operator receiving the copied attribute.
        ref_op: the operator whose attribute is read from `ctx`.
        ctx: context providing `get_op_dist_attr_for_program` and
            `set_op_dist_attr_for_program`.

    Raises:
        AssertionError: if a slot of `ref_op` is missing from `new_op`, or if
            a slot of either op does not hold exactly one argument.
    """
    ref_dist_attr = ctx.get_op_dist_attr_for_program(ref_op)
    new_op_dist_attr = OperatorDistributedAttribute()
    new_op_dist_attr.process_mesh = ref_dist_attr.process_mesh

    for input_name in ref_op.input_names:
        assert input_name in new_op.input_names
        assert len(ref_op.input(input_name)) == 1
        assert len(new_op.input(input_name)) == 1

        ref_tensor_dist_attr = ref_dist_attr.get_input_dist_attr(
            ref_op.input(input_name)[0]
        )
        new_op_dist_attr.set_input_dist_attr(
            new_op.input(input_name)[0], ref_tensor_dist_attr
        )

    for output_name in ref_op.output_names:
        assert output_name in new_op.output_names
        assert len(ref_op.output(output_name)) == 1
        assert len(new_op.output(output_name)) == 1

        ref_tensor_dist_attr = ref_dist_attr.get_output_dist_attr(
            ref_op.output(output_name)[0]
        )
        new_op_dist_attr.set_output_dist_attr(
            new_op.output(output_name)[0], ref_tensor_dist_attr
        )

    ctx.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
352 353 354 355 356 357 358 359


def get_data_parallel_group(dist_ctx, op, act_grad_names, rank):
    """
    Deduce the data parallel communication group for the current operator.

    The first name in `act_grad_names` whose dimension 0 is mapped to a mesh
    axis of size greater than 1 determines the group; later names are not
    examined.

    Args:
        dist_ctx (DistributedContext): dist context.
        op (Operator): the current (backward) operator.
        act_grad_names (list): list of input activation grads variable name to the current operator.
        rank (int): global ranks index for current process.

    Returns:
        The process group created for the data parallel ranks, or None when no
        variable in `act_grad_names` is sharded along its first dimension.
    """
    dp_group = None

    op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op)
    process_mesh = op_dist_attr.process_mesh
    mesh_shape = process_mesh.shape
    # FIXME Hack for Pipeline Parallelism where the current operator
    # not belong to the mesh the current rank belong to.
    if rank not in process_mesh.process_ids:
        rank = _get_corresponding_rank(dist_ctx, process_mesh, rank)

    for var_name in act_grad_names:
        var_dim_mapping = op_dist_attr.get_input_dims_mapping(var_name)
        # An empty dims mapping is treated as "not sharded" (-1).
        # TODO utilize the batch_dim attr instead of "0" in future
        batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1

        if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
            group_ranks = _get_comm_group(
                process_mesh.process_ids,
                process_mesh.shape,
                batch_size_axis,
                rank,
            )
            dp_group = new_process_group(group_ranks)
            break

    return dp_group


def sync_and_scale_gradients(dist_ctx, op, dp_group, allreduce_var_names):
    """
    Insert the allreduce and scale ops for gradients of model
    parameters for operator in data parallelism.

    For every name in `allreduce_var_names` a 'c_allreduce_sum' op is appended
    to the work block and, when `dist_ctx.gradient_scale` is truthy, a 'scale'
    op with factor 1 / dp degree. Both ops work in place on the gradient
    variable and are tagged with the data parallel name scope.

    Args:
        dist_ctx (DistributedContext): dist context.
        op (Operator): the current (backward) operator.
        dp_group: the data parallel process group; `dp_group.id` is used as
            the ring id and `len(dp_group.ranks)` as the dp degree.
        allreduce_var_names (list): list of the parameter's grads variable name in the current operator output.
    """

    op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op)
    process_mesh = op_dist_attr.process_mesh
    dist_op_context = dist_ctx.dist_op_context
    main_block = dist_op_context.work_block
    dp_degree = len(dp_group.ranks)

    for var_name in allreduce_var_names:
        added_ops = []
        grad_var = main_block.var(var_name)
        allreduce_op = main_block.append_op(
            type='c_allreduce_sum',
            inputs={'X': [grad_var]},
            outputs={'Out': [grad_var]},
            attrs={
                'ring_id': dp_group.id,
                'use_calc_stream': True,
                OP_ROLE_KEY: OpRole.Backward,
            },
        )
        allreduce_op._set_attr(
            'op_namescope', '/' + ParallelMode.DataParallel
        )
        added_ops.append(allreduce_op)

        if dist_ctx.gradient_scale:
            scale_op = main_block.append_op(
                type='scale',
                inputs={'X': grad_var},
                outputs={'Out': grad_var},
                attrs={'scale': 1.0 / dp_degree, OP_ROLE_KEY: OpRole.Backward},
            )
            scale_op._set_attr(
                'op_namescope', '/' + ParallelMode.DataParallel
            )
            added_ops.append(scale_op)

        dims_mapping = op_dist_attr.get_output_dims_mapping(grad_var.name)
        assert (
            dims_mapping is not None
        ), "Unexception: dims_mapping of output [{}] of op [{}] is None".format(
            grad_var.name, op_dist_attr.op_type
        )
        # NOTE auxiliary op's dist attr should follow dist_op not dist_tensor
        for new_op in added_ops:
            new_op_attr = OperatorDistributedAttribute()
            new_op_attr.process_mesh = process_mesh
            new_op_attr.set_output_dims_mapping(grad_var.name, dims_mapping)
            new_op_attr.set_input_dims_mapping(grad_var.name, dims_mapping)
            dist_ctx.set_op_dist_attr_for_program(new_op, new_op_attr)


456 457 458
def gradient_synchronization(
    dist_ctx, op, act_grad_names, out_grad_names, rank
):
    """
    Conduct the allreduce and scaling (dp size) for gradients of model
    parameters for operator in data parallelism.

    Nothing is inserted when the context is not in the backward phase, when
    `op` is an optimize op, when either name list is empty, or when no data
    parallel group can be deduced for the operator.

    Args:
        dist_ctx (DistributedContext): dist context.
        op (Operator): the current (backward) operator.
        act_grad_names (list): list of input activation grads variable name to the current operator.
        out_grad_names (list): list of the output parameter's grads variable name of the current operator.
        rank (int): global ranks index for current process.
    """

    if not is_in_backward_phase(dist_ctx):
        return

    if (
        is_optimize_op(op)
        or len(act_grad_names) == 0
        or len(out_grad_names) == 0
    ):
        return

    dp_group = get_data_parallel_group(dist_ctx, op, act_grad_names, rank)

    if not dp_group:
        return

    sync_and_scale_gradients(dist_ctx, op, dp_group, out_grad_names)
487 488 489


def is_data_parallel_scale_op(op):
    """
    Tell whether `op` is a 'scale' op tagged with the data parallel name scope.

    The op qualifies when its type is "scale", its desc has an 'op_namescope'
    attribute, and that attribute contains `ParallelMode.DataParallel`.
    """
    return (
        op.type == "scale"
        and op.desc.has_attr("op_namescope")
        and ParallelMode.DataParallel in op.desc.attr("op_namescope")
    )
495 496 497


def is_data_parallel_reduce_op(op):
    """
    Tell whether `op` is a reduce op tagged with the data parallel name scope.

    The op qualifies when its type is "c_reduce_sum" or "c_allreduce_sum", its
    desc has an 'op_namescope' attribute, and that attribute contains
    `ParallelMode.DataParallel`.
    """
    return (
        op.type in ["c_reduce_sum", "c_allreduce_sum"]
        and op.desc.has_attr("op_namescope")
        and ParallelMode.DataParallel in op.desc.attr("op_namescope")
    )
503 504 505 506 507 508 509 510 511


def is_in_backward_phase(dist_ctx):
    """Return the result of `dist_ctx.dist_op_context.in_backward_phase()`."""
    # NOTE: high-order differentiation in Paddle currently does NOT
    # distinguish gradient computation operators in the Forward phase from
    # operators in the Backward phase (both carry op_role=1). That would
    # mislead auto parallel into adding gradient synchronization for gradient
    # computation operators in the Forward phase, so this flag is used to
    # tell the two phases apart for the time being.
    op_context = dist_ctx.dist_op_context
    return op_context.in_backward_phase()