From 0d47f3872715cebbf556a7af1cc4743f93235ee5 Mon Sep 17 00:00:00 2001 From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com> Date: Thu, 24 Aug 2023 10:35:30 +0800 Subject: [PATCH] [IR] Auto gen fused op (#56585) * add code * fix bug * fix bug --- paddle/fluid/ir/dialect/op_generator/op_gen.py | 1 + .../fluid/ir/dialect/paddle_dialect/ir/CMakeLists.txt | 8 +++++++- paddle/phi/api/yaml/fused_ops.yaml | 1 + paddle/phi/infermeta/fusion.cc | 10 +++++----- paddle/phi/infermeta/fusion.h | 10 +++++----- 5 files changed, 19 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/ir/dialect/op_generator/op_gen.py b/paddle/fluid/ir/dialect/op_generator/op_gen.py index 87459d7486f..29d4a1b1fab 100644 --- a/paddle/fluid/ir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/ir/dialect/op_generator/op_gen.py @@ -112,6 +112,7 @@ CC_FILE_TEMPLATE = """// This file is generated by "paddle/fluid/ir/dialect/op_g #include "paddle/phi/infermeta/unary.h" #include "paddle/phi/infermeta/ternary.h" #include "paddle/phi/infermeta/backward.h" +#include "paddle/phi/infermeta/fusion.h" #include "paddle/phi/api/lib/utils/allocator.h" #include "paddle/fluid/primitive/rule/vjp/vjp.h" {def_primitive} diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/CMakeLists.txt b/paddle/fluid/ir/dialect/paddle_dialect/ir/CMakeLists.txt index ee84abd7f32..69ffb2fcebb 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/CMakeLists.txt +++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/CMakeLists.txt @@ -17,11 +17,17 @@ set(op_backward_yaml_file1 set(op_backward_yaml_file2 ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_backward_ops.parsed.yaml ) +set(fused_op_forward_yaml_file + ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/fused_ops.parsed.yaml +) +set(fused_op_backward_yaml_file + ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/fused_backward.parsed.yaml +) set(op_yaml_file3 
${PADDLE_SOURCE_DIR}/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op.yaml) set(op_yaml_files - ${op_forward_yaml_file1},${op_forward_yaml_file2},${op_backward_yaml_file1},${op_backward_yaml_file2},${op_yaml_file3} + ${op_forward_yaml_file1},${op_forward_yaml_file2},${op_backward_yaml_file1},${op_backward_yaml_file2},${fused_op_forward_yaml_file},${fused_op_backward_yaml_file},${op_yaml_file3} ) set(op_namespace paddle,dialect) set(dialect_name pd) diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml index 0f7d5f2b7a1..648384422ca 100644 --- a/paddle/phi/api/yaml/fused_ops.yaml +++ b/paddle/phi/api/yaml/fused_ops.yaml @@ -58,6 +58,7 @@ output: Tensor(out), Tensor(seq_lod), Tensor(max_seq_len) infer_meta : func: EmbeddingWithEltwiseAddXPUInferMeta + param : [ids, tables, mask] kernel: func: embedding_with_eltwise_add_xpu data_type: tables diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index e63852a7201..3143c5cde2e 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -466,11 +466,11 @@ void FusedMultiTransformerXpuInferMeta( const std::vector<const MetaTensor*>& ffn2_bias, const std::vector<const MetaTensor*>& cache_kv, const std::vector<const MetaTensor*>& pre_caches, - const std::vector<const MetaTensor*>& rotary_pos_emb, - const std::vector<const MetaTensor*>& time_step, - const std::vector<const MetaTensor*>& seq_lengths, - const std::vector<const MetaTensor*>& src_mask, - const std::vector<const MetaTensor*>& gather_index, + const MetaTensor& rotary_pos_emb, + const MetaTensor& time_step, + const MetaTensor& seq_lengths, + const MetaTensor& src_mask, + const MetaTensor& gather_index, bool pre_layer_norm, int rotary_emb_dims, float epsilon, diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h index 618128fdcd4..25c27bdd406 100644 --- a/paddle/phi/infermeta/fusion.h +++ b/paddle/phi/infermeta/fusion.h @@ -143,11 +143,11 @@ void FusedMultiTransformerXpuInferMeta( const std::vector<const MetaTensor*>& ffn2_bias, const std::vector<const MetaTensor*>& cache_kv, const std::vector<const MetaTensor*>& pre_caches, - const std::vector<const MetaTensor*>& rotary_pos_emb, - const 
std::vector<const MetaTensor*>& time_step, - const std::vector<const MetaTensor*>& seq_lengths, - const std::vector<const MetaTensor*>& src_mask, - const std::vector<const MetaTensor*>& gather_index, + const MetaTensor& rotary_pos_emb, + const MetaTensor& time_step, + const MetaTensor& seq_lengths, + const MetaTensor& src_mask, + const MetaTensor& gather_index, bool pre_layer_norm, int rotary_emb_dims, float epsilon, -- GitLab