未验证 提交 d12b1ffa 编写于 作者: Y Yuanle Liu 提交者: GitHub

move delete_cast_op_pass (#52788)

上级 b835d958
...@@ -126,6 +126,7 @@ pass_library(matmul_scale_fuse_pass inference) ...@@ -126,6 +126,7 @@ pass_library(matmul_scale_fuse_pass inference)
pass_library(gpu_cpu_map_matmul_to_mul_pass inference) pass_library(gpu_cpu_map_matmul_to_mul_pass inference)
pass_library(dense_fc_to_sparse_pass inference) pass_library(dense_fc_to_sparse_pass inference)
pass_library(dense_multihead_matmul_to_sparse_pass inference) pass_library(dense_multihead_matmul_to_sparse_pass inference)
pass_library(delete_cast_op_pass inference)
pass_library(generate_pass DEPS pass_desc_proto) pass_library(generate_pass DEPS pass_desc_proto)
target_link_libraries(generate_pass pass_desc_proto) target_link_libraries(generate_pass pass_desc_proto)
...@@ -242,7 +243,6 @@ if(WITH_XPU) ...@@ -242,7 +243,6 @@ if(WITH_XPU)
pass_library(fused_multi_transformer_xpu_quant_pass inference DIR xpu DEPS pass_library(fused_multi_transformer_xpu_quant_pass inference DIR xpu DEPS
${XPU_PASS_DEPS}) ${XPU_PASS_DEPS})
pass_library(stack_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(stack_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(delete_cast_op_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
endif() endif()
cc_library( cc_library(
...@@ -407,6 +407,11 @@ cc_test( ...@@ -407,6 +407,11 @@ cc_test(
test_delete_dequant_weight_linear_op_pass test_delete_dequant_weight_linear_op_pass
SRCS delete_weight_dequant_linear_op_pass_tester.cc SRCS delete_weight_dequant_linear_op_pass_tester.cc
DEPS delete_weight_dequant_linear_op_pass) DEPS delete_weight_dequant_linear_op_pass)
cc_test(
test_delete_cast_op_pass
SRCS delete_cast_op_pass_test.cc
DEPS delete_cast_op_pass)
if(WITH_GPU OR WITH_ROCM) if(WITH_GPU OR WITH_ROCM)
cc_test( cc_test(
test_embedding_eltwise_layernorm_fuse_pass test_embedding_eltwise_layernorm_fuse_pass
...@@ -521,8 +526,4 @@ if(WITH_XPU) ...@@ -521,8 +526,4 @@ if(WITH_XPU)
test_stack_fuse_pass test_stack_fuse_pass
SRCS xpu/stack_fuse_pass_test.cc SRCS xpu/stack_fuse_pass_test.cc
DEPS stack_fuse_pass) DEPS stack_fuse_pass)
cc_test(
test_delete_cast_op_pass
SRCS xpu/delete_cast_op_pass_test.cc
DEPS delete_cast_op_pass)
endif() endif()
...@@ -12,10 +12,9 @@ ...@@ -12,10 +12,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/xpu/delete_cast_op_pass.h" #include "paddle/fluid/framework/ir/delete_cast_op_pass.h"
#include <string>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -127,11 +126,11 @@ int DeleteCastOpPass::ApplyCastWriteReadPass(ir::Graph* graph) const { ...@@ -127,11 +126,11 @@ int DeleteCastOpPass::ApplyCastWriteReadPass(ir::Graph* graph) const {
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) { Graph* graph) {
VLOG(4) << "handle ApplyCastWriteReadPass fuse"; VLOG(4) << "handle ApplyCastWriteReadPass fuse";
GET_IR_NODE(cast0); GET_IR_NODE_FROM_SUBGRAPH(cast0, cast0, pattern);
GET_IR_NODE(write_to_array); GET_IR_NODE_FROM_SUBGRAPH(write_to_array, write_to_array, pattern);
GET_IR_NODE(cast0_in); GET_IR_NODE_FROM_SUBGRAPH(cast0_in, cast0_in, pattern);
GET_IR_NODE(cast0_out); GET_IR_NODE_FROM_SUBGRAPH(cast0_out, cast0_out, pattern);
GET_IR_NODE(write_to_array_out); GET_IR_NODE_FROM_SUBGRAPH(write_to_array_out, write_to_array_out, pattern);
// write_to_array_out(in graph1) may not link to any op nodes, so we fine // write_to_array_out(in graph1) may not link to any op nodes, so we fine
// read_from_array by write_to_array_out name. // read_from_array by write_to_array_out name.
...@@ -281,13 +280,13 @@ int DeleteCastOpPass::ApplyCastLodResetWriteReadPass(ir::Graph* graph) const { ...@@ -281,13 +280,13 @@ int DeleteCastOpPass::ApplyCastLodResetWriteReadPass(ir::Graph* graph) const {
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) { Graph* graph) {
VLOG(4) << "handle ApplyCastLodResetWriteReadPass fuse"; VLOG(4) << "handle ApplyCastLodResetWriteReadPass fuse";
GET_IR_NODE(cast0); GET_IR_NODE_FROM_SUBGRAPH(cast0, cast0, pattern);
GET_IR_NODE(lod_reset); GET_IR_NODE_FROM_SUBGRAPH(lod_reset, lod_reset, pattern);
GET_IR_NODE(write_to_array); GET_IR_NODE_FROM_SUBGRAPH(write_to_array, write_to_array, pattern);
GET_IR_NODE(cast0_in); GET_IR_NODE_FROM_SUBGRAPH(cast0_in, cast0_in, pattern);
GET_IR_NODE(cast0_out); GET_IR_NODE_FROM_SUBGRAPH(cast0_out, cast0_out, pattern);
GET_IR_NODE(lod_reset_out); GET_IR_NODE_FROM_SUBGRAPH(lod_reset_out, lod_reset_out, pattern);
GET_IR_NODE(write_to_array_out); GET_IR_NODE_FROM_SUBGRAPH(write_to_array_out, write_to_array_out, pattern);
// write_to_array_out(in graph1) may not link to any op nodes, so we fine // write_to_array_out(in graph1) may not link to any op nodes, so we fine
// read_from_array by write_to_array_out name. // read_from_array by write_to_array_out name.
...@@ -482,13 +481,13 @@ int DeleteCastOpPass::ApplyCastIndexSamplePass(ir::Graph* graph) const { ...@@ -482,13 +481,13 @@ int DeleteCastOpPass::ApplyCastIndexSamplePass(ir::Graph* graph) const {
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) { Graph* graph) {
VLOG(4) << "handle ApplyCastIndexSamplePass fuse"; VLOG(4) << "handle ApplyCastIndexSamplePass fuse";
GET_IR_NODE(cast0); GET_IR_NODE_FROM_SUBGRAPH(cast0, cast0, pattern);
GET_IR_NODE(index_sample); GET_IR_NODE_FROM_SUBGRAPH(index_sample, index_sample, pattern);
GET_IR_NODE(cast1); GET_IR_NODE_FROM_SUBGRAPH(cast1, cast1, pattern);
GET_IR_NODE(cast0_in); GET_IR_NODE_FROM_SUBGRAPH(cast0_in, cast0_in, pattern);
GET_IR_NODE(cast0_out); GET_IR_NODE_FROM_SUBGRAPH(cast0_out, cast0_out, pattern);
GET_IR_NODE(index_sample_out); GET_IR_NODE_FROM_SUBGRAPH(index_sample_out, index_sample_out, pattern);
GET_IR_NODE(cast1_out); GET_IR_NODE_FROM_SUBGRAPH(cast1_out, cast1_out, pattern);
index_sample->Op()->RenameInput(cast0_out->Name(), cast0_in->Name()); index_sample->Op()->RenameInput(cast0_out->Name(), cast0_in->Name());
index_sample->Op()->RenameOutput(index_sample_out->Name(), index_sample->Op()->RenameOutput(index_sample_out->Name(),
...@@ -545,9 +544,9 @@ int DeleteCastOpPass::ApplyCastPass(ir::Graph* graph) const { ...@@ -545,9 +544,9 @@ int DeleteCastOpPass::ApplyCastPass(ir::Graph* graph) const {
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) { Graph* graph) {
VLOG(4) << "handle ApplyCastPass fuse"; VLOG(4) << "handle ApplyCastPass fuse";
GET_IR_NODE(cast); GET_IR_NODE_FROM_SUBGRAPH(cast, cast, pattern);
GET_IR_NODE(cast_in); GET_IR_NODE_FROM_SUBGRAPH(cast_in, cast_in, pattern);
GET_IR_NODE(cast_out); GET_IR_NODE_FROM_SUBGRAPH(cast_out, cast_out, pattern);
for (auto* out_op_node : cast_out->outputs) { for (auto* out_op_node : cast_out->outputs) {
out_op_node->Op()->RenameInput(cast_out->Name(), cast_in->Name()); out_op_node->Op()->RenameInput(cast_out->Name(), cast_in->Name());
IR_NODE_LINK_TO(cast_in, out_op_node); IR_NODE_LINK_TO(cast_in, out_op_node);
......
...@@ -276,6 +276,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { ...@@ -276,6 +276,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"transpose_flatten_concat_fuse_pass", // "transpose_flatten_concat_fuse_pass", //
"conv2d_fusion_layout_transfer_pass", // "conv2d_fusion_layout_transfer_pass", //
"auto_mixed_precision_pass", // "auto_mixed_precision_pass", //
"delete_cast_op_pass", //
"inplace_op_var_pass", // should be the last pass. "inplace_op_var_pass", // should be the last pass.
}); });
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册