fused_transformer.py 43.8 KB
Newer Older
L
Li Min 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
from paddle.fluid.layer_helper import LayerHelper
J
Jiabin Yang 已提交
16
from paddle.fluid.framework import _non_static_mode, default_main_program
17
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
18
from paddle.fluid import core, dygraph_utils
L
Li Min 已提交
19 20 21 22 23
from paddle import _C_ops

__all__ = []


24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45
def _verify_dropout_rate(dropout_rate):
    if not isinstance(dropout_rate, (float, int)):
        raise TypeError("dropout_rate argument should be a number")
    if dropout_rate < 0 or dropout_rate > 1:
        raise ValueError("dropout_rate argument should between 0 and 1")


def fused_feedforward(x,
                      linear1_weight,
                      linear2_weight,
                      linear1_bias=None,
                      linear2_bias=None,
                      ln1_scale=None,
                      ln1_bias=None,
                      ln2_scale=None,
                      ln2_bias=None,
                      dropout1_rate=0.5,
                      dropout2_rate=0.5,
                      activation="relu",
                      ln1_epsilon=1e-5,
                      ln2_epsilon=1e-5,
                      pre_layer_norm=False,
                      training=True,
                      mode='upscale_in_train',
                      ring_id=-1,
                      add_residual=True,
                      name=None):
    r"""
    Fused feed-forward layer of the transformer model architecture.

    This operator only supports running on GPU. Its computation is consistent
    with the following pseudo code:

    .. code-block:: python

        residual = x
        if pre_layer_norm:
            out = layer_norm1(x)
        else:
            out = x
        out = linear2(dropout1(activation(linear1(out))))
        if add_residual:
            out = residual + dropout2(out)
        else:
            out = dropout2(out)
        if not pre_layer_norm:
            out = layer_norm2(out)

    Args:
        x (Tensor): The input tensor, a 3-D tensor of shape `[batch\_size, sequence\_length, d\_model]`. The data type can be float16, float32 or float64.
        linear1_weight (Tensor): The weight of the first linear, the data type is same as `x`, the shape is `[d\_model, dim\_feedforward]`.
        linear2_weight (Tensor): The weight of the second linear, the data type is same as `x`, the shape is `[dim\_feedforward, d\_model]`.
        linear1_bias (Tensor, optional): The bias of the first linear, the data type is same as `x`, the shape is `[dim\_feedforward]`. Default None.
        linear2_bias (Tensor, optional): The bias of the second linear, the data type is same as `x`, the shape is `[d\_model]`. Default None.
        ln1_scale (Tensor, optional): The weight of the first layer_norm, the data type is float32 or float64, the shape is `[d\_model]`. Default None.
        ln1_bias (Tensor, optional): The bias of the first layer_norm, the data type is float32 or float64, the shape is `[d\_model]`. Default None.
        ln2_scale (Tensor, optional): The weight of the second layer_norm, the data type is float32 or float64, the shape is `[d\_model]`. Default None.
        ln2_bias (Tensor, optional): The bias of the second layer_norm, the data type is float32 or float64, the shape is `[d\_model]`. Default None.
        dropout1_rate (float, optional): The first dropout probability of setting units to zero. Default 0.5.
        dropout2_rate (float, optional): The second dropout probability of setting units to zero. Default 0.5.
        activation (str, optional): The activation. Default "relu".
        ln1_epsilon (float, optional): Small float of the first layer_norm added to the denominator to avoid dividing by zero. Default is 1e-5.
        ln2_epsilon (float, optional): Small float of the second layer_norm added to the denominator to avoid dividing by zero. Default is 1e-5.
        pre_layer_norm (bool, optional): Whether to apply layer_norm in the pre-processing stage (True) or the post-processing stage (False). Default False.
        training (bool, optional): A flag indicating whether it is in the training phase or not. Default True.
        mode (str, optional): ['upscale_in_train'(default) | 'downscale_in_infer']

                               1. upscale_in_train(default), upscale the output at training time

                                  - train: out = input * mask / ( 1.0 - p )
                                  - inference: out = input

                               2. downscale_in_infer, downscale the output at inference

                                  - train: out = input * mask
                                  - inference: out = input * (1.0 - p)
        ring_id (int, optional): For distributed forward in tensor model parallel, only support NCCL. Default is -1, means not using tensor parallel.
        add_residual (bool, optional): Whether to add the residual at the end. Default is True.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor: The output Tensor, the data type and shape is same as `x`.

    Raises:
        TypeError: If `dropout1_rate` or `dropout2_rate` is not a number.
        ValueError: If a dropout rate is outside `[0, 1]`, or `mode` is not one
            of the two supported strings.

    Examples:
        .. code-block:: python

            # required: gpu
            import paddle
            import paddle.incubate.nn.functional as F

            x = paddle.randn(shape=(1, 8, 8), dtype="float32")
            linear1_weight = paddle.randn(shape=(8, 8), dtype="float32")
            linear2_weight = paddle.randn(shape=(8, 8), dtype="float32")
            out = F.fused_feedforward(x, linear1_weight, linear2_weight)
            print(out.shape)
            # (1, 8, 8)
    """
    _verify_dropout_rate(dropout1_rate)
    _verify_dropout_rate(dropout2_rate)

    if mode not in ('downscale_in_infer', 'upscale_in_train'):
        raise ValueError(
            "mode argument should be 'downscale_in_infer' or 'upscale_in_train'"
        )
    # The op attribute spells the inference-time scaling mode
    # 'downgrade_in_infer'; translate the public spelling before passing it on.
    mode = 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode

    # A seed is only fixed when the program carries a non-zero random_seed.
    seed = None

    if _non_static_mode():
        if default_main_program().random_seed != 0:
            seed = default_main_program().random_seed
        # NOTE(review): the two positional None values after `x` are presumably
        # the op's optional dropout seed tensor inputs -- confirm against the
        # op definition. Only the first of the eleven outputs is returned.
        out, _, _, _, _, _, _, _, _, _, _ = _C_ops.fused_feedforward(
            x, None, None, linear1_weight, linear1_bias, linear2_weight,
            linear2_bias, ln1_scale, ln1_bias, ln2_scale, ln2_bias,
            'pre_layer_norm', pre_layer_norm, 'ln1_epsilon', ln1_epsilon,
            'ln2_epsilon', ln2_epsilon, 'act_method', activation,
            'dropout1_rate', dropout1_rate, 'dropout2_rate', dropout2_rate,
            "is_test", not training, "dropout1_fix_seed", seed is not None,
            "dropout2_fix_seed", seed is not None, "dropout1_seed",
            seed if seed is not None else 0, "dropout2_seed",
            seed if seed is not None else 0, 'dropout1_implementation', mode,
            'dropout2_implementation', mode, 'add_residual', add_residual,
            'ring_id', ring_id)
        return out

    # Static-graph path: validate dtypes, declare the op outputs and append
    # the op to the current program.
    helper = LayerHelper("fused_feedforward")
    dtype = x.dtype
    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                             'fused_feedforward')
    check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'],
                'fused_feedforward')

    # The creation order below is kept stable on purpose, since it determines
    # the order in which temporary variables are generated.
    out = helper.create_variable_for_type_inference(x.dtype)
    dropout1_mask = helper.create_variable_for_type_inference(
        'uint8', stop_gradient=True)
    dropout2_mask = helper.create_variable_for_type_inference(
        'uint8', stop_gradient=True)
    ln1_mean = helper.create_variable_for_type_inference(x.dtype,
                                                         stop_gradient=True)
    ln1_variance = helper.create_variable_for_type_inference(x.dtype,
                                                             stop_gradient=True)
    ln2_mean = helper.create_variable_for_type_inference(x.dtype,
                                                         stop_gradient=True)
    ln2_variance = helper.create_variable_for_type_inference(x.dtype,
                                                             stop_gradient=True)
    linear1_out = helper.create_variable_for_type_inference(x.dtype,
                                                            stop_gradient=True)
    ln1_out = helper.create_variable_for_type_inference(x.dtype,
                                                        stop_gradient=True)
    dropout1_out = helper.create_variable_for_type_inference(x.dtype,
                                                             stop_gradient=True)
    dropout2_out = helper.create_variable_for_type_inference(x.dtype,
                                                             stop_gradient=True)

    # `seed` is still None here (it is only assigned in the dynamic branch,
    # which returns early), so only the program's random_seed matters.
    if helper.main_program.random_seed != 0:
        seed = helper.main_program.random_seed

    helper.append_op(type='fused_feedforward',
                     inputs={
                         'X': x,
                         'Linear1Weight': linear1_weight,
                         'Linear1Bias': linear1_bias,
                         'Linear2Weight': linear2_weight,
                         'Linear2Bias': linear2_bias,
                         'Ln1Scale': ln1_scale,
                         'Ln1Bias': ln1_bias,
                         'Ln2Scale': ln2_scale,
                         'Ln2Bias': ln2_bias,
                     },
                     outputs={
                         'Out': out,
                         'Dropout1Mask': dropout1_mask,
                         'Dropout2Mask': dropout2_mask,
                         'Ln1Mean': ln1_mean,
                         'Ln1Variance': ln1_variance,
                         'Ln2Mean': ln2_mean,
                         'Ln2Variance': ln2_variance,
                         'Linear1Out': linear1_out,
                         'Ln1Out': ln1_out,
                         'Dropout1Out': dropout1_out,
                         'Dropout2Out': dropout2_out,
                     },
                     attrs={
                         'dropout1_rate': dropout1_rate,
                         'dropout2_rate': dropout2_rate,
                         'act_method': activation,
                         'pre_layer_norm': pre_layer_norm,
                         'ln1_epsilon': ln1_epsilon,
                         'ln2_epsilon': ln2_epsilon,
                         'is_test': not training,
                         'dropout1_fix_seed': seed is not None,
                         'dropout2_fix_seed': seed is not None,
                         'dropout1_seed': seed if seed is not None else 0,
                         'dropout2_seed': seed if seed is not None else 0,
                         'dropout1_implementation': mode,
                         'dropout2_implementation': mode,
                         'add_residual': add_residual,
                         'ring_id': ring_id,
                     })
    return out


225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291
def fused_bias_dropout_residual_layer_norm(x,
                                           residual,
                                           bias=None,
                                           ln_scale=None,
                                           ln_bias=None,
                                           dropout_rate=0.5,
                                           ln_epsilon=1e-5,
                                           training=True,
                                           mode='upscale_in_train',
                                           name=None):
    r"""
    The fused_bias_dropout_residual_layer_norm operator. The pseudo code is as follows:

    .. code-block:: python

        y = layer_norm(residual + dropout(bias + x))

    Parameters:
        x (Tensor): The input tensor. The shape is `[*, embed\_dim]`.
        residual (Tensor): The residual tensor. The shape is same as x.
        bias (Tensor, optional): The bias of linear. The shape is `[embed_dim]`. Default None.
        ln_scale (Tensor, optional): The weight tensor of layernorm. The shape is `[embed_dim]`. Default None.
        ln_bias (Tensor, optional): The bias tensor of layernorm. The shape is `[embed_dim]`. Default None.
        dropout_rate (float, optional): The dropout probability applied to
            `bias + x` before the residual is added. 0 for no dropout.
            Default 0.5.
        ln_epsilon (float, optional): Small float value added to denominator of layer_norm
            to avoid dividing by zero. Default is 1e-5.
        training (bool, optional): A flag indicating whether it is in the training phase or not. Default True.
        mode (str, optional): ['upscale_in_train'(default) | 'downscale_in_infer']

                               1. upscale_in_train(default), upscale the output at training time

                                  - train: out = input * mask / ( 1.0 - p )
                                  - inference: out = input

                               2. downscale_in_infer, downscale the output at inference

                                  - train: out = input * mask
                                  - inference: out = input * (1.0 - p)
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor: The output Tensor, the data type and shape is same as `x`.

    Raises:
        ValueError: If `mode` is not one of the two supported strings.
        AssertionError: If `ln_scale` or `ln_bias` is given but is not 1-D with
            a length equal to the last dimension of `x`.

    Examples:

        .. code-block:: python

            # required: gpu
            import paddle
            import paddle.incubate.nn.functional as F

            # input: [batch_size, seq_len, embed_dim]
            x = paddle.rand(shape=(2, 4, 128), dtype="float32")
            # residual: [batch_size, seq_len, embed_dim]
            residual = paddle.rand(shape=(2, 4, 128), dtype="float32")
            # linear bias: [embed_dim]
            bias = paddle.rand(shape=[128], dtype="float32")
            # output: [batch_size, seq_len, embed_dim]
            output = F.fused_bias_dropout_residual_layer_norm(
                x, residual, bias)
            # [2, 4, 128]
            print(output.shape)
    """
    # A seed is only fixed when the program carries a non-zero random_seed.
    seed = None
    if mode not in ('downscale_in_infer', 'upscale_in_train'):
        raise ValueError(
            "mode argument should be 'downscale_in_infer' or 'upscale_in_train'"
        )
    # The op attribute spells the inference-time scaling mode
    # 'downgrade_in_infer'; translate the public spelling before passing it on.
    mode = 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode

    # The layer_norm parameters, when given, must be vectors matching the
    # normalized (last) dimension of x.
    if ln_scale is not None:
        assert len(ln_scale.shape
                   ) == 1, "The dims of the shape of ln_scale should be 1."
        assert x.shape[-1] == ln_scale.shape[
            0], "The dim of ln_scale must equal to the last dim of x."
    if ln_bias is not None:
        assert len(
            ln_bias.shape) == 1, "The dims of the shape of ln_bias should be 1."
        assert x.shape[-1] == ln_bias.shape[
            0], "The dim of ln_bias must equal to the last dim of x."

    if _non_static_mode():
        if default_main_program().random_seed != 0:
            seed = default_main_program().random_seed
        # Only the last of the five op outputs is the public result.
        _, _, _, _, final_out = _C_ops.fused_bias_dropout_residual_layer_norm(
            x, residual, bias, ln_scale, ln_bias, 'dropout_rate', dropout_rate,
            'ln_epsilon', ln_epsilon, 'is_test', not training,
            'dropout_fix_seed', seed is not None, 'dropout_seed',
            seed if seed is not None else 0, 'dropout_implementation', mode)
        return final_out

    # Static-graph path. No new local names are introduced before this call so
    # that the keyword arguments forwarded through **locals() stay unchanged.
    helper = LayerHelper('fused_bias_dropout_residual_layer_norm',
                         **locals())
    dtype = x.dtype
    # check dtypes
    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                             'fused_bias_dropout_residual_layer_norm')
    check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'],
                'fused_bias_dropout_residual_layer_norm')

    # set inputs; optional tensors are tested with `is not None` because the
    # truth value of a tensor/variable is not a reliable presence check.
    inputs = dict()
    inputs['X'] = [x]
    inputs['Residual'] = [residual]
    if bias is not None:
        inputs['Bias'] = [bias]
    if ln_scale is not None:
        inputs['LnScale'] = [ln_scale]
    if ln_bias is not None:
        inputs['LnBias'] = [ln_bias]

    # `seed` is still None here (it is only assigned in the dynamic branch,
    # which returns early), so only the program's random_seed matters.
    if helper.main_program.random_seed != 0:
        seed = helper.main_program.random_seed

    # set attrs
    attrs = {
        'ln_epsilon': ln_epsilon,
        'dropout_rate': dropout_rate,
        'is_test': not training,
        'dropout_fix_seed': seed is not None,
        'dropout_seed': seed if seed is not None else 0,
        'dropout_implementation': mode,
    }

    # set outputs; creation order is kept stable on purpose.
    dropout_mask_out = helper.create_variable_for_type_inference(
        dtype=core.VarDesc.VarType.UINT8, stop_gradient=True)
    ln_mean_out = helper.create_variable_for_type_inference(
        dtype=dtype, stop_gradient=True)
    ln_variance_out = helper.create_variable_for_type_inference(
        dtype=dtype, stop_gradient=True)
    bias_dropout_residual_out = helper.create_variable_for_type_inference(
        dtype=dtype)
    final_out = helper.create_variable_for_type_inference(dtype=dtype)

    helper.append_op(type='fused_bias_dropout_residual_layer_norm',
                     inputs=inputs,
                     outputs={
                         "BiasDropoutResidualOut":
                         bias_dropout_residual_out,
                         "DropoutMaskOut": dropout_mask_out,
                         "LnMean": ln_mean_out,
                         "LnVariance": ln_variance_out,
                         'Y': final_out,
                     },
                     attrs=attrs)
    return final_out


L
Li Min 已提交
371 372 373 374 375 376 377 378 379 380 381
def fused_multi_head_attention(x,
                               qkv_weight,
                               linear_weight,
                               pre_layer_norm=False,
                               pre_ln_scale=None,
                               pre_ln_bias=None,
                               ln_scale=None,
                               ln_bias=None,
                               pre_ln_epsilon=1e-05,
                               qkv_bias=None,
                               linear_bias=None,
382
                               cache_kv=None,
L
Li Min 已提交
383 384 385 386
                               attn_mask=None,
                               dropout_rate=0.5,
                               attn_dropout_rate=0.5,
                               ln_epsilon=1e-05,
387 388
                               training=True,
                               mode='upscale_in_train',
389
                               ring_id=-1,
390
                               add_residual=True,
L
Li Min 已提交
391
                               name=None):
392
    r"""
L
Li Min 已提交
393 394
    Attention mapps queries and a set of key-value pairs to outputs, and
    Multi-Head Attention performs multiple parallel attention to jointly attending
395
    to information from different representation subspaces. This API only
L
Li Min 已提交
396
    support self_attention. The pseudo code is as follows:
397 398 399

    .. code-block:: python

400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422
        residual = x
        if pre_layer_norm:
            out = layer_norm(x)
        else:
            out = x
        # compute q, k, v
        out = matmul(out, qkv_weight) + qkv_bias
        out = transpose(out, perm=[2, 0, 3, 1, 4])
        # extract q, k and v from out
        q = out[0:1,::] * (head_dim ** -0.5)
        k = out[1:2,::]
        v = out[2:3,::]
        out = matmul(q, k, transpose_y=True)
        out = out + attn_mask
        out = softmax(out)
        out = dropout(out)
        out = matmul(out, v)
        # combine heads
        out = transpose(out, perm=[0, 2, 1, 3])
        # project to output
        out = linear(out)
        if add_residual:
            out = residual + dropout(out)
423
        else:
424 425 426 427
            out = dropout(out)
        if not pre_layer_norm:
            out = layer_norm(out)

L
Li Min 已提交
428 429

    Parameters:
430
        x (Tensor): The input tensor of fused_multi_head_attention. The shape is
L
Li Min 已提交
431 432 433
            `[batch\_size, sequence\_len, embed\_dim]`.
        qkv_weight (Tensor): The qkv weight tensor. The shape is `[3, num_head, dim_head, dim_embed]`.
        linear_weight (Tensor): The linear weight tensor. The shape is `[embed_dim, embed_dim]`.
434
        pre_layer_norm (bool, optional): whether it is pre_layer_norm (True) or post_layer_norm architecture
435
                                        (False). Default False.
L
Li Min 已提交
436 437 438 439
        pre_ln_scale (Tensor, optional): The weight tensor of pre layernorm. Default None.
        pre_ln_bias (Tensor, optional): The bias tensor of pre layernorm. Default None.
        ln_scale (Tensor, optional): The weight tensor of layernorm. Default None.
        ln_bias (Tensor, optional): The bias tensor of layernorm. Default None.
440
        pre_ln_epsilon (float, optional): Small float value added to denominator of the pre layer_norm
L
Li Min 已提交
441
            to avoid dividing by zero. Default is 1e-5.
442
        qkv_bias (Tensor, optional): The bias of qkv computation. The shape is `[3, num_head, dim_head]`.
L
Li Min 已提交
443 444
            Default None.
        linear_bias (Tensor, optional): The bias of linear. The shape is `[embed_dim]`. Default None.
445
        cache_kv (Tensor, optional): For generation model, cache structure. The shape is `[2, bsz, num_head, seq_len, head_dim]`. Default None.
446
        attn_mask (Tensor, optional):  A tensor used in multi-head attention to prevents attention to
447
            some unwanted positions, usually the paddings or the subsequent positions. It is a tensor
448 449 450 451
            with shape broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`. When the
            data type is bool, the unwanted positions have `False` values and the others have `True` values.
            When the data type is int, the unwanted positions have 0 values and the others have 1 values.
            When the data type is float, the unwanted positions have `-INF` values and the others have 0 values.
452
            It can be None when nothing wanted or needed to be prevented attention to. Default None.
L
Li Min 已提交
453
        dropout_rate (float, optional): The dropout probability used on attention
454
            weights to drop some attention targets for the dropout after attention.
455
            0 for no dropout. Default 0.5.
L
Li Min 已提交
456
        attn_dropout_rate (float, optional): The dropout probability used on attention
457
            weights to drop some attention targets for the dropout in attention.
458
            0 for no dropout. Default 0.5.
459
        ln_epsilon (float, optional): Small float value added to denominator of layer_norm
L
Li Min 已提交
460
            to avoid dividing by zero. Default is 1e-5.
L
Li Min 已提交
461 462 463 464 465 466 467 468 469 470 471 472
        training (bool, optional): A flag indicating whether it is in train phrase or not. Default True.
        mode (str, optional): ['upscale_in_train'(default) | 'downscale_in_infer']

                               1. upscale_in_train(default), upscale the output at training time

                                  - train: out = input * mask / ( 1.0 - p )
                                  - inference: out = input

                               2. downscale_in_infer, downscale the output at inference

                                  - train: out = input * mask
                                  - inference: out = input * (1.0 - p)
473
        ring_id (int, optional): For distributed forward in mp, only support NCCL and forward. Default is -1, means not using mp
474
        add_residual (bool, optional): Whether add residual at the end. Default is True.
475
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
476

477 478 479
    Returns:
        Tensor: The output Tensor, the data type and shape is same as `x`.

L
Li Min 已提交
480 481 482
    Examples:

        .. code-block:: python
483 484

            # required: gpu
L
Li Min 已提交
485
            import paddle
486
            import paddle.incubate.nn.functional as F
L
Li Min 已提交
487 488 489

            # input: [batch_size, seq_len, embed_dim]
            x = paddle.rand(shape=(2, 4, 128), dtype="float32")
490
            # qkv_weight: [3, num_head, head_dim, embed_dim]
L
Li Min 已提交
491
            qkv_weight = paddle.rand(shape=(3, 4, 32, 128), dtype="float32")
492
            # qkv_bias: [3, num_head, head_dim]
L
Li Min 已提交
493 494 495 496 497 498 499 500 501 502 503 504
            qkv_bias = paddle.rand(shape=(3, 4, 32), dtype="float32")
            # linear_weight: [embed_dim, embed_dim]
            linear_weight = paddle.rand(shape=(128, 128), dtype="float32")
            # linear_bias: [embed_dim]
            linear_bias = paddle.rand(shape=[128], dtype="float32")
            # self attention mask: [batch_size, num_heads, seq_len, seq_len]
            attn_mask = paddle.rand(shape=(2, 4, 4, 4), dtype="float32")

            # output: [batch_size, seq_len, embed_dim]
            output = F.fused_multi_head_attention(
                x, qkv_weight, linear_weight, False,
                None, None, None, None, 1e-5, qkv_bias,
505
                linear_bias, None, attn_mask)
L
Li Min 已提交
506 507 508
            # [2, 4, 128]
            print(output.shape)
    """
509 510 511 512

    seed = None
    if mode not in ('downscale_in_infer', 'upscale_in_train'):
        raise ValueError(
513 514
            "mode argument should be 'downscale_in_infer' or 'upscale_in_train'"
        )
515 516
    mode = 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode  #semantic transfer

J
Jiabin Yang 已提交
517
    if _non_static_mode():
518 519
        if default_main_program().random_seed != 0:
            seed = default_main_program().random_seed
520 521
        # pre_ln_mean, pre_ln_variance, pre_ln_out, qkv_out, qkv_bias_out, transpose_out, qk_out,
        # qktv_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, attn_mask_out, fmha_out,
L
Li Min 已提交
522
        # linear_out, dropout_mask_out, ln_mean_out, ln_var_out, bias_dropout_residual_out, final_out
523 524 525 526 527 528
        assert len(qkv_weight.shape
                   ) == 4, "The dims of the shape of qkv_weight should be 4."
        assert qkv_weight.shape[
            0] == 3, "The shape of qkv_weight should be [3, num_head, head_dim, embed_dim]."
        assert qkv_weight.shape[3] == x.shape[
            2], "The 3rd dim of qkv_weight and 2nd dim of x should be the same, i.e., embed_dim."
529 530 531 532
        if ring_id == -1:
            # under mp, the num head will be split, this equation will not hold
            assert qkv_weight.shape[1] * qkv_weight.shape[2] == qkv_weight.shape[
                3], "embed_dim must be divisible by num_heads."
533

534 535 536 537 538
        _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, cache_kv_out, final_out = _C_ops.fused_attention(
            x, pre_ln_scale, pre_ln_bias, qkv_weight, qkv_bias, cache_kv,
            attn_mask, linear_weight, linear_bias, ln_scale, ln_bias,
            'pre_layer_norm', pre_layer_norm, 'epsilon', pre_ln_epsilon,
            'dropout_rate', dropout_rate, 'attn_dropout_rate',
L
Li Min 已提交
539 540
            attn_dropout_rate, 'ln_epsilon', ln_epsilon, 'is_test',
            not training, 'attn_dropout_fix_seed', seed is not None,
541 542 543
            'dropout_fix_seed', seed is not None, 'attn_dropout_seed',
            seed if seed is not None else 0, 'dropout_seed',
            seed if seed is not None else 0, 'attn_dropout_implementation',
544 545
            mode, 'dropout_implementation', mode, 'add_residual', add_residual,
            'ring_id', ring_id)
546 547
        if cache_kv is not None:
            return final_out, cache_kv_out
L
Li Min 已提交
548
        return final_out
549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565
    else:
        helper = LayerHelper('fused_multi_head_attention', **locals())
        dtype = x.dtype
        # check dtypes
        check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                                 'fused_multihead_attention')
        check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'],
                    'fused_multi_head_attention')

        # set inputs
        inputs = dict()
        inputs['X'] = [x]
        if pre_ln_scale:
            inputs['LnScale'] = [pre_ln_scale]
        if pre_ln_bias:
            inputs['LnBias'] = [pre_ln_bias]
        inputs['QKVW'] = [qkv_weight]
566 567
        if qkv_bias is not None:
            inputs['QKVBias'] = [qkv_bias]
568 569
        inputs['SrcMask'] = attn_mask
        inputs['OutLinearW'] = [linear_weight]
570 571
        if linear_bias is not None:
            inputs['OutLinearBias'] = [linear_bias]
572 573 574 575
        if ln_scale:
            inputs['Ln2Scale'] = [ln_scale]
        if ln_bias:
            inputs['Ln2Bias'] = [ln_bias]
576
        if cache_kv: inputs['CacheKV'] = [cache_kv]
577

578 579 580
        if (seed is None or seed == 0) and helper.main_program.random_seed != 0:
            seed = helper.main_program.random_seed

581 582 583 584 585 586
        # set attrs
        attrs = {
            'pre_layer_norm': pre_layer_norm,
            'epsilon': pre_ln_epsilon,
            'ln_epsilon': ln_epsilon,
            'dropout_rate': dropout_rate,
587
            'attn_dropout_rate': attn_dropout_rate,
L
Li Min 已提交
588
            'is_test': not training,
589 590 591 592 593 594
            'attn_dropout_fix_seed': seed is not None,
            'dropout_fix_seed': seed is not None,
            'attn_dropout_seed': seed if seed is not None else 0,
            'dropout_seed': seed if seed is not None else 0,
            'attn_dropout_implementation': mode,
            'dropout_implementation': mode,
595
            'add_residual': add_residual,
596
            'ring_id': ring_id
597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628
        }

        # set outputs
        pre_ln_mean_out = helper.create_variable_for_type_inference(
            dtype=dtype, stop_gradient=True)
        pre_ln_variance_out = helper.create_variable_for_type_inference(
            dtype=dtype, stop_gradient=True)
        pre_ln_out = helper.create_variable_for_type_inference(dtype=dtype)

        qkv_out = helper.create_variable_for_type_inference(dtype=dtype)
        qkv_bias_out = helper.create_variable_for_type_inference(dtype=dtype)

        transpose_out = helper.create_variable_for_type_inference(dtype=dtype)
        qk_out = helper.create_variable_for_type_inference(dtype=dtype)
        qktv_out = helper.create_variable_for_type_inference(dtype=dtype)
        softmax_out = helper.create_variable_for_type_inference(dtype=dtype)
        attn_dropout_mask_out = helper.create_variable_for_type_inference(
            dtype=core.VarDesc.VarType.UINT8, stop_gradient=True)
        attn_dropout_out = helper.create_variable_for_type_inference(
            dtype=dtype)
        attn_mask_out = helper.create_variable_for_type_inference(dtype=dtype)
        fmha_out = helper.create_variable_for_type_inference(dtype=dtype)
        out_linear_out = helper.create_variable_for_type_inference(dtype=dtype)
        dropout_mask_out = helper.create_variable_for_type_inference(
            dtype=core.VarDesc.VarType.UINT8, stop_gradient=True)
        ln_mean_out = helper.create_variable_for_type_inference(
            dtype=dtype, stop_gradient=True)
        ln_variance_out = helper.create_variable_for_type_inference(
            dtype=dtype, stop_gradient=True)
        bias_dropout_residual_out = helper.create_variable_for_type_inference(
            dtype=dtype)
        final_out = helper.create_variable_for_type_inference(dtype=dtype)
629
        cache_kv_out = helper.create_variable_for_type_inference(dtype=dtype)
630

631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656
        helper.append_op(type='fused_attention',
                         inputs=inputs,
                         outputs={
                             "LnMean": pre_ln_mean_out,
                             "LnVariance": pre_ln_variance_out,
                             "LnOut": pre_ln_out,
                             "QKVOut": qkv_out,
                             "QKVBiasOut": qkv_bias_out,
                             "TransposeOut2": transpose_out,
                             "QKOut": qk_out,
                             "QKTVOut": qktv_out,
                             "SoftmaxOut": softmax_out,
                             "AttnDropoutMaskOut": attn_dropout_mask_out,
                             "AttnDropoutOut": attn_dropout_out,
                             "SrcMaskOut": attn_mask_out,
                             "FMHAOut": fmha_out,
                             "OutLinearOut": out_linear_out,
                             "DropoutMaskOut": dropout_mask_out,
                             "Ln2Mean": ln_mean_out,
                             "Ln2Variance": ln_variance_out,
                             "BiasDropoutResidualOut":
                             bias_dropout_residual_out,
                             'Y': final_out,
                             'CacheKVOut': cache_kv_out
                         },
                         attrs=attrs)
657 658

        return (final_out, cache_kv_out) if cache_kv else final_out
659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817


def fused_multi_transformer(x,
                            ln_scales,
                            ln_biases,
                            qkv_weights,
                            qkv_biases,
                            linear_weights,
                            linear_biases,
                            ffn_ln_scales,
                            ffn_ln_biases,
                            ffn1_weights,
                            ffn1_biases,
                            ffn2_weights,
                            ffn2_biases,
                            pre_layer_norm=True,
                            epsilon=1e-05,
                            cache_kvs=None,
                            time_step=None,
                            attn_mask=None,
                            dropout_rate=0.0,
                            activation="gelu",
                            training=False,
                            mode='upscale_in_train',
                            ring_id=-1,
                            name=None):
    r"""
    This is a fusion operator to compute multi transformer layers in transformer model architecture.
    This operator only supports running on GPU. The function of the transformer layer is consistent
    with the following pseudo code:

    .. code-block:: python

        if pre_layer_norm:
            out = layer_norm(x)
            out = qkv_linear(out) + qkv_bias
        else:
            out = qkv_linear(x) + qkv_bias
        out = transpose(out, perm=[2, 0, 3, 1, 4])
        # extract q, k and v from out.
        q = out[0:1, ::]
        k = out[1:2, ::]
        v = out[2:3, ::]
        out = q * k^t
        out = attn_mask + out
        out = softmax(out)
        out = dropout(out)
        out = out * v
        out = transpose(out, perm=[0, 2, 1, 3])
        out = linear(out)
        if pre_layer_norm:
            out = x + dropout(out + bias)
        else:
            out = layer_norm(x + dropout(out + bias))

        residual = out
        if pre_layer_norm:
            out = ffn_layer_norm(out)
        out = ffn1_linear(out)
        out = dropout(activation(out + ffn1_bias))
        out = ffn2_linear(out)
        out = residual + dropout(out + ffn2_bias)
        if not pre_layer_norm:
            out = ffn_layer_norm(out)

    Args:
        x (Tensor): the input tensor could be 3-D tensor, the input data type could be float16 or float32, the shape is `[batch\_size, sequence\_length, d\_model]`.
        ln_scales (list(Tensor)|tuple(Tensor)): The weight tensors of attention layer_norm, the shape is `[d\_model]`.
        ln_biases (list(Tensor)|tuple(Tensor)): The bias tensors of attention layer_norm. the shape is `[d\_model]`.
        qkv_weights (list(Tensor)|tuple(Tensor)): The weight tensors of attention qkv computation. The shape is `[3, num\_head, dim\_head, d\_model]`.
        qkv_biases (list(Tensor)|tuple(Tensor)|None): The bias tensors of attention qkv computation. The shape is `[3, num\_head, dim\_head]`.
        linear_weights (list(Tensor)|tuple(Tensor)): The weight tensors of attention linear. The shape is `[num\_head * dim\_head, d\_model]`.
        linear_biases (list(Tensor)|tuple(Tensor)|None): The bias tensors of attention linear. The shape is `[d\_model]`.
        ffn_ln_scales (list(Tensor)|tuple(Tensor)): The weight tensors of feedforward layer_norm, the shape is `[d\_model]`
        ffn_ln_biases (list(Tensor)|tuple(Tensor)): The bias tensors of feedforward layer_norm, the shape is `[d\_model]`
        ffn1_weights (list(Tensor)|tuple(Tensor)): The weight tensors of feedforward first linear, the shape is `[d\_model, dim\_feedforward]`.
        ffn1_biases (list(Tensor)|tuple(Tensor)|None): The bias tensors of feedforward first linear, the shape is `[dim\_feedforward]`.
        ffn2_weights (list(Tensor)|tuple(Tensor)): The weight tensors of feedforward second linear, the shape is `[dim\_feedforward, d\_model]`.
        ffn2_biases (list(Tensor)|tuple(Tensor)|None): The bias tensors of feedforward second linear, the shape is `[d\_model]`.
        pre_layer_norm (bool, optional): whether it is pre_layer_norm(True) or post_layer_norm(False). Default True.
        epsilon (float, optional): Small float value added to denominator of the layer_norm to avoid dividing by zero. Default is 1e-5.
        cache_kvs (list(Tensor)|tuple(Tensor), optional): The cache structure tensors for the generation model. The shape is `[2, bsz, num\_head, max\_seq\_len, head\_dim]`. Default None.
        time_step (Tensor, optional): The time step tensor for the generation model. It is used in the decode stage to represent the time step, that is, the real seq_len of CacheKV. The shape is `[1]`, must be in CPUPlace. It is only forwarded to the operator when `cache_kvs` is not None. Default None.
        attn_mask (Tensor, optional):  A tensor used in multi-head attention to prevent attention to
            some unwanted positions, usually the paddings or the subsequent positions. It is a tensor
            with shape `[batch_size, 1, sequence_length, sequence_length]`. Default None.
        dropout_rate (float, optional): The dropout probability of setting units to zero. Default 0.0.
        activation (str, optional): The activation. Default "gelu".
        training (bool, optional): A flag indicating whether it is in train phase or not. Default False.
        mode (str, optional): ['upscale_in_train'(default) | 'downscale_in_infer']

                               1. upscale_in_train(default), upscale the output at training time

                                  - train: out = input * mask / ( 1.0 - p )
                                  - inference: out = input

                               2. downscale_in_infer, downscale the output at inference

                                  - train: out = input * mask
                                  - inference: out = input * (1.0 - p)
        ring_id (int, optional): For distributed forward in tensor model parallel, only support NCCL. Default is -1, means not using mp.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor|tuple: If `cache_kvs` is None, return a tensor that has
        the same shape and data type with `x`, representing the output
        of Transformer layers. If `cache_kvs` is not None, return the
        tuple (output, cache_kvs), where output is the output of
        Transformer layers, cache_kvs is inplace with input `cache_kvs`.

    Raises:
        ValueError: If `mode` is neither 'downscale_in_infer' nor 'upscale_in_train'.
        AssertionError: In static graph mode, if `cache_kvs` is given and its
            length differs from the length of `qkv_weights`.

    Examples:
        .. code-block:: python

            # required: gpu
            import paddle
            import paddle.incubate.nn.functional as F

            # input: [batch_size, seq_len, embed_dim]
            x = paddle.rand(shape=(2, 4, 128), dtype="float32")

            # ln_scale: [embed_dim], ln_bias: [embed_dim]
            ln_scale = paddle.rand(shape=(128,), dtype="float32")
            ln_bias = paddle.rand(shape=(128,), dtype="float32")

            # qkv_weight: [3, num_head, head_dim, embed_dim], qkv_bias: [3, num_head, head_dim]
            qkv_weight = paddle.rand(shape=(3, 4, 32, 128), dtype="float32")
            qkv_bias = paddle.rand(shape=(3, 4, 32), dtype="float32")

            # linear_weight: [embed_dim, embed_dim], linear_bias: [embed_dim]
            linear_weight = paddle.rand(shape=(128, 128), dtype="float32")
            linear_bias = paddle.rand(shape=(128,), dtype="float32")

            # ffn_ln_scale: [embed_dim], ffn_ln_bias: [embed_dim]
            ffn_ln_scale = paddle.rand(shape=(128,), dtype="float32")
            ffn_ln_bias = paddle.rand(shape=(128,), dtype="float32")

            # ffn1_weight: [embed_dim, 4*embed_dim], ffn1_bias: [4*embed_dim]
            ffn1_weight = paddle.rand(shape=(128, 4*128), dtype="float32")
            ffn1_bias = paddle.rand(shape=(4*128,), dtype="float32")

            # ffn2_weight: [4*embed_dim, embed_dim], ffn2_bias: [embed_dim]
            ffn2_weight = paddle.rand(shape=(4*128, 128), dtype="float32")
            ffn2_bias = paddle.rand(shape=(128,), dtype="float32")

            # self attention mask: [batch_size, 1, seq_len, seq_len]
            attn_mask = paddle.rand(shape=(2, 1, 4, 4), dtype="float32")

            # output: [batch_size, seq_len, embed_dim]
            output = F.fused_multi_transformer(
                x, [ln_scale], [ln_bias], [qkv_weight], [qkv_bias],
                [linear_weight], [linear_bias], [ffn_ln_scale], [ffn_ln_bias],
                [ffn1_weight], [ffn1_bias], [ffn2_weight], [ffn2_bias],
                attn_mask=attn_mask)
            # [2, 4, 128]
            print(output.shape)
    """
    # Validate the public mode name before touching any framework state.
    if mode not in ('downscale_in_infer', 'upscale_in_train'):
        raise ValueError(
            "mode argument should be 'downscale_in_infer' or 'upscale_in_train'"
        )
    # Semantic transfer: the operator attribute is spelled
    # 'downgrade_in_infer' where the public API says 'downscale_in_infer'.
    mode = 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode

    if _non_static_mode():
        # Dygraph path: call the fused kernel directly. Positional tensor
        # arguments come first, followed by flattened (name, value) attr pairs.
        # NOTE(review): cache_kvs is passed a second time after the inputs;
        # presumably that slot is the in-place CacheKVOut output, mirroring the
        # static branch below -- confirm against the op definition.
        cache_kv_out, final_out = _C_ops.fused_multi_transformer(
            x, ln_scales, ln_biases, qkv_weights, qkv_biases, cache_kvs,
            time_step, attn_mask, linear_weights, linear_biases, ffn_ln_scales,
            ffn_ln_biases, ffn1_weights, ffn1_biases, ffn2_weights, ffn2_biases,
            cache_kvs, 'pre_layer_norm', pre_layer_norm, 'epsilon', epsilon,
            'dropout_rate', dropout_rate, 'is_test', not training,
            'dropout_implementation', mode, 'act_method', activation, 'ring_id',
            ring_id)
        if cache_kvs is not None:
            return final_out, cache_kv_out
        return final_out
    else:
        # Static graph path: describe the op and append it to the program.
        # locals() is forwarded as-is, so no extra local names are introduced
        # before this call.
        helper = LayerHelper('fused_multi_transformer', **locals())
        dtype = x.dtype
        # check dtypes
        check_variable_and_dtype(x, 'x', ['float16', 'float32'],
                                 'fused_multi_transformer')
        check_dtype(dtype, 'dtype', ['float16', 'float32'],
                    'fused_multi_transformer')

        # set inputs; optional biases are only wired in when provided
        inputs = dict()
        inputs['X'] = [x]
        inputs['LnScale'] = ln_scales
        inputs['LnBias'] = ln_biases
        inputs['QKVW'] = qkv_weights
        if qkv_biases is not None:
            inputs['QKVBias'] = qkv_biases
        if cache_kvs is not None:
            # one cache tensor is expected per transformer layer
            assert len(cache_kvs) == len(qkv_weights)
            inputs['CacheKV'] = cache_kvs
            # time_step is only forwarded together with the KV cache
            if time_step is not None:
                inputs['TimeStep'] = time_step
        inputs['SrcMask'] = attn_mask
        inputs['OutLinearW'] = linear_weights
        if linear_biases is not None:
            inputs['OutLinearBias'] = linear_biases

        inputs['FFNLnScale'] = ffn_ln_scales
        inputs['FFNLnBias'] = ffn_ln_biases
        inputs['FFN1Weight'] = ffn1_weights
        if ffn1_biases is not None:
            inputs['FFN1Bias'] = ffn1_biases
        inputs['FFN2Weight'] = ffn2_weights
        if ffn2_biases is not None:
            inputs['FFN2Bias'] = ffn2_biases

        # set attrs
        attrs = {
            'pre_layer_norm': pre_layer_norm,
            'epsilon': epsilon,
            'dropout_rate': dropout_rate,
            'is_test': not training,
            'dropout_implementation': mode,
            'act_method': activation,
            'ring_id': ring_id
        }

        # set outputs
        outputs = dict()
        final_out = helper.create_variable_for_type_inference(dtype=dtype)
        outputs['Out'] = final_out
        # Use the same `is not None` test as the dygraph branch and the input
        # wiring above, so both modes agree on when a tuple is returned.
        if cache_kvs is not None:
            # NOTE: inplace
            outputs['CacheKVOut'] = cache_kvs

        helper.append_op(type='fused_multi_transformer',
                         inputs=inputs,
                         outputs=outputs,
                         attrs=attrs)

        if cache_kvs is not None:
            return final_out, cache_kvs
        return final_out