Unverified · Commit 3bac6264 authored by Sonder, committed by GitHub

Move fused_attention op to phi [migrate backward GPU OpKernel] (#51909)

* add kernel functions

* update kernel functions

* update func parameters' name

* create codes for gpu device

* adjust file locations

* fix include error

* remove dependent files to phi/

* restore fused_attention_op.cu

* fix dependence errors

* fix dependence errors

* fix include error

* fix all dependence errors [build success]

* remove useless include

* recover useless include

* use phi::ToNCCLDataType

* fix namespace

* update new register code

* fix error in fused_gemm_epilogue_utils

* fix error in FusedAttentionKernel param

* finish fused_attention register code [build success]

* add paddle::optional

* add sig file

* fix build error

* fix a include error

* restore forward code

* update CMakeLists

* trans Compute function to phi [build success]

* add register code and fix include error [build success]

* fix parameter sequence

* add include file

* update #if before include

* update #if before include

* fix grammar error

* update codes for DropoutParam

* remove const cast

* trans some fluid api to phi api

* remove const cast

* trans some fluid api to phi api

* add #if

* update test code

* update test codes

* recover test codes

* fix namespace and remove fluid include

* recover random seed

* remove fluid quant_helper

* fix include error

* include utils in funcs

* change include file

* move grad codes back to fluid folder

* move grad codes back to fluid folder

* fix sig file error

* update include

* recover codes to develop

* update register codes

* fix build error

* recover fluid include

* remove some fluid include

* remove some fluid include

* Update fused_attention_op.cu

* remove fluid include

* add some fluid include

* Update fused_attention_op.cu

* Update fused_attention_op.cu

* Update fused_attention_op.cu

* Update fused_attention_op.cu

* remove useless include
Parent ab163063
@@ -58,7 +58,83 @@ KernelSignature AttentionFuseOpArgumentMapping(
"CacheKVOut", "Y"});
}

KernelSignature AttentionGradFuseOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("fused_attention_grad",
                         {"Y@GRAD",
                          "X",
                          "QKVW",
                          "QKVBias",
                          "QKVBiasOut",
                          "SrcMask",
                          "SrcMaskOut",
                          "OutLinearW",
                          "OutLinearBias",
                          "LnScale",
                          "LnBias",
                          "Ln2Scale",
                          "Ln2Bias",
                          "LnOut",
                          "LnMean",
                          "LnVariance",
                          "Ln2Mean",
                          "Ln2Variance",
                          "BiasDropoutResidualOut",
                          "QKVOut",
                          "TransposeOut2",
                          "QKOut",
                          "QKTVOut",
                          "SoftmaxOut",
                          "AttnDropoutMaskOut",
                          "AttnDropoutOut",
                          "FMHAOut",
                          "OutLinearOut",
                          "DropoutMaskOut"},
                         {"num_heads",
                          "transpose_qkv_wb",
                          "pre_layer_norm",
                          "epsilon",
                          "attn_dropout_rate",
                          "is_test",
                          "attn_dropout_fix_seed",
                          "attn_dropout_seed",
                          "attn_dropout_implementation",
                          "dropout_rate",
                          "dropout_fix_seed",
                          "dropout_seed",
                          "dropout_implementation",
                          "ln_epsilon",
                          "add_residual",
                          "ring_id"},
                         {
                             "QKVBias@GRAD",
                             "QKVBiasOut@GRAD",
                             "SrcMaskOut@GRAD",
                             "OutLinearBias@GRAD",
                             "LnScale@GRAD",
                             "LnBias@GRAD",
                             "Ln2Scale@GRAD",
                             "Ln2Bias@GRAD",
                             "X@GRAD",
                             "QKVW@GRAD",
                             "OutLinearW@GRAD",
                             "LnOut@GRAD",
                             "BiasDropoutResidualOut@GRAD",
                             "QKVOut@GRAD",
                             "QKTVOut@GRAD",
                             "TransposeOut2@GRAD",
                             "QKOut@GRAD",
                             "SoftmaxOut@GRAD",
                             "AttnDropoutOut@GRAD",
                             "FMHAOut@GRAD",
                             "OutLinearOut@GRAD",
                         });
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(fused_attention,
phi::AttentionFuseOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(fused_attention_grad,
phi::AttentionGradFuseOpArgumentMapping);
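
Note on how this mapping is consumed: the KernelSignature above only tells the framework how to translate the legacy fused_attention_grad operator's inputs, attributes, and outputs into the new phi kernel's argument order; the migrated GPU kernel itself is registered separately on the phi side. A minimal sketch of what that registration could look like is shown below — the function name phi::fusion::FusedAttentionGradKernel and the dtype list are assumptions for illustration, not taken from this diff.

// Hypothetical sketch of the phi-side GPU registration. The kernel name must
// match the first argument passed to KernelSignature above; the C++ function
// name and the dtype list are assumed, not copied from this PR.
PD_REGISTER_KERNEL(fused_attention_grad,                   // kernel name
                   GPU,                                    // backend
                   ALL_LAYOUT,                             // data layout
                   phi::fusion::FusedAttentionGradKernel,  // assumed kernel fn
                   float,
                   double,
                   phi::dtype::float16) {}

With both pieces in place, the legacy fused_attention_grad op resolves its arguments through AttentionGradFuseOpArgumentMapping and then dispatches to the phi GPU kernel.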