mp_ops.py 28.6 KB
Newer Older
W
wuhuachaocoding 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
16
from paddle import _legacy_C_ops
W
wuhuachaocoding 已提交
17 18 19 20 21 22 23 24
from paddle.fluid import core
from paddle.fluid.framework import _non_static_mode
from paddle.fluid.framework import _in_legacy_dygraph
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle.fluid.dygraph import layers
from paddle.distributed import collective
25
from ....communication.reduce import ReduceOp, _get_reduce_op
W
wuhuachaocoding 已提交
26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45
from paddle.fluid.data_feeder import check_dtype
import paddle.fluid.dygraph_utils as dygraph_utils


def _c_identity(tensor, group=None):
    """
    Return a copy of the tensor, mainly used with model parallel.

    In eager (new dygraph) mode the op is wrapped in a PyLayer whose
    backward pass sum-allreduces the incoming gradient over ``group``.

    Args:
        tensor (Tensor): The input Tensor. Its data type
            should be float16, float32, float64, int32 or int64.
        group (Group, optional): The process group to work on. When None,
            ring id 0 is used in legacy-dygraph/static mode and the default
            process group is used in eager mode. Default: None.

    Returns:
        Tensor, or None when the current process is not a member of
        ``group``.
    """
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id

    if in_dygraph_mode():
        from paddle.autograd import PyLayer

        # The PyLayer below dereferences ``group.id`` and
        # ``group.process_group``; resolve the default group so that
        # ``group=None`` does not raise AttributeError.
        group = collective._get_default_group() if group is None else group

        class c_identity_eager(PyLayer):
            @staticmethod
            def forward(ctx, tensor):
                return _legacy_C_ops.c_identity(
                    tensor,
                    'use_calc_stream',
                    True,
                    'ring_id',
                    group.id,
                    'use_model_parallel',
                    True,
                )

            @staticmethod
            def backward(ctx, dy):
                # Every rank sees the same forward input, so the input
                # gradient is the sum of the per-rank output gradients.
                op_type = _get_reduce_op(ReduceOp.SUM, "_c_identity")
                group.process_group.allreduce_on_calc_stream(dy, op_type)
                return dy

        return c_identity_eager.apply(tensor)

    if _in_legacy_dygraph():
        return _legacy_C_ops.c_identity(
            tensor,
            'use_calc_stream',
            True,
            'ring_id',
            ring_id,
            'use_model_parallel',
            True,
        )

    op_type = 'c_identity'
    # NOTE: LayerHelper receives every bound local as a keyword argument.
    helper = LayerHelper(op_type, **locals())
    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)

    check_variable_and_dtype(
        tensor,
        'tensor',
        ['float16', 'float32', 'float64', 'int32', 'int64'],
        '_c_identity',
    )

    helper.append_op(
        type=op_type,
        inputs={'X': tensor},
        outputs={'Out': out},
        attrs={
            'ring_id': ring_id,
            'use_calc_stream': True,
            'use_model_parallel': True,
        },
    )
    return out


def _c_concat(tensor, group=None):
    """
    Return allgather of the tensor, mainly used with model parallel.

    The pieces from every rank of ``group`` are concatenated by the
    ``c_concat`` op, which receives this process's rank and the group size.

    Args:
        tensor (Tensor): The input Tensor. Its data type
            should be float16, float32, float64, int32 or int64.
        group (Group, optional): The process group to work on. None means
            the default process group. Default: None.

    Returns:
        Tensor, or None when the current process is not a member of
        ``group``.
    """
    if group is not None and not group.is_member():
        return
    group = collective._get_default_group() if group is None else group
    ring_id = group.id

    rank = group.rank
    nranks = group.nranks

    if _non_static_mode():
        return _legacy_C_ops.c_concat(
            tensor,
            'ring_id',
            ring_id,
            'use_calc_stream',
            True,
            'rank',
            rank,
            'nranks',
            nranks,
            'use_model_parallel',
            True,
        )

    op_type = 'c_concat'
    # NOTE: LayerHelper receives every bound local as a keyword argument.
    helper = LayerHelper(op_type, **locals())
    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)

    check_variable_and_dtype(
        tensor,
        'tensor',
        ['float16', 'float32', 'float64', 'int32', 'int64'],
        '_c_concat',
    )

    helper.append_op(
        type=op_type,
        inputs={'X': tensor},
        outputs={'Out': out},
        attrs={
            'ring_id': ring_id,
            'use_calc_stream': True,
            'use_model_parallel': True,
            'nranks': nranks,
            'rank': rank,
        },
    )
    return out


def _c_split(tensor, group=None):
    """
    Split tensor evenly among all members, mainly used with model parallel.

    The ``c_split`` op receives this process's rank inside ``group`` and the
    group size, and returns the piece belonging to this rank.

    Args:
        tensor (Tensor): The input Tensor. Its data type
            should be float16, float32, float64, int32 or int64.
        group (Group, optional): The process group to work on. When None,
            ring id 0 is used together with the global rank and world
            size. Default: None.

    Returns:
        Tensor, or None when the current process is not a member of
        ``group``.
    """
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id

    global_rank = collective._get_global_env().rank
    rank = global_rank if group is None else group.get_group_rank(global_rank)
    nranks = (
        collective._get_global_env().world_size
        if group is None
        else group.nranks
    )

    if _non_static_mode():
        return _legacy_C_ops.c_split(
            tensor,
            'use_calc_stream',
            True,
            'ring_id',
            ring_id,
            'rank',
            rank,
            'nranks',
            nranks,
            'use_model_parallel',
            True,
        )

    op_type = 'c_split'
    # NOTE: LayerHelper receives every bound local as a keyword argument.
    helper = LayerHelper(op_type, **locals())
    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)

    check_variable_and_dtype(
        tensor,
        'tensor',
        ['float16', 'float32', 'float64', 'int32', 'int64'],
        '_c_split',
    )

    helper.append_op(
        type=op_type,
        inputs={'X': tensor},
        outputs={'Out': out},
        attrs={
            'ring_id': ring_id,
            'use_calc_stream': True,
            'rank': rank,
            'nranks': nranks,
            'use_model_parallel': True,
        },
    )
    return out


def _mp_allreduce(
    tensor,
    op=ReduceOp.SUM,
    group=None,
    use_calc_stream=True,
    use_model_parallel=True,
):
    """
    In-place sum-allreduce of ``tensor`` across a model-parallel group.

    It is the same as allreduce, but it supports model parallel and an
    in-place strategy. In eager mode the backward pass applies
    ``c_identity`` to the gradient on the same ring.

    Args:
        tensor (Tensor): The tensor to reduce. In static mode its data type
            should be float16, float32, float64, int32 or int64.
        op (ReduceOp, optional): Only ``ReduceOp.SUM`` is supported.
            Default: ReduceOp.SUM.
        group (Group, optional): The process group. None means the default
            group in eager mode and ring id 0 otherwise. Default: None.
        use_calc_stream (bool, optional): Whether to run on the calculation
            stream. Default: True.
        use_model_parallel (bool, optional): Forwarded as the
            ``use_model_parallel`` op attribute. Default: True.

    Returns:
        Tensor, or None when the current process is not a member of
        ``group``.

    Raises:
        AssertionError: In eager mode, if ``op`` is not ``ReduceOp.SUM``.
        ValueError: In legacy dygraph or static mode, if ``op`` is not
            ``ReduceOp.SUM``.
    """
    if group is not None and not group.is_member():
        return

    if in_dygraph_mode():
        group = collective._get_default_group() if group is None else group
        assert op == ReduceOp.SUM, "Unknown parameter: {}.".format(op)

        from paddle.autograd import PyLayer

        class mp_allreduce_eager(PyLayer):
            @staticmethod
            def forward(
                ctx, tensor, group, use_calc_stream, use_model_parallel
            ):
                ctx.ring_id = group.id

                if use_calc_stream:
                    op_type = _get_reduce_op(op, "_mp_allreduce")
                    group.process_group.allreduce_on_calc_stream(
                        tensor, op_type
                    )
                    return tensor
                # Use the ring id stored on ctx: the enclosing function's
                # ``ring_id`` is only assigned after this eager branch has
                # returned, so reading it here raised NameError.
                return _legacy_C_ops.c_allreduce_sum_(
                    tensor,
                    'use_calc_stream',
                    use_calc_stream,
                    'ring_id',
                    ctx.ring_id,
                    "use_model_parallel",
                    use_model_parallel,
                )

            @staticmethod
            def backward(ctx, dy):
                return _legacy_C_ops.c_identity(
                    dy,
                    'use_calc_stream',
                    True,
                    'ring_id',
                    ctx.ring_id,
                    'use_model_parallel',
                    True,
                )

        return mp_allreduce_eager.apply(
            tensor, group, use_calc_stream, use_model_parallel
        )

    ring_id = 0 if group is None else group.id
    # Legacy dygraph and static graph only implement SUM. The static path
    # used to emit c_allreduce_sum for any ``op`` without complaint.
    if op != ReduceOp.SUM:
        raise ValueError("Unknown parameter: {}.".format(op))

    if _in_legacy_dygraph():
        return _legacy_C_ops.c_allreduce_sum_(
            tensor,
            'use_calc_stream',
            use_calc_stream,
            'ring_id',
            ring_id,
            "use_model_parallel",
            use_model_parallel,
        )

    op_type = 'c_allreduce_sum'
    # NOTE: LayerHelper receives every bound local as a keyword argument.
    helper = LayerHelper(op_type, **locals())
    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)

    check_variable_and_dtype(
        tensor,
        'tensor',
        ['float16', 'float32', 'float64', 'int32', 'int64'],
        op_type,
    )

    helper.append_op(
        type=op_type,
        inputs={'X': tensor},
        outputs={'Out': out},
        attrs={
            'ring_id': ring_id,
            'use_calc_stream': use_calc_stream,
            'use_model_parallel': use_model_parallel,
        },
    )
    return out


def _c_lookup_table(table, index, start_index=0, name=None):
    """
    Lookup table according to index.

    Args:
        table (Tensor): The input Tensor. Its data type
            should be float16, float32, float64.
        index (Tensor): The index to lookup table. In static mode its data
            type is checked to be int32 or int64.
        start_index (int): The initial index for table range. Default: 0.
        name (string): The name of the api. Only used in static mode, where
            it is forwarded to LayerHelper. Default: None.

    Returns:
        Tensor.
    """
    if _non_static_mode():
        return _legacy_C_ops.c_embedding(
            table, index, "start_index", start_index
        )

    op_type = 'c_embedding'
    # NOTE: ``input_dtype`` below reads the ``table`` keyword, which reaches
    # LayerHelper through ``**locals()``; keep the parameter name.
    helper = LayerHelper(op_type, **locals())
    dtype = helper.input_dtype(input_param_name='table')
    check_variable_and_dtype(index, 'input', ['int32', 'int64'], op_type)
    tmp = helper.create_variable_for_type_inference(dtype)
    helper.append_op(
        type=op_type,
        inputs={'Ids': index, 'W': table},
        outputs={'Out': tmp},
        attrs={"start_index": start_index},
    )
    return tmp


class _Linear(layers.Layer):
    """
    Linear layer computing ``input @ weight + bias`` through ``_linear``.

    Args:
        in_features (int): Number of rows of the weight, i.e. the size of
            the last input dimension.
        out_features (int): Number of columns of the weight and size of the
            bias.
        weight_attr (ParamAttr, optional): Attribute of the weight
            parameter. Default: None.
        bias_attr (ParamAttr, optional): Attribute of the bias parameter.
            Default: None.
        name (str, optional): Name forwarded to ``_linear``. Default: None.
    """

    def __init__(
        self,
        in_features,
        out_features,
        weight_attr=None,
        bias_attr=None,
        name=None,
    ):
        super().__init__()
        self._dtype = self._helper.get_default_dtype()
        self._weight_attr = weight_attr
        self._bias_attr = bias_attr
        # Stored as [in_features, out_features] so that ``_linear`` can
        # multiply without transposing the weight.
        self.weight = self.create_parameter(
            shape=[in_features, out_features],
            attr=self._weight_attr,
            dtype=self._dtype,
            is_bias=False,
        )
        self.bias = self.create_parameter(
            shape=[out_features],
            attr=self._bias_attr,
            dtype=self._dtype,
            is_bias=True,
        )
        self.name = name

    def forward(self, input):
        out = _linear(
            x=input, weight=self.weight, bias=self.bias, name=self.name
        )
        return out

    def extra_repr(self):
        name_str = ', name={}'.format(self.name) if self.name else ''
        return 'in_features={}, out_features={}, dtype={}{}'.format(
            self.weight.shape[0], self.weight.shape[1], self._dtype, name_str
        )


def _c_softmax_with_cross_entropy(
    logits, label, group=None, return_softmax=False
):
    """
    Softmax with cross entropy over a process group, mainly used with
    model parallel.

    Args:
        logits (Tensor): The local logits. Presumably each rank holds a
            slice of the class dimension (the op receives ``rank`` and
            ``nranks``) - TODO confirm against the
            c_softmax_with_cross_entropy op.
        label (Tensor): The labels. Its rank must equal the rank of
            ``logits`` or be one less; in the latter case a trailing axis
            of size 1 is appended.
        group (Group, optional): The process group to work on. When None,
            ring id 0 is used together with the global rank and world
            size. Default: None.
        return_softmax (bool, optional): Whether to also return the softmax
            output. Default: False.

    Returns:
        Tensor ``loss``, or the tuple ``(loss, softmax)`` when
        ``return_softmax`` is True. None when the current process is not a
        member of ``group``.

    Raises:
        ValueError: If the ranks of ``logits`` and ``label`` are
            incompatible.
    """
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id
    global_rank = collective._get_global_env().rank
    rank = global_rank if group is None else group.get_group_rank(global_rank)
    nranks = (
        collective._get_global_env().world_size
        if group is None
        else group.nranks
    )

    input_dims = len(list(logits.shape))
    label_dims = len(list(label.shape))
    if input_dims - 1 != label_dims and input_dims != label_dims:
        raise ValueError(
            'Expected input_dims - 1 == label_dims or input_dims == label_dims'
            ' (got input_dims {}, label_dims {})'.format(input_dims, label_dims)
        )
    if input_dims - 1 == label_dims:
        label = paddle.unsqueeze(label, axis=-1)

    if _non_static_mode():
        softmax, loss = _legacy_C_ops.c_softmax_with_cross_entropy(
            logits, label, 'ring_id', ring_id, 'rank', rank, 'nranks', nranks
        )
        if not return_softmax:
            return loss
        else:
            return loss, softmax

    attrs = {
        'ring_id': ring_id,
        'rank': rank,
        'nranks': nranks,
    }
    # NOTE: LayerHelper receives every bound local as a keyword argument.
    helper = LayerHelper('c_softmax_with_cross_entropy', **locals())
    softmax = helper.create_variable_for_type_inference(dtype=logits.dtype)
    loss = helper.create_variable_for_type_inference(dtype=logits.dtype)
    helper.append_op(
        type='c_softmax_with_cross_entropy',
        inputs={'Logits': logits, 'Label': label},
        outputs={'Softmax': softmax, 'Loss': loss},
        attrs=attrs,
    )

    if return_softmax:
        return loss, softmax

    return loss


def _linear(x, weight, bias=None, name=None):
    """
    Compute ``x @ weight`` and add ``bias`` along the last axis if given.

    Args:
        x (Tensor): The input. In static mode its data type must be
            float16, float32 or float64 and it may have at most 3
            dimensions.
        weight (Tensor): The weight matrix, multiplied without
            transposition.
        bias (Tensor, optional): Bias added along the last axis of the
            matmul result. Default: None.
        name (str, optional): Forwarded to LayerHelper in static mode.
            Default: None.

    Returns:
        Tensor: The linear output.
    """
    if _non_static_mode():
        # ``_varbase_creator`` is not imported at module level; without this
        # import the dygraph path raised NameError.
        from paddle.fluid.framework import _varbase_creator

        pre_bias = _varbase_creator(dtype=x.dtype)
        _legacy_C_ops.matmul(
            x,
            weight,
            pre_bias,
            'transpose_X',
            False,
            'transpose_Y',
            False,
            "alpha",
            1,
        )
        return dygraph_utils._append_bias_in_dygraph(
            pre_bias, bias, axis=len(x.shape) - 1
        )

    # NOTE: ``**locals()`` forwards x, weight, bias and name to LayerHelper.
    helper = LayerHelper('linear', **locals())
    dtype = x.dtype
    assert (
        len(x.shape) < 4
    ), "X latitude is not supported greater than 3 now."

    check_variable_and_dtype(
        x, 'x', ['float16', 'float32', 'float64'], 'linear'
    )
    check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear')

    inputs = {'X': [x], 'Y': [weight]}
    # matmul_v2 names its transpose attributes trans_x / trans_y and has no
    # alpha; the matmul (v1) names used before are not matmul_v2
    # attributes. Both flags are False, matching the dygraph branch.
    attrs = {
        'trans_x': False,
        'trans_y': False,
    }
    tmp = helper.create_variable_for_type_inference(dtype)
    helper.append_op(
        type='matmul_v2', inputs=inputs, outputs={'Out': tmp}, attrs=attrs
    )
    if bias is None:
        return tmp

    res = helper.create_variable_for_type_inference(dtype)
    helper.append_op(
        type='elementwise_add',
        inputs={'X': [tmp], 'Y': [bias]},
        outputs={'Out': [res]},
        attrs={'axis': len(x.shape) - 1},
    )
    return res


def _set_var_distributed(var):
    if var is None:
        return

    var.is_distributed = True

    # NOTE: use current_block and find_var_recursive to support while_loop
    startup_block = paddle.static.default_startup_program().current_block()
    main_block = paddle.static.default_main_program().current_block()
    startup_block._find_var_recursive(var.name).is_distributed = True
    main_block._find_var_recursive(var.name).is_distributed = True


def _parallel_linear(
    x,
    num_rows,
    num_cols,
    axis,
    param_attr,
    bias_attr,
    gather_out,
    inner_rank,
    nranks,
    split_tensor,
    name,
    group=None,
):
    """
    Parallel Linear. Builds static-graph ops in the default main program.

    ``axis`` selects the dimension of the linear weight that is split:
    ``axis = 0`` splits the rows (row parallel) and ``axis = 1`` splits the
    columns (column parallel).

    Args:
        x (Variable): The input variable.
        num_rows (int): Number of rows of the local weight shard.
        num_cols (int): Number of columns of the local weight shard.
        axis (int): 0 for row parallel, 1 for column parallel.
        param_attr (ParamAttr): Attribute of the weight.
        bias_attr (ParamAttr): Attribute of the bias.
        gather_out (bool): Whether to combine the local outputs: allreduce
            for ``axis = 0``, c_concat for ``axis = 1``.
        inner_rank (int): Rank inside the model parallel group, passed to
            c_concat.
        nranks (int): Number of model parallel ranks.
        split_tensor (bool): For row parallel only, whether ``x`` must
            first be split across ranks with ``_c_split``.
        name (str): Name of the linear layer.
        group (Group, optional): The process group. None means ring id 0.
            Default: None.

    Returns:
        Variable: the local output when ``gather_out`` is False, otherwise
        the combined output. None when the current process is not a member
        of ``group``.
    """
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id

    if axis == 0:
        if split_tensor:
            x = _c_split(x, group=group)
    else:
        x = _c_identity(x, group=group)

    linear = paddle.nn.Linear(
        num_rows,
        num_cols,
        weight_attr=param_attr,
        bias_attr=bias_attr,
        name=name,
    )

    # NOTE: npu linear function use matmul_v2 but linear use matmul
    linear_function = (
        _linear if core.is_compiled_with_npu() else paddle.nn.functional.linear
    )
    linear_out = linear_function(
        x,
        linear.weight,
        # NOTE(wangxi): row split, bias need add after allreduce
        None if axis == 0 else linear.bias,
        linear.name,
    )

    _set_var_distributed(linear.weight)
    # Set is_distributed for a split bias:
    # - row split: every rank holds the complete bias and the copies must
    #   be identical, so it is not distributed;
    # - column split: the bias is split across ranks like its weight.
    if axis == 1 and linear._bias_attr is not False:
        _set_var_distributed(linear.bias)

    if not gather_out:
        return linear_out

    out_shape = list(linear_out.shape)
    if axis == 1:
        # c_concat gathers the column shards along the LAST dimension, so
        # that dimension (not dim 0, the batch dim) grows by ``nranks``.
        out_shape[-1] *= nranks
    main_block = paddle.static.default_main_program().current_block()
    out = main_block.create_var(
        shape=out_shape,
        dtype=linear_out.dtype,
        type=linear_out.type,
        lod_level=linear_out.lod_level,
        persistable=False,
        is_data=False,
        need_check_feed=linear_out.desc.need_check_feed(),
    )
    if axis == 0:
        main_block.append_op(
            type='c_allreduce_sum',
            inputs={'X': linear_out},
            outputs={'Out': out},
            attrs={
                'ring_id': ring_id,
                'use_calc_stream': True,
                'use_model_parallel': True,
            },
        )
        if linear.bias is not None:
            out = out + linear.bias
    else:
        main_block.append_op(
            type='c_concat',
            inputs={'X': linear_out},
            outputs={'Out': out},
            attrs={
                'rank': inner_rank,
                'ring_id': ring_id,
                'nranks': nranks,
                'use_calc_stream': True,
                'use_model_parallel': True,
            },
        )
    return out


def _parallel_embedding(
    x,
    per_part_embeddings,
    origin_size,
    param_attr,
    inner_rank,
    num_partitions,
    name,
    group=None,
):
    """
    Parallel Embedding (static graph).

    Creates this rank's shard of the embedding table with shape
    ``[per_part_embeddings, origin_size[1]]``. It looks ``x`` up in that
    shard with ``start_index = inner_rank * per_part_embeddings`` and then
    sum-allreduces the partial results over ``group``.

    Args:
        x (Variable): The ids to look up.
        per_part_embeddings (int): Number of vocabulary rows held per rank.
        origin_size (list|tuple): Shape of the full table; only
            ``origin_size[1]`` (the embedding width) is used here.
        param_attr (ParamAttr): Attribute of the embedding weight.
        inner_rank (int): Rank inside the model parallel group.
        num_partitions (int): Number of shards. With 1, a plain embedding
            is returned without any communication.
        name (str): Name of the layer.
        group (Group, optional): The process group. None means ring id 0.
            Default: None.

    Returns:
        Variable, or None when the current process is not a member of
        ``group``.
    """
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id

    # NOTE: LayerHelper receives every bound local (including name and
    # param_attr) as a keyword argument.
    helper = LayerHelper("_parallel_embedding", **locals())

    per_part_size = per_part_embeddings
    rank = inner_rank

    vocab_start_index = rank * per_part_size
    dtype = helper.get_default_dtype()
    size = [per_part_size, origin_size[1]]

    weight = helper.create_parameter(
        attr=param_attr, shape=size, dtype=dtype, is_bias=False
    )

    if num_partitions == 1:
        return paddle.nn.functional.embedding(
            x, weight=weight, padding_idx=None, sparse=False, name=name
        )

    startup_block = paddle.static.default_startup_program().global_block()
    main_block = paddle.static.default_main_program().global_block()
    startup_block.vars[weight.name].is_distributed = True
    main_block.vars[weight.name].is_distributed = True

    output_parallel = _c_lookup_table(
        weight, x, start_index=vocab_start_index, name=name
    )
    out = _mp_allreduce(
        output_parallel,
        group=group,
        use_calc_stream=True,
        use_model_parallel=True,
    )
    return out


def split(
    x,
    size,
    operation,
    axis=0,
    num_partitions=1,
    gather_out=True,
    weight_attr=None,
    bias_attr=None,
    name=None,
):
    """

    Split the weight of the specified operation into multiple devices
    and do the computation in parallel.

    Now the following three cases are supported.

    Case 1: Parallel Embedding
        The weight of the embedding operation is a NxM matrix with N rows and M columns.
        With parallel embedding, the weight is split into num_partitions partitions, each
        of which is a matrix with N/num_partitions rows and M columns.

        Suppose we split the NxM weight into two partitions on device_0 and device_1
        respectively. Then, on each device, the local weight has N/2 rows. On device_0,
        all values in the input within [0, N/2 - 1] are looked up in the local weight and
        all other values are mapped to all zeros. In the same way, on device_1, a value V
        in the input within [N/2, N-1] is looked up at local row (V - N/2), and all other
        values are mapped to all zeros. Finally, the results on the two devices are
        sum-reduced.

        The Embedding put on single card is as shown below:

        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_embedding_single.png
            :width: 800
            :height: 350
            :alt: single_embedding
            :align: center

        Parallel Embedding is shown as below:

        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_embedding_split.png
            :width: 800
            :alt: split_embedding
            :align: center

    Case 2: Row Parallel Linear
        The weight of the linear operation is a NxM matrix with N rows and M columns.
        With row parallel linear, the weight is split into num_partitions partitions, each
        of which is a matrix with N/num_partitions rows and M columns.

        The linear layer put on single card is shown as below, the input variable is represented by X,
        the weight matrix is represented by W and the output variable is O. The linear layer on single card is
        simple matrix multiplication operation, O = X * W.

        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_single.png
            :width: 800
            :alt: single_linear
            :align: center

        Row Parallel Linear is shown as below. As the name suggests, Row Parallel Linear splits the weight matrix W into
        [[W_row1], [W_row2]] along the row. Accordingly, the input is split along the column into [X_col1, X_col2], and each part
        is multiplied by its respective weight matrix. Finally, AllReduce is applied to the output from each card to get the final output.

        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_row.png
            :width: 800
            :alt: split_row
            :align: center

    Case 3: Column Parallel Linear
        The weight of the linear operation is a NxM matrix with N rows and M columns.
        With column parallel linear, the weight is split into num_partitions partitions, each
        of which is a matrix with N rows and M/num_partitions columns.

        The linear layer put on single card has been illustrated in case 2 and Column Parallel Linear
        is shown as below. The Column Parallel Linear splits the weight matrix W into [W_col1, W_col2] along the column, and
        each split matrix multiplies the input. Finally, AllGather is applied to the output from each card to get the final output.

        .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_col.png
            :width: 800
            :alt: split_col
            :align: center

    As observed, the column parallel linear and row parallel linear can be combined to skip one ALLGATHER communication
    operator. Furthermore the Attention and MLP can be combined to improve the performance as shown below.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/split_col_row.png
            :width: 800
            :alt: split_col_row
            :align: center

    Args:
        x (Tensor): Input tensor. Its data type should be float16, float32, float64, int32 or int64.
        size (list|tuple): A list or tuple with two elements indicating the shape of the weight.
        operation (str): The name of the operation. The supported operations are 'linear' and 'embedding'.
        axis (int, Optional): Indicate along which axis to split the weight. Default: 0.
        num_partitions (int, Optional): How many parts the weight is partitioned. Default: 1.
        gather_out (bool, Optional): Whether to gather the output after computation. By default, the output
            on each partition will be gathered after computation. Default: True.
        weight_attr (ParamAttr, Optional): The parameter attribute for the learnable
            weights(Parameter) of the specified operation. Default: None.
        bias_attr (ParamAttr, Optional): The parameter attribute for the bias
            of the specified operation. Default: None.
        name (str, Optional): The default value is None. Normally there is no need for user to set this
            property. Default: None. For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed.fleet as fleet

            paddle.enable_static()
            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            fleet.init(is_collective=True)
            data = paddle.randint(0, 8, shape=[10,4])
            emb_out = paddle.distributed.split(
                data,
                (8, 8),
                operation="embedding",
                num_partitions=2)

    """
    assert isinstance(size, (list, tuple)), (
        "The type of size for "
        "paddle.distributed.split must be list or tuple."
    )
    assert len(size) == 2, (
        "Number of elements in size of " "paddle.distributed.split must be two."
    )
    assert isinstance(operation, str), (
        "The type of operation for " "paddle.distributed.split must be str."
    )
    supported_operations = [
        'linear',
        'embedding',
    ]
    assert operation in supported_operations, (
        "The operation for "
        "paddle.distributed.split must be one of {}.".format(
            supported_operations
        )
    )
    if _non_static_mode():
        raise ValueError(
            "paddle.distributed.split cannot be used in dynamic "
            "graph mode, plese use ParallelEmbedding, ParallelRowLinear, "
            "ParallelColumnLinear instead."
        )

    from paddle.distributed.fleet import fleet

    assert fleet._role_maker, (
        "To use paddle.distributed.split, "
        "you must call fleet.init() firstly."
    )
    rank = fleet.worker_index()

    # rank within a model parallel group
    inner_rank = rank % num_partitions

    if operation == "embedding":
        assert axis == 0, (
            "We only support to split the weight of embedding "
            "along the first axis now."
        )
        assert size[0] % num_partitions == 0, (
            "The length of the vocabulary must be divisible by num_partitions "
            "but received vocabulary={} num_partitions={}".format(
                size[0], num_partitions
            )
        )

        per_part_size = size[0] // num_partitions
        emb_out = _parallel_embedding(
            x,
            per_part_size,
            size,
            weight_attr,
            inner_rank,
            num_partitions,
            name,
            group=None,
        )
        return emb_out

    should_split = False
    if axis == 0:
        assert size[0] % num_partitions == 0, (
            "Number of rows of the weight for linear ({}) must be"
            " divisible by num_partitions ({})".format(
                size[0], num_partitions
            )
        )
        per_part_size = size[0] // num_partitions
        linear_size = (per_part_size, size[1])
        # The input still carries the full feature dimension, so it has to
        # be split across ranks before the row-parallel matmul.
        if x.shape[-1] == size[0]:
            should_split = True

    elif axis == 1:
        assert size[1] % num_partitions == 0, (
            "Number of column of the weight for linear ({}) must be"
            " divisible by num_partitions ({})".format(
                size[1], num_partitions
            )
        )
        per_part_size = size[1] // num_partitions
        linear_size = (size[0], per_part_size)
    else:
        raise ValueError(
            "The value of axis must be 0 or 1, but the value "
            "given is {}.".format(axis)
        )

    linear_out = _parallel_linear(
        x,
        linear_size[0],
        linear_size[1],
        axis,
        weight_attr,
        bias_attr,
        gather_out,
        inner_rank,
        num_partitions,
        should_split,
        name=name,
        group=None,
    )
    return linear_out