# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.nn import functional as F
from paddle.incubate.nn import functional as incubate_f
from paddle.nn import Layer
from paddle.framework import ParamAttr
import paddle
from paddle.nn.layer.transformer import _convert_attention_mask, _convert_param_attr_to_list
from paddle.nn.initializer import Constant

import collections


# for distributed tensor model parallel
def _set_var_distributed(var):
    """Flag ``var`` as distributed for tensor model parallelism.

    Sets ``is_distributed = True`` on ``var`` itself, and on the variable of
    the same name looked up (recursively) from the current block of the
    default startup program and of the default main program.

    Args:
        var: The variable/parameter to flag. ``None`` is accepted and ignored.
    """
    if var is None:
        return

    var.is_distributed = True

    # NOTE: use current_block and find_var_recursive to support while_loop
    blocks = [
        paddle.static.default_startup_program().current_block(),
        paddle.static.default_main_program().current_block(),
    ]
    for block in blocks:
        block._find_var_recursive(var.name).is_distributed = True


class FusedMultiHeadAttention(Layer):
    """
    Attention maps queries and a set of key-value pairs to outputs, and
    Multi-Head Attention performs multiple parallel attention to jointly attend
    to information from different representation subspaces.
    Please refer to `Attention Is All You Need <https://arxiv.org/pdf/1706.03762.pdf>`_
    for more details.

    The computation is delegated to
    `paddle.incubate.nn.functional.fused_multi_head_attention`. Only `query`
    is passed to the fused op, so the `key` and `value` arguments of
    :meth:`forward` (and the `kdim`/`vdim` settings) are not used by it.

    Parameters:
        embed_dim (int): The expected feature size in the input and output.
        num_heads (int): The number of heads in multi-head attention.
            `embed_dim` must be divisible by `num_heads`.
        dropout_rate (float, optional): The dropout probability used on attention
            weights to drop some attention targets for the dropout after attention.
            0 for no dropout. Default 0.5.
        attn_dropout_rate (float, optional): The dropout probability used on attention
            weights to drop some attention targets for the dropout in attention.
            0 for no dropout. Default 0.5.
        kdim (int, optional): The feature size in key. If None, assumed equal to
            `embed_dim`. It is only recorded (shown by `extra_repr`) and is not
            used by the fused computation. Default None.
        vdim (int, optional): The feature size in value. If None, assumed equal to
            `embed_dim`. It is only recorded (shown by `extra_repr`) and is not
            used by the fused computation. Default None.
        normalize_before (bool, optional): Indicate whether it is pre_layer_norm
            (True) or post_layer_norm architecture (False). When True only the
            `pre_ln_*` parameters are created; when False only the `ln_*`
            parameters are created. Default False.
        need_weights (bool, optional): Indicate whether to return the attention
            weights. Now, only False is supported. Default False.
        qkv_weight_attr(ParamAttr, optional): To specify the weight parameter property
            for QKV projection computation. Default: None, which means the default weight
            parameter property is used. See usage for details in :code:`ParamAttr`.
        qkv_bias_attr(ParamAttr|bool, optional): To specify the bias parameter property
            for QKV projection computation. The `False` value means the corresponding layer
            would not have trainable bias parameter. Default: None, which means the
            default bias parameter property is used. See usage for details in :code:`ParamAttr`.
        linear_weight_attr(ParamAttr, optional): To specify the weight parameter property
            for linear projection computation. Default: None, which means the default weight
            parameter property is used. See usage for details in :code:`ParamAttr`.
        linear_bias_attr(ParamAttr|bool, optional): To specify the bias parameter property
            for linear projection computation. The `False` value means the corresponding layer would
            not have trainable bias parameter. Default: None, which means the default bias
            parameter property is used. See usage for details in :code:`ParamAttr`.
        pre_ln_scale_attr(ParamAttr, optional): To specify the weight parameter property
            for pre_layer_norm computation. Default: None, which means the default weight
            parameter property is used. See usage for details in :code:`ParamAttr`.
        pre_ln_bias_attr(ParamAttr|bool, optional): To specify the bias parameter property
            for pre_layer_norm computation. The `False` value means the corresponding layer would
            not have trainable bias parameter. Default: None, which means the default bias
            parameter property is used. See usage for details in :code:`ParamAttr`.
        ln_scale_attr(ParamAttr, optional): To specify the weight parameter property
            for post_layer_norm computation. Default: None, which means the default weight
            parameter property is used. See usage for details in :code:`ParamAttr`.
        ln_bias_attr(ParamAttr|bool, optional): To specify the bias parameter property
            for post_layer_norm computation. The `False` value means the corresponding layer would
            not have trainable bias parameter. Default: None, which means the default bias
            parameter property is used. See usage for details in :code:`ParamAttr`.
        epsilon (float, optional): The small value added to the variance to prevent
            division by zero. It is used for both the pre and the post layer norm.
            Default: 1e-05.
        nranks (int, optional): Distributed tensor model parallel nranks. Default is 1, means not using tensor parallel.
            `num_heads` must be divisible by `nranks`; each rank holds
            `num_heads // nranks` heads.
        ring_id (int, optional): For distributed tensor model parallel. Default is -1, means not using tensor parallel.
            It must be set (not -1) when `nranks > 1`.
        name (str, optional): The default value is None. Normally there is no need for user to set
            this property. For more information, please refer to :ref:`api_guide_Name`.

    Examples:

        .. code-block:: python

            # required: gpu
            import paddle
            # input: [batch_size, sequence_length, embed_dim]
            query = paddle.rand((2, 4, 128))
            # self attention mask: [batch_size, num_heads, query_len, query_len]
            attn_mask = paddle.rand((2, 2, 4, 4))
            multi_head_attn = paddle.incubate.nn.FusedMultiHeadAttention(128, 2)
            output = multi_head_attn(query, None, None, attn_mask=attn_mask)  # [2, 4, 128]
    """

    def __init__(self,
                 embed_dim,
                 num_heads,
                 dropout_rate=0.5,
                 attn_dropout_rate=0.5,
                 kdim=None,
                 vdim=None,
                 normalize_before=False,
                 need_weights=False,
                 qkv_weight_attr=None,
                 qkv_bias_attr=None,
                 linear_weight_attr=None,
                 linear_bias_attr=None,
                 pre_ln_scale_attr=None,
                 pre_ln_bias_attr=None,
                 ln_scale_attr=None,
                 ln_bias_attr=None,
                 epsilon=1e-5,
                 nranks=1,
                 ring_id=-1,
                 name=None):
        super(FusedMultiHeadAttention, self).__init__()

        assert embed_dim > 0, ("Expected embed_dim to be greater than 0, "
                               "but received {}".format(embed_dim))
        assert num_heads > 0, ("Expected nhead to be greater than 0, "
                               "but received {}".format(num_heads))

        self.normalize_before = normalize_before
        self._dtype = self._helper.get_default_dtype()
        self._epsilon = epsilon
        self._ring_id = ring_id

        # Public attributes describe the full (un-sharded) layer.
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.kdim = kdim
        self.vdim = vdim
        self.need_weights = need_weights
        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
        assert need_weights is False, "Only support need_weight is False now."

        # Tensor model parallel: validate the config before any parameter is
        # created, then shard the heads so each rank holds num_heads // nranks.
        assert num_heads % nranks == 0, (
            "Expected num_heads ({}) to be divisible by nranks ({})".format(
                num_heads, nranks))
        if nranks > 1:
            assert ring_id != -1, (
                "ring_id must be set (not -1) when nranks > 1, "
                "but received nranks={}".format(nranks))
        local_num_heads = num_heads // nranks

        self.qkv_weight = self.create_parameter(
            shape=[3, local_num_heads, self.head_dim, embed_dim],
            attr=qkv_weight_attr,
            dtype=self._dtype,
            is_bias=False)
        self.qkv_bias = self.create_parameter(
            shape=[3, local_num_heads, self.head_dim],
            attr=qkv_bias_attr,
            dtype=self._dtype,
            is_bias=True)
        self.linear_weight = self.create_parameter(
            shape=[local_num_heads * self.head_dim, embed_dim],
            attr=linear_weight_attr,
            dtype=self._dtype,
            is_bias=False)
        self.linear_bias = self.create_parameter(
            shape=[embed_dim],
            attr=linear_bias_attr,
            dtype=self._dtype,
            is_bias=True)

        # tensor model parallel
        if nranks > 1:
            # column parallel
            _set_var_distributed(self.qkv_weight)
            _set_var_distributed(self.qkv_bias)
            # row parallel (linear_bias is not sharded)
            _set_var_distributed(self.linear_weight)

        # Only the layer-norm parameters used by the chosen architecture are
        # created; the unused pair is kept as None.
        if normalize_before:
            self.pre_ln_scale = self.create_parameter(
                attr=pre_ln_scale_attr,
                shape=[embed_dim],
                default_initializer=Constant(value=1.0))
            self.pre_ln_bias = self.create_parameter(
                attr=pre_ln_bias_attr, shape=[embed_dim], is_bias=True)
            self.ln_scale = None
            self.ln_bias = None
        else:
            self.pre_ln_scale = None
            self.pre_ln_bias = None
            self.ln_scale = self.create_parameter(
                attr=ln_scale_attr,
                shape=[embed_dim],
                default_initializer=Constant(value=1.0))
            self.ln_bias = self.create_parameter(
                attr=ln_bias_attr, shape=[embed_dim], is_bias=True)

        self.dropout_rate = dropout_rate
        self.attn_dropout_rate = attn_dropout_rate

        self.name = name

    def forward(self, query, key=None, value=None, attn_mask=None, cache=None):
        """
        Applies multi-head attention to map queries and a set of key-value pairs
        to outputs.

        Parameters:
            query (Tensor): The queries for multi-head attention. It is a
                tensor with shape `[batch_size, query_length, embed_dim]`. The
                data type should be float32 or float64.
            key (Tensor, optional): Accepted for API compatibility but not used:
                only `query` is passed to the fused op. Default None.
            value (Tensor, optional): Accepted for API compatibility but not used:
                only `query` is passed to the fused op. Default None.
            attn_mask (Tensor, optional): A tensor used in multi-head attention
                to prevent attention to some unwanted positions, usually the
                paddings or the subsequent positions. It is a tensor with shape
                broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`.
                When the data type is bool, the unwanted positions have `False`
                values and the others have `True` values. When the data type is
                int, the unwanted positions have 0 values and the others have 1
                values. When the data type is float, the unwanted positions have
                `-INF` values and the others have 0 values. It can be None when
                nothing wanted or needed to be prevented attention to. Default None.
            cache (MultiHeadAttention.Cache|MultiHeadAttention.StaticCache, optional):
                Forwarded to the fused op as `cache_kv`. Now, only None is
                supported. Default None.

        Returns:
            Tensor|tuple: It is a tensor that has the same shape and data type \
                as `query`, representing attention output.
        """
        if attn_mask is not None:
            # Support bool or int mask
            attn_mask = _convert_attention_mask(attn_mask, query.dtype)

        out = incubate_f.fused_multi_head_attention(
            x=query,
            qkv_weight=self.qkv_weight,
            linear_weight=self.linear_weight,
            pre_layer_norm=self.normalize_before,
            pre_ln_scale=self.pre_ln_scale,
            pre_ln_bias=self.pre_ln_bias,
            ln_scale=self.ln_scale,
            ln_bias=self.ln_bias,
            pre_ln_epsilon=self._epsilon,
            qkv_bias=self.qkv_bias,
            linear_bias=self.linear_bias,
            cache_kv=cache,
            attn_mask=attn_mask,
            dropout_rate=self.dropout_rate,
            attn_dropout_rate=self.attn_dropout_rate,
            ln_epsilon=self._epsilon,
            training=self.training,
            ring_id=self._ring_id,
            name=self.name)
        return out

    def extra_repr(self):
        name_str = ', name={}'.format(self.name) if self.name else ''
        return 'embed_dim={}, num_heads={}, dropout_rate={}, attn_dropout_rate={}, epsilon={}, kdim={}, vdim={}, normalize_before={}, need_weights={}, dtype={}{}'.format(
            self.embed_dim, self.num_heads, self.dropout_rate,
            self.attn_dropout_rate, self._epsilon, self.kdim, self.vdim,
            self.normalize_before, self.need_weights, self._dtype, name_str)


class FusedFeedForward(Layer):
    """
    A feed-forward network (FFN) layer computed by the fused operator
    `paddle.incubate.nn.functional.fused_feedforward`. It holds two linear
    projections (`d_model -> dim_feedforward -> d_model`) with the given
    activation in between. Dropout with `act_dropout_rate` (the op's
    `dropout1_rate`) and with `dropout_rate` (the op's `dropout2_rate`) is
    applied by the op. Layer normalization is applied either before
    (`normalize_before=True`, parameters `ln1_*`) or after
    (`normalize_before=False`, parameters `ln2_*`) the FFN.

    Parameters:
        d_model (int): The expected feature size in the input and output.
        dim_feedforward (int): The hidden layer size.
        dropout_rate (float, optional): The dropout probability used in pre-process
            and post-process. Default 0.1
        epsilon (float, optional): The small value added to the variance to prevent
            division by zero. It is used for both layer norms. Default: 1e-05.
        activation (str, optional): The activation function. Default relu.
        act_dropout_rate (float, optional): The dropout probability after activation.
            If None, use the value of `dropout_rate`. Default None
        normalize_before (bool, optional): Indicate whether to put layer normalization
            into preprocessing or postprocessing. Default False
        linear1_weight_attr(ParamAttr, optional): To specify the weight parameter property
            for FFN first linear. Default: None, which means the default weight
            parameter property is used. See usage for details in :code:`ParamAttr`.
        linear1_bias_attr(ParamAttr|bool, optional): To specify the bias parameter property
            for FFN first linear. The `False` value means the corresponding layer would
            not have trainable bias parameter. Default: None, which means the default bias
            parameter property is used. See usage for details in :code:`ParamAttr`.
        linear2_weight_attr(ParamAttr, optional): To specify the weight parameter property
            for FFN second linear. Default: None, which means the default weight
            parameter property is used. See usage for details in :code:`ParamAttr`.
        linear2_bias_attr(ParamAttr|bool, optional): To specify the bias parameter property
            for FFN second linear. The `False` value means the corresponding layer would
            not have trainable bias parameter. Default: None, which means the default bias
            parameter property is used. See usage for details in :code:`ParamAttr`.
        ln1_scale_attr(ParamAttr, optional): To specify the weight parameter property
            for FFN pre_layer_norm. Default: None, which means the default weight
            parameter property is used. See usage for details in :code:`ParamAttr`.
        ln1_bias_attr(ParamAttr|bool, optional): To specify the bias parameter property
            for FFN pre_layer_norm. The `False` value means the corresponding layer would
            not have trainable bias parameter. Default: None, which means the default bias
            parameter property is used. See usage for details in :code:`ParamAttr`.
        ln2_scale_attr(ParamAttr, optional): To specify the weight parameter property
            for FFN post_layer_norm. Default: None, which means the default weight
            parameter property is used. See usage for details in :code:`ParamAttr`.
        ln2_bias_attr(ParamAttr|bool, optional): To specify the bias parameter property
            for FFN post_layer_norm. The `False` value means the corresponding layer would
            not have trainable bias parameter. Default: None, which means the default bias
            parameter property is used. See usage for details in :code:`ParamAttr`.
        nranks (int, optional): Distributed tensor model parallel nranks. Default is 1, means not using tensor parallel.
            `dim_feedforward` must be divisible by `nranks`; each rank holds
            `dim_feedforward // nranks` hidden units.
        ring_id (int, optional): For distributed tensor model parallel. Default is -1, means not using tensor parallel.
            It must be set (not -1) when `nranks > 1`.
        name (str, optional): The default value is None.  Normally there is no need for user to set
            this property. For more information, please refer to :ref:`api_guide_Name`.

    Examples:
        .. code-block:: python

            # required: gpu
            import paddle
            from paddle.incubate.nn import FusedFeedForward

            fused_feedforward_layer = FusedFeedForward(8, 8)
            x = paddle.rand((1, 8, 8))
            out = fused_feedforward_layer(x)
            print(out.numpy().shape)
            # (1, 8, 8)
    """

    def __init__(self,
                 d_model,
                 dim_feedforward,
                 dropout_rate=0.1,
                 epsilon=1e-05,
                 activation="relu",
                 act_dropout_rate=None,
                 normalize_before=False,
                 linear1_weight_attr=None,
                 linear1_bias_attr=None,
                 linear2_weight_attr=None,
                 linear2_bias_attr=None,
                 ln1_scale_attr=None,
                 ln1_bias_attr=None,
                 ln2_scale_attr=None,
                 ln2_bias_attr=None,
                 nranks=1,
                 ring_id=-1,
                 name=None):

        super(FusedFeedForward, self).__init__()
        assert d_model > 0, (
            "Expected d_model to be greater than 0, but received {}".format(
                d_model))
        assert dim_feedforward > 0, (
            "Expected dim_feedforward to be greater than 0, but received {}".
            format(dim_feedforward))

        self._dtype = self._helper.get_default_dtype()
        self._d_model = d_model

        # Tensor model parallel: validate the config before any parameter is
        # created, then shard the hidden dimension across ranks.
        assert dim_feedforward % nranks == 0, (
            "Expected dim_feedforward ({}) to be divisible by nranks ({})".
            format(dim_feedforward, nranks))
        if nranks > 1:
            assert ring_id != -1, (
                "ring_id must be set (not -1) when nranks > 1, "
                "but received nranks={}".format(nranks))
        dim_feedforward = dim_feedforward // nranks

        # NOTE: this is the per-rank hidden size (equal to the full size when nranks == 1).
        self._dim_feedforward = dim_feedforward
        self._dropout_rate = dropout_rate
        self._act_dropout_rate = dropout_rate if act_dropout_rate is None else act_dropout_rate
        self._act_method = activation
        self._normalize_before = normalize_before
        self._epsilon = epsilon
        self._ring_id = ring_id

        self._linear1_weight = self.create_parameter(
            shape=[d_model, dim_feedforward],
            attr=linear1_weight_attr,
            dtype=self._dtype,
            is_bias=False)
        self._linear1_bias = self.create_parameter(
            shape=[dim_feedforward],
            attr=linear1_bias_attr,
            dtype=self._dtype,
            is_bias=True)

        self._linear2_weight = self.create_parameter(
            shape=[dim_feedforward, d_model],
            attr=linear2_weight_attr,
            dtype=self._dtype,
            is_bias=False)

        self._linear2_bias = self.create_parameter(
            shape=[d_model],
            attr=linear2_bias_attr,
            dtype=self._dtype,
            is_bias=True)

        if nranks > 1:
            # column parallel
            _set_var_distributed(self._linear1_weight)
            _set_var_distributed(self._linear1_bias)
            # row parallel (linear2_bias is not sharded)
            _set_var_distributed(self._linear2_weight)

        # Only the layer-norm parameters used by the chosen architecture are
        # created; the unused pair is kept as None.
        if normalize_before:
            self._ln1_scale = self.create_parameter(
                shape=[d_model],
                attr=ln1_scale_attr,
                is_bias=False,
                default_initializer=Constant(1.0))
            self._ln1_bias = self.create_parameter(
                shape=[d_model], attr=ln1_bias_attr, is_bias=True)
            self._ln2_scale = None
            self._ln2_bias = None
        else:
            self._ln1_scale = None
            self._ln1_bias = None
            self._ln2_scale = self.create_parameter(
                shape=[d_model],
                attr=ln2_scale_attr,
                is_bias=False,
                default_initializer=Constant(1.0))
            self._ln2_bias = self.create_parameter(
                shape=[d_model], attr=ln2_bias_attr, is_bias=True)

        self.name = name

    def forward(self, src, cache=None):
        """
        Applies the fused feed-forward network to `src`.

        Parameters:
            src (Tensor): The input tensor; its last dimension is expected to
                be `d_model` (assumption based on the parameter shapes).
            cache (optional): Accepted for API compatibility; it is not used.

        Returns:
            Tensor: The output of `fused_feedforward`.
        """
        out = incubate_f.fused_feedforward(
            src,
            self._linear1_weight,
            self._linear2_weight,
            self._linear1_bias,
            self._linear2_bias,
            self._ln1_scale,
            self._ln1_bias,
            self._ln2_scale,
            self._ln2_bias,
            dropout1_rate=self._act_dropout_rate,
            dropout2_rate=self._dropout_rate,
            activation=self._act_method,
            ln1_epsilon=self._epsilon,
            ln2_epsilon=self._epsilon,
            pre_layer_norm=self._normalize_before,
            training=self.training,
            ring_id=self._ring_id,
            name=self.name)
        return out

    def extra_repr(self):
        name_str = ', name={}'.format(self.name) if self.name else ''
        return 'd_model={}, dim_feedforward={}, dropout_rate={}, epsilon={}, activation={}, act_dropout_rate={}, normalize_before={}, dtype={}{}'.format(
            self._d_model, self._dim_feedforward, self._dropout_rate,
            self._epsilon, self._act_method, self._act_dropout_rate,
            self._normalize_before, self._dtype, name_str)


class FusedTransformerEncoderLayer(Layer):
    """
    FusedTransformerEncoderLayer is composed of two sub-layers which are self (multi-head)
    attention and feedforward network. Before and after each sub-layer, pre-process
    and post-process would be applied on the input and output accordingly. If
    `normalize_before` is True, pre-process is layer normalization and post-process
    includes dropout, residual connection. Otherwise, no pre-process and post-process
    includes dropout, residual connection, layer normalization.

    Parameters:
        d_model (int): The expected feature size in the input and output.
        nhead (int): The number of heads in multi-head attention(MHA).
        dim_feedforward (int): The hidden layer size in the feedforward network(FFN).
        dropout_rate (float, optional): The dropout probability used in pre-process
            and post-process of MHA and FFN sub-layer. Default 0.1
        activation (str, optional): The activation function in the feedforward
            network. Default relu.
        attn_dropout_rate (float, optional): The dropout probability used
            in MHA to drop some attention target. If None, use the value of
            `dropout_rate`. Default None
        act_dropout_rate (float, optional): The dropout probability used after FFN
            activation.  If None, use the value of `dropout_rate`. Default None
        normalize_before (bool, optional): Indicate whether to put layer normalization
            into preprocessing of MHA and FFN sub-layers. If True, pre-process is layer
            normalization and post-process includes dropout, residual connection.
            Otherwise, no pre-process and post-process includes dropout, residual
            connection, layer normalization. Default False
        weight_attr(ParamAttr|list|tuple, optional): To specify the weight parameter property.
            If it is a list/tuple, `weight_attr[0]` would be used as `weight_attr` for
            MHA (QKV and output linear weights and the layer-norm scale), and
            `weight_attr[1]` would be used as `weight_attr` for both linears in FFN.
            Otherwise, MHA and FFN both use it as `weight_attr` to create parameters.
            The FFN layer-norm parameters always use the default property.
            Default: None, which means the default weight parameter property is used.
            See usage for details in :code:`ParamAttr` .
        bias_attr (ParamAttr|list|tuple|bool, optional): To specify the bias parameter property.
            If it is a list/tuple, `bias_attr[0]` would be used as `bias_attr` for
            MHA, and `bias_attr[1]` would be used as `bias_attr` for linear in FFN.
            Otherwise, MHA and FFN both use it as `bias_attr` to create parameters.
            The `False` value means the corresponding layer would not have trainable
            bias parameter. See usage for details in :code:`ParamAttr` . Default: None,
            which means the default bias parameter property is used.

    Examples:

        .. code-block:: python

            # required: gpu
            import paddle
            from paddle.incubate.nn import FusedTransformerEncoderLayer

            # encoder input: [batch_size, src_len, d_model]
            enc_input = paddle.rand((2, 4, 128))
            # self attention mask: [batch_size, n_head, src_len, src_len]
            attn_mask = paddle.rand((2, 2, 4, 4))
            encoder_layer = FusedTransformerEncoderLayer(128, 2, 512)
            enc_output = encoder_layer(enc_input, attn_mask)  # [2, 4, 128]
    """

    def __init__(self,
                 d_model,
                 nhead,
                 dim_feedforward,
                 dropout_rate=0.1,
                 activation="relu",
                 attn_dropout_rate=None,
                 act_dropout_rate=None,
                 normalize_before=False,
                 weight_attr=None,
                 bias_attr=None):
        # Capture the constructor arguments first, before any other local is bound.
        self._config = locals()
        self._config.pop("self")
        self._config.pop("__class__", None)  # py3

        super(FusedTransformerEncoderLayer, self).__init__()
        assert d_model > 0, ("Expected d_model to be greater than 0, "
                             "but received {}".format(d_model))
        assert nhead > 0, ("Expected nhead to be greater than 0, "
                           "but received {}".format(nhead))
        assert dim_feedforward > 0, (
            "Expected dim_feedforward to be greater than 0, "
            "but received {}".format(dim_feedforward))
        attn_dropout_rate = dropout_rate if attn_dropout_rate is None else attn_dropout_rate
        act_dropout_rate = dropout_rate if act_dropout_rate is None else act_dropout_rate
        self.normalize_before = normalize_before

        # Index 0 configures the attention sub-layer, index 1 the FFN.
        weight_attrs = _convert_param_attr_to_list(weight_attr, 2)
        bias_attrs = _convert_param_attr_to_list(bias_attr, 2)

        self.fused_attn = FusedMultiHeadAttention(
            d_model,
            nhead,
            dropout_rate=dropout_rate,
            attn_dropout_rate=attn_dropout_rate,
            normalize_before=self.normalize_before,
            qkv_weight_attr=weight_attrs[0],
            qkv_bias_attr=bias_attrs[0],
            linear_weight_attr=weight_attrs[0],
            linear_bias_attr=bias_attrs[0],
            pre_ln_scale_attr=weight_attrs[0],
            pre_ln_bias_attr=bias_attrs[0],
            ln_scale_attr=weight_attrs[0],
            ln_bias_attr=bias_attrs[0])

        self.ffn = FusedFeedForward(
            d_model,
            dim_feedforward,
            dropout_rate=dropout_rate,
            activation=activation,
            act_dropout_rate=act_dropout_rate,
            normalize_before=self.normalize_before,
            linear1_weight_attr=weight_attrs[1],
            linear1_bias_attr=bias_attrs[1],
            linear2_weight_attr=weight_attrs[1],
            linear2_bias_attr=bias_attrs[1])

    def forward(self, src, src_mask=None, cache=None):
        """
        Applies a Transformer encoder layer on the input.

        Parameters:
            src (Tensor): The input of Transformer encoder layer. It is
                a tensor with shape `[batch_size, sequence_length, d_model]`.
                The data type should be float32 or float64.
            src_mask (Tensor, optional): A tensor used in multi-head attention
                to prevent attention to some unwanted positions, usually the
                paddings or the subsequent positions. It is a tensor with shape
                broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`.
                When the data type is bool, the unwanted positions have `False`
                values and the others have `True` values. When the data type is
                int, the unwanted positions have 0 values and the others have 1
                values. When the data type is float, the unwanted positions have
                `-INF` values and the others have 0 values. It can be None when
                nothing wanted or needed to be prevented attention to. Default None.
            cache (Tensor, optional): When not None it is forwarded to the
                attention sub-layer, and the incremental cache that sub-layer
                returns is included in the output. It is only used for
                inference and should be None for training. Default None.

        Returns:
            Tensor|tuple: It is a tensor that has the same shape and data type \
                as `src`, representing the output of Transformer encoder \
                layer. Or a tuple `(output, incremental_cache)` if `cache` is \
                not None.
        """
        src_mask = _convert_attention_mask(src_mask, src.dtype)
        if cache is None:
            attn_out = self.fused_attn(src, attn_mask=src_mask)
        else:
            attn_out, incremental_cache = self.fused_attn(
                src, attn_mask=src_mask, cache=cache)

        ffn_out = self.ffn(attn_out)

        return ffn_out if cache is None else (ffn_out, incremental_cache)


class FusedTransformer(Layer):
    """
    A Transformer model composed of an instance of `TransformerEncoder` and an
    instance of `TransformerDecoder`. While the embedding layer and output layer
    are not included.

    Please refer to `Attention is all you need <http://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf>`_ ,
    and see `TransformerEncoder` and `TransformerDecoder` for more details.
630

631 632 633 634
    Users can configurate the model architecture with corresponding parameters.
    Note the usage of `normalize_before` representing where to apply layer
    normalization (in pre-process or post-precess of multi-head attention or FFN),
    and some transformer like models are different on this, such as
635
    `BERT <https://arxiv.org/abs/1810.04805>`_ and `GPT2 <https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf>`_ .
636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660
    The default architecture here places layer normalization in post-process and
    applies another layer normalization on the output of last encoder/decoder layer.

    Parameters:
        d_model (int, optional): The expected feature size in the encoder/decoder input
            and output. Default 512
        nhead (int, optional): The number of heads in multi-head attention(MHA). Default 8
        num_encoder_layers (int, optional): The number of layers in encoder. Default 6
        num_decoder_layers (int, optional): The number of layers in decoder. Default 6
        dim_feedforward (int, optional): The hidden layer size in the feedforward network(FFN). Default 2048
        dropout (float, optional): The dropout probability used in pre-process
            and post-precess of MHA and FFN sub-layer. Default 0.1
        activation (str, optional): The activation function in the feedforward
            network. Default relu.
        attn_dropout (float, optional): The dropout probability used
            in MHA to drop some attention target. If None, use the value of
            `dropout`. Default None
        act_dropout (float, optional): The dropout probability used after FFN
            activition.  If None, use the value of `dropout`. Default None
        normalize_before (bool, optional): Indicate whether to put layer normalization
            into preprocessing of MHA and FFN sub-layers. If True, pre-process is layer
            normalization and post-precess includes dropout, residual connection.
            Otherwise, no pre-process and post-precess includes dropout, residual
            connection, layer normalization. Default False
        weight_attr(ParamAttr|list|tuple, optional): To specify the weight parameter property.
661 662 663 664 665 666 667 668 669 670
            If it is a list/tuple, the length of `weight_attr` could be 1, 2 or 3. If it is 3,
            `weight_attr[0]` would be used as `weight_attr` for self attention, `weight_attr[1]`
            would be used as `weight_attr` for cross attention of `TransformerDecoder`,
            and `weight_attr[2]` would be used as `weight_attr` for linear in FFN.
            If it is 2, `weight_attr[0]` would be used as `weight_attr` both for self attention
            and cross attntion and `weight_attr[1]` would be used as `weight_attr` for
            linear in FFN. If it is 1, `weight_attr[0]` would be used as `weight_attr`
            for self attention, cross attention and linear in FFN. Otherwise,
            the three sub-layers all uses it as `weight_attr` to create parameters.
            Default: None, which means the default weight parameter property is used.
671
            See usage for details
672
            in :code:`ParamAttr` .
673
        bias_attr (ParamAttr|list|tuple|bool, optional): To specify the bias parameter property.
674 675 676 677 678 679 680 681 682 683 684
            If it is a list/tuple, the length of `bias_attr` could be 1, 2 or 3. If it is 3,
            `bias_attr[0]` would be used as `bias_attr` for self attention, `bias_attr[1]`
            would be used as `bias_attr` for cross attention of `TransformerDecoder`,
            and `bias_attr[2]` would be used as `bias_attr` for linear in FFN.
            If it is 2, `bias_attr[0]` would be used as `bias_attr` both for self attention
            and cross attntion and `bias_attr[1]` would be used as `bias_attr` for
            linear in FFN. If it is 1, `bias_attr[0]` would be used as `bias_attr`
            for self attention, cross attention and linear in FFN. Otherwise,
            the three sub-layers all uses it as `bias_attr` to create parameters.
            The `False` value means the corresponding layer would not have trainable
            bias parameter. See usage for details in :code:`ParamAttr` .
685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731
            Default: None,which means the default bias parameter property is used.
        custom_encoder (Layer, optional): If custom encoder is provided, use it as the encoder.
            Default None
        custom_decoder (Layer, optional): If custom decoder is provided, use it as the decoder.
            Default None

    Examples:

        .. code-block:: python

            import paddle
            from paddle.nn import Transformer

            # src: [batch_size, tgt_len, d_model]
            enc_input = paddle.rand((2, 4, 128))
            # tgt: [batch_size, src_len, d_model]
            dec_input = paddle.rand((2, 6, 128))
            # src_mask: [batch_size, n_head, src_len, src_len]
            enc_self_attn_mask = paddle.rand((2, 2, 4, 4))
            # tgt_mask: [batch_size, n_head, tgt_len, tgt_len]
            dec_self_attn_mask = paddle.rand((2, 2, 6, 6))
            # memory_mask: [batch_size, n_head, tgt_len, src_len]
            cross_attn_mask = paddle.rand((2, 2, 6, 4))
            transformer = Transformer(128, 2, 4, 4, 512)
            output = transformer(enc_input,
                                 dec_input,
                                 enc_self_attn_mask,
                                 dec_self_attn_mask,
                                 cross_attn_mask)  # [2, 6, 128]
    """

    def __init__(self,
                 d_model=512,
                 nhead=8,
                 num_encoder_layers=6,
                 num_decoder_layers=6,
                 dim_feedforward=2048,
                 dropout=0.1,
                 activation="relu",
                 attn_dropout=None,
                 act_dropout=None,
                 normalize_before=False,
                 weight_attr=None,
                 bias_attr=None,
                 custom_encoder=None,
                 custom_decoder=None):
        super(fusedTransformer, self).__init__()
732
        raise NotImplementedError()
733 734

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None):
735
        raise NotImplementedError()


class FusedMultiTransformer(Layer):
    """
    FusedMultiTransformer is composed of multi transformer layers which contains two
    sub-layers which are self (multi-head) attention and feedforward network. The
    function of one transformer layer is consistent with the following pseudo code:

    .. code-block:: python

        if pre_layer_norm:
            out = layer_norm(x)
            out = qkv_linear(out) + qkv_bias
        else:
            out = qkv_linear(x) + qkv_bias
        out = transpose(out, perm=[2, 0, 3, 1, 4])
        # extract q, k and v from out.
        q = out[0:1, ::]
        k = out[1:2, ::]
        v = out[2:3, ::]
        out = q * k^t
        out = attn_mask + out
        out = softmax(out)
        out = dropout(out)
        out = out * v
        out = transpose(out, perm=[0, 2, 1, 3])
        out = linear(out)
        if pre_layer_norm:
            out = x + dropout(out + bias)
        else:
            out = layer_norm(x + dropout(out + bias))

        residual = out
        if pre_layer_norm:
            out = ffn_layer_norm(out)
        out = ffn1_linear(out)
        out = dropout(activation(out + ffn1_bias))
        out = ffn2_linear(out)
        out = residual + dropout(out + ffn2_bias)
        if not pre_layer_norm:
            out = ffn_layer_norm(out)

    Parameters:
        embed_dim (int): The expected feature size in the input and output.
        num_heads (int): The number of heads in multi-head attention(MHA).
        dim_feedforward (int): The hidden layer size in the feedforward network(FFN).
        dropout_rate (float, optional): The dropout probability used in pre-process
            and post-process of MHA and FFN sub-layer. Default 0.0
        activation (str, optional): The activation function in the feedforward
            network. Default "gelu".
        normalize_before (bool, optional): Indicate whether to put layer normalization
            into preprocessing of MHA and FFN sub-layers. If True, pre-process is layer
            normalization and post-process includes dropout, residual connection.
            Otherwise, no pre-process and post-process includes dropout, residual
            connection, layer normalization. Default True
        ln_scale_attrs(ParamAttr|list|tuple, optional): To specify the weight parameter property
            for Attention layer_norm. For Attention layer_norm weight, if it is a list/tuple, `attrs[0]`
            would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as
            `attr` for transformer layer 1, etc. Otherwise, all layers use it as
            `attr` to create parameters. Default: None, which means the default weight
            parameter property is used. See usage for details in :code:`ParamAttr`.
        ln_bias_attrs(ParamAttr|list|tuple|bool, optional): To specify the bias parameter property
            for Attention layer_norm. For Attention layer_norm bias, if it is a list/tuple, `attrs[0]`
            would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as
            `attr` for transformer layer 1, etc. Otherwise, all layers use it as
            `attr` to create parameters. The `False` value means the corresponding layer would
            not have trainable bias parameter. Default: None, which means the default bias
            parameter property is used. See usage for details in :code:`ParamAttr`.
        qkv_weight_attrs(ParamAttr|list|tuple, optional): To specify the weight parameter property
            for Attention qkv computation. For Attention qkv weight, if it is a list/tuple, `attrs[0]`
            would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as
            `attr` for transformer layer 1, etc. Otherwise, all layers use it as
            `attr` to create parameters. Default: None, which means the default weight
            parameter property is used. See usage for details in :code:`ParamAttr`.
        qkv_bias_attrs(ParamAttr|list|tuple|bool, optional): To specify the bias parameter property
            for Attention qkv computation. For Attention qkv bias, if it is a list/tuple, `attrs[0]`
            would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as
            `attr` for transformer layer 1, etc. Otherwise, all layers use it as
            `attr` to create parameters. The `False` value means the corresponding layer would
            not have trainable bias parameter. Default: None, which means the default bias
            parameter property is used. See usage for details in :code:`ParamAttr`.
        linear_weight_attrs(ParamAttr|list|tuple, optional): To specify the weight parameter property
            for Attention linear. For Attention linear weight, if it is a list/tuple, `attrs[0]`
            would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as
            `attr` for transformer layer 1, etc. Otherwise, all layers use it as
            `attr` to create parameters. Default: None, which means the default weight
            parameter property is used. See usage for details in :code:`ParamAttr`.
        linear_bias_attrs(ParamAttr|list|tuple|bool, optional): To specify the bias parameter property
            for Attention linear computation. For Attention linear bias, if it is a list/tuple, `attrs[0]`
            would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as
            `attr` for transformer layer 1, etc. Otherwise, all layers use it as
            `attr` to create parameters. The `False` value means the corresponding layer would
            not have trainable bias parameter. Default: None, which means the default bias
            parameter property is used. See usage for details in :code:`ParamAttr`.
        ffn_ln_scale_attrs(ParamAttr|list|tuple, optional): To specify the weight parameter property
            for FFN layer_norm. For FFN layer_norm weight, if it is a list/tuple, `attrs[0]`
            would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as
            `attr` for transformer layer 1, etc. Otherwise, all layers use it as
            `attr` to create parameters. Default: None, which means the default weight
            parameter property is used. See usage for details in :code:`ParamAttr`.
        ffn_ln_bias_attrs(ParamAttr|list|tuple|bool, optional): To specify the bias parameter property
            for FFN layer_norm. For FFN layer_norm bias, if it is a list/tuple, `attrs[0]`
            would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as
            `attr` for transformer layer 1, etc. Otherwise, all layers use it as
            `attr` to create parameters. The `False` value means the corresponding layer would
            not have trainable bias parameter. Default: None, which means the default bias
            parameter property is used. See usage for details in :code:`ParamAttr`.
        ffn1_weight_attrs(ParamAttr|list|tuple, optional): To specify the weight parameter property
            for FFN first linear. For FFN first linear weight, if it is a list/tuple, `attrs[0]`
            would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as
            `attr` for transformer layer 1, etc. Otherwise, all layers use it as
            `attr` to create parameters. Default: None, which means the default weight
            parameter property is used. See usage for details in :code:`ParamAttr`.
        ffn1_bias_attrs(ParamAttr|list|tuple|bool, optional): To specify the bias parameter property
            for FFN first linear. For FFN first linear bias, if it is a list/tuple, `attrs[0]`
            would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as
            `attr` for transformer layer 1, etc. Otherwise, all layers use it as
            `attr` to create parameters. The `False` value means the corresponding layer would
            not have trainable bias parameter. Default: None, which means the default bias
            parameter property is used. See usage for details in :code:`ParamAttr`.
        ffn2_weight_attrs(ParamAttr|list|tuple, optional): To specify the weight parameter property
            for FFN second linear. For FFN second linear weight, if it is a list/tuple, `attrs[0]`
            would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as
            `attr` for transformer layer 1, etc. Otherwise, all layers use it as
            `attr` to create parameters. Default: None, which means the default weight
            parameter property is used. See usage for details in :code:`ParamAttr`.
        ffn2_bias_attrs(ParamAttr|list|tuple|bool, optional): To specify the bias parameter property
            for FFN second linear. For FFN second linear bias, if it is a list/tuple, `attrs[0]`
            would be used as `attr` for transformer layer 0, and `attrs[1]` would be used as
            `attr` for transformer layer 1, etc. Otherwise, all layers use it as
            `attr` to create parameters. The `False` value means the corresponding layer would
            not have trainable bias parameter. Default: None, which means the default bias
            parameter property is used. See usage for details in :code:`ParamAttr`.
        epsilon (float, optional): Small float value added to denominator of the layer_norm to
            avoid dividing by zero. Default: 1e-05.
        num_layers (int, optional): The number of layers of the transformer. If `qkv_weight_attrs`
            is a list or tuple, the number of layers is obtained from `qkv_weight_attrs`. num_layers
            only takes effect when `qkv_weight_attrs` is not a list or tuple. Default: -1.
        nranks (int, optional): Distributed tensor model parallel nranks. Default is 1, means not using mp.
            Must be positive; `num_heads` and `dim_feedforward` must be divisible by it.
        ring_id (int, optional): For distributed tensor model parallel. Default is -1, means not using mp.
            Must not be -1 when `nranks` > 1.
        name (str, optional): The default value is None.  Normally there is no need for user to set
            this property. For more information, please refer to :ref:`api_guide_Name`.

    Examples:

        .. code-block:: python

            # required: gpu
            import paddle
            from paddle.incubate.nn import FusedMultiTransformer

            # encoder input: [batch_size, src_len, d_model]
            enc_input = paddle.rand((2, 4, 128))
            # self attention mask: [batch_size, 1, src_len, src_len]
            attn_mask = paddle.rand((2, 1, 4, 4))
            encoder_layers = FusedMultiTransformer(128, 2, 512, num_layers=1)
            enc_output = encoder_layers(enc_input, attn_mask)  # [2, 4, 128]
    """

    def __init__(self,
                 embed_dim,
                 num_heads,
                 dim_feedforward,
                 dropout_rate=0.0,
                 activation="gelu",
                 normalize_before=True,
                 ln_scale_attrs=None,
                 ln_bias_attrs=None,
                 qkv_weight_attrs=None,
                 qkv_bias_attrs=None,
                 linear_weight_attrs=None,
                 linear_bias_attrs=None,
                 ffn_ln_scale_attrs=None,
                 ffn_ln_bias_attrs=None,
                 ffn1_weight_attrs=None,
                 ffn1_bias_attrs=None,
                 ffn2_weight_attrs=None,
                 ffn2_bias_attrs=None,
                 epsilon=1e-5,
                 num_layers=-1,
                 nranks=1,
                 ring_id=-1,
                 name=None):
        super(FusedMultiTransformer, self).__init__()

        assert embed_dim > 0, ("Expected embed_dim to be greater than 0, "
                               "but received {}".format(embed_dim))
        assert num_heads > 0, ("Expected num_heads to be greater than 0, "
                               "but received {}".format(num_heads))
        assert dim_feedforward > 0, (
            "Expected dim_feedforward to be greater than 0, but received {}".
            format(dim_feedforward))

        self.normalize_before = normalize_before
        self._dtype = self._helper.get_default_dtype()
        self._epsilon = epsilon
        self._ring_id = ring_id

        self.embed_dim = embed_dim
        # NOTE: self.num_heads / self.head_dim are computed from the global
        # (pre tensor-parallel split) head count; the per-rank head count is
        # derived below and only used for parameter shapes.
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"

        # tensor model parallel: each rank owns num_heads // nranks heads and
        # dim_feedforward // nranks FFN hidden units.
        assert nranks > 0, (
            "Expected nranks to be greater than 0, but received {}".format(
                nranks))
        if nranks > 1:
            assert ring_id != -1, (
                "ring_id must be set (not -1) when nranks > 1, "
                "but received ring_id={}".format(ring_id))
        assert num_heads % nranks == 0, (
            "num_heads ({}) must be divisible by nranks ({})".format(
                num_heads, nranks))
        assert dim_feedforward % nranks == 0, (
            "dim_feedforward ({}) must be divisible by nranks ({})".format(
                dim_feedforward, nranks))
        num_heads = num_heads // nranks
        dim_feedforward = dim_feedforward // nranks
        self._dim_feedforward = dim_feedforward

        # A list/tuple of qkv weight attrs fixes the number of layers.
        if isinstance(qkv_weight_attrs, (list, tuple)):
            num_layers = len(qkv_weight_attrs)
        assert num_layers > 0, (
            "Expected num_layers to be greater than 0 (or qkv_weight_attrs to "
            "be a list/tuple), but received {}".format(num_layers))

        self.ln_scales, self.ln_biases = [], []
        self.qkv_weights, self.qkv_biases = [], []
        self.linear_weights, self.linear_biases = [], []
        self.ffn_ln_scales, self.ffn_ln_biases = [], []
        self.ffn1_weights, self.ffn1_biases = [], []
        self.ffn2_weights, self.ffn2_biases = [], []

        def get_attr(attrs, idx):
            # A list/tuple supplies one attr per layer; any other value
            # (including None) is shared by every layer.
            if isinstance(attrs, (list, tuple)):
                assert len(attrs) == num_layers, (
                    "Expected {} per-layer attrs, but received {}".format(
                        num_layers, len(attrs)))
                return attrs[idx]
            return attrs

        for i in range(num_layers):
            ln_scale_attr = get_attr(ln_scale_attrs, i)
            ln_bias_attr = get_attr(ln_bias_attrs, i)
            qkv_weight_attr = get_attr(qkv_weight_attrs, i)
            qkv_bias_attr = get_attr(qkv_bias_attrs, i)
            linear_weight_attr = get_attr(linear_weight_attrs, i)
            linear_bias_attr = get_attr(linear_bias_attrs, i)

            ffn_ln_scale_attr = get_attr(ffn_ln_scale_attrs, i)
            ffn_ln_bias_attr = get_attr(ffn_ln_bias_attrs, i)
            ffn1_weight_attr = get_attr(ffn1_weight_attrs, i)
            ffn1_bias_attr = get_attr(ffn1_bias_attrs, i)
            ffn2_weight_attr = get_attr(ffn2_weight_attrs, i)
            ffn2_bias_attr = get_attr(ffn2_bias_attrs, i)

            # NOTE(review): the layer_norm scale/bias parameters are created
            # without an explicit dtype, unlike the linear weights below;
            # presumably the fused kernel expects them in the default float
            # dtype even for fp16 models — confirm before changing.
            ln_scale = self.create_parameter(
                attr=ln_scale_attr,
                shape=[embed_dim],
                default_initializer=Constant(value=1.0))
            ln_bias = self.create_parameter(
                attr=ln_bias_attr, shape=[embed_dim], is_bias=True)
            # num_heads here is the per-rank head count.
            qkv_weight = self.create_parameter(
                shape=[3, num_heads, self.head_dim, embed_dim],
                attr=qkv_weight_attr,
                dtype=self._dtype,
                is_bias=False)
            qkv_bias = self.create_parameter(
                shape=[3, num_heads, self.head_dim],
                attr=qkv_bias_attr,
                dtype=self._dtype,
                is_bias=True)
            linear_weight = self.create_parameter(
                shape=[num_heads * self.head_dim, embed_dim],
                attr=linear_weight_attr,
                dtype=self._dtype,
                is_bias=False)
            linear_bias = self.create_parameter(
                shape=[embed_dim],
                attr=linear_bias_attr,
                dtype=self._dtype,
                is_bias=True)

            ffn_ln_scale = self.create_parameter(
                shape=[embed_dim],
                attr=ffn_ln_scale_attr,
                is_bias=False,
                default_initializer=Constant(1.0))
            ffn_ln_bias = self.create_parameter(
                shape=[embed_dim], attr=ffn_ln_bias_attr, is_bias=True)
            # dim_feedforward here is the per-rank FFN hidden size.
            ffn1_weight = self.create_parameter(
                shape=[embed_dim, dim_feedforward],
                attr=ffn1_weight_attr,
                dtype=self._dtype,
                is_bias=False)
            ffn1_bias = self.create_parameter(
                shape=[dim_feedforward],
                attr=ffn1_bias_attr,
                dtype=self._dtype,
                is_bias=True)
            ffn2_weight = self.create_parameter(
                shape=[dim_feedforward, embed_dim],
                attr=ffn2_weight_attr,
                dtype=self._dtype,
                is_bias=False)
            ffn2_bias = self.create_parameter(
                shape=[embed_dim],
                attr=ffn2_bias_attr,
                dtype=self._dtype,
                is_bias=True)

            # tensor model parallel
            if nranks > 1:
                # column parallel: weights and biases are both split per rank
                _set_var_distributed(qkv_weight)
                _set_var_distributed(qkv_bias)
                _set_var_distributed(ffn1_weight)
                _set_var_distributed(ffn1_bias)
                # row parallel: only the weights are split; linear_bias and
                # ffn2_bias are intentionally left replicated (not marked).
                _set_var_distributed(linear_weight)
                _set_var_distributed(ffn2_weight)

            self.ln_scales.append(ln_scale)
            self.ln_biases.append(ln_bias)
            self.qkv_weights.append(qkv_weight)
            self.qkv_biases.append(qkv_bias)
            self.linear_weights.append(linear_weight)
            self.linear_biases.append(linear_bias)

            self.ffn_ln_scales.append(ffn_ln_scale)
            self.ffn_ln_biases.append(ffn_ln_bias)
            self.ffn1_weights.append(ffn1_weight)
            self.ffn1_biases.append(ffn1_bias)
            self.ffn2_weights.append(ffn2_weight)
            self.ffn2_biases.append(ffn2_bias)

        self.dropout_rate = dropout_rate
        self.activation = activation
        self.name = name

    def forward(self, src, attn_mask=None, caches=None, time_step=None):
        """
        Applies multi transformer layers on the input.

        Parameters:
            src (Tensor): The input of Transformer layers. It is
                a tensor with shape `[batch_size, sequence_length, d_model]`.
                The data type should be float16 or float32.
            attn_mask (Tensor, optional): A tensor used in multi-head attention
                to prevent attention to some unwanted positions, usually the
                paddings or the subsequent positions. It is a tensor with shape
                `[batch_size, 1, sequence_length, sequence_length]`. It can be
                None when nothing wanted or needed to be prevented attention to.
                Default None.
            caches (list(Tensor)|tuple(Tensor), optional): The cache structure
                tensors for the inference generation model, one per layer. It is
                only used for inference and should be None for training. The shape
                is `[2, batch_size, num_head, max_seq_len, head_dim]`. Default None.
            time_step (Tensor, optional): The time step tensor for the generation
                model. Which used in decode stage, to represent the time step,
                that is, the real seq_len of CacheKV. The shape is `[1]`, must be
                in CPUPlace. Default None.

        Returns:
            Tensor|tuple: If `caches` is None, return a tensor that has
            the same shape and data type with `src`, representing the output
            of Transformer layers. If `caches` is not None, return the
            tuple (output, caches), which output is the output of
            Transformer layers, caches is inplace with input `caches`.
        """

        if caches is not None:
            # One cache tensor is expected per transformer layer.
            assert len(caches) == len(self.qkv_weights), (
                "Expected {} caches (one per layer), but received {}".format(
                    len(self.qkv_weights), len(caches)))
        out = incubate_f.fused_multi_transformer(
            src,
            self.ln_scales,
            self.ln_biases,
            self.qkv_weights,
            self.qkv_biases,
            self.linear_weights,
            self.linear_biases,
            self.ffn_ln_scales,
            self.ffn_ln_biases,
            self.ffn1_weights,
            self.ffn1_biases,
            self.ffn2_weights,
            self.ffn2_biases,
            pre_layer_norm=self.normalize_before,
            epsilon=self._epsilon,
            cache_kvs=caches,
            time_step=time_step,
            attn_mask=attn_mask,
            dropout_rate=self.dropout_rate,
            activation=self.activation,
            training=self.training,
            mode='upscale_in_train',
            ring_id=self._ring_id,
            name=self.name)
        return out