# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np

import paddle
from paddle.fluid import core
from paddle.fluid.core import VarDesc
from paddle.fluid.dygraph import no_grad
from paddle.fluid.framework import _non_static_mode, convert_np_dtype_to_dtype_
from paddle.incubate.nn import functional as incubate_f
from paddle.nn import Layer
from paddle.nn.initializer import Constant
from paddle.nn.layer.transformer import (
    _convert_attention_mask,
    _convert_param_attr_to_list,
)


# for distributed tensor model parallel
def _set_var_distributed(var):
    if var is None:
        return

    var.is_distributed = True

37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55
    if not _non_static_mode():
        # NOTE: use current_block and find_var_recursive to support while_loop
        startup_block = paddle.static.default_startup_program().current_block()
        main_block = paddle.static.default_main_program().current_block()
        startup_block._find_var_recursive(var.name).is_distributed = True
        main_block._find_var_recursive(var.name).is_distributed = True


def _to_dtype(t, dtype):
    """Convert the floating-point tensor ``t`` to ``dtype`` in place.

    This is a pruned copy of ``Layer._transform``, used to fix fused ops under
    ``amp.decorate`` (O2). Tensors that are not floating point are returned
    unchanged.

    When ``t`` lives on GPU and the free device memory is below a rough
    estimate of the converted tensor's size, ``t`` is first copied to CPU and
    its GPU storage is cleared, so the cast happens on the CPU copy.

    Args:
        t (Tensor): The tensor (e.g. a layer parameter) to convert.
        dtype: Target data type. Values that are not already a
            ``VarDesc.VarType`` are passed through
            ``convert_np_dtype_to_dtype_``.

    Returns:
        Tensor: ``t`` itself; its underlying storage now shares data with the
        converted tensor.
    """
    if not paddle.is_floating_point(t):
        return t

    if not isinstance(dtype, VarDesc.VarType):
        dtype = convert_np_dtype_to_dtype_(dtype)

    t_used = t
    if t.place.is_gpu_place():
        size_dtype = core.size_of_dtype(dtype)
        # Rough byte estimate for the cast result: (numel * itemsize + 256),
        # with 20% headroom.
        waiting_alloc_memory = (
            ((np.prod(t.shape) * size_dtype) / 256 + 1) * 256 * 1.2
        )
        if core.gpu_memory_available() < waiting_alloc_memory:
            # Not enough device memory: move the data to CPU and release the
            # GPU storage of t before casting.
            t_used = t._copy_to(paddle.CPUPlace(), False)
            t.value().get_tensor()._clear()

    if dtype is not None and dtype != t_used.dtype:
        with paddle.fluid.framework._dygraph_place_guard(place=t_used.place):
            t_casted = t_used.cast(dtype=dtype)
    else:
        t_casted = t_used

    # Rebind t's underlying storage to the converted data so that every
    # reference to t (e.g. the layer's parameter dict) sees the new dtype.
    dst_tensor = t.value().get_tensor()
    src_tensor = t_casted.value().get_tensor()
    dst_tensor._share_data_with(src_tensor)

    return t


class FusedBiasDropoutResidualLayerNorm(Layer):
    """
    Applies fused_bias_dropout_residual_layer_norm operation.

    The bias addition, dropout, residual addition and layer normalization are
    computed by a single fused op
    (``incubate_f.fused_bias_dropout_residual_layer_norm``); per that op's
    documentation the result is ``layer_norm(residual + dropout(x + bias))``.

    Parameters:
        embed_dim (int): The expected feature size in the input and output.
        dropout_rate (float, optional): The dropout probability applied inside
            the fused op (before the residual addition). 0 for no dropout.
            Default 0.5.
        weight_attr (ParamAttr, optional): To specify the weight parameter
            property of the layer-norm scale. Default: None, which means the
            default weight parameter property is used. See usage for details
            in :code:`ParamAttr`.
        bias_attr (ParamAttr|bool, optional): To specify the bias parameter property.
            It is used for both the added bias and the layer-norm bias.
            Default: None, which means the default bias parameter property is used.
            If it is set to False, this layer will not have trainable bias parameter.
            See usage for details in :code:`ParamAttr`.
        epsilon (float, optional): The small value added to the variance to prevent
            division by zero. Default: 1e-05.
        name (str, optional): The default value is None. Normally there is no need
            for user to set this property. For more information, please refer to
            :ref:`api_guide_Name`.

    Examples:

        .. code-block:: python

            # required: gpu
            import paddle
            # input: [batch_size, seq_len, embed_dim]
            x = paddle.rand((2, 4, 128))
            # residual: [batch_size, seq_len, embed_dim]
            residual = paddle.rand((2, 4, 128))
            fused_bias_dropout_residual_ln = paddle.incubate.nn.FusedBiasDropoutResidualLayerNorm(128)
            output = fused_bias_dropout_residual_ln(x, residual)  # [2, 4, 128]
    """

    def __init__(
        self,
        embed_dim,
        dropout_rate=0.5,
        weight_attr=None,
        bias_attr=None,
        epsilon=1e-5,
        name=None,
    ):
        super().__init__()
        assert embed_dim > 0, (
            "Expected embed_dim to be greater than 0, "
            "but received {}".format(embed_dim)
        )
        self._dtype = self._helper.get_default_dtype()
        self._bias_attr = bias_attr
        self._weight_attr = weight_attr
        self.embed_dim = embed_dim

        # Bias handed to the fused op as `bias`.
        self.linear_bias = self.create_parameter(
            shape=[embed_dim],
            attr=self._bias_attr,
            dtype=self._dtype,
            is_bias=True,
        )
        # Layer-norm scale (initialized to 1.0) and bias. NOTE: the bias reuses
        # bias_attr, the same attribute as linear_bias above.
        self.ln_scale = self.create_parameter(
            attr=self._weight_attr,
            shape=[embed_dim],
            default_initializer=Constant(value=1.0),
        )
        self.ln_bias = self.create_parameter(
            attr=self._bias_attr, shape=[embed_dim], is_bias=True
        )
        self.dropout_rate = dropout_rate
        self._epsilon = epsilon

        self.name = name

    def forward(self, x, residual):
        """
        Applies fused_bias_dropout_residual_layer_norm operation.

        Parameters:
            x (Tensor): The input tensor. It is a tensor with shape
                `[batch_size, seq_len, embed_dim]`. The data type should be
                float32 or float64.
            residual (Tensor): The residual tensor added after dropout. It is
                expected to have the same shape as `x`, i.e.
                `[batch_size, seq_len, embed_dim]`. The data type should be
                float32 or float64.

        Returns:
            Tensor|tuple: It is a tensor that has the same shape and data type \
                as `x`.
        """

        # Dropout is only active in training mode; 'upscale_in_train' rescales
        # kept activations during training so inference needs no scaling.
        out = incubate_f.fused_bias_dropout_residual_layer_norm(
            x=x,
            residual=residual,
            bias=self.linear_bias,
            ln_scale=self.ln_scale,
            ln_bias=self.ln_bias,
            dropout_rate=self.dropout_rate,
            ln_epsilon=self._epsilon,
            training=self.training,
            mode='upscale_in_train',
            name=self.name,
        )
        return out

    def extra_repr(self):
        # NOTE: this layer has no `seq_len` attribute; the previous repr
        # referenced `self.seq_len` and raised AttributeError.
        name_str = ', name={}'.format(self.name) if self.name else ''
        return 'embed_dim={}, dropout_rate={}, epsilon={}, dtype={}{}'.format(
            self.embed_dim,
            self.dropout_rate,
            self._epsilon,
            self._dtype,
            name_str,
        )


class FusedMultiHeadAttention(Layer):
    """
    Attention maps queries and a set of key-value pairs to outputs, and
    Multi-Head Attention performs multiple parallel attention to jointly attend
    to information from different representation subspaces.
    Please refer to `Attention Is All You Need <https://arxiv.org/pdf/1706.03762.pdf>`_
    for more details.

    The whole attention block is computed by
    ``incubate_f.fused_multi_head_attention``. Only `query` is passed to that
    op, so this layer computes self-attention on `query`.

    Parameters:
        embed_dim (int): The expected feature size in the input and output.
        num_heads (int): The number of heads in multi-head attention.
        dropout_rate (float, optional): The dropout probability used on attention
            weights to drop some attention targets for the dropout after attention.
            0 for no dropout. Default 0.5.
        attn_dropout_rate (float, optional): The dropout probability used on attention
            weights to drop some attention targets for the dropout in attention.
            0 for no dropout. Default 0.5.
        kdim (int, optional): The feature size in key. Only stored and shown in
            :code:`extra_repr`; it is not used to build parameters. Default None.
        vdim (int, optional): The feature size in value. Only stored and shown in
            :code:`extra_repr`; it is not used to build parameters. Default None.
        normalize_before (bool, optional): Indicate whether it is pre_layer_norm
            (True) or post_layer_norm architecture (False). Default False.
        need_weights (bool, optional): Indicate whether to return the attention
            weights. Now, only False is supported. Default False.
        qkv_weight_attr(ParamAttr, optional): To specify the weight parameter property
            for QKV projection computation. Default: None, which means the default weight
            parameter property is used. See usage for details in :code:`ParamAttr`.
        qkv_bias_attr(ParamAttr|bool, optional): To specify the bias parameter property
            for QKV projection computation. The `False` value means the corresponding layer
            would not have trainable bias parameter. Default: None, which means the
            default bias parameter property is used. See usage for details in :code:`ParamAttr`.
        linear_weight_attr(ParamAttr, optional): To specify the weight parameter property
            for linear projection computation. Default: None, which means the default weight
            parameter property is used. See usage for details in :code:`ParamAttr`.
        linear_bias_attr(ParamAttr|bool, optional): To specify the bias parameter property
            for linear projection computation. The `False` value means the corresponding layer would
            not have trainable bias parameter. Default: None, which means the default bias
            parameter property is used. See usage for details in :code:`ParamAttr`.
        pre_ln_scale_attr(ParamAttr, optional): To specify the weight parameter property
            for pre_layer_norm computation. Only used when `normalize_before` is True.
            Default: None, which means the default weight parameter property is used.
            See usage for details in :code:`ParamAttr`.
        pre_ln_bias_attr(ParamAttr|bool, optional): To specify the bias parameter property
            for pre_layer_norm computation. Only used when `normalize_before` is True.
            The `False` value means the corresponding layer would not have trainable
            bias parameter. Default: None, which means the default bias parameter
            property is used. See usage for details in :code:`ParamAttr`.
        ln_scale_attr(ParamAttr, optional): To specify the weight parameter property
            for post_layer_norm computation. Only used when `normalize_before` is False.
            Default: None, which means the default weight parameter property is used.
            See usage for details in :code:`ParamAttr`.
        ln_bias_attr(ParamAttr|bool, optional): To specify the bias parameter property
            for post_layer_norm computation. Only used when `normalize_before` is False.
            The `False` value means the corresponding layer would not have trainable
            bias parameter. Default: None, which means the default bias parameter
            property is used. See usage for details in :code:`ParamAttr`.
        epsilon (float, optional): The small value added to the variance to prevent
            division by zero. Used for both the pre and post layer norm. Default: 1e-05.
        nranks (int, optional): Distributed tensor model parallel nranks. Default is 1,
            means not using tensor parallel. `num_heads` must be divisible by it; each
            rank holds `num_heads // nranks` heads.
        ring_id (int, optional): For distributed tensor model parallel. Default is -1,
            means not using tensor parallel. Must not be -1 when `nranks` > 1.
        transpose_qkv_wb (bool, optional): Support input qkv matmul weight shape as
            [hidden_size, 3 * hidden_size] and qkv matmul bias shape as [3 * hidden_size].
            Will transpose the weight to [3, num_head, head_dim, hidden_size] and transpose bias to
            [3, num_head, hidden_size] in the fused_attention_op. Only support for GPU for now.
            The default value is False, which is not do transpose to qkv_w and qkv_b.
        name (str, optional): The default value is None. Normally there is no need for user to set
            this property. For more information, please refer to :ref:`api_guide_Name`.

    Examples:

        .. code-block:: python

            # required: gpu
            import paddle
            # input: [batch_size, sequence_length, embed_dim]
            query = paddle.rand((2, 4, 128))
            # self attention mask: [batch_size, num_heads, query_len, query_len]
            attn_mask = paddle.rand((2, 2, 4, 4))
            multi_head_attn = paddle.incubate.nn.FusedMultiHeadAttention(128, 2)
            output = multi_head_attn(query, None, None, attn_mask=attn_mask)  # [2, 4, 128]
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout_rate=0.5,
        attn_dropout_rate=0.5,
        kdim=None,
        vdim=None,
        normalize_before=False,
        need_weights=False,
        qkv_weight_attr=None,
        qkv_bias_attr=None,
        linear_weight_attr=None,
        linear_bias_attr=None,
        pre_ln_scale_attr=None,
        pre_ln_bias_attr=None,
        ln_scale_attr=None,
        ln_bias_attr=None,
        epsilon=1e-5,
        nranks=1,
        ring_id=-1,
        transpose_qkv_wb=False,
        name=None,
    ):
        super().__init__()

        assert embed_dim > 0, (
            "Expected embed_dim to be greater than 0, "
            "but received {}".format(embed_dim)
        )
        assert (
            num_heads > 0
        ), "Expected nhead to be greater than 0, " "but received {}".format(
            num_heads
        )

        self.normalize_before = normalize_before
        self._dtype = self._helper.get_default_dtype()
        self._epsilon = epsilon
        self._ring_id = ring_id

        self.embed_dim = embed_dim
        # head_dim is derived from the *global* head count, before any
        # tensor-parallel split.
        self.head_dim = embed_dim // num_heads
        self.kdim = kdim
        self.vdim = vdim
        self.need_weights = need_weights
        assert (
            self.head_dim * num_heads == embed_dim
        ), "embed_dim must be divisible by num_heads"
        assert need_weights is False, "Only support need_weight is False now."

        # tensor model parallel: each rank keeps num_heads // nranks heads.
        assert num_heads % nranks == 0, (
            "Expected num_heads to be divisible by nranks, "
            "but received num_heads={} and nranks={}".format(num_heads, nranks)
        )
        self.num_heads = num_heads // nranks

        self.transpose_qkv_wb = transpose_qkv_wb
        # For tensor model parallel, use the per-rank num_heads * head_dim to
        # compute the real (local) shapes.
        local_qkv_dim = 3 * self.num_heads * self.head_dim
        if self.transpose_qkv_wb:
            qkv_weight_shape = [embed_dim, local_qkv_dim]
            qkv_bias_shape = [local_qkv_dim]
        else:
            qkv_weight_shape = [3, self.num_heads, self.head_dim, embed_dim]
            qkv_bias_shape = [3, self.num_heads, self.head_dim]

        self.qkv_weight = self.create_parameter(
            shape=qkv_weight_shape,
            attr=qkv_weight_attr,
            dtype=self._dtype,
            is_bias=False,
        )
        self.qkv_bias = self.create_parameter(
            shape=qkv_bias_shape,
            attr=qkv_bias_attr,
            dtype=self._dtype,
            is_bias=True,
        )
        self.linear_weight = self.create_parameter(
            shape=[self.num_heads * self.head_dim, embed_dim],
            attr=linear_weight_attr,
            dtype=self._dtype,
            is_bias=False,
        )
        self.linear_bias = self.create_parameter(
            shape=[embed_dim],
            attr=linear_bias_attr,
            dtype=self._dtype,
            is_bias=True,
        )

        # tensor model parallel
        if nranks > 1:
            assert (
                ring_id != -1
            ), "ring_id must be set (not -1) when nranks > 1"
            # column parallel
            _set_var_distributed(self.qkv_weight)
            _set_var_distributed(self.qkv_bias)
            # row parallel (linear_bias is not marked as distributed)
            _set_var_distributed(self.linear_weight)

        # Only one pair of layer-norm parameters exists, depending on whether
        # the layer norm runs before (pre) or after (post) attention.
        if normalize_before:
            self.pre_ln_scale = self.create_parameter(
                attr=pre_ln_scale_attr,
                shape=[embed_dim],
                default_initializer=Constant(value=1.0),
            )
            self.pre_ln_bias = self.create_parameter(
                attr=pre_ln_bias_attr, shape=[embed_dim], is_bias=True
            )
            self.ln_scale = None
            self.ln_bias = None
        else:
            self.pre_ln_scale = None
            self.pre_ln_bias = None
            self.ln_scale = self.create_parameter(
                attr=ln_scale_attr,
                shape=[embed_dim],
                default_initializer=Constant(value=1.0),
            )
            self.ln_bias = self.create_parameter(
                attr=ln_bias_attr, shape=[embed_dim], is_bias=True
            )

        self.dropout_rate = dropout_rate
        self.attn_dropout_rate = attn_dropout_rate

        self.name = name

    def forward(self, query, key=None, value=None, attn_mask=None, cache=None):
        """
        Applies multi-head attention to map queries and a set of key-value pairs
        to outputs.

        Parameters:
            query (Tensor): The queries for multi-head attention. It is a
                tensor with shape `[batch_size, query_length, embed_dim]`. The
                data type should be float32 or float64.
            key (Tensor, optional): Currently not used: only `query` is passed
                to the fused op, so attention is computed over `query` itself.
                Default None.
            value (Tensor, optional): Currently not used, like `key`.
                Default None.
            attn_mask (Tensor, optional): A tensor used in multi-head attention
                to prevents attention to some unwanted positions, usually the
                paddings or the subsequent positions. It is a tensor with shape
                broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`.
                When the data type is bool, the unwanted positions have `False`
                values and the others have `True` values. When the data type is
                int, the unwanted positions have 0 values and the others have 1
                values. When the data type is float, the unwanted positions have
                `-INF` values and the others have 0 values. It can be None when
                nothing wanted or needed to be prevented attention to. Default None.
            cache (MultiHeadAttention.Cache|MultiHeadAttention.StaticCache, optional):
                Passed to the fused op as `cache_kv`. Now, only None is supported.
                Default None.

        Returns:
            Tensor|tuple: It is a tensor that has the same shape and data type \
                as `query`, representing attention output.
        """
        if attn_mask is not None:
            # Support bool or int mask
            attn_mask = _convert_attention_mask(attn_mask, query.dtype)

        out = incubate_f.fused_multi_head_attention(
            x=query,
            qkv_weight=self.qkv_weight,
            linear_weight=self.linear_weight,
            pre_layer_norm=self.normalize_before,
            pre_ln_scale=self.pre_ln_scale,
            pre_ln_bias=self.pre_ln_bias,
            ln_scale=self.ln_scale,
            ln_bias=self.ln_bias,
            pre_ln_epsilon=self._epsilon,
            qkv_bias=self.qkv_bias,
            linear_bias=self.linear_bias,
            cache_kv=cache,
            attn_mask=attn_mask,
            dropout_rate=self.dropout_rate,
            attn_dropout_rate=self.attn_dropout_rate,
            ln_epsilon=self._epsilon,
            training=self.training,
            ring_id=self._ring_id,
            num_heads=self.num_heads,
            transpose_qkv_wb=self.transpose_qkv_wb,
            name=self.name,
        )
        return out

    def extra_repr(self):
        name_str = ', name={}'.format(self.name) if self.name else ''
        return 'embed_dim={}, num_heads={}, dropout_rate={}, attn_dropout_rate={}, epsilon={}, kdim={}, vdim={}, normalize_before={}, need_weights={}, dtype={}{}'.format(
            self.embed_dim,
            self.num_heads,
            self.dropout_rate,
            self.attn_dropout_rate,
            self._epsilon,
            self.kdim,
            self.vdim,
            self.normalize_before,
            self.need_weights,
            self._dtype,
            name_str,
        )

    def _amp_decorate(self, dtype):
        """Cast parameters to ``dtype`` in place for amp.decorate (O2).

        Temporary fix for amp.decorate(O2): every non-None parameter except the
        active layer-norm scale/bias is converted with ``_to_dtype``; the layer
        norm parameters keep their original dtype. Afterwards ``dtype`` becomes
        this layer's recorded default dtype.
        """
        if self.normalize_before:
            layer_norm_param_ids = {id(self.pre_ln_scale), id(self.pre_ln_bias)}
        else:
            layer_norm_param_ids = {id(self.ln_scale), id(self.ln_bias)}

        with no_grad():
            for param in self._parameters.values():
                if param is None or id(param) in layer_norm_param_ids:
                    continue
                _to_dtype(param, dtype)

        self._dtype = dtype


class FusedFeedForward(Layer):
    """
    Applies a fused Transformer feed-forward block.

    The whole block is computed by ``incubate_f.fused_feedforward``. Per that
    op's documentation it consists of an optional pre layer norm, linear1,
    activation, dropout, linear2, dropout, residual addition and an optional
    post layer norm.

    Parameters:
        d_model (int): The expected feature size in the input and output.
        dim_feedforward (int): The hidden layer size.
        dropout_rate (float, optional): The dropout probability used in pre-process
            and post-process. Default 0.1
        epsilon (float, optional): The small value added to the variance to prevent
            division by zero. Default: 1e-05.
        activation (str, optional): The activation function. Default relu.
        act_dropout_rate (float, optional): The dropout probability after activation.
            If None, use the value of `dropout_rate`. Default None
        normalize_before (bool, optional): Indicate whether to put layer normalization
            into preprocessing or postprocessing. Default False
        linear1_weight_attr(ParamAttr, optional): To specify the weight parameter property
            for FFN first linear. Default: None, which means the default weight
            parameter property is used. See usage for details in :code:`ParamAttr`.
        linear1_bias_attr(ParamAttr|bool, optional): To specify the bias parameter property
            for FFN first linear. The `False` value means the corresponding layer would
            not have trainable bias parameter. Default: None, which means the default bias
            parameter property is used. See usage for details in :code:`ParamAttr`.
        linear2_weight_attr(ParamAttr, optional): To specify the weight parameter property
            for FFN second linear. Default: None, which means the default weight
            parameter property is used. See usage for details in :code:`ParamAttr`.
        linear2_bias_attr(ParamAttr|bool, optional): To specify the bias parameter property
            for FFN second linear. The `False` value means the corresponding layer would
            not have trainable bias parameter. Default: None, which means the default bias
            parameter property is used. See usage for details in :code:`ParamAttr`.
        ln1_scale_attr(ParamAttr, optional): To specify the weight parameter property
            for FFN pre_layer_norm. Only used when `normalize_before` is True.
            Default: None, which means the default weight parameter property is used.
            See usage for details in :code:`ParamAttr`.
        ln1_bias_attr(ParamAttr|bool, optional): To specify the bias parameter property
            for FFN pre_layer_norm. Only used when `normalize_before` is True.
            The `False` value means the corresponding layer would not have trainable
            bias parameter. Default: None, which means the default bias parameter
            property is used. See usage for details in :code:`ParamAttr`.
        ln2_scale_attr(ParamAttr, optional): To specify the weight parameter property
            for FFN post_layer_norm. Only used when `normalize_before` is False.
            Default: None, which means the default weight parameter property is used.
            See usage for details in :code:`ParamAttr`.
        ln2_bias_attr(ParamAttr|bool, optional): To specify the bias parameter property
            for FFN post_layer_norm. Only used when `normalize_before` is False.
            The `False` value means the corresponding layer would not have trainable
            bias parameter. Default: None, which means the default bias parameter
            property is used. See usage for details in :code:`ParamAttr`.
        nranks (int, optional): Distributed tensor model parallel nranks. Default is 1,
            means not using tensor parallel. `dim_feedforward` must be divisible by it;
            each rank holds `dim_feedforward // nranks` hidden units.
        ring_id (int, optional): For distributed tensor model parallel. Default is -1,
            means not using tensor parallel. Must not be -1 when `nranks` > 1.
        name (str, optional): The default value is None.  Normally there is no need for user to set
            this property. For more information, please refer to :ref:`api_guide_Name`.

    Examples:
        .. code-block:: python

            # required: gpu
            import paddle
            from paddle.incubate.nn import FusedFeedForward

            fused_feedforward_layer = FusedFeedForward(8, 8)
            x = paddle.rand((1, 8, 8))
            out = fused_feedforward_layer(x)
            print(out.shape)
            # [1, 8, 8]
    """

    def __init__(
        self,
        d_model,
        dim_feedforward,
        dropout_rate=0.1,
        epsilon=1e-05,
        activation="relu",
        act_dropout_rate=None,
        normalize_before=False,
        linear1_weight_attr=None,
        linear1_bias_attr=None,
        linear2_weight_attr=None,
        linear2_bias_attr=None,
        ln1_scale_attr=None,
        ln1_bias_attr=None,
        ln2_scale_attr=None,
        ln2_bias_attr=None,
        nranks=1,
        ring_id=-1,
        name=None,
    ):
        super().__init__()
        assert (
            d_model > 0
        ), "Expected d_model to be greater than 0, but received {}".format(
            d_model
        )
        assert (
            dim_feedforward > 0
        ), "Expected dim_feedforward to be greater than 0, but received {}".format(
            dim_feedforward
        )

        self._dtype = self._helper.get_default_dtype()
        self._d_model = d_model

        # tensor model parallel: each rank keeps a slice of the hidden layer.
        assert dim_feedforward % nranks == 0, (
            "Expected dim_feedforward to be divisible by nranks, "
            "but received dim_feedforward={} and nranks={}".format(
                dim_feedforward, nranks
            )
        )
        dim_feedforward = dim_feedforward // nranks
        self._dim_feedforward = dim_feedforward
        self._dropout_rate = dropout_rate
        self._act_dropout_rate = (
            dropout_rate if act_dropout_rate is None else act_dropout_rate
        )
        self._act_method = activation
        self._normalize_before = normalize_before
        self._epsilon = epsilon
        self._ring_id = ring_id

        self._linear1_weight = self.create_parameter(
            shape=[d_model, dim_feedforward],
            attr=linear1_weight_attr,
            dtype=self._dtype,
            is_bias=False,
        )
        self._linear1_bias = self.create_parameter(
            shape=[dim_feedforward],
            attr=linear1_bias_attr,
            dtype=self._dtype,
            is_bias=True,
        )

        self._linear2_weight = self.create_parameter(
            shape=[dim_feedforward, d_model],
            attr=linear2_weight_attr,
            dtype=self._dtype,
            is_bias=False,
        )

        self._linear2_bias = self.create_parameter(
            shape=[d_model],
            attr=linear2_bias_attr,
            dtype=self._dtype,
            is_bias=True,
        )

        if nranks > 1:
            assert (
                ring_id != -1
            ), "ring_id must be set (not -1) when nranks > 1"
            # column parallel
            _set_var_distributed(self._linear1_weight)
            _set_var_distributed(self._linear1_bias)
            # row parallel (linear2 bias is not marked as distributed)
            _set_var_distributed(self._linear2_weight)

        # Only one pair of layer-norm parameters exists: ln1 for pre layer
        # norm, ln2 for post layer norm.
        if normalize_before:
            self._ln1_scale = self.create_parameter(
                shape=[d_model],
                attr=ln1_scale_attr,
                is_bias=False,
                default_initializer=Constant(1.0),
            )
            self._ln1_bias = self.create_parameter(
                shape=[d_model], attr=ln1_bias_attr, is_bias=True
            )
            self._ln2_scale = None
            self._ln2_bias = None
        else:
            self._ln1_scale = None
            self._ln1_bias = None
            self._ln2_scale = self.create_parameter(
                shape=[d_model],
                attr=ln2_scale_attr,
                is_bias=False,
                default_initializer=Constant(1.0),
            )
            self._ln2_bias = self.create_parameter(
                shape=[d_model], attr=ln2_bias_attr, is_bias=True
            )

        self.name = name

    def forward(self, src, cache=None):
        """
        Applies the fused feed-forward computation.

        Parameters:
            src (Tensor): The input tensor. Its last dimension is expected to be
                `d_model` (e.g. `[batch_size, seq_len, d_model]`).
            cache (optional): Currently not used. Default None.

        Returns:
            Tensor: The output of the fused feed-forward op.
        """
        # dropout1 follows the activation (act_dropout_rate); dropout2 is
        # applied to the second linear's output (dropout_rate).
        out = incubate_f.fused_feedforward(
            src,
            self._linear1_weight,
            self._linear2_weight,
            self._linear1_bias,
            self._linear2_bias,
            self._ln1_scale,
            self._ln1_bias,
            self._ln2_scale,
            self._ln2_bias,
            dropout1_rate=self._act_dropout_rate,
            dropout2_rate=self._dropout_rate,
            activation=self._act_method,
            ln1_epsilon=self._epsilon,
            ln2_epsilon=self._epsilon,
            pre_layer_norm=self._normalize_before,
            training=self.training,
            ring_id=self._ring_id,
            name=self.name,
        )
        return out

    def extra_repr(self):
        name_str = ', name={}'.format(self.name) if self.name else ''
        return 'd_model={}, dim_feedforward={}, dropout_rate={}, epsilon={}, activation={}, act_dropout_rate={}, normalize_before={}, dtype={}{}'.format(
            self._d_model,
            self._dim_feedforward,
            self._dropout_rate,
            self._epsilon,
            self._act_method,
            self._act_dropout_rate,
            self._normalize_before,
            self._dtype,
            name_str,
        )

    def _amp_decorate(self, dtype):
        """Cast parameters to ``dtype`` in place for amp.decorate (O2).

        Temporary fix for amp.decorate(O2): every non-None parameter except the
        active layer-norm scale/bias is converted with ``_to_dtype``; the layer
        norm parameters keep their original dtype. Afterwards ``dtype`` becomes
        this layer's recorded default dtype.
        """
        if self._normalize_before:
            layer_norm_param_ids = {id(self._ln1_scale), id(self._ln1_bias)}
        else:
            layer_norm_param_ids = {id(self._ln2_scale), id(self._ln2_bias)}

        with no_grad():
            for param in self._parameters.values():
                if param is None or id(param) in layer_norm_param_ids:
                    continue
                _to_dtype(param, dtype)

        self._dtype = dtype


class FusedTransformerEncoderLayer(Layer):
    """
    FusedTransformerEncoderLayer is composed of two sub-layers which are self (multi-head)
    attention and feedforward network. Before and after each sub-layer, pre-process
    and post-process would be applied on the input and output accordingly. If
    `normalize_before` is True, pre-process is layer normalization and post-process
    includes dropout, residual connection. Otherwise, there is no pre-process and
    post-process includes dropout, residual connection, layer normalization.

    Parameters:
        d_model (int): The expected feature size in the input and output.
        nhead (int): The number of heads in multi-head attention(MHA).
        dim_feedforward (int): The hidden layer size in the feedforward network(FFN).
        dropout_rate (float, optional): The dropout probability used in pre-process
            and post-process of MHA and FFN sub-layer. Default 0.1
        activation (str, optional): The activation function in the feedforward
            network. Default relu.
        attn_dropout_rate (float, optional): The dropout probability used
            in MHA to drop some attention target. If None, use the value of
            `dropout_rate`. Default None
        act_dropout_rate (float, optional): The dropout probability used after FFN
            activation. If None, use the value of `dropout_rate`. Default None
        normalize_before (bool, optional): Indicate whether to put layer normalization
            into preprocessing of MHA and FFN sub-layers. If True, pre-process is layer
            normalization and post-process includes dropout, residual connection.
            Otherwise, there is no pre-process and post-process includes dropout,
            residual connection, layer normalization. Default False
        weight_attr(ParamAttr|list|tuple, optional): To specify the weight parameter property.
            If it is a list/tuple, `weight_attr[0]` would be used as `weight_attr` for
            MHA (its qkv and output linear weights as well as its layer-norm scales),
            and `weight_attr[1]` would be used as `weight_attr` for both linears in FFN.
            Otherwise, MHA and FFN both use it as `weight_attr` to create parameters.
            Default: None, which means the default weight parameter property is used.
            See usage for details in :code:`ParamAttr` .
        bias_attr (ParamAttr|list|tuple|bool, optional): To specify the bias parameter property.
            If it is a list/tuple, `bias_attr[0]` would be used as `bias_attr` for
            MHA (its qkv and output linear biases as well as its layer-norm biases),
            and `bias_attr[1]` would be used as `bias_attr` for both linears in FFN.
            Otherwise, MHA and FFN both use it as `bias_attr` to create parameters.
            The `False` value means the corresponding layer would not have trainable
            bias parameter. See usage for details in :code:`ParamAttr` . Default: None,
            which means the default bias parameter property is used.

    Examples:
        .. code-block:: python

            # required: gpu
            import paddle
            from paddle.incubate.nn import FusedTransformerEncoderLayer

            # encoder input: [batch_size, src_len, d_model]
            enc_input = paddle.rand((2, 4, 128))
            # self attention mask: [batch_size, n_head, src_len, src_len]
            attn_mask = paddle.rand((2, 2, 4, 4))
            encoder_layer = FusedTransformerEncoderLayer(128, 2, 512)
            enc_output = encoder_layer(enc_input, attn_mask)  # [2, 4, 128]

    """

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward,
        dropout_rate=0.1,
        activation="relu",
        attn_dropout_rate=None,
        act_dropout_rate=None,
        normalize_before=False,
        weight_attr=None,
        bias_attr=None,
    ):
        # Snapshot the constructor arguments first: locals() must be taken
        # before any other local name is bound in this frame.
        self._config = locals()
        self._config.pop("self")
        self._config.pop("__class__", None)  # py3

        super().__init__()
        assert d_model > 0, (
            "Expected d_model to be greater than 0, "
            "but received {}".format(d_model)
        )
        assert nhead > 0, (
            "Expected nhead to be greater than 0, "
            "but received {}".format(nhead)
        )
        assert dim_feedforward > 0, (
            "Expected dim_feedforward to be greater than 0, "
            "but received {}".format(dim_feedforward)
        )
        # Unspecified sub-layer dropout rates fall back to dropout_rate.
        attn_dropout_rate = (
            dropout_rate if attn_dropout_rate is None else attn_dropout_rate
        )
        act_dropout_rate = (
            dropout_rate if act_dropout_rate is None else act_dropout_rate
        )
        self.normalize_before = normalize_before

        # Index 0 configures the attention sub-layer, index 1 the FFN.
        weight_attrs = _convert_param_attr_to_list(weight_attr, 2)
        bias_attrs = _convert_param_attr_to_list(bias_attr, 2)

        self.fused_attn = FusedMultiHeadAttention(
            d_model,
            nhead,
            dropout_rate=dropout_rate,
            attn_dropout_rate=attn_dropout_rate,
            normalize_before=self.normalize_before,
            qkv_weight_attr=weight_attrs[0],
            qkv_bias_attr=bias_attrs[0],
            linear_weight_attr=weight_attrs[0],
            linear_bias_attr=bias_attrs[0],
            pre_ln_scale_attr=weight_attrs[0],
            pre_ln_bias_attr=bias_attrs[0],
            ln_scale_attr=weight_attrs[0],
            ln_bias_attr=bias_attrs[0],
        )

        self.ffn = FusedFeedForward(
            d_model,
            dim_feedforward,
            dropout_rate=dropout_rate,
            activation=activation,
            act_dropout_rate=act_dropout_rate,
            normalize_before=self.normalize_before,
            linear1_weight_attr=weight_attrs[1],
            linear1_bias_attr=bias_attrs[1],
            linear2_weight_attr=weight_attrs[1],
            linear2_bias_attr=bias_attrs[1],
        )

    def forward(self, src, src_mask=None, cache=None):
        """
        Applies a Transformer encoder layer on the input.

        Parameters:
            src (Tensor): The input of Transformer encoder layer. It is
                a tensor with shape `[batch_size, sequence_length, d_model]`.
                The data type should be float32 or float64.
            src_mask (Tensor, optional): A tensor used in multi-head attention
                to prevent attention to some unwanted positions, usually the
                paddings or the subsequent positions. It is a tensor with shape
                broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`.
                When the data type is bool, the unwanted positions have `False`
                values and the others have `True` values. When the data type is
                int, the unwanted positions have 0 values and the others have 1
                values. When the data type is float, the unwanted positions have
                `-INF` values and the others have 0 values. It can be None when
                no positions need to be masked. Default None.
            cache (Tensor, optional): It is an instance of `MultiHeadAttention.Cache`.
                See :ref:`api_paddle_nn_TransformerEncoderLayer`.gen_cache for more details. It is
                only used for inference and should be None for training. Default
                None.

        Returns:
            Tensor|tuple, It is a tensor that has the same shape and data type \
                as `src`, representing the output of Transformer encoder \
                layer. Or a tuple if `cache` is not None, except for encoder \
                layer output, the tuple includes the new cache which is same \
                as input `cache` argument but `incremental_cache` has an \
                incremental length. See `MultiHeadAttention.gen_cache` and \
                `MultiHeadAttention.forward` for more details.

        """
        src_mask = _convert_attention_mask(src_mask, src.dtype)
        if cache is None:
            attn_out = self.fused_attn(src, attn_mask=src_mask)
        else:
            attn_out, incremental_cache = self.fused_attn(
                src, attn_mask=src_mask, cache=cache
            )

        ffn_out = self.ffn(attn_out)

        return ffn_out if cache is None else (ffn_out, incremental_cache)


class FusedTransformer(Layer):
    """
    A Transformer model composed of an instance of `TransformerEncoder` and an
    instance of `TransformerDecoder`. The embedding layer and output layer
    are not included.

    Note:
        This fused variant is not implemented yet: both the constructor and
        ``forward`` raise ``NotImplementedError``. The example below therefore
        uses the non-fused :code:`paddle.nn.Transformer`.

    Please refer to `Attention is all you need <http://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf>`_ ,
    and see `TransformerEncoder` and `TransformerDecoder` for more details.

    Users can configure the model architecture with corresponding parameters.
    Note the usage of `normalize_before` representing where to apply layer
    normalization (in pre-process or post-process of multi-head attention or FFN),
    and some transformer like models are different on this, such as
    `BERT <https://arxiv.org/abs/1810.04805>`_ and `GPT2 <https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf>`_ .
    The default architecture here places layer normalization in post-process and
    applies another layer normalization on the output of last encoder/decoder layer.

    Parameters:
        d_model (int, optional): The expected feature size in the encoder/decoder input
            and output. Default 512
        nhead (int, optional): The number of heads in multi-head attention(MHA). Default 8
        num_encoder_layers (int, optional): The number of layers in encoder. Default 6
        num_decoder_layers (int, optional): The number of layers in decoder. Default 6
        dim_feedforward (int, optional): The hidden layer size in the feedforward network(FFN). Default 2048
        dropout (float, optional): The dropout probability used in pre-process
            and post-process of MHA and FFN sub-layer. Default 0.1
        activation (str, optional): The activation function in the feedforward
            network. Default relu.
        attn_dropout (float, optional): The dropout probability used
            in MHA to drop some attention target. If None, use the value of
            `dropout`. Default None
        act_dropout (float, optional): The dropout probability used after FFN
            activation. If None, use the value of `dropout`. Default None
        normalize_before (bool, optional): Indicate whether to put layer normalization
            into preprocessing of MHA and FFN sub-layers. If True, pre-process is layer
            normalization and post-process includes dropout, residual connection.
            Otherwise, there is no pre-process and post-process includes dropout,
            residual connection, layer normalization. Default False
        weight_attr(ParamAttr|list|tuple, optional): To specify the weight parameter property.
            If it is a list/tuple, the length of `weight_attr` could be 1, 2 or 3. If it is 3,
            `weight_attr[0]` would be used as `weight_attr` for self attention, `weight_attr[1]`
            would be used as `weight_attr` for cross attention of `TransformerDecoder`,
            and `weight_attr[2]` would be used as `weight_attr` for linear in FFN.
            If it is 2, `weight_attr[0]` would be used as `weight_attr` both for self attention
            and cross attention and `weight_attr[1]` would be used as `weight_attr` for
            linear in FFN. If it is 1, `weight_attr[0]` would be used as `weight_attr`
            for self attention, cross attention and linear in FFN. Otherwise,
            the three sub-layers all use it as `weight_attr` to create parameters.
            Default: None, which means the default weight parameter property is used.
            See usage for details in :code:`ParamAttr` .
        bias_attr (ParamAttr|list|tuple|bool, optional): To specify the bias parameter property.
            If it is a list/tuple, the length of `bias_attr` could be 1, 2 or 3. If it is 3,
            `bias_attr[0]` would be used as `bias_attr` for self attention, `bias_attr[1]`
            would be used as `bias_attr` for cross attention of `TransformerDecoder`,
            and `bias_attr[2]` would be used as `bias_attr` for linear in FFN.
            If it is 2, `bias_attr[0]` would be used as `bias_attr` both for self attention
            and cross attention and `bias_attr[1]` would be used as `bias_attr` for
            linear in FFN. If it is 1, `bias_attr[0]` would be used as `bias_attr`
            for self attention, cross attention and linear in FFN. Otherwise,
            the three sub-layers all use it as `bias_attr` to create parameters.
            The `False` value means the corresponding layer would not have trainable
            bias parameter. See usage for details in :code:`ParamAttr` .
            Default: None, which means the default bias parameter property is used.
        custom_encoder (Layer, optional): If custom encoder is provided, use it as the encoder.
            Default None
        custom_decoder (Layer, optional): If custom decoder is provided, use it as the decoder.
            Default None

    Examples:

        .. code-block:: python

            import paddle
            from paddle.nn import Transformer

            # src: [batch_size, src_len, d_model]
            enc_input = paddle.rand((2, 4, 128))
            # tgt: [batch_size, tgt_len, d_model]
            dec_input = paddle.rand((2, 6, 128))
            # src_mask: [batch_size, n_head, src_len, src_len]
            enc_self_attn_mask = paddle.rand((2, 2, 4, 4))
            # tgt_mask: [batch_size, n_head, tgt_len, tgt_len]
            dec_self_attn_mask = paddle.rand((2, 2, 6, 6))
            # memory_mask: [batch_size, n_head, tgt_len, src_len]
            cross_attn_mask = paddle.rand((2, 2, 6, 4))
            transformer = Transformer(128, 2, 4, 4, 512)
            output = transformer(enc_input,
                                 dec_input,
                                 enc_self_attn_mask,
                                 dec_self_attn_mask,
                                 cross_attn_mask)  # [2, 6, 128]
    """

    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        attn_dropout=None,
        act_dropout=None,
        normalize_before=False,
        weight_attr=None,
        bias_attr=None,
        custom_encoder=None,
        custom_decoder=None,
    ):
        """Initialize the base Layer, then fail: this layer is not implemented.

        Raises:
            NotImplementedError: Always.
        """
        super().__init__()
        raise NotImplementedError()

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None):
        """Not implemented for the fused Transformer.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()


class FusedMultiTransformer(Layer):
    """
    FusedMultiTransformer is composed of multi transformer layers which contains two
    sub-layers which are self (multi-head) attention and feedforward network. The
    function of one transformer layer is consistent with the following pseudo code:

    .. code-block:: python

        if pre_layer_norm:
            out = layer_norm(x)
            out = qkv_linear(out) + qkv_bias
        else:
            out = qkv_linear(x) + qkv_bias
        out = transpose(out, perm=[2, 0, 3, 1, 4])
        # extract q, k and v from out.
        q = out[0:1, ::]
        k = out[1:2, ::]
        v = out[2:3, ::]
        out = q * k^t
        out = attn_mask + out
        out = softmax(out)
        out = dropout(out)
        out = out * v
        out = transpose(out, perm=[0, 2, 1, 3])
        out = linear(out)
        if pre_layer_norm:
            out = x + dropout(out + bias)
        else:
            out = layer_norm(x + dropout(out + bias))

        residual = out;
        if pre_layer_norm:
            out = ffn_layer_norm(out)
        out = ffn1_linear(out)
        out = dropout(activation(out + ffn1_bias))
        out = ffn2_linear(out)
        out = residual + dropout(out + ffn2_bias)
        if not pre_layer_norm:
            out = ffn_layer_norm(out)

    Parameters:
        embed_dim (int): The expected feature size in the input and output.
        num_heads (int): The number of heads in multi-head attention(MHA).
        dim_feedforward (int): The hidden layer size in the feedforward network(FFN).
        dropout_rate (float, optional): The dropout probability used in pre-process
            and post-precess of MHA and FFN sub-layer. Default 0.0
        activation (str, optional): The activation function in the feedforward
            network. Default "gelu".
        normalize_before (bool, optional): Indicate whether to put layer normalization
            into preprocessing of MHA and FFN sub-layers. If True, pre-process is layer
            normalization and post-precess includes dropout, residual connection.
            Otherwise, no pre-process and post-precess includes dropout, residual
            connection, layer normalization. Default True
        ln_scale_attrs(ParamAttr|list|tuple, optional): To specify the weight parameter property
            for Attention layer_norm. For Attention layer_norm weight, if it is a list/tuple, `attrs[0]`
            would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as
            `attr` for transformer layer 1,etc. Otherwise, all layers both use it as
            `attr` to create parameters. Default: None, which means the default weight
            parameter property is used. See usage for details in :code:`ParamAttr`.
        ln_bias_attrs(ParamAttr|list|tuple|bool, optional): To specify the bias parameter property
            for Attention layer_norm. For Attention layer_norm bias, if it is a list/tuple, `attrs[0]`
            would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as
            `attr` for transformer layer 1,etc. Otherwise, all layers both use it as
            `attr` to create parameters. The `False` value means the corresponding layer would
            not have trainable bias parameter. Default: None, which means the default bias
            parameter property is used. See usage for details in :code:`ParamAttr`.
        qkv_weight_attrs(ParamAttr|list|tuple, optional): To specify the weight parameter property
            for Attention qkv computation. For Attention qkv weight, if it is a list/tuple, `attrs[0]`
            would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as
            `attr` for transformer layer 1,etc. Otherwise, all layers both use it as
            `attr` to create parameters. Default: None, which means the default weight
            parameter property is used. See usage for details in :code:`ParamAttr`.
        qkv_bias_attrs(ParamAttr|list|tuple|bool, optional): To specify the bias parameter property
            for Attention qkv computation. For Attention qkv bias, if it is a list/tuple, `attrs[0]`
            would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as
            `attr` for transformer layer 1,etc. Otherwise, all layers both use it as
            `attr` to create parameters. The `False` value means the corresponding layer would
            not have trainable bias parameter. Default: None, which means the default bias
            parameter property is used. See usage for details in :code:`ParamAttr`.
        linear_weight_attrs(ParamAttr|list|tuple, optional): To specify the weight parameter property
            for Attention linear. For Attention linear weight, if it is a list/tuple, `attrs[0]`
            would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as
            `attr` for transformer layer 1,etc. Otherwise, all layers both use it as
            `attr` to create parameters. Default: None, which means the default weight
            parameter property is used. See usage for details in :code:`ParamAttr`.
        linear_bias_attrs(ParamAttr|list|tuple|bool, optional): To specify the bias parameter property
            for Attention linear computation. For Attention linear bias, if it is a list/tuple, `attrs[0]`
            would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as
            `attr` for transformer layer 1,etc. Otherwise, all layers both use it as
            `attr` to create parameters. The `False` value means the corresponding layer would
            not have trainable bias parameter. Default: None, which means the default bias
            parameter property is used. See usage for details in :code:`ParamAttr`.
        ffn_ln_scale_attrs(ParamAttr|list|tuple, optional): To specify the weight parameter property
            for FFN layer_norm. For FFN layer_norm weight, if it is a list/tuple, `attrs[0]`
            would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as
            `attr` for transformer layer 1,etc. Otherwise, all layers both use it as
            `attr` to create parameters. Default: None, which means the default weight
            parameter property is used. See usage for details in :code:`ParamAttr`.
        ffn_ln_bias_attrs(ParamAttr|list|tuple|bool, optional): To specify the bias parameter property
            for FFN layer_norm. For FFN layer_norm bias, if it is a list/tuple, `attrs[0]`
            would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as
            `attr` for transformer layer 1,etc. Otherwise, all layers both use it as
            `attr` to create parameters. The `False` value means the corresponding layer would
            not have trainable bias parameter. Default: None, which means the default bias
            parameter property is used. See usage for details in :code:`ParamAttr`.
        ffn1_weight_attrs(ParamAttr|list|tuple, optional): To specify the weight parameter property
            for FFN first linear. For FFN first linear weight, if it is a list/tuple, `attrs[0]`
            would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as
            `attr` for transformer layer 1,etc. Otherwise, all layers both use it as
            `attr` to create parameters. Default: None, which means the default weight
            parameter property is used. See usage for details in :code:`ParamAttr`.
        ffn1_bias_attrs(ParamAttr|list|tuple|bool, optional): To specify the bias parameter property
            for FFN first linear. For FFN first linear bias, if it is a list/tuple, `attrs[0]`
            would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as
            `attr` for transformer layer 1,etc. Otherwise, all layers both use it as
            `attr` to create parameters. The `False` value means the corresponding layer would
            not have trainable bias parameter. Default: None, which means the default bias
            parameter property is used. See usage for details in :code:`ParamAttr`.
        ffn2_weight_attrs(ParamAttr|list|tuple, optional): To specify the weight parameter property
            for FFN second linear. For FFN second linear weight, if it is a list/tuple, `attrs[0]`
            would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as
            `attr` for transformer layer 1,etc. Otherwise, all layers both use it as
            `attr` to create parameters. Default: None, which means the default weight
            parameter property is used. See usage for details in :code:`ParamAttr`.
        ffn2_bias_attrs(ParamAttr|list|tuple|bool, optional): To specify the bias parameter property
            for FFN second linear. For FFN second linear bias, if it is a list/tuple, `attrs[0]`
            would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as
            `attr` for transformer layer 1,etc. Otherwise, all layers both use it as
            `attr` to create parameters. The `False` value means the corresponding layer would
            not have trainable bias parameter. Default: None, which means the default bias
            parameter property is used. See usage for details in :code:`ParamAttr`.
        epsilon (float, optional): Small float value added to denominator of the layer_norm to
            avoid dividing by zero. Default: 1e-05.
        num_layers (int, optional): The number of layers of the transformer. If `qkv_weight_attrs`
            is a list or tuple, the number of layers is obtained from `qkv_weight_attrs`. num_layers
            only takes effect when `qkv_weight_attrs` is not a list or tuple. Default: -1.
        nranks (int, optional): Distributed tensor model parallel nranks. Default is 1, means not using mp.
1157 1158 1159
        trans_qkvw (bool, optional): Whether to transpose for weights of qkv.
            If true, the shape eights of qkv should be [3, num_head, dim_head, dim_embed].
            Otherwise the shape of weights of qkv should be [dim_embed, 3, num_head, dim_head]. Default: True.
1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179
        ring_id (int, optional): For distributed tensor model parallel. Default is -1, means not using mp.
        name (str, optional): The default value is None.  Normally there is no need for user to set
            this property. For more information, please refer to :ref:`api_guide_Name`.

    Examples:

        .. code-block:: python

            # required: gpu
            import paddle
            from paddle.incubate.nn import FusedMultiTransformer

            # encoder input: [batch_size, src_len, d_model]
            enc_input = paddle.rand((2, 4, 128))
            # self attention mask: [batch_size, 1, src_len, src_len]
            attn_mask = paddle.rand((2, 1, 4, 4))
            encoder_layers = FusedMultiTransformer(128, 2, 512, num_layers=1)
            enc_output = encoder_layers(enc_input, attn_mask)  # [2, 4, 128]
    """

1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206
    def __init__(
        self,
        embed_dim,
        num_heads,
        dim_feedforward,
        dropout_rate=0.0,
        activation="gelu",
        normalize_before=True,
        ln_scale_attrs=None,
        ln_bias_attrs=None,
        qkv_weight_attrs=None,
        qkv_bias_attrs=None,
        linear_weight_attrs=None,
        linear_bias_attrs=None,
        ffn_ln_scale_attrs=None,
        ffn_ln_bias_attrs=None,
        ffn1_weight_attrs=None,
        ffn1_bias_attrs=None,
        ffn2_weight_attrs=None,
        ffn2_bias_attrs=None,
        epsilon=1e-5,
        num_layers=-1,
        nranks=1,
        trans_qkvw=True,
        ring_id=-1,
        name=None,
    ):
1207
        super().__init__()
1208

1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222
        assert embed_dim > 0, (
            "Expected embed_dim to be greater than 0, "
            "but received {}".format(embed_dim)
        )
        assert (
            num_heads > 0
        ), "Expected nhead to be greater than 0, " "but received {}".format(
            num_heads
        )
        assert (
            dim_feedforward > 0
        ), "Expected dim_feedforward to be greater than 0, but received {}".format(
            dim_feedforward
        )
1223 1224 1225 1226

        self.normalize_before = normalize_before
        self._dtype = self._helper.get_default_dtype()
        self._epsilon = epsilon
1227
        self._trans_qkvw = trans_qkvw
1228 1229 1230 1231 1232
        self._ring_id = ring_id

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
1233 1234 1235
        assert (
            self.head_dim * num_heads == embed_dim
        ), "embed_dim must be divisible by num_heads"
1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280

        # tensor model parallel
        if nranks > 1:
            assert ring_id != -1
        assert num_heads % nranks == 0
        assert dim_feedforward % nranks == 0
        num_heads = num_heads // nranks
        dim_feedforward = dim_feedforward // nranks
        self._dim_feedforward = dim_feedforward

        if isinstance(qkv_weight_attrs, (list, tuple)):
            num_layers = len(qkv_weight_attrs)
        assert num_layers > 0

        self.ln_scales, self.ln_biases = [], []
        self.qkv_weights, self.qkv_biases = [], []
        self.linear_weights, self.linear_biases = [], []
        self.ffn_ln_scales, self.ffn_ln_biases = [], []
        self.ffn1_weights, self.ffn1_biases = [], []
        self.ffn2_weights, self.ffn2_biases = [], []

        def get_attr(attrs, idx):
            if isinstance(attrs, (list, tuple)):
                assert len(attrs) == num_layers
                return attrs[idx]
            return attrs

        for i in range(num_layers):
            ln_scale_attr = get_attr(ln_scale_attrs, i)
            ln_bias_attr = get_attr(ln_bias_attrs, i)
            qkv_weight_attr = get_attr(qkv_weight_attrs, i)
            qkv_bias_attr = get_attr(qkv_bias_attrs, i)
            linear_weight_attr = get_attr(linear_weight_attrs, i)
            linear_bias_attr = get_attr(linear_bias_attrs, i)

            ffn_ln_scale_attr = get_attr(ffn_ln_scale_attrs, i)
            ffn_ln_bias_attr = get_attr(ffn_ln_bias_attrs, i)
            ffn1_weight_attr = get_attr(ffn1_weight_attrs, i)
            ffn1_bias_attr = get_attr(ffn1_bias_attrs, i)
            ffn2_weight_attr = get_attr(ffn2_weight_attrs, i)
            ffn2_bias_attr = get_attr(ffn2_bias_attrs, i)

            ln_scale = self.create_parameter(
                attr=ln_scale_attr,
                shape=[embed_dim],
1281 1282 1283 1284 1285
                default_initializer=Constant(value=1.0),
            )
            ln_bias = self.create_parameter(
                attr=ln_bias_attr, shape=[embed_dim], is_bias=True
            )
1286
            qkv_weight = self.create_parameter(
1287
                shape=[3, num_heads, self.head_dim, embed_dim]
1288 1289
                if trans_qkvw
                else [embed_dim, 3, num_heads, self.head_dim],
1290 1291
                attr=qkv_weight_attr,
                dtype=self._dtype,
1292 1293
                is_bias=False,
            )
1294 1295 1296 1297
            qkv_bias = self.create_parameter(
                shape=[3, num_heads, self.head_dim],
                attr=qkv_bias_attr,
                dtype=self._dtype,
1298 1299
                is_bias=True,
            )
1300 1301 1302 1303
            linear_weight = self.create_parameter(
                shape=[num_heads * self.head_dim, embed_dim],
                attr=linear_weight_attr,
                dtype=self._dtype,
1304 1305 1306 1307 1308 1309 1310 1311
                is_bias=False,
            )
            linear_bias = self.create_parameter(
                shape=[embed_dim],
                attr=linear_bias_attr,
                dtype=self._dtype,
                is_bias=True,
            )
1312 1313 1314 1315 1316

            ffn_ln_scale = self.create_parameter(
                shape=[embed_dim],
                attr=ffn_ln_scale_attr,
                is_bias=False,
1317 1318 1319 1320 1321
                default_initializer=Constant(1.0),
            )
            ffn_ln_bias = self.create_parameter(
                shape=[embed_dim], attr=ffn_ln_bias_attr, is_bias=True
            )
1322 1323 1324 1325
            ffn1_weight = self.create_parameter(
                shape=[embed_dim, dim_feedforward],
                attr=ffn1_weight_attr,
                dtype=self._dtype,
1326 1327 1328 1329 1330 1331 1332 1333
                is_bias=False,
            )
            ffn1_bias = self.create_parameter(
                shape=[dim_feedforward],
                attr=ffn1_bias_attr,
                dtype=self._dtype,
                is_bias=True,
            )
1334 1335 1336 1337
            ffn2_weight = self.create_parameter(
                shape=[dim_feedforward, embed_dim],
                attr=ffn2_weight_attr,
                dtype=self._dtype,
1338 1339 1340 1341 1342 1343 1344 1345
                is_bias=False,
            )
            ffn2_bias = self.create_parameter(
                shape=[embed_dim],
                attr=ffn2_bias_attr,
                dtype=self._dtype,
                is_bias=True,
            )
1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375

            # tensor model parallel
            if nranks > 1:
                # column parallel
                _set_var_distributed(qkv_weight)
                _set_var_distributed(qkv_bias)
                _set_var_distributed(ffn1_weight)
                _set_var_distributed(ffn1_bias)
                # row parallel
                _set_var_distributed(linear_weight)
                _set_var_distributed(ffn2_weight)

            self.ln_scales.append(ln_scale)
            self.ln_biases.append(ln_bias)
            self.qkv_weights.append(qkv_weight)
            self.qkv_biases.append(qkv_bias)
            self.linear_weights.append(linear_weight)
            self.linear_biases.append(linear_bias)

            self.ffn_ln_scales.append(ffn_ln_scale)
            self.ffn_ln_biases.append(ffn_ln_bias)
            self.ffn1_weights.append(ffn1_weight)
            self.ffn1_biases.append(ffn1_bias)
            self.ffn2_weights.append(ffn2_weight)
            self.ffn2_biases.append(ffn2_bias)

        self.dropout_rate = dropout_rate
        self.activation = activation
        self.name = name

1376
    def forward(
        self,
        src,
        attn_mask=None,
        caches=None,
        pre_caches=None,
        rotary_embs=None,
        rotary_emb_dims=0,
        time_step=None,
    ):
        r"""
        Applies multi transformer layers on the input.

        The per-layer parameter lists held by this layer (layer-norm scales and
        biases, QKV weights and biases, output projection weights and biases,
        and the two FFN weights and biases) are passed, in one call, to
        ``incubate_f.fused_multi_transformer`` together with the inputs below.

        Parameters:
            src (Tensor): The input of Transformer layers. It is
                a tensor with shape `[batch_size, sequence_length, d_model]`.
                The data type should be float16 or float32.
            attn_mask (Tensor, optional): A tensor used in multi-head attention
                to prevent attention to some unwanted positions, usually the
                paddings or the subsequent positions. It is a tensor with shape
                `[batch_size, 1, sequence_length, sequence_length]`. It can be
                None when nothing wanted or needed to be prevented attention to.
                Default None.
            caches (list(Tensor)|tuple(Tensor), optional): The cache structure
                tensors for the inference generation model. It is only used for
                inference and should be None for training. It must contain one
                entry per layer. The shape is
                `[2, batch_size, num_head, max_seq_len, head_dim]`. Default None.
            pre_caches (list(Tensor)|tuple(Tensor), optional): The prefix caches
                for the generation model. The shape is
                `[2, batch_size, num_head, cache_len, head_dim]`. Default None.
            rotary_embs (Tensor, optional): The RoPE embeddings for the rotary
                computation. The shape is
                `[2, batch_size, 1, seq_len, head_dim]`. Default None.
            rotary_emb_dims (int, optional): The rotary_emb_dims of rotary
                computation. It is 0 when rotary_embs is None, 1 when
                rotary_embs is not None and pos_extra_ids is None, and 2 when
                rotary_embs and pos_extra_ids are both not None. Default 0.
            time_step (Tensor, optional): The time step tensor for the generation
                model. It is used in the decode stage to represent the time step,
                that is, the real seq_len of CacheKV. The shape is `[1]`, and it
                must be in CPUPlace. Default None.

        Returns:
            Tensor|tuple: If `caches` is None, return a tensor that has
            the same shape and data type with `src`, representing the output
            of Transformer layers. If `caches` is not None, return the
            tuple (output, caches), where output is the output of the
            Transformer layers and caches is the input `caches`, updated
            in place.

        Raises:
            AssertionError: If `caches` is given but its length differs from
                the number of layers of this module.
        """

        if caches is not None:
            # One QKV weight is created per layer, so its count is the layer
            # count; each layer needs exactly one cache entry.
            num_layers = len(self.qkv_weights)
            assert len(caches) == num_layers, (
                "The length of `caches` ({}) must equal the number of "
                "layers ({}).".format(len(caches), num_layers)
            )
        out = incubate_f.fused_multi_transformer(
            src,
            self.ln_scales,
            self.ln_biases,
            self.qkv_weights,
            self.qkv_biases,
            self.linear_weights,
            self.linear_biases,
            self.ffn_ln_scales,
            self.ffn_ln_biases,
            self.ffn1_weights,
            self.ffn1_biases,
            self.ffn2_weights,
            self.ffn2_biases,
            pre_layer_norm=self.normalize_before,
            epsilon=self._epsilon,
            cache_kvs=caches,
            pre_caches=pre_caches,
            rotary_embs=rotary_embs,
            time_step=time_step,
            attn_mask=attn_mask,
            dropout_rate=self.dropout_rate,
            rotary_emb_dims=rotary_emb_dims,
            activation=self.activation,
            training=self.training,
            mode='upscale_in_train',
            trans_qkvw=self._trans_qkvw,
            ring_id=self._ring_id,
            name=self.name,
        )
        return out