From d12b1ffa4bca06c10ed9b70a2675285cfaae818b Mon Sep 17 00:00:00 2001 From: Yuanle Liu Date: Wed, 12 Apr 2023 10:32:00 +0800 Subject: [PATCH] move delete_cast_op_pass (#52788) --- paddle/fluid/framework/ir/CMakeLists.txt | 11 +++-- .../ir/{xpu => }/delete_cast_op_pass.cc | 49 +++++++++---------- .../ir/{xpu => }/delete_cast_op_pass.h | 0 .../ir/{xpu => }/delete_cast_op_pass_test.cc | 0 .../inference/api/paddle_pass_builder.cc | 1 + 5 files changed, 31 insertions(+), 30 deletions(-) rename paddle/fluid/framework/ir/{xpu => }/delete_cast_op_pass.cc (93%) rename paddle/fluid/framework/ir/{xpu => }/delete_cast_op_pass.h (100%) rename paddle/fluid/framework/ir/{xpu => }/delete_cast_op_pass_test.cc (100%) diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 91c3ba6d608..b1db3dd0a43 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -126,6 +126,7 @@ pass_library(matmul_scale_fuse_pass inference) pass_library(gpu_cpu_map_matmul_to_mul_pass inference) pass_library(dense_fc_to_sparse_pass inference) pass_library(dense_multihead_matmul_to_sparse_pass inference) +pass_library(delete_cast_op_pass inference) pass_library(generate_pass DEPS pass_desc_proto) target_link_libraries(generate_pass pass_desc_proto) @@ -242,7 +243,6 @@ if(WITH_XPU) pass_library(fused_multi_transformer_xpu_quant_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(stack_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) - pass_library(delete_cast_op_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) endif() cc_library( @@ -407,6 +407,11 @@ cc_test( test_delete_dequant_weight_linear_op_pass SRCS delete_weight_dequant_linear_op_pass_tester.cc DEPS delete_weight_dequant_linear_op_pass) +cc_test( + test_delete_cast_op_pass + SRCS delete_cast_op_pass_test.cc + DEPS delete_cast_op_pass) + if(WITH_GPU OR WITH_ROCM) cc_test( test_embedding_eltwise_layernorm_fuse_pass @@ -521,8 +526,4 @@ if(WITH_XPU) test_stack_fuse_pass SRCS xpu/stack_fuse_pass_test.cc DEPS stack_fuse_pass) - cc_test( - test_delete_cast_op_pass - SRCS xpu/delete_cast_op_pass_test.cc - DEPS delete_cast_op_pass) endif() diff --git a/paddle/fluid/framework/ir/xpu/delete_cast_op_pass.cc b/paddle/fluid/framework/ir/delete_cast_op_pass.cc similarity index 93% rename from paddle/fluid/framework/ir/xpu/delete_cast_op_pass.cc rename to paddle/fluid/framework/ir/delete_cast_op_pass.cc index fb417322476..bfda0f32380 100644 --- a/paddle/fluid/framework/ir/xpu/delete_cast_op_pass.cc +++ b/paddle/fluid/framework/ir/delete_cast_op_pass.cc @@ -12,10 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/ir/xpu/delete_cast_op_pass.h" -#include +#include "paddle/fluid/framework/ir/delete_cast_op_pass.h" + #include "paddle/fluid/framework/ir/graph_pattern_detector.h" -#include "paddle/fluid/framework/ir/xpu/pass_utils.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/platform/enforce.h" @@ -127,11 +126,11 @@ int DeleteCastOpPass::ApplyCastWriteReadPass(ir::Graph* graph) const { auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { VLOG(4) << "handle ApplyCastWriteReadPass fuse"; - GET_IR_NODE(cast0); - GET_IR_NODE(write_to_array); - GET_IR_NODE(cast0_in); - GET_IR_NODE(cast0_out); - GET_IR_NODE(write_to_array_out); + GET_IR_NODE_FROM_SUBGRAPH(cast0, cast0, pattern); + GET_IR_NODE_FROM_SUBGRAPH(write_to_array, write_to_array, pattern); + GET_IR_NODE_FROM_SUBGRAPH(cast0_in, cast0_in, pattern); + GET_IR_NODE_FROM_SUBGRAPH(cast0_out, cast0_out, pattern); + GET_IR_NODE_FROM_SUBGRAPH(write_to_array_out, write_to_array_out, pattern); // write_to_array_out(in graph1) may not link to any op nodes, so we fine // read_from_array by write_to_array_out name. @@ -281,13 +280,13 @@ int DeleteCastOpPass::ApplyCastLodResetWriteReadPass(ir::Graph* graph) const { auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { VLOG(4) << "handle ApplyCastLodResetWriteReadPass fuse"; - GET_IR_NODE(cast0); - GET_IR_NODE(lod_reset); - GET_IR_NODE(write_to_array); - GET_IR_NODE(cast0_in); - GET_IR_NODE(cast0_out); - GET_IR_NODE(lod_reset_out); - GET_IR_NODE(write_to_array_out); + GET_IR_NODE_FROM_SUBGRAPH(cast0, cast0, pattern); + GET_IR_NODE_FROM_SUBGRAPH(lod_reset, lod_reset, pattern); + GET_IR_NODE_FROM_SUBGRAPH(write_to_array, write_to_array, pattern); + GET_IR_NODE_FROM_SUBGRAPH(cast0_in, cast0_in, pattern); + GET_IR_NODE_FROM_SUBGRAPH(cast0_out, cast0_out, pattern); + GET_IR_NODE_FROM_SUBGRAPH(lod_reset_out, lod_reset_out, pattern); + GET_IR_NODE_FROM_SUBGRAPH(write_to_array_out, write_to_array_out, pattern); // write_to_array_out(in graph1) may not link to any op nodes, so we fine // read_from_array by write_to_array_out name. @@ -482,13 +481,13 @@ int DeleteCastOpPass::ApplyCastIndexSamplePass(ir::Graph* graph) const { auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { VLOG(4) << "handle ApplyCastIndexSamplePass fuse"; - GET_IR_NODE(cast0); - GET_IR_NODE(index_sample); - GET_IR_NODE(cast1); - GET_IR_NODE(cast0_in); - GET_IR_NODE(cast0_out); - GET_IR_NODE(index_sample_out); - GET_IR_NODE(cast1_out); + GET_IR_NODE_FROM_SUBGRAPH(cast0, cast0, pattern); + GET_IR_NODE_FROM_SUBGRAPH(index_sample, index_sample, pattern); + GET_IR_NODE_FROM_SUBGRAPH(cast1, cast1, pattern); + GET_IR_NODE_FROM_SUBGRAPH(cast0_in, cast0_in, pattern); + GET_IR_NODE_FROM_SUBGRAPH(cast0_out, cast0_out, pattern); + GET_IR_NODE_FROM_SUBGRAPH(index_sample_out, index_sample_out, pattern); + GET_IR_NODE_FROM_SUBGRAPH(cast1_out, cast1_out, pattern); index_sample->Op()->RenameInput(cast0_out->Name(), cast0_in->Name()); index_sample->Op()->RenameOutput(index_sample_out->Name(), @@ -545,9 +544,9 @@ int DeleteCastOpPass::ApplyCastPass(ir::Graph* graph) const { auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { VLOG(4) << "handle ApplyCastPass fuse"; - GET_IR_NODE(cast); - GET_IR_NODE(cast_in); - GET_IR_NODE(cast_out); + GET_IR_NODE_FROM_SUBGRAPH(cast, cast, pattern); + GET_IR_NODE_FROM_SUBGRAPH(cast_in, cast_in, pattern); + GET_IR_NODE_FROM_SUBGRAPH(cast_out, cast_out, pattern); for (auto* out_op_node : cast_out->outputs) { out_op_node->Op()->RenameInput(cast_out->Name(), cast_in->Name()); IR_NODE_LINK_TO(cast_in, out_op_node); diff --git a/paddle/fluid/framework/ir/xpu/delete_cast_op_pass.h b/paddle/fluid/framework/ir/delete_cast_op_pass.h similarity index 100% rename from paddle/fluid/framework/ir/xpu/delete_cast_op_pass.h rename to paddle/fluid/framework/ir/delete_cast_op_pass.h diff --git a/paddle/fluid/framework/ir/xpu/delete_cast_op_pass_test.cc b/paddle/fluid/framework/ir/delete_cast_op_pass_test.cc similarity index 100% rename from paddle/fluid/framework/ir/xpu/delete_cast_op_pass_test.cc rename to paddle/fluid/framework/ir/delete_cast_op_pass_test.cc diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 3cc8b077ad7..a1fe08b081e 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -276,6 +276,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { "transpose_flatten_concat_fuse_pass", // "conv2d_fusion_layout_transfer_pass", // "auto_mixed_precision_pass", // + "delete_cast_op_pass", // "inplace_op_var_pass", // should be the last pass. }); -- GitLab