未验证 提交 d12b1ffa 编写于 作者: Y Yuanle Liu 提交者: GitHub

move delete_cast_op_pass (#52788)

上级 b835d958
......@@ -126,6 +126,7 @@ pass_library(matmul_scale_fuse_pass inference)
pass_library(gpu_cpu_map_matmul_to_mul_pass inference)
pass_library(dense_fc_to_sparse_pass inference)
pass_library(dense_multihead_matmul_to_sparse_pass inference)
pass_library(delete_cast_op_pass inference)
pass_library(generate_pass DEPS pass_desc_proto)
target_link_libraries(generate_pass pass_desc_proto)
......@@ -242,7 +243,6 @@ if(WITH_XPU)
pass_library(fused_multi_transformer_xpu_quant_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(stack_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(delete_cast_op_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
endif()
cc_library(
......@@ -407,6 +407,11 @@ cc_test(
test_delete_dequant_weight_linear_op_pass
SRCS delete_weight_dequant_linear_op_pass_tester.cc
DEPS delete_weight_dequant_linear_op_pass)
cc_test(
test_delete_cast_op_pass
SRCS delete_cast_op_pass_test.cc
DEPS delete_cast_op_pass)
if(WITH_GPU OR WITH_ROCM)
cc_test(
test_embedding_eltwise_layernorm_fuse_pass
......@@ -521,8 +526,4 @@ if(WITH_XPU)
test_stack_fuse_pass
SRCS xpu/stack_fuse_pass_test.cc
DEPS stack_fuse_pass)
cc_test(
test_delete_cast_op_pass
SRCS xpu/delete_cast_op_pass_test.cc
DEPS delete_cast_op_pass)
endif()
......@@ -12,10 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/xpu/delete_cast_op_pass.h"
#include <string>
#include "paddle/fluid/framework/ir/delete_cast_op_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -127,11 +126,11 @@ int DeleteCastOpPass::ApplyCastWriteReadPass(ir::Graph* graph) const {
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle ApplyCastWriteReadPass fuse";
GET_IR_NODE(cast0);
GET_IR_NODE(write_to_array);
GET_IR_NODE(cast0_in);
GET_IR_NODE(cast0_out);
GET_IR_NODE(write_to_array_out);
GET_IR_NODE_FROM_SUBGRAPH(cast0, cast0, pattern);
GET_IR_NODE_FROM_SUBGRAPH(write_to_array, write_to_array, pattern);
GET_IR_NODE_FROM_SUBGRAPH(cast0_in, cast0_in, pattern);
GET_IR_NODE_FROM_SUBGRAPH(cast0_out, cast0_out, pattern);
GET_IR_NODE_FROM_SUBGRAPH(write_to_array_out, write_to_array_out, pattern);
// write_to_array_out(in graph1) may not link to any op nodes, so we find
// read_from_array by write_to_array_out name.
......@@ -281,13 +280,13 @@ int DeleteCastOpPass::ApplyCastLodResetWriteReadPass(ir::Graph* graph) const {
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle ApplyCastLodResetWriteReadPass fuse";
GET_IR_NODE(cast0);
GET_IR_NODE(lod_reset);
GET_IR_NODE(write_to_array);
GET_IR_NODE(cast0_in);
GET_IR_NODE(cast0_out);
GET_IR_NODE(lod_reset_out);
GET_IR_NODE(write_to_array_out);
GET_IR_NODE_FROM_SUBGRAPH(cast0, cast0, pattern);
GET_IR_NODE_FROM_SUBGRAPH(lod_reset, lod_reset, pattern);
GET_IR_NODE_FROM_SUBGRAPH(write_to_array, write_to_array, pattern);
GET_IR_NODE_FROM_SUBGRAPH(cast0_in, cast0_in, pattern);
GET_IR_NODE_FROM_SUBGRAPH(cast0_out, cast0_out, pattern);
GET_IR_NODE_FROM_SUBGRAPH(lod_reset_out, lod_reset_out, pattern);
GET_IR_NODE_FROM_SUBGRAPH(write_to_array_out, write_to_array_out, pattern);
// write_to_array_out(in graph1) may not link to any op nodes, so we find
// read_from_array by write_to_array_out name.
......@@ -482,13 +481,13 @@ int DeleteCastOpPass::ApplyCastIndexSamplePass(ir::Graph* graph) const {
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle ApplyCastIndexSamplePass fuse";
GET_IR_NODE(cast0);
GET_IR_NODE(index_sample);
GET_IR_NODE(cast1);
GET_IR_NODE(cast0_in);
GET_IR_NODE(cast0_out);
GET_IR_NODE(index_sample_out);
GET_IR_NODE(cast1_out);
GET_IR_NODE_FROM_SUBGRAPH(cast0, cast0, pattern);
GET_IR_NODE_FROM_SUBGRAPH(index_sample, index_sample, pattern);
GET_IR_NODE_FROM_SUBGRAPH(cast1, cast1, pattern);
GET_IR_NODE_FROM_SUBGRAPH(cast0_in, cast0_in, pattern);
GET_IR_NODE_FROM_SUBGRAPH(cast0_out, cast0_out, pattern);
GET_IR_NODE_FROM_SUBGRAPH(index_sample_out, index_sample_out, pattern);
GET_IR_NODE_FROM_SUBGRAPH(cast1_out, cast1_out, pattern);
index_sample->Op()->RenameInput(cast0_out->Name(), cast0_in->Name());
index_sample->Op()->RenameOutput(index_sample_out->Name(),
......@@ -545,9 +544,9 @@ int DeleteCastOpPass::ApplyCastPass(ir::Graph* graph) const {
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle ApplyCastPass fuse";
GET_IR_NODE(cast);
GET_IR_NODE(cast_in);
GET_IR_NODE(cast_out);
GET_IR_NODE_FROM_SUBGRAPH(cast, cast, pattern);
GET_IR_NODE_FROM_SUBGRAPH(cast_in, cast_in, pattern);
GET_IR_NODE_FROM_SUBGRAPH(cast_out, cast_out, pattern);
for (auto* out_op_node : cast_out->outputs) {
out_op_node->Op()->RenameInput(cast_out->Name(), cast_in->Name());
IR_NODE_LINK_TO(cast_in, out_op_node);
......
......@@ -276,6 +276,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"transpose_flatten_concat_fuse_pass", //
"conv2d_fusion_layout_transfer_pass", //
"auto_mixed_precision_pass", //
"delete_cast_op_pass", //
"inplace_op_var_pass", // should be the last pass.
});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册