#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle import _legacy_C_ops
from paddle.fluid import core
from paddle.fluid.data_feeder import check_dtype, check_variable_and_dtype
from paddle.fluid.framework import _non_static_mode, default_main_program
from paddle.fluid.layer_helper import LayerHelper

__all__ = []


def _verify_dropout_rate(dropout_rate):
    if not isinstance(dropout_rate, (float, int)):
        raise TypeError("dropout_rate argument should be a number")
    if dropout_rate < 0 or dropout_rate > 1:
        raise ValueError("dropout_rate argument should between 0 and 1")


def fused_feedforward(
    x,
    linear1_weight,
    linear2_weight,
    linear1_bias=None,
    linear2_bias=None,
    ln1_scale=None,
    ln1_bias=None,
    ln2_scale=None,
    ln2_bias=None,
    dropout1_rate=0.5,
    dropout2_rate=0.5,
    activation="relu",
    ln1_epsilon=1e-5,
    ln2_epsilon=1e-5,
    pre_layer_norm=False,
    training=True,
    mode='upscale_in_train',
    ring_id=-1,
    add_residual=True,
    name=None,
):
    r"""
    This is a fusion operator to compute feed forward layer in transformer model architecture.
    This operator only supports running on GPU. The function of the operator is consistent with
    the following pseudo code:

    .. code-block:: python

        residual = x
        if pre_layer_norm:
            out = layer_norm1(x)
        else:
            out = x
        out = linear2(dropout1(activation(linear1(out))))
        if add_residual:
            out = residual + dropout2(out)
        else:
            out = dropout2(out)
        if not pre_layer_norm:
            out = layer_norm2(out)

    Args:
        x (Tensor): the input tensor could be 3-D tensor, the input data type could be float16, float32 or float64, the shape is `[batch\_size, sequence\_length, d_model]`.
        linear1_weight (Tensor): The weight of first linear, the data type is same as `x`, the shape is `[d\_model, dim\_feedforward]`.
        linear2_weight (Tensor): The weight of second linear, the data type is same as `x`, the shape is `[dim\_feedforward, d\_model]`.
        linear1_bias (Tensor, optional): The bias of first linear, the data type is same as `x`, the shape is `[dim_feedforward]`. Default None.
        linear2_bias (Tensor, optional): The bias of second linear, the data type is same as `x`, the shape is `[d_model]`. Default None.
        ln1_scale (Tensor, optional): The weight of first layer_norm, the data type is float32 or float64, the shape is `[d\_model]`. Default None.
        ln1_bias (Tensor, optional): The bias of first layer_norm, the data type is float32 or float64, the shape is `[d\_model]`. Default None.
        ln2_scale (Tensor, optional): The weight of second layer_norm, the data type is float32 or float64, the shape is `[d\_model]`. Default None.
        ln2_bias (Tensor, optional): The bias of second layer_norm, the data type is float32 or float64, the shape is `[d\_model]`. Default None.
        dropout1_rate (float, optional): The first dropout probability of setting units to zero. Default 0.5.
        dropout2_rate (float, optional): The second dropout probability of setting units to zero. Default 0.5.
        activation (str, optional): The activation applied to the output of the first linear. Default "relu".
        ln1_epsilon (float, optional): Small float of first layer_norm added to denominator to avoid dividing by zero. Default is 1e-5.
        ln2_epsilon (float, optional): Small float of second layer_norm added to denominator to avoid dividing by zero. Default is 1e-5.
        pre_layer_norm (bool, optional): Add layer_norm in the pre-processing stage (True) or the post-processing stage (False). Default False.
        training (bool, optional): A flag indicating whether it is in train phase or not. Default True.
        mode (str, optional): ['upscale_in_train'(default) | 'downscale_in_infer']

                               1. upscale_in_train(default), upscale the output at training time

                                  - train: out = input * mask / ( 1.0 - p )
                                  - inference: out = input

                               2. downscale_in_infer, downscale the output at inference

                                  - train: out = input * mask
                                  - inference: out = input * (1.0 - p)
        ring_id (int, optional): For distributed forward in tensor model parallel, only support NCCL. Default is -1, means not using tensor parallel.
        add_residual (bool, optional): Whether add residual at the end. Default is True.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor: The output Tensor, the data type and shape is same as `x`.

    Raises:
        TypeError: If `dropout1_rate` or `dropout2_rate` is not a number.
        ValueError: If a dropout rate is not in `[0, 1]`, or `mode` is not
            'upscale_in_train' or 'downscale_in_infer'.

    Examples:
        .. code-block:: python

            # required: gpu
            import paddle
            import paddle.incubate.nn.functional as F

            x = paddle.randn(shape=(1, 8, 8), dtype="float32")
            linear1_weight = paddle.randn(shape=(8, 8), dtype="float32")
            linear2_weight = paddle.randn(shape=(8, 8), dtype="float32")
            out = F.fused_feedforward(x, linear1_weight, linear2_weight)
            print(out.shape)
            # (1, 8, 8)
    """
    _verify_dropout_rate(dropout1_rate)
    _verify_dropout_rate(dropout2_rate)

    seed = None
    if mode not in ('downscale_in_infer', 'upscale_in_train'):
        raise ValueError(
            "mode argument should be 'downscale_in_infer' or 'upscale_in_train'"
        )
    # The op attribute spells the inference-time downscaling mode as
    # 'downgrade_in_infer'; translate the public name to it.
    mode = (
        'downgrade_in_infer' if mode == 'downscale_in_infer' else mode
    )

    if _non_static_mode():
        # A non-zero program-level random seed fixes both dropout seeds so
        # the masks are reproducible.
        if default_main_program().random_seed != 0:
            seed = default_main_program().random_seed
        # Only the first of the op's outputs (the final result) is returned.
        out, _, _, _, _, _, _, _, _, _, _ = _legacy_C_ops.fused_feedforward(
            x,
            None,
            None,
            linear1_weight,
            linear1_bias,
            linear2_weight,
            linear2_bias,
            ln1_scale,
            ln1_bias,
            ln2_scale,
            ln2_bias,
            'pre_layer_norm',
            pre_layer_norm,
            'ln1_epsilon',
            ln1_epsilon,
            'ln2_epsilon',
            ln2_epsilon,
            'act_method',
            activation,
            'dropout1_rate',
            dropout1_rate,
            'dropout2_rate',
            dropout2_rate,
            "is_test",
            not training,
            "dropout1_fix_seed",
            seed is not None,
            "dropout2_fix_seed",
            seed is not None,
            "dropout1_seed",
            seed if seed is not None else 0,
            "dropout2_seed",
            seed if seed is not None else 0,
            'dropout1_implementation',
            mode,
            'dropout2_implementation',
            mode,
            'add_residual',
            add_residual,
            'ring_id',
            ring_id,
        )
        return out

    # Static-graph path: declare every op output explicitly, then append the op.
    helper = LayerHelper("fused_feedforward")
    dtype = x.dtype
    check_variable_and_dtype(
        x, 'x', ['float16', 'float32', 'float64'], 'fused_feedforward'
    )
    check_dtype(
        dtype, 'dtype', ['float16', 'float32', 'float64'], 'fused_feedforward'
    )

    out = helper.create_variable_for_type_inference(x.dtype)
    dropout1_mask = helper.create_variable_for_type_inference(
        'uint8', stop_gradient=True
    )
    dropout2_mask = helper.create_variable_for_type_inference(
        'uint8', stop_gradient=True
    )
    ln1_mean = helper.create_variable_for_type_inference(
        x.dtype, stop_gradient=True
    )
    ln1_variance = helper.create_variable_for_type_inference(
        x.dtype, stop_gradient=True
    )
    ln2_mean = helper.create_variable_for_type_inference(
        x.dtype, stop_gradient=True
    )
    ln2_variance = helper.create_variable_for_type_inference(
        x.dtype, stop_gradient=True
    )
    linear1_out = helper.create_variable_for_type_inference(
        x.dtype, stop_gradient=True
    )
    ln1_out = helper.create_variable_for_type_inference(
        x.dtype, stop_gradient=True
    )
    dropout1_out = helper.create_variable_for_type_inference(
        x.dtype, stop_gradient=True
    )
    dropout2_out = helper.create_variable_for_type_inference(
        x.dtype, stop_gradient=True
    )

    # Same seeding rule as the dynamic path, read from the helper's program.
    if (seed is None or seed == 0) and helper.main_program.random_seed != 0:
        seed = helper.main_program.random_seed

    helper.append_op(
        type='fused_feedforward',
        inputs={
            'X': x,
            'Linear1Weight': linear1_weight,
            'Linear1Bias': linear1_bias,
            'Linear2Weight': linear2_weight,
            'Linear2Bias': linear2_bias,
            'Ln1Scale': ln1_scale,
            'Ln1Bias': ln1_bias,
            'Ln2Scale': ln2_scale,
            'Ln2Bias': ln2_bias,
        },
        outputs={
            'Out': out,
            'Dropout1Mask': dropout1_mask,
            'Dropout2Mask': dropout2_mask,
            'Ln1Mean': ln1_mean,
            'Ln1Variance': ln1_variance,
            'Ln2Mean': ln2_mean,
            'Ln2Variance': ln2_variance,
            'Linear1Out': linear1_out,
            'Ln1Out': ln1_out,
            'Dropout1Out': dropout1_out,
            'Dropout2Out': dropout2_out,
        },
        attrs={
            'dropout1_rate': dropout1_rate,
            'dropout2_rate': dropout2_rate,
            'act_method': activation,
            'pre_layer_norm': pre_layer_norm,
            'ln1_epsilon': ln1_epsilon,
            'ln2_epsilon': ln2_epsilon,
            'is_test': not training,
            'dropout1_fix_seed': seed is not None,
            'dropout2_fix_seed': seed is not None,
            'dropout1_seed': seed if seed is not None else 0,
            'dropout2_seed': seed if seed is not None else 0,
            'dropout1_implementation': mode,
            'dropout2_implementation': mode,
            'add_residual': add_residual,
            'ring_id': ring_id,
        },
    )
    return out


def fused_bias_dropout_residual_layer_norm(
    x,
    residual,
    bias=None,
    ln_scale=None,
    ln_bias=None,
    dropout_rate=0.5,
    ln_epsilon=1e-5,
    training=True,
    mode='upscale_in_train',
    name=None,
):
    r"""
    The fused_bias_dropout_residual_layer_norm operator. The pseudo code is as follows:

    .. code-block:: python

        y = layer_norm(residual + dropout(bias + x))

    Parameters:
        x (Tensor): The input tensor. The shape is `[*, embed\_dim]`.
        residual (Tensor): The residual tensor. The shape is same as x.
        bias (Tensor, optional): The bias of linear. The shape is `[embed_dim]`. Default None.
        ln_scale (Tensor, optional): The weight tensor of layernorm. The shape is `[embed_dim]`. Default None.
        ln_bias (Tensor, optional): The bias tensor of layernorm. The shape is `[embed_dim]`. Default None.
        dropout_rate (float, optional): The dropout probability applied to `bias + x`
            before the residual is added. 0 for no dropout. Default 0.5.
        ln_epsilon (float, optional): Small float value added to denominator of layer_norm
            to avoid dividing by zero. Default is 1e-5.
        training (bool, optional): A flag indicating whether it is in train phase or not. Default True.
        mode (str, optional): ['upscale_in_train'(default) | 'downscale_in_infer']

                               1. upscale_in_train(default), upscale the output at training time

                                  - train: out = input * mask / ( 1.0 - p )
                                  - inference: out = input

                               2. downscale_in_infer, downscale the output at inference

                                  - train: out = input * mask
                                  - inference: out = input * (1.0 - p)
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor: The output Tensor, the data type and shape is same as `x`.

    Raises:
        TypeError: If `dropout_rate` is not a number.
        ValueError: If `dropout_rate` is not in `[0, 1]`, or `mode` is not
            'upscale_in_train' or 'downscale_in_infer'.

    Examples:
        .. code-block:: python

            # required: gpu
            import paddle
            import paddle.incubate.nn.functional as F

            # input: [batch_size, seq_len, embed_dim]
            x = paddle.rand(shape=(2, 4, 128), dtype="float32")
            # residual: [batch_size, seq_len, embed_dim]
            residual = paddle.rand(shape=(2, 4, 128), dtype="float32")
            # linear bias: [embed_dim]
            bias = paddle.rand(shape=[128], dtype="float32")
            # output: [batch_size, seq_len, embed_dim]
            output = F.fused_bias_dropout_residual_layer_norm(
                x, residual, bias)
            # [2, 4, 128]
            print(output.shape)

    """
    # Same dropout-rate validation as fused_feedforward.
    _verify_dropout_rate(dropout_rate)

    seed = None
    if mode not in ('downscale_in_infer', 'upscale_in_train'):
        raise ValueError(
            "mode argument should be 'downscale_in_infer' or 'upscale_in_train'"
        )
    # The op attribute spells the inference-time downscaling mode as
    # 'downgrade_in_infer'; translate the public name to it.
    mode = (
        'downgrade_in_infer' if mode == 'downscale_in_infer' else mode
    )

    # Layer-norm parameters must be 1-D and match the last dim of x.
    if ln_scale is not None:
        assert (
            len(ln_scale.shape) == 1
        ), "The dims of the shape of ln_scale should be 1."
        assert (
            x.shape[-1] == ln_scale.shape[0]
        ), "The dim of ln_scale must equal to the last dim of x."
    if ln_bias is not None:
        assert (
            len(ln_bias.shape) == 1
        ), "The dims of the shape of ln_bias should be 1."
        assert (
            x.shape[-1] == ln_bias.shape[0]
        ), "The dim of ln_bias must equal to the last dim of x."

    if _non_static_mode():
        # A non-zero program-level random seed fixes the dropout seed so the
        # mask is reproducible.
        if default_main_program().random_seed != 0:
            seed = default_main_program().random_seed
        # The final layer-norm output is the last of the op's five outputs.
        (
            _,
            _,
            _,
            _,
            final_out,
        ) = _legacy_C_ops.fused_bias_dropout_residual_layer_norm(
            x,
            residual,
            bias,
            ln_scale,
            ln_bias,
            'dropout_rate',
            dropout_rate,
            'ln_epsilon',
            ln_epsilon,
            'is_test',
            not training,
            'dropout_fix_seed',
            seed is not None,
            'dropout_seed',
            seed if seed is not None else 0,
            'dropout_implementation',
            mode,
        )
        return final_out
    else:
        helper = LayerHelper(
            'fused_bias_dropout_residual_layer_norm', **locals()
        )
        dtype = x.dtype
        # check dtypes
        check_variable_and_dtype(
            x,
            'x',
            ['float16', 'float32', 'float64'],
            'fused_bias_dropout_residual_layer_norm',
        )
        check_dtype(
            dtype,
            'dtype',
            ['float16', 'float32', 'float64'],
            'fused_bias_dropout_residual_layer_norm',
        )
        # set inputs; optional tensors are only wired in when provided.
        # Use explicit `is not None` checks rather than tensor truthiness.
        inputs = dict()
        inputs['X'] = [x]
        inputs['Residual'] = [residual]
        if bias is not None:
            inputs['Bias'] = [bias]
        if ln_scale is not None:
            inputs['LnScale'] = [ln_scale]
        if ln_bias is not None:
            inputs['LnBias'] = [ln_bias]
        if (seed is None or seed == 0) and helper.main_program.random_seed != 0:
            seed = helper.main_program.random_seed
        # set attrs
        attrs = {
            'ln_epsilon': ln_epsilon,
            'dropout_rate': dropout_rate,
            'is_test': not training,
            'dropout_fix_seed': seed is not None,
            'dropout_seed': seed if seed is not None else 0,
            'dropout_implementation': mode,
        }
        # set outputs
        dropout_mask_out = helper.create_variable_for_type_inference(
            dtype=core.VarDesc.VarType.UINT8, stop_gradient=True
        )
        ln_mean_out = helper.create_variable_for_type_inference(
            dtype=dtype, stop_gradient=True
        )
        ln_variance_out = helper.create_variable_for_type_inference(
            dtype=dtype, stop_gradient=True
        )
        bias_dropout_residual_out = helper.create_variable_for_type_inference(
            dtype=dtype
        )
        final_out = helper.create_variable_for_type_inference(dtype=dtype)

        helper.append_op(
            type='fused_bias_dropout_residual_layer_norm',
            inputs=inputs,
            outputs={
                "BiasDropoutResidualOut": bias_dropout_residual_out,
                "DropoutMaskOut": dropout_mask_out,
                "LnMean": ln_mean_out,
                "LnVariance": ln_variance_out,
                'Y': final_out,
            },
            attrs=attrs,
        )
        return final_out


def fused_multi_head_attention(
    x,
    qkv_weight,
    linear_weight,
    pre_layer_norm=False,
    pre_ln_scale=None,
    pre_ln_bias=None,
    ln_scale=None,
    ln_bias=None,
    pre_ln_epsilon=1e-05,
    qkv_bias=None,
    linear_bias=None,
    cache_kv=None,
    attn_mask=None,
    dropout_rate=0.5,
    attn_dropout_rate=0.5,
    ln_epsilon=1e-05,
    training=True,
    mode='upscale_in_train',
    ring_id=-1,
    add_residual=True,
485 486
    num_heads=-1,
    transpose_qkv_wb=False,
487 488
    name=None,
):
489
    r"""
L
Li Min 已提交
490 491
    Attention mapps queries and a set of key-value pairs to outputs, and
    Multi-Head Attention performs multiple parallel attention to jointly attending
492
    to information from different representation subspaces. This API only
L
Li Min 已提交
493
    support self_attention. The pseudo code is as follows:
494 495 496

    .. code-block:: python

497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519
        residual = x
        if pre_layer_norm:
            out = layer_norm(x)
        else:
            out = x
        # compute q, k, v
        out = matmul(out, qkv_weight) + qkv_bias
        out = transpose(out, perm=[2, 0, 3, 1, 4])
        # extract q, k and v from out
        q = out[0:1,::] * (head_dim ** -0.5)
        k = out[1:2,::]
        v = out[2:3,::]
        out = matmul(q, k, transpose_y=True)
        out = out + attn_mask
        out = softmax(out)
        out = dropout(out)
        out = matmul(out, v)
        # combine heads
        out = transpose(out, perm=[0, 2, 1, 3])
        # project to output
        out = linear(out)
        if add_residual:
            out = residual + dropout(out)
520
        else:
521 522 523 524
            out = dropout(out)
        if not pre_layer_norm:
            out = layer_norm(out)

L
Li Min 已提交
525 526

    Parameters:
527
        x (Tensor): The input tensor of fused_multi_head_attention. The shape is
L
Li Min 已提交
528 529 530
            `[batch\_size, sequence\_len, embed\_dim]`.
        qkv_weight (Tensor): The qkv weight tensor. The shape is `[3, num_head, dim_head, dim_embed]`.
        linear_weight (Tensor): The linear weight tensor. The shape is `[embed_dim, embed_dim]`.
531
        pre_layer_norm (bool, optional): whether it is pre_layer_norm (True) or post_layer_norm architecture
532
                                        (False). Default False.
L
Li Min 已提交
533 534 535 536
        pre_ln_scale (Tensor, optional): The weight tensor of pre layernorm. Default None.
        pre_ln_bias (Tensor, optional): The bias tensor of pre layernorm. Default None.
        ln_scale (Tensor, optional): The weight tensor of layernorm. Default None.
        ln_bias (Tensor, optional): The bias tensor of layernorm. Default None.
537
        pre_ln_epsilon (float, optional): Small float value added to denominator of the pre layer_norm
L
Li Min 已提交
538
            to avoid dividing by zero. Default is 1e-5.
539
        qkv_bias (Tensor, optional): The bias of qkv computation. The shape is `[3, num_head, dim_head]`.
L
Li Min 已提交
540 541
            Default None.
        linear_bias (Tensor, optional): The bias of linear. The shape is `[embed_dim]`. Default None.
542
        cache_kv (Tensor, optional): For generation model, cache structure. The shape is `[2, bsz, num_head, seq_len, head_dim]`. Default None.
543
        attn_mask (Tensor, optional):  A tensor used in multi-head attention to prevents attention to
544
            some unwanted positions, usually the paddings or the subsequent positions. It is a tensor
545 546 547 548
            with shape broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`. When the
            data type is bool, the unwanted positions have `False` values and the others have `True` values.
            When the data type is int, the unwanted positions have 0 values and the others have 1 values.
            When the data type is float, the unwanted positions have `-INF` values and the others have 0 values.
549
            It can be None when nothing wanted or needed to be prevented attention to. Default None.
L
Li Min 已提交
550
        dropout_rate (float, optional): The dropout probability used on attention
551
            weights to drop some attention targets for the dropout after attention.
552
            0 for no dropout. Default 0.5.
L
Li Min 已提交
553
        attn_dropout_rate (float, optional): The dropout probability used on attention
554
            weights to drop some attention targets for the dropout in attention.
555
            0 for no dropout. Default 0.5.
556
        ln_epsilon (float, optional): Small float value added to denominator of layer_norm
L
Li Min 已提交
557
            to avoid dividing by zero. Default is 1e-5.
L
Li Min 已提交
558 559 560 561 562 563 564 565 566 567 568 569
        training (bool, optional): A flag indicating whether it is in train phrase or not. Default True.
        mode (str, optional): ['upscale_in_train'(default) | 'downscale_in_infer']

                               1. upscale_in_train(default), upscale the output at training time

                                  - train: out = input * mask / ( 1.0 - p )
                                  - inference: out = input

                               2. downscale_in_infer, downscale the output at inference

                                  - train: out = input * mask
                                  - inference: out = input * (1.0 - p)
570
        ring_id (int, optional): For distributed forward in mp, only support NCCL and forward. Default is -1, means not using mp
571
        add_residual (bool, optional): Whether add residual at the end. Default is True.
572 573
        num_heads (int, optional): If enable transpose_qkv_wb, should provide the num_heads. Default is -1, means not transpose qkv wb.
        transpose_qkv_wb (bool, optional): Whether transpose the qkv_weight and qkv_bias in the op. Only support GPU for now. Default is false, means not transpose qkv wb.
574
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
575

576 577 578
    Returns:
        Tensor: The output Tensor, the data type and shape is same as `x`.

L
Li Min 已提交
579 580 581
    Examples:

        .. code-block:: python
582 583

            # required: gpu
L
Li Min 已提交
584
            import paddle
585
            import paddle.incubate.nn.functional as F
L
Li Min 已提交
586 587 588

            # input: [batch_size, seq_len, embed_dim]
            x = paddle.rand(shape=(2, 4, 128), dtype="float32")
589
            # qkv_weight: [3, num_head, head_dim, embed_dim]
L
Li Min 已提交
590
            qkv_weight = paddle.rand(shape=(3, 4, 32, 128), dtype="float32")
591
            # qkv_bias: [3, num_head, head_dim]
L
Li Min 已提交
592 593 594 595 596 597 598 599 600 601 602 603
            qkv_bias = paddle.rand(shape=(3, 4, 32), dtype="float32")
            # linear_weight: [embed_dim, embed_dim]
            linear_weight = paddle.rand(shape=(128, 128), dtype="float32")
            # linear_bias: [embed_dim]
            linear_bias = paddle.rand(shape=[128], dtype="float32")
            # self attention mask: [batch_size, num_heads, seq_len, seq_len]
            attn_mask = paddle.rand(shape=(2, 4, 4, 4), dtype="float32")

            # output: [batch_size, seq_len, embed_dim]
            output = F.fused_multi_head_attention(
                x, qkv_weight, linear_weight, False,
                None, None, None, None, 1e-5, qkv_bias,
604
                linear_bias, None, attn_mask)
L
Li Min 已提交
605 606 607
            # [2, 4, 128]
            print(output.shape)
    """
608 609 610 611

    seed = None
    if mode not in ('downscale_in_infer', 'upscale_in_train'):
        raise ValueError(
612 613
            "mode argument should be 'downscale_in_infer' or 'upscale_in_train'"
        )
614 615 616
    mode = (
        'downgrade_in_infer' if mode == 'downscale_in_infer' else mode
    )  # semantic transfer
617

J
Jiabin Yang 已提交
618
    if _non_static_mode():
619 620
        if default_main_program().random_seed != 0:
            seed = default_main_program().random_seed
621 622
        # pre_ln_mean, pre_ln_variance, pre_ln_out, qkv_out, qkv_bias_out, transpose_out, qk_out,
        # qktv_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, attn_mask_out, fmha_out,
L
Li Min 已提交
623
        # linear_out, dropout_mask_out, ln_mean_out, ln_var_out, bias_dropout_residual_out, final_out
624
        if not transpose_qkv_wb:
625
            assert (
626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665
                len(qkv_weight.shape) == 4
            ), "The dims of the shape of qkv_weight should be 4."
            assert (
                qkv_weight.shape[0] == 3
            ), "The shape of qkv_weight should be [3, num_head, head_dim, embed_dim]."
            assert (
                qkv_weight.shape[3] == x.shape[2]
            ), "The 3rd dim of qkv_weight and 2nd dim of x should be the same, i.e., embed_dim."
            if ring_id == -1:
                # under mp, the num head will be split, this equation will not hold
                assert (
                    qkv_weight.shape[1] * qkv_weight.shape[2]
                    == qkv_weight.shape[3]
                ), "embed_dim must be divisible by num_heads."
        else:
            assert (
                num_heads > 0
            ), "When enable transpose_qkv_wb, the num_heads should be provided and greater than 0."
            assert len(qkv_weight.shape) == 2, (
                "When enable transpose_qkv_wb, the dims of the shape of qkv_weight "
                "should be 2 when enable transpose_qkv_wb."
            )
            if ring_id == -1:
                # under mp, the num head will be split, this equation will not hold
                assert qkv_weight.shape[1] == 3 * qkv_weight.shape[0], (
                    "When enable transpose_qkv_wb, the shape of qkv_weight should be "
                    "[embed_dim, 3 * embed_dim] when enable transpose_qkv_wb."
                )
            assert qkv_weight.shape[0] == x.shape[2], (
                "When enable transpose_qkv_wb, the 1st dim of qkv_weight and 2nd dim of x "
                "should be the same, i.e., embed_dim."
            )
            if qkv_bias is not None:
                assert (
                    len(qkv_bias.shape) == 1
                ), "When enable transpose_qkv_wb, the dims of the shape of qkv_bias should be 1."
                assert qkv_bias.shape[0] == qkv_weight.shape[1], (
                    "When enable transpose_qkv_wb, the 1st dim of qkv_bias and 2nd dim of "
                    "qkv_weight should be the same, i.e., embed_dim."
                )
666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698
        (
            _,
            _,
            _,
            _,
            _,
            _,
            _,
            _,
            _,
            _,
            _,
            _,
            _,
            _,
            _,
            _,
            _,
            _,
            cache_kv_out,
            final_out,
        ) = _legacy_C_ops.fused_attention(
            x,
            pre_ln_scale,
            pre_ln_bias,
            qkv_weight,
            qkv_bias,
            cache_kv,
            attn_mask,
            linear_weight,
            linear_bias,
            ln_scale,
            ln_bias,
699 700 701 702
            'num_heads',
            num_heads,
            'transpose_qkv_wb',
            transpose_qkv_wb,
703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731
            'pre_layer_norm',
            pre_layer_norm,
            'epsilon',
            pre_ln_epsilon,
            'dropout_rate',
            dropout_rate,
            'attn_dropout_rate',
            attn_dropout_rate,
            'ln_epsilon',
            ln_epsilon,
            'is_test',
            not training,
            'attn_dropout_fix_seed',
            seed is not None,
            'dropout_fix_seed',
            seed is not None,
            'attn_dropout_seed',
            seed if seed is not None else 0,
            'dropout_seed',
            seed if seed is not None else 0,
            'attn_dropout_implementation',
            mode,
            'dropout_implementation',
            mode,
            'add_residual',
            add_residual,
            'ring_id',
            ring_id,
        )
732 733
        if cache_kv is not None:
            return final_out, cache_kv_out
L
Li Min 已提交
734
        return final_out
735 736 737 738
    else:
        helper = LayerHelper('fused_multi_head_attention', **locals())
        dtype = x.dtype
        # check dtypes
739 740 741 742 743 744 745 746 747 748 749 750
        check_variable_and_dtype(
            x,
            'x',
            ['float16', 'float32', 'float64'],
            'fused_multihead_attention',
        )
        check_dtype(
            dtype,
            'dtype',
            ['float16', 'float32', 'float64'],
            'fused_multi_head_attention',
        )
751 752 753 754 755 756 757 758 759

        # set inputs
        inputs = dict()
        inputs['X'] = [x]
        if pre_ln_scale:
            inputs['LnScale'] = [pre_ln_scale]
        if pre_ln_bias:
            inputs['LnBias'] = [pre_ln_bias]
        inputs['QKVW'] = [qkv_weight]
760 761
        if qkv_bias is not None:
            inputs['QKVBias'] = [qkv_bias]
762 763
        inputs['SrcMask'] = attn_mask
        inputs['OutLinearW'] = [linear_weight]
764 765
        if linear_bias is not None:
            inputs['OutLinearBias'] = [linear_bias]
766 767 768 769
        if ln_scale:
            inputs['Ln2Scale'] = [ln_scale]
        if ln_bias:
            inputs['Ln2Bias'] = [ln_bias]
770 771
        if cache_kv:
            inputs['CacheKV'] = [cache_kv]
772

773 774 775
        if (seed is None or seed == 0) and helper.main_program.random_seed != 0:
            seed = helper.main_program.random_seed

776 777 778 779 780 781
        # set attrs
        attrs = {
            'pre_layer_norm': pre_layer_norm,
            'epsilon': pre_ln_epsilon,
            'ln_epsilon': ln_epsilon,
            'dropout_rate': dropout_rate,
782
            'attn_dropout_rate': attn_dropout_rate,
L
Li Min 已提交
783
            'is_test': not training,
784 785 786 787 788 789
            'attn_dropout_fix_seed': seed is not None,
            'dropout_fix_seed': seed is not None,
            'attn_dropout_seed': seed if seed is not None else 0,
            'dropout_seed': seed if seed is not None else 0,
            'attn_dropout_implementation': mode,
            'dropout_implementation': mode,
790
            'add_residual': add_residual,
791
            'ring_id': ring_id,
792 793
            'num_heads': num_heads,
            'transpose_qkv_wb': transpose_qkv_wb,
794 795 796 797
        }

        # set outputs
        pre_ln_mean_out = helper.create_variable_for_type_inference(
798 799
            dtype=dtype, stop_gradient=True
        )
800
        pre_ln_variance_out = helper.create_variable_for_type_inference(
801 802
            dtype=dtype, stop_gradient=True
        )
803 804 805 806 807 808 809 810 811 812
        pre_ln_out = helper.create_variable_for_type_inference(dtype=dtype)

        qkv_out = helper.create_variable_for_type_inference(dtype=dtype)
        qkv_bias_out = helper.create_variable_for_type_inference(dtype=dtype)

        transpose_out = helper.create_variable_for_type_inference(dtype=dtype)
        qk_out = helper.create_variable_for_type_inference(dtype=dtype)
        qktv_out = helper.create_variable_for_type_inference(dtype=dtype)
        softmax_out = helper.create_variable_for_type_inference(dtype=dtype)
        attn_dropout_mask_out = helper.create_variable_for_type_inference(
813 814
            dtype=core.VarDesc.VarType.UINT8, stop_gradient=True
        )
815
        attn_dropout_out = helper.create_variable_for_type_inference(
816 817
            dtype=dtype
        )
818 819 820 821
        attn_mask_out = helper.create_variable_for_type_inference(dtype=dtype)
        fmha_out = helper.create_variable_for_type_inference(dtype=dtype)
        out_linear_out = helper.create_variable_for_type_inference(dtype=dtype)
        dropout_mask_out = helper.create_variable_for_type_inference(
822 823
            dtype=core.VarDesc.VarType.UINT8, stop_gradient=True
        )
824
        ln_mean_out = helper.create_variable_for_type_inference(
825 826
            dtype=dtype, stop_gradient=True
        )
827
        ln_variance_out = helper.create_variable_for_type_inference(
828 829
            dtype=dtype, stop_gradient=True
        )
830
        bias_dropout_residual_out = helper.create_variable_for_type_inference(
831 832
            dtype=dtype
        )
833
        final_out = helper.create_variable_for_type_inference(dtype=dtype)
834
        cache_kv_out = helper.create_variable_for_type_inference(dtype=dtype)
835

836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862
        helper.append_op(
            type='fused_attention',
            inputs=inputs,
            outputs={
                "LnMean": pre_ln_mean_out,
                "LnVariance": pre_ln_variance_out,
                "LnOut": pre_ln_out,
                "QKVOut": qkv_out,
                "QKVBiasOut": qkv_bias_out,
                "TransposeOut2": transpose_out,
                "QKOut": qk_out,
                "QKTVOut": qktv_out,
                "SoftmaxOut": softmax_out,
                "AttnDropoutMaskOut": attn_dropout_mask_out,
                "AttnDropoutOut": attn_dropout_out,
                "SrcMaskOut": attn_mask_out,
                "FMHAOut": fmha_out,
                "OutLinearOut": out_linear_out,
                "DropoutMaskOut": dropout_mask_out,
                "Ln2Mean": ln_mean_out,
                "Ln2Variance": ln_variance_out,
                "BiasDropoutResidualOut": bias_dropout_residual_out,
                'Y': final_out,
                'CacheKVOut': cache_kv_out,
            },
            attrs=attrs,
        )
863 864

        return (final_out, cache_kv_out) if cache_kv else final_out
865 866


def fused_multi_transformer(
    x,
    ln_scales,
    ln_biases,
    qkv_weights,
    qkv_biases,
    linear_weights,
    linear_biases,
    ffn_ln_scales,
    ffn_ln_biases,
    ffn1_weights,
    ffn1_biases,
    ffn2_weights,
    ffn2_biases,
    pre_layer_norm=True,
    epsilon=1e-05,
    cache_kvs=None,
    pre_caches=None,
    rotary_embs=None,
    time_step=None,
    attn_mask=None,
    dropout_rate=0.0,
    rotary_emb_dims=0,
    activation="gelu",
    training=False,
    mode='upscale_in_train',
    trans_qkvw=True,
    ring_id=-1,
    name=None,
):
    r"""
    This is a fusion operator to compute multi transformer layers in transformer model architecture.
    This operator only supports running on GPU. The function of the transformer layer is consistent
    with the following pseudo code:

    .. code-block:: python

        if pre_layer_norm:
            out = layer_norm(x)
            out = qkv_linear(out) + qkv_bias
        else:
            out = qkv_linear(x) + qkv_bias
        out = transpose(out, perm=[2, 0, 3, 1, 4])
        # extract q, k and v from out.
        q = out[0:1, ::]
        k = out[1:2, ::]
        v = out[2:3, ::]
        out = q * k^t
        out = attn_mask + out
        out = softmax(out)
        out = dropout(out)
        out = out * v
        out = transpose(out, perm=[0, 2, 1, 3])
        out = linear(out)
        if pre_layer_norm:
            out = x + dropout(out + bias)
        else:
            out = layer_norm(x + dropout(out + bias))

        residual = out
        if pre_layer_norm:
            out = ffn_layer_norm(out)
        out = ffn1_linear(out)
        out = dropout(activation(out + ffn1_bias))
        out = ffn2_linear(out)
        out = residual + dropout(out + ffn2_bias)
        if not pre_layer_norm:
            out = ffn_layer_norm(out)

    Args:
        x (Tensor): the input tensor could be 3-D tensor, the input data type could be float16 or float32, the shape is `[batch\_size, sequence\_length, d\_model]`.
        ln_scales (list(Tensor)|tuple(Tensor)): The weight tensors of attention layer_norm, the shape is `[d\_model]`.
        ln_biases (list(Tensor)|tuple(Tensor)): The bias tensors of attention layer_norm. the shape is `[d\_model]`.
        qkv_weights (list(Tensor)|tuple(Tensor)): The weight tensors of attention qkv computation. The shape is `[3, num\_head, dim\_head, d\_model]`.
        qkv_biases (list(Tensor)|tuple(Tensor)|None): The bias tensors of attention qkv computation. The shape is `[3, num\_head, dim\_head]`.
        linear_weights (list(Tensor)|tuple(Tensor)): The weight tensors of attention linear. The shape is `[num\_head * dim\_head, d\_model]`.
        linear_biases (list(Tensor)|tuple(Tensor)|None): The bias tensors of attention linear. The shape is `[d\_model]`.
        ffn_ln_scales (list(Tensor)|tuple(Tensor)): The weight tensors of feedforward layer_norm, the shape is `[d\_model]`.
        ffn_ln_biases (list(Tensor)|tuple(Tensor)): The bias tensors of feedforward layer_norm, the shape is `[d\_model]`.
        ffn1_weights (list(Tensor)|tuple(Tensor)): The weight tensors of feedforward first linear, the shape is `[d\_model, dim\_feedforward]`.
        ffn1_biases (list(Tensor)|tuple(Tensor)|None): The bias tensors of feedforward first linear, the shape is `[dim\_feedforward]`.
        ffn2_weights (list(Tensor)|tuple(Tensor)): The weight tensors of feedforward second linear, the shape is `[dim\_feedforward, d\_model]`.
        ffn2_biases (list(Tensor)|tuple(Tensor)|None): The bias tensors of feedforward second linear, the shape is `[d\_model]`.
        pre_layer_norm (bool, optional): whether it is pre_layer_norm(True) or post_layer_norm(False). Default True.
        epsilon (float, optional): Small float value added to denominator of the layer_norm to avoid dividing by zero. Default is 1e-5.
        cache_kvs (list(Tensor)|tuple(Tensor), optional): The cache structure tensors for the generation model. The shape is `[2, bsz, num\_head, max\_seq\_len, head\_dim]`. Default None.
        pre_caches (list(Tensor)|tuple(Tensor), optional): The prefix caches for the generation model. The shape is `[2, bsz, num\_head, cache\_len, head\_dim]`. Default None.
        rotary_embs (Tensor, optional): The RoPE embs for rotary computation. The shape is `[2, bsz, 1, seq\_len, head\_dim]`. Default None.
        time_step (Tensor, optional): The time step tensor for the generation model. It is used in the decode stage to represent the time step, i.e., the real seq_len of CacheKV. The shape is `[1]`, must be in CPUPlace. Only used when `cache_kvs` is not None. Default None.
        attn_mask (Tensor, optional):  A tensor used in multi-head attention to prevent attention to
            some unwanted positions, usually the paddings or the subsequent positions. It is a tensor
            with shape `[batch_size, 1, sequence_length, sequence_length]`. Default None.
        dropout_rate (float, optional): The dropout probability of setting units to zero. Default 0.0.
        rotary_emb_dims (int, optional): The rotary_emb_dims of rotary computation, and it is 0 when rotary_embs is None,
            1 when rotary_embs is not None and pos_extra_ids is None, 2 when rotary_embs and pos_extra_ids are both not None. Default 0.
        activation (str, optional): The activation. Default "gelu".
        training (bool, optional): A flag indicating whether it is in train phase or not. Default False.
        mode (str, optional): ['upscale_in_train'(default) | 'downscale_in_infer']

                               1. upscale_in_train(default), upscale the output at training time

                                  - train: out = input * mask / ( 1.0 - p )
                                  - inference: out = input

                               2. downscale_in_infer, downscale the output at inference

                                  - train: out = input * mask
                                  - inference: out = input * (1.0 - p)
        trans_qkvw (bool, optional): Whether to transpose for weights of qkv.
            If true, the shape of weights of qkv should be [3, num_head, dim_head, dim_embed].
            Otherwise the shape of weights of qkv should be [dim_embed, 3, num_head, dim_head]. Default True.
        ring_id (int, optional): For distributed forward in tensor model parallel, only support NCCL. Default is -1, means not using mp.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor|tuple: If `cache_kvs` is None, return a tensor that has
        the same shape and data type with `x`, representing the output
        of Transformer layers. If `cache_kvs` is not None, return the
        tuple (output, cache_kvs), which output is the output of
        Transformer layers, cache_kvs is inplace with input `cache_kvs`.

    Raises:
        ValueError: If `mode` is neither 'downscale_in_infer' nor 'upscale_in_train'.

    Examples:
        .. code-block:: python

            # required: gpu
            import paddle
            import paddle.incubate.nn.functional as F

            # input: [batch_size, seq_len, embed_dim]
            x = paddle.rand(shape=(2, 4, 128), dtype="float32")

            # ln_scale: [embed_dim], ln_bias: [embed_dim]
            ln_scale = paddle.rand(shape=(128,), dtype="float32")
            ln_bias = paddle.rand(shape=(128,), dtype="float32")

            # qkv_weight: [3, num_head, head_dim, embed_dim], qkv_bias: [3, num_head, head_dim]
            qkv_weight = paddle.rand(shape=(3, 4, 32, 128), dtype="float32")
            qkv_bias = paddle.rand(shape=(3, 4, 32), dtype="float32")

            # linear_weight: [embed_dim, embed_dim], linear_bias: [embed_dim]
            linear_weight = paddle.rand(shape=(128, 128), dtype="float32")
            linear_bias = paddle.rand(shape=(128,), dtype="float32")

            # ffn_ln_scale: [embed_dim], ffn_ln_bias: [embed_dim]
            ffn_ln_scale = paddle.rand(shape=(128,), dtype="float32")
            ffn_ln_bias = paddle.rand(shape=(128,), dtype="float32")

            # ffn1_weight: [embed_dim, 4*embed_dim], ffn1_bias: [4*embed_dim]
            ffn1_weight = paddle.rand(shape=(128, 4*128), dtype="float32")
            ffn1_bias = paddle.rand(shape=(4*128,), dtype="float32")

            # ffn2_weight: [4*embed_dim, embed_dim], ffn2_bias: [embed_dim]
            ffn2_weight = paddle.rand(shape=(4*128, 128), dtype="float32")
            ffn2_bias = paddle.rand(shape=(128,), dtype="float32")

            # self attention mask: [batch_size, 1, seq_len, seq_len]
            attn_mask = paddle.rand(shape=(2, 1, 4, 4), dtype="float32")

            # output: [batch_size, seq_len, embed_dim]
            output = F.fused_multi_transformer(
                x, [ln_scale], [ln_bias], [qkv_weight], [qkv_bias],
                [linear_weight], [linear_bias], [ffn_ln_scale], [ffn_ln_bias],
                [ffn1_weight], [ffn1_bias], [ffn2_weight], [ffn2_bias],
                attn_mask=attn_mask)
            # [2, 4, 128]
            print(output.shape)
    """
    if mode not in ('downscale_in_infer', 'upscale_in_train'):
        raise ValueError(
            "mode argument should be 'downscale_in_infer' or 'upscale_in_train'"
        )
    # The underlying op spells the inference-downscale mode differently from
    # the public API, so translate it here.
    mode = (
        'downgrade_in_infer' if mode == 'downscale_in_infer' else mode
    )  # semantic transfer

    if _non_static_mode():
        # Dynamic graph: call the legacy op directly. `cache_kvs` is passed a
        # second time (after ffn2_biases) as the in-place CacheKVOut output.
        cache_kv_out, final_out = _legacy_C_ops.fused_multi_transformer(
            x,
            ln_scales,
            ln_biases,
            qkv_weights,
            qkv_biases,
            cache_kvs,
            pre_caches,
            rotary_embs,
            time_step,
            attn_mask,
            linear_weights,
            linear_biases,
            ffn_ln_scales,
            ffn_ln_biases,
            ffn1_weights,
            ffn1_biases,
            ffn2_weights,
            ffn2_biases,
            cache_kvs,
            'pre_layer_norm',
            pre_layer_norm,
            'epsilon',
            epsilon,
            'dropout_rate',
            dropout_rate,
            'rotary_emb_dims',
            rotary_emb_dims,
            'is_test',
            not training,
            'dropout_implementation',
            mode,
            'act_method',
            activation,
            'trans_qkvw',
            trans_qkvw,
            'ring_id',
            ring_id,
        )
        if cache_kvs is not None:
            return final_out, cache_kv_out
        return final_out

    # Static graph: build the op through LayerHelper.
    helper = LayerHelper('fused_multi_transformer', **locals())
    dtype = x.dtype
    # check dtypes
    check_variable_and_dtype(
        x, 'x', ['float16', 'float32'], 'fused_multi_transformer'
    )
    check_dtype(
        dtype, 'dtype', ['float16', 'float32'], 'fused_multi_transformer'
    )

    # set inputs; optional inputs are only registered when provided
    inputs = {}
    inputs['X'] = [x]
    inputs['LnScale'] = ln_scales
    inputs['LnBias'] = ln_biases
    inputs['QKVW'] = qkv_weights
    if qkv_biases is not None:
        inputs['QKVBias'] = qkv_biases
    if cache_kvs is not None:
        assert len(cache_kvs) == len(qkv_weights), (
            "The number of cache_kvs should be equal to the number of "
            "qkv_weights (one cache per layer)."
        )
        inputs['CacheKV'] = cache_kvs
        # time_step is only meaningful together with a KV cache.
        if time_step is not None:
            inputs['TimeStep'] = time_step
    if pre_caches is not None:
        inputs['PreCaches'] = pre_caches
    if rotary_emb_dims > 0:
        inputs['RotaryPosEmb'] = rotary_embs
    inputs['SrcMask'] = attn_mask
    inputs['OutLinearW'] = linear_weights
    if linear_biases is not None:
        inputs['OutLinearBias'] = linear_biases

    inputs['FFNLnScale'] = ffn_ln_scales
    inputs['FFNLnBias'] = ffn_ln_biases
    inputs['FFN1Weight'] = ffn1_weights
    if ffn1_biases is not None:
        inputs['FFN1Bias'] = ffn1_biases
    inputs['FFN2Weight'] = ffn2_weights
    if ffn2_biases is not None:
        inputs['FFN2Bias'] = ffn2_biases

    # set attrs
    attrs = {
        'pre_layer_norm': pre_layer_norm,
        'epsilon': epsilon,
        'dropout_rate': dropout_rate,
        'rotary_emb_dims': rotary_emb_dims,
        'is_test': not training,
        'dropout_implementation': mode,
        'act_method': activation,
        'trans_qkvw': trans_qkvw,
        'ring_id': ring_id,
    }

    outputs = {}
    final_out = helper.create_variable_for_type_inference(dtype=dtype)
    outputs['Out'] = final_out
    # Use the same `is not None` test as the CacheKV input registration above
    # (and the dynamic-graph branch), so input, output and return stay in sync.
    if cache_kvs is not None:
        # NOTE: inplace
        outputs['CacheKVOut'] = cache_kvs

    helper.append_op(
        type='fused_multi_transformer',
        inputs=inputs,
        outputs=outputs,
        attrs=attrs,
    )

    return (final_out, cache_kvs) if cache_kvs is not None else final_out