From 2811dcd0d06383aed31b3ff3f2c065ab9d1ca7fe Mon Sep 17 00:00:00 2001
From: Yuang Liu
Date: Thu, 5 Jan 2023 12:50:27 +0800
Subject: [PATCH] udpate_fused_attention_en_docs, test=document_fix (#49564)

---
 python/paddle/incubate/nn/functional/fused_transformer.py | 6 +++---
 python/paddle/incubate/nn/layer/fused_transformer.py      | 3 ++-
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/python/paddle/incubate/nn/functional/fused_transformer.py b/python/paddle/incubate/nn/functional/fused_transformer.py
index d8dfdb0242..01d2161b22 100644
--- a/python/paddle/incubate/nn/functional/fused_transformer.py
+++ b/python/paddle/incubate/nn/functional/fused_transformer.py
@@ -526,7 +526,7 @@ def fused_multi_head_attention(
     Parameters:
         x (Tensor): The input tensor of fused_multi_head_attention. The shape is
             `[batch\_size, sequence\_len, embed\_dim]`.
-        qkv_weight (Tensor): The qkv weight tensor. The shape is `[3, num_head, dim_head, dim_embed]`.
+        qkv_weight (Tensor): The qkv weight tensor. If `transpose_qkv_wb` is False, the shape is `[3, num_head, dim_head, dim_embed]`. Otherwise, the shape is `[dim_embed, 3 * dim_embed]`.
         linear_weight (Tensor): The linear weight tensor. The shape is `[embed_dim, embed_dim]`.
         pre_layer_norm (bool, optional): whether it is pre_layer_norm (True) or post_layer_norm architecture
             (False). Default False.
@@ -536,7 +536,7 @@
         ln_bias (Tensor, optional): The bias tensor of layernorm. Default None.
         pre_ln_epsilon (float, optional): Small float value added to denominator of the pre layer_norm
             to avoid dividing by zero. Default is 1e-5.
-        qkv_bias (Tensor, optional): The bias of qkv computation. The shape is `[3, num_head, dim_head]`.
+        qkv_bias (Tensor, optional): The bias of qkv computation. If `transpose_qkv_wb` is False, the shape is `[3, num_head, dim_head]`. Otherwise, the shape is `[3 * dim_embed]`.
             Default None.
         linear_bias (Tensor, optional): The bias of linear. The shape is `[embed_dim]`. Default None.
         cache_kv (Tensor, optional): For generation model, cache structure. The shape is `[2, bsz, num_head, seq_len, head_dim]`. Default None.
@@ -571,7 +571,7 @@
         add_residual (bool, optional): Whether add residual at the end. Default is True.
         num_heads (int, optional): If enable transpose_qkv_wb, should provide the num_heads. Default is -1, means not transpose qkv wb.
         transpose_qkv_wb (bool, optional): Whether transpose the qkv_weight and qkv_bias in the op. Only support GPU for now. Default is false, means not transpose qkv wb.
-        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
+        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
 
     Returns:
         Tensor: The output Tensor, the data type and shape is same as `x`.
diff --git a/python/paddle/incubate/nn/layer/fused_transformer.py b/python/paddle/incubate/nn/layer/fused_transformer.py
index b49cc50086..d1305b674f 100644
--- a/python/paddle/incubate/nn/layer/fused_transformer.py
+++ b/python/paddle/incubate/nn/layer/fused_transformer.py
@@ -246,12 +246,13 @@ class FusedMultiHeadAttention(Layer):
         epsilon (float, optional): The small value added to the variance to prevent
             division by zero. Default: 1e-05.
         nranks (int, optional): Distributed tensor model parallel nranks.
             Default is 1, means not using tensor parallel.
+        ring_id (int, optional): For distributed tensor model parallel. Default is -1, means not using tensor parallel.
         transpose_qkv_wb (bool, optional): Support input qkv matmul weight shape as
             [hidden_size, 3 * hidden_size] and qkv matmul bias shape as [3 * hidden_size].
             Will transpose the weight to [3, num_head, head_dim, hidden_size] and transpose
             bias to [3, num_head, hidden_size] in the fused_attention_op. Only support
             for GPU for now. The default value is False, which is not do transpose to qkv_w and qkv_b.
-        ring_id (int, optional): For distributed tensor model parallel. Default is -1, means not using tensor parallel.
+        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
 
     Examples:
--
GitLab
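
For reference, below is a minimal sketch (not part of the patch) exercising the two qkv layouts that the updated `fused_multi_head_attention` docstring describes. It assumes a GPU build of PaddlePaddle (the fused op, and `transpose_qkv_wb`, are GPU-only per the docs) and uses illustrative sizes: batch_size=2, seq_len=4, embed_dim=128, num_heads=4, head_dim=32.

import paddle
import paddle.incubate.nn.functional as F

# Sketch only: requires a GPU build of PaddlePaddle. Sizes are assumptions.
batch_size, seq_len, embed_dim, num_heads = 2, 4, 128, 4
head_dim = embed_dim // num_heads  # 32

x = paddle.rand((batch_size, seq_len, embed_dim), dtype="float32")
linear_weight = paddle.rand((embed_dim, embed_dim), dtype="float32")
linear_bias = paddle.rand((embed_dim,), dtype="float32")
# Attention mask: [batch_size, num_heads, seq_len, seq_len].
attn_mask = paddle.rand((batch_size, num_heads, seq_len, seq_len), dtype="float32")

# Default layout (transpose_qkv_wb=False):
# qkv_weight is [3, num_head, dim_head, dim_embed], qkv_bias is [3, num_head, dim_head].
qkv_weight = paddle.rand((3, num_heads, head_dim, embed_dim), dtype="float32")
qkv_bias = paddle.rand((3, num_heads, head_dim), dtype="float32")
out = F.fused_multi_head_attention(
    x, qkv_weight, linear_weight,
    qkv_bias=qkv_bias, linear_bias=linear_bias, attn_mask=attn_mask)

# Transposed layout (transpose_qkv_wb=True; num_heads must then be provided):
# qkv_weight is [dim_embed, 3 * dim_embed], qkv_bias is [3 * dim_embed].
qkv_weight_t = paddle.rand((embed_dim, 3 * embed_dim), dtype="float32")
qkv_bias_t = paddle.rand((3 * embed_dim,), dtype="float32")
out_t = F.fused_multi_head_attention(
    x, qkv_weight_t, linear_weight,
    qkv_bias=qkv_bias_t, linear_bias=linear_bias, attn_mask=attn_mask,
    num_heads=num_heads, transpose_qkv_wb=True)

print(out.shape)  # [2, 4, 128], same shape and dtype as x

In both calls the result matches the shape of `x`; the transposed layout lets callers keep plain `[hidden_size, 3 * hidden_size]` matmul weights, and the op performs the transpose internally, as the updated `FusedMultiHeadAttention` docstring explains.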