diff --git a/paddle/fluid/ir/dialect/op_generator/op_gen.py b/paddle/fluid/ir/dialect/op_generator/op_gen.py index 87459d7486f3626adff7b9c3a5e957b7b89ee3a6..29d4a1b1fab9a78bde93cad03bed246cc9372ea6 100644 --- a/paddle/fluid/ir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/ir/dialect/op_generator/op_gen.py @@ -112,6 +112,7 @@ CC_FILE_TEMPLATE = """// This file is generated by "paddle/fluid/ir/dialect/op_g #include "paddle/phi/infermeta/unary.h" #include "paddle/phi/infermeta/ternary.h" #include "paddle/phi/infermeta/backward.h" +#include "paddle/phi/infermeta/fusion.h" #include "paddle/phi/api/lib/utils/allocator.h" #include "paddle/fluid/primitive/rule/vjp/vjp.h" {def_primitive} diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/CMakeLists.txt b/paddle/fluid/ir/dialect/paddle_dialect/ir/CMakeLists.txt index ee84abd7f32704ba4782701686fd438a627d343a..69ffb2fcebb0651af3ea9fc61ea38ea60a314e02 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/CMakeLists.txt +++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/CMakeLists.txt @@ -17,11 +17,17 @@ set(op_backward_yaml_file1 set(op_backward_yaml_file2 ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_backward_ops.parsed.yaml ) +set(fused_op_forward_yaml_file + ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/fused_ops.parsed.yaml +) +set(fused_op_backward_yaml_file + ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/fused_backward.parsed.yaml +) set(op_yaml_file3 ${PADDLE_SOURCE_DIR}/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op.yaml) set(op_yaml_files - ${op_forward_yaml_file1},${op_forward_yaml_file2},${op_backward_yaml_file1},${op_backward_yaml_file2},${op_yaml_file3} + ${op_forward_yaml_file1},${op_forward_yaml_file2},${op_backward_yaml_file1},${op_backward_yaml_file2},${fused_op_forward_yaml_file},${fused_op_backward_yaml_file},${op_yaml_file3} ) set(op_namespace paddle,dialect) set(dialect_name pd) diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml index 0f7d5f2b7a182b16bbabfc4ac2a4a80da0b56bcc..648384422ca8abd394754ef64ccebdb889afe319 100644 --- a/paddle/phi/api/yaml/fused_ops.yaml +++ b/paddle/phi/api/yaml/fused_ops.yaml @@ -58,6 +58,7 @@ output: Tensor(out), Tensor(seq_lod), Tensor(max_seq_len) infer_meta : func: EmbeddingWithEltwiseAddXPUInferMeta + param : [ids, tables, mask] kernel: func: embedding_with_eltwise_add_xpu data_type: tables diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index e63852a720115995effa82a92566caee5359c574..3143c5cde2e1e5b2d05563cb0daabd98ee428fe0 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -466,11 +466,11 @@ void FusedMultiTransformerXpuInferMeta( const std::vector& ffn2_bias, const std::vector& cache_kv, const std::vector& pre_caches, - const std::vector& rotary_pos_emb, - const std::vector& time_step, - const std::vector& seq_lengths, - const std::vector& src_mask, - const std::vector& gather_index, + const MetaTensor& rotary_pos_emb, + const MetaTensor& time_step, + const MetaTensor& seq_lengths, + const MetaTensor& src_mask, + const MetaTensor& gather_index, bool pre_layer_norm, int rotary_emb_dims, float epsilon, diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h index 618128fdcd48c5dfeab2ff071bafe56b73d7b747..25c27bdd406b96da5ee3ddaf8e68194ffa6ac532 100644 --- a/paddle/phi/infermeta/fusion.h +++ b/paddle/phi/infermeta/fusion.h @@ -143,11 +143,11 @@ void FusedMultiTransformerXpuInferMeta( const std::vector& ffn2_bias, const std::vector& cache_kv, const std::vector& pre_caches, - const std::vector& rotary_pos_emb, - const std::vector& time_step, - const std::vector& seq_lengths, - const std::vector& src_mask, - const std::vector& gather_index, + const MetaTensor& rotary_pos_emb, + const MetaTensor& time_step, + const MetaTensor& seq_lengths, + const MetaTensor& src_mask, + const MetaTensor& gather_index, bool pre_layer_norm, int rotary_emb_dims, float epsilon,