mp_layers.py 26.5 KB
Newer Older
W
wuhuachaocoding 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
16
from paddle.autograd import PyLayer
W
wuhuachaocoding 已提交
17 18
from paddle.fluid import core
from paddle.nn import functional as F
19

20
from ....communication.reduce import ReduceOp, _get_reduce_op
W
wuhuachaocoding 已提交
21
from ...base import topology as tp
22
from . import mp_ops
23
from .mp_ops import _get_mp_env_flag
24
from .random import get_rng_state_tracker
W
wuhuachaocoding 已提交
25 26 27 28 29 30 31 32 33

# Nothing from this module is exported through ``from ... import *``.
__all__ = []

# The tensor (model) parallel layers in this file follow the design of:
# Shoeybi M, Patwary M, Puri R, et al. Megatron-LM: Training Multi-Billion
# Parameter Language Models Using Model Parallelism. arXiv preprint
# arXiv:1909.08053, 2019. (https://arxiv.org/abs/1909.08053)


def is_fused_matmul_bias_supported():
    """Report whether the legacy eager op ``fused_gemm_epilogue`` is available.

    The ``fuse_matmul_bias=True`` option of the parallel linear layers in this
    file refuses to run when this returns False.
    """
    legacy_ops = core.eager.ops.legacy
    return hasattr(legacy_ops, 'fused_gemm_epilogue')


def is_fused_linear_param_grad_add_supported():
    """Report whether ``paddle._C_ops.fused_linear_param_grad_add`` can be used.

    Only CUDA (non-ROCm) builds are considered; on any other build this returns
    False without probing ``paddle._C_ops``.
    """
    if not paddle.is_compiled_with_cuda() or paddle.is_compiled_with_rocm():
        return False
    return hasattr(paddle._C_ops, 'fused_linear_param_grad_add')


class VocabParallelEmbedding(paddle.nn.Layer):
    """Embedding mp parallelized in the vocabulary dimension.

    This class splits an embedding table across the mp group. Each rank holds
    ``num_embeddings // world_size`` consecutive rows, starting at row
    ``rank * (num_embeddings // world_size)``. The per-rank lookups are combined
    with a model-parallel all-reduce.

    Args:
        num_embeddings(int): The size of the dictionary of embeddings. Must be
            divisible by the model parallel world size.
        embedding_dim(int): The size of each embedding vector.
        weight_attr(ParamAttr|None): To specify the weight parameter property. Default: None, which means the
            default weight parameter property is used. See usage for details in :ref:`api_ParamAttr` . In addition,
            user-defined or pre-trained word vectors can be loaded with the :attr:`weight_attr` parameter.
            The local word vectors need to be transformed into numpy format, and their shape should match
            this rank's shard, ``[num_embeddings // world_size, embedding_dim]``. Then
            :ref:`api_initializer_NumpyArrayInitializer` is used to load custom or pre-trained word vectors.
        mp_group(Group): The tensor parallel group. Default: None, which means the
            model parallel group of the global hybrid communicate group is used.
        name(str, optional): For detailed information, please refer
            to :ref:`api_guide_Name`. Usually name is no need to set and
            None by default.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.distributed import fleet

            class SimpleMPNet(paddle.nn.Layer):
                def __init__(self, vocab_size, hidden_size, inner_size, output_size):
                    super().__init__()
                    self.linear1 = fleet.meta_parallel.ColumnParallelLinear(
                        hidden_size,
                        inner_size,
                        gather_output=False,
                        has_bias=True)

                    self.linear2 = fleet.meta_parallel.RowParallelLinear(
                        inner_size,
                        hidden_size,
                        input_is_parallel=True,
                        has_bias=True)

                    self.linear3 = paddle.nn.Linear(hidden_size, output_size)

                    self.embedding = fleet.meta_parallel.VocabParallelEmbedding(
                        vocab_size,
                        hidden_size)

                def forward(self, x):
                    x = self.embedding(x)
                    x = self.linear1(x)
                    x = self.linear2(x)
                    x = self.linear3(x)
                    return x
    """

    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        weight_attr=None,
        mp_group=None,
        name=None,
    ):
        super().__init__()

        # Resolve the communication group: an explicit group wins, otherwise
        # fall back to the global hybrid-parallel topology.
        if mp_group is None:
            hcg = tp._HYBRID_PARALLEL_GROUP
            self.model_parallel_group = hcg.get_model_parallel_group()
            self.world_size = hcg.get_model_parallel_world_size()
            self.rank = hcg.get_model_parallel_rank()
        else:
            self.model_parallel_group = mp_group
            self.world_size = mp_group.nranks
            self.rank = mp_group.rank

        self.origin_num_embeddings = num_embeddings
        self.is_mp = self.world_size > 1

        assert (
            num_embeddings % self.world_size == 0
        ), "The length of the vocabulary must be divisible by the parallelism degree of MP"

        rows_per_rank = num_embeddings // self.world_size

        self.vocab_start_index = self.rank * rows_per_rank
        self._dtype = self._helper.get_default_dtype()
        self._size = [rows_per_rank, embedding_dim]
        self._weight_attr = weight_attr
        self._name = name

        def make_weight():
            return self.create_parameter(
                attr=self._weight_attr,
                shape=self._size,
                dtype=self._dtype,
                is_bias=False,
            )

        # Under mp in dynamic mode, each shard is initialized inside the
        # tracked RNG state.
        if self.is_mp and paddle.in_dynamic_mode():
            with get_rng_state_tracker().rng_state():
                self.weight = make_weight()
        else:
            self.weight = make_weight()

        self.weight.is_distributed = bool(self.is_mp)
        if self.weight.is_distributed:
            # Rows (vocabulary entries) are what is split across ranks.
            self.weight.split_axis = 0

    def forward(self, x):
        if not self.is_mp:
            return F.embedding(
                x,
                weight=self.weight,
                padding_idx=None,
                sparse=False,
                name=self._name,
            )

        # Look up ids against this rank's shard (offset by vocab_start_index),
        # then combine the partial results across the mp group.
        partial = mp_ops._c_lookup_table(
            self.weight,
            x,
            start_index=self.vocab_start_index,
            name=self._name,
        )
        return mp_ops._mp_allreduce(
            partial,
            group=self.model_parallel_group,
            use_calc_stream=True,
            use_model_parallel=True,
        )


class ColumnParallelLinear(paddle.nn.Layer):
    """Linear layer with mp parallelized(column).

    This class splits a Linear layer in the mp group along the columns
    (output features) of its weight. Each rank holds a weight of shape
    ``[in_features, out_features // world_size]`` and, if enabled, a bias of
    shape ``[out_features // world_size]``.

    Args:
        in_features(int): The number of input units.
        out_features(int): The number of output units. Must be divisible by the
            model parallel world size.
        weight_attr(ParamAttr|None): The attribute for the learnable weight of this layer.
            Default: None, which means the default weight parameter property of
            ``create_parameter`` is used. For detailed information, please refer to paddle.ParamAttr.
        has_bias(bool|None): Whether to add a bias. Any falsy value, including
            the default None, creates no bias. The bias is initialized to zero.
        gather_output(bool): Whether to all-gather (concatenate) the per-rank
            outputs so every rank gets the full ``out_features`` output.
        fuse_matmul_bias(bool): Whether to fuse matmul and bias.
        mp_group(Group): The tensor parallel group. Default: None, which means the
            model parallel group of the global hybrid communicate group is used.
        name(str, optional): Normally there is no need for user to set this parameter.
            For detailed information, please refer to :ref:`api_guide_Name` .

    Raises:
        NotImplementedError: If ``fuse_matmul_bias=True`` but the fused GEMM
            epilogue op is unavailable in this Paddle build.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.distributed import fleet

            class SimpleMPNet(paddle.nn.Layer):
                def __init__(self, vocab_size, hidden_size, inner_size, output_size):
                    super().__init__()
                    self.linear1 = fleet.meta_parallel.ColumnParallelLinear(
                        hidden_size,
                        inner_size,
                        gather_output=False,
                        has_bias=True)

                    self.linear2 = fleet.meta_parallel.RowParallelLinear(
                        inner_size,
                        hidden_size,
                        input_is_parallel=True,
                        has_bias=True)

                    self.linear3 = paddle.nn.Linear(hidden_size, output_size)

                    self.embedding = fleet.meta_parallel.VocabParallelEmbedding(
                        vocab_size,
                        hidden_size)

                def forward(self, x):
                    x = self.embedding(x)
                    x = self.linear1(x)
                    x = self.linear2(x)
                    x = self.linear3(x)
                    return x
    """

    def __init__(
        self,
        in_features,
        out_features,
        weight_attr=None,
        has_bias=None,
        gather_output=True,
        fuse_matmul_bias=False,
        mp_group=None,
        name=None,
    ):
        super().__init__()

        self.model_parallel_group = (
            tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group()
            if mp_group is None
            else mp_group
        )
        self.world_size = (
            tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size()
            if mp_group is None
            else mp_group.nranks
        )
        self._name = name
        self.is_mp = self.world_size > 1

        self.gather_output = gather_output
        assert out_features % self.world_size == 0, (
            "Number of column of the weight for linear ({}) must be"
            " divisible by model parallel size ({})".format(
                out_features, self.world_size
            )
        )
        self.output_size_per_partition = out_features // self.world_size

        self._weight_attr = weight_attr
        self._dtype = self._helper.get_default_dtype()

        # Under mp in dynamic mode, each shard is initialized inside the
        # tracked RNG state.
        if self.is_mp and paddle.in_dynamic_mode():
            with get_rng_state_tracker().rng_state():
                self.weight = self.create_parameter(
                    shape=[in_features, self.output_size_per_partition],
                    attr=self._weight_attr,
                    dtype=self._dtype,
                    is_bias=False,
                )
        else:
            self.weight = self.create_parameter(
                shape=[in_features, self.output_size_per_partition],
                attr=self._weight_attr,
                dtype=self._dtype,
                is_bias=False,
            )

        self.weight.is_distributed = True if self.is_mp else False

        if self.weight.is_distributed:
            # Columns (output features) are what is split across ranks.
            self.weight.split_axis = 1

        if has_bias:
            # initialize bias to zero like Megatron
            self.bias = self.create_parameter(
                shape=[self.output_size_per_partition],
                attr=paddle.nn.initializer.Constant(value=0.0),
                dtype=self._dtype,
                is_bias=True,
            )
            self.bias.is_distributed = True if self.is_mp else False
            if self.bias.is_distributed:
                self.bias.split_axis = 0
        else:
            self.bias = None

        self.linear = F.linear

        self.fuse_matmul_bias = fuse_matmul_bias
        if self.fuse_matmul_bias:
            if not is_fused_matmul_bias_supported():
                raise NotImplementedError(
                    "You set fuse_matmul_bias=True in ColumnParallelLinear, "
                    "however, the paddle you are using not support this operation. "
                    "Please set fuse_matmul_bias=False or use paddle compiled "
                    "with cuda 11.6 or higher."
                )
            from paddle.incubate.nn.functional import fused_linear

            self.linear = fused_linear

    def forward(self, x):
        """Apply this rank's column shard of the linear layer to ``x``.

        If the ``Flags_mp_aysnc_allreduce`` env flag is set, a custom PyLayer
        is used whose backward launches the all-reduce of the input gradient
        asynchronously and waits on it only after the weight/bias gradients
        are computed. Otherwise the input goes through ``mp_ops._c_identity``
        (when mp is enabled) before the linear op. If ``gather_output`` is set
        and mp is enabled, the per-rank outputs are concatenated with
        ``mp_ops._c_concat``.
        """
        # use inner api to process identity

        def _overlap_linear():
            fuse_matmul_bias = self.fuse_matmul_bias

            class InnerOverlapLinear(paddle.autograd.PyLayer):
                @staticmethod
                def forward(ctx, x, weight, bias):
                    ctx.save_for_backward(x, weight, bias)
                    if (
                        _get_mp_env_flag("Flags_mp_aysnc_allreduce")
                        and _get_mp_env_flag("Flags_skip_mp_c_identity")
                    ) is False:
                        x = paddle._legacy_C_ops.c_identity(
                            x,
                            'use_calc_stream',
                            True,
                            'ring_id',
                            self.model_parallel_group.id,
                            'use_model_parallel',
                            True,
                        )
                    if not fuse_matmul_bias:
                        return paddle._C_ops.linear(x, weight, bias)
                    else:
                        return paddle._legacy_C_ops.fused_gemm_epilogue(
                            x, weight, bias
                        )

                @staticmethod
                def backward(ctx, dy):
                    x, weight, bias = ctx.saved_tensor()
                    dx = paddle.matmul(dy, weight, transpose_y=True)
                    op_type = _get_reduce_op(ReduceOp.SUM, "_c_identity")
                    # Start the all-reduce of dx without blocking; every return
                    # path below calls task.wait() only after the parameter
                    # gradients are computed.
                    task = self.model_parallel_group.process_group.all_reduce(
                        dx, op_type, sync_op=False
                    )
                    # TODO(GhostScreaming): remove it in future.
                    tmp = paddle.ones([512])  # noqa: F841 (kept deliberately)

                    if _get_mp_env_flag("Flags_fused_linear_param_grad_add"):
                        if not is_fused_linear_param_grad_add_supported():
                            raise NotImplementedError(
                                "You set environment variable Flags_fused_linear_param_grad_add=True, "
                                "however, the paddle you are using not support this operation. "
                                "Please unset Flags_fused_linear_param_grad_add or use paddle compiled "
                                "with cuda 11.6 or higher."
                            )

                        if bias is None:
                            if hasattr(weight, "main_grad"):
                                # Accumulate directly into main_grad; nothing
                                # is returned for the weight.
                                (
                                    weight.main_grad,
                                    _,
                                ) = paddle._C_ops.fused_linear_param_grad_add(
                                    x, dy, weight.main_grad, None, True, False
                                )
                                task.wait()
                                return dx, None
                            if weight.grad is not None:
                                (
                                    weight.grad,
                                    _,
                                ) = paddle._C_ops.fused_linear_param_grad_add(
                                    x, dy, weight.grad, None, False, False
                                )
                                task.wait()
                                return dx, None
                            (
                                dw,
                                _,
                            ) = paddle._C_ops.fused_linear_param_grad_add(
                                x, dy, None, None, False, False
                            )
                            task.wait()
                            return dx, dw

                        if hasattr(weight, "main_grad") and hasattr(
                            bias, "main_grad"
                        ):
                            (
                                weight.main_grad,
                                bias.main_grad,
                            ) = paddle._C_ops.fused_linear_param_grad_add(
                                # Use the saved activation ``x``; the original
                                # code passed the Python builtin ``input``.
                                x,
                                dy,
                                weight.main_grad,
                                bias.main_grad,
                                True,
                                True,
                            )
                            task.wait()
                            return dx, None, None
                        if weight.grad is not None:
                            assert bias.grad is not None
                            (
                                weight.grad,
                                bias.grad,
                            ) = paddle._C_ops.fused_linear_param_grad_add(
                                x, dy, weight.grad, bias.grad, False, True
                            )
                            task.wait()
                            return dx, None, None
                        (
                            dw,
                            dbias,
                        ) = paddle._C_ops.fused_linear_param_grad_add(
                            x, dy, None, None, False, True
                        )
                        task.wait()
                        return dx, dw, dbias

                    # Non-fused path: flatten all leading dims so the
                    # gradients match the 2-D weight and 1-D bias shapes for
                    # inputs of any rank.
                    x_2d = x.reshape([-1, x.shape[-1]])
                    dy_2d = dy.reshape([-1, dy.shape[-1]])
                    dw = paddle.matmul(x_2d, dy_2d, transpose_x=True)
                    if bias is None:
                        task.wait()
                        return dx, dw
                    dbias = paddle.sum(dy_2d, axis=0)
                    task.wait()
                    return dx, dw, dbias

            return InnerOverlapLinear.apply(x, self.weight, self.bias)

        if _get_mp_env_flag("Flags_mp_aysnc_allreduce"):
            output_parallel = _overlap_linear()
        else:
            if self.is_mp:
                input_parallel = mp_ops._c_identity(
                    x, group=self.model_parallel_group
                )
            else:
                input_parallel = x

            output_parallel = self.linear(
                input_parallel, self.weight, self.bias, name=self._name
            )

        if self.gather_output and self.is_mp:
            output = mp_ops._c_concat(
                output_parallel, group=self.model_parallel_group
            )
        else:
            output = output_parallel
        return output


class MPScale(PyLayer):
    """Divide a tensor by ``mp_degree`` in forward; pass the gradient through.

    The backward pass is the identity: the incoming gradient is returned
    unchanged, with no ``1 / mp_degree`` factor applied.
    """

    @staticmethod
    def forward(ctx, x, mp_degree):
        factor = 1.0 / mp_degree
        return paddle.scale(x, factor)

    @staticmethod
    def backward(ctx, dout):
        # Straight-through gradient.
        return dout


class RowParallelLinear(paddle.nn.Layer):
    """Linear layer with mp parallelized(row).

    This class splits a Linear layer in the mp group along the rows (input
    features) of its weight. Each rank holds a weight of shape
    ``[in_features // world_size, out_features]``. The per-rank partial
    products are summed with a model-parallel all-reduce. The bias, if any,
    is not split.

    Args:
        in_features(int): The number of input units. Must be divisible by the
            model parallel world size.
        out_features(int): The number of output units.
        weight_attr(ParamAttr|None): The attribute for the learnable weight of this layer.
            Default: None, which means the default weight parameter property of
            ``create_parameter`` is used. For detailed information, please refer to paddle.ParamAttr.
        has_bias(bool): Whether to add a bias. The bias is initialized to zero.
        input_is_parallel(bool): Whether the input has already been split across the mp group.
            If False and mp is enabled, the input's last dimension is split here.
        fuse_matmul_bias(bool): Whether to fuse matmul and bias.
        mp_group(Group): The tensor parallel group. Default: None, which means the
            model parallel group of the global hybrid communicate group is used.
        name(str, optional): Normally there is no need for user to set this parameter.
            For detailed information, please refer to :ref:`api_guide_Name` .

    Raises:
        NotImplementedError: If ``fuse_matmul_bias=True`` but the fused GEMM
            epilogue op is unavailable in this Paddle build.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.distributed import fleet

            class SimpleMPNet(paddle.nn.Layer):
                def __init__(self, vocab_size, hidden_size, inner_size, output_size):
                    super().__init__()
                    self.linear1 = fleet.meta_parallel.ColumnParallelLinear(
                        hidden_size,
                        inner_size,
                        gather_output=False,
                        has_bias=True)

                    self.linear2 = fleet.meta_parallel.RowParallelLinear(
                        inner_size,
                        hidden_size,
                        input_is_parallel=True,
                        has_bias=True)

                    self.linear3 = paddle.nn.Linear(hidden_size, output_size)

                    self.embedding = fleet.meta_parallel.VocabParallelEmbedding(
                        vocab_size,
                        hidden_size)

                def forward(self, x):
                    x = self.embedding(x)
                    x = self.linear1(x)
                    x = self.linear2(x)
                    x = self.linear3(x)
                    return x
    """

    def __init__(
        self,
        in_features,
        out_features,
        weight_attr=None,
        has_bias=True,
        input_is_parallel=False,
        fuse_matmul_bias=False,
        mp_group=None,
        name=None,
    ):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.input_is_parallel = input_is_parallel
        self._weight_attr = weight_attr
        self._dtype = self._helper.get_default_dtype()
        self._name = name

        self.model_parallel_group = (
            tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group()
            if mp_group is None
            else mp_group
        )
        self.world_size = (
            tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size()
            if mp_group is None
            else mp_group.nranks
        )
        self.rank = (
            tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank()
            if mp_group is None
            else mp_group.rank
        )

        self.is_mp = self.world_size > 1
        assert in_features % self.world_size == 0, (
            "Number of row of the weight for linear ({}) must be"
            " divisible by model parallel size ({})".format(
                in_features, self.world_size
            )
        )

        self.input_size_per_partition = in_features // self.world_size

        # Under mp in dynamic mode, each shard is initialized inside the
        # tracked RNG state.
        if self.is_mp and paddle.in_dynamic_mode():
            with get_rng_state_tracker().rng_state():
                self.weight = self.create_parameter(
                    shape=[self.input_size_per_partition, self.out_features],
                    attr=self._weight_attr,
                    dtype=self._dtype,
                    is_bias=False,
                )
        else:
            self.weight = self.create_parameter(
                shape=[self.input_size_per_partition, self.out_features],
                attr=self._weight_attr,
                dtype=self._dtype,
                is_bias=False,
            )

        self.weight.is_distributed = True if self.is_mp else False
        if self.weight.is_distributed:
            # Rows (input features) are what is split across ranks.
            self.weight.split_axis = 0

        if has_bias:
            self.bias = self.create_parameter(
                shape=[self.out_features],
                attr=paddle.nn.initializer.Constant(value=0.0),
                dtype=self._dtype,
                is_bias=True,
            )
        else:
            self.bias = None

        self.linear = F.linear

        if fuse_matmul_bias:
            if not is_fused_matmul_bias_supported():
                raise NotImplementedError(
                    "You set fuse_matmul_bias=True in RowParallelLinear, "
                    "however, the paddle you are using not support this operation. "
                    "Please set fuse_matmul_bias=False or use paddle compiled "
                    "with cuda 11.6 or higher."
                )
            from paddle.incubate.nn.functional import fused_linear

            self.linear = fused_linear
        self.fuse_matmul_bias = fuse_matmul_bias

    def forward(self, x):
        if self.input_is_parallel or (not self.is_mp):
            input_parallel = x
        else:
            # split last dim
            input_parallel = mp_ops._c_split(x, group=self.model_parallel_group)

        if self.is_mp:
            if self.fuse_matmul_bias:
                # The bias is fused into every rank's partial product, so it is
                # pre-scaled by 1/world_size; after the all-reduce (assumed to
                # sum across ranks) it contributes once. Skip scaling when
                # there is no bias: MPScale cannot take None.
                bias = (
                    MPScale.apply(self.bias, self.world_size)
                    if self.bias is not None
                    else None
                )
                output_parallel = self.linear(
                    input_parallel, self.weight, bias, name=self._name
                )
                output = mp_ops._mp_allreduce(
                    output_parallel,
                    group=self.model_parallel_group,
                    use_calc_stream=True,
                    use_model_parallel=True,
                )
            else:
                output_parallel = self.linear(
                    input_parallel, self.weight, name=self._name
                )
                output_ = mp_ops._mp_allreduce(
                    output_parallel,
                    group=self.model_parallel_group,
                    use_calc_stream=True,
                    use_model_parallel=True,
                )
                # Add the (unsplit) bias once, after reduction.
                output = (
                    output_ + self.bias if self.bias is not None else output_
                )
        else:
            output = self.linear(
                input_parallel, self.weight, self.bias, name=self._name
            )

        return output


class ParallelCrossEntropy(paddle.nn.Layer):
    """CrossEntropy with mp parallelized.

    This class computes softmax cross entropy in the mp group via
    ``mp_ops._c_softmax_with_cross_entropy``. ``input`` is presumably this
    rank's vocabulary shard of the logits (TODO confirm against mp_ops).

    Args:
        mp_group(Group): The tensor parallel group. Default: None, which means the
            model parallel group of the global hybrid communicate group is used.
        name(str, optional): Normally there is no need for user to set this parameter.
            For detailed information, please refer to :ref:`api_guide_Name` .
        ignore_index (int, optional): Specifies a target value that is ignored and
            does not contribute to the loss. A negative value means that no label value
            needs to be ignored. Default is -100 .

    Examples:
        .. code-block:: python

            loss_func = ParallelCrossEntropy()
            loss = loss_func(logits, label)
    """

    def __init__(self, mp_group=None, name=None, ignore_index=-100):
        super().__init__()
        self.name = name

        # Resolve the communication group: an explicit group wins, otherwise
        # fall back to the global hybrid-parallel topology.
        if mp_group is None:
            hcg = tp._HYBRID_PARALLEL_GROUP
            self.model_parallel_group = hcg.get_model_parallel_group()
            self.world_size = hcg.get_model_parallel_world_size()
            self.rank = hcg.get_model_parallel_rank()
        else:
            self.model_parallel_group = mp_group
            self.world_size = mp_group.nranks
            self.rank = mp_group.rank

        self.ignore_index = ignore_index

    def forward(self, input, label):
        return mp_ops._c_softmax_with_cross_entropy(
            input,
            label,
            group=self.model_parallel_group,
            ignore_index=self.ignore_index,
        )