未验证 提交 0d47f387 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Auto gen fused op (#56585)

* add code

* fix bug

* fix bug
上级 773ee87c
...@@ -112,6 +112,7 @@ CC_FILE_TEMPLATE = """// This file is generated by "paddle/fluid/ir/dialect/op_g ...@@ -112,6 +112,7 @@ CC_FILE_TEMPLATE = """// This file is generated by "paddle/fluid/ir/dialect/op_g
#include "paddle/phi/infermeta/unary.h" #include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/infermeta/ternary.h" #include "paddle/phi/infermeta/ternary.h"
#include "paddle/phi/infermeta/backward.h" #include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/fusion.h"
#include "paddle/phi/api/lib/utils/allocator.h" #include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/fluid/primitive/rule/vjp/vjp.h" #include "paddle/fluid/primitive/rule/vjp/vjp.h"
{def_primitive} {def_primitive}
......
...@@ -17,11 +17,17 @@ set(op_backward_yaml_file1 ...@@ -17,11 +17,17 @@ set(op_backward_yaml_file1
set(op_backward_yaml_file2 set(op_backward_yaml_file2
${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_backward_ops.parsed.yaml ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_backward_ops.parsed.yaml
) )
set(fused_op_forward_yaml_file
${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/fused_ops.parsed.yaml
)
set(fused_op_backward_yaml_file
${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/fused_backward.parsed.yaml
)
set(op_yaml_file3 set(op_yaml_file3
${PADDLE_SOURCE_DIR}/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op.yaml) ${PADDLE_SOURCE_DIR}/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op.yaml)
set(op_yaml_files set(op_yaml_files
${op_forward_yaml_file1},${op_forward_yaml_file2},${op_backward_yaml_file1},${op_backward_yaml_file2},${op_yaml_file3} ${op_forward_yaml_file1},${op_forward_yaml_file2},${op_backward_yaml_file1},${op_backward_yaml_file2},${fused_op_forward_yaml_file},${fused_op_backward_yaml_file},${op_yaml_file3}
) )
set(op_namespace paddle,dialect) set(op_namespace paddle,dialect)
set(dialect_name pd) set(dialect_name pd)
......
...@@ -58,6 +58,7 @@ ...@@ -58,6 +58,7 @@
output: Tensor(out), Tensor(seq_lod), Tensor(max_seq_len) output: Tensor(out), Tensor(seq_lod), Tensor(max_seq_len)
infer_meta : infer_meta :
func: EmbeddingWithEltwiseAddXPUInferMeta func: EmbeddingWithEltwiseAddXPUInferMeta
param : [ids, tables, mask]
kernel: kernel:
func: embedding_with_eltwise_add_xpu func: embedding_with_eltwise_add_xpu
data_type: tables data_type: tables
......
...@@ -466,11 +466,11 @@ void FusedMultiTransformerXpuInferMeta( ...@@ -466,11 +466,11 @@ void FusedMultiTransformerXpuInferMeta(
const std::vector<const MetaTensor*>& ffn2_bias, const std::vector<const MetaTensor*>& ffn2_bias,
const std::vector<const MetaTensor*>& cache_kv, const std::vector<const MetaTensor*>& cache_kv,
const std::vector<const MetaTensor*>& pre_caches, const std::vector<const MetaTensor*>& pre_caches,
const std::vector<const MetaTensor*>& rotary_pos_emb, const MetaTensor& rotary_pos_emb,
const std::vector<const MetaTensor*>& time_step, const MetaTensor& time_step,
const std::vector<const MetaTensor*>& seq_lengths, const MetaTensor& seq_lengths,
const std::vector<const MetaTensor*>& src_mask, const MetaTensor& src_mask,
const std::vector<const MetaTensor*>& gather_index, const MetaTensor& gather_index,
bool pre_layer_norm, bool pre_layer_norm,
int rotary_emb_dims, int rotary_emb_dims,
float epsilon, float epsilon,
......
...@@ -143,11 +143,11 @@ void FusedMultiTransformerXpuInferMeta( ...@@ -143,11 +143,11 @@ void FusedMultiTransformerXpuInferMeta(
const std::vector<const MetaTensor*>& ffn2_bias, const std::vector<const MetaTensor*>& ffn2_bias,
const std::vector<const MetaTensor*>& cache_kv, const std::vector<const MetaTensor*>& cache_kv,
const std::vector<const MetaTensor*>& pre_caches, const std::vector<const MetaTensor*>& pre_caches,
const std::vector<const MetaTensor*>& rotary_pos_emb, const MetaTensor& rotary_pos_emb,
const std::vector<const MetaTensor*>& time_step, const MetaTensor& time_step,
const std::vector<const MetaTensor*>& seq_lengths, const MetaTensor& seq_lengths,
const std::vector<const MetaTensor*>& src_mask, const MetaTensor& src_mask,
const std::vector<const MetaTensor*>& gather_index, const MetaTensor& gather_index,
bool pre_layer_norm, bool pre_layer_norm,
int rotary_emb_dims, int rotary_emb_dims,
float epsilon, float epsilon,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册