From e8c6c7dfe5a7746a12e02a9df0bca4b3ddc08ca5 Mon Sep 17 00:00:00 2001 From: wuhuanzhou Date: Thu, 2 Dec 2021 11:02:28 +0800 Subject: [PATCH] cinn_compiler add RemoveIdentity pass after Decomposer, test=develop (#37738) --- paddle/fluid/framework/paddle2cinn/cinn_compiler.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc index 360c927078..7fc8eff3d3 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc @@ -193,6 +193,8 @@ std::unique_ptr CinnCompiler::CompileGraph( CinnGraphSymbolization symbol{compiled_num, graph, target, input_tensors}; auto frontend_program = symbol(); ProgramPass::Apply(&frontend_program, target, {"Decomposer"}); + auto fetch_ids = symbol.GetFetchIds(); + ::cinn::frontend::ApplyPass(&frontend_program, fetch_ids, "RemoveIdentity"); auto cinn_graph = std::make_shared<::cinn::hlir::framework::Graph>( frontend_program, target); VLOG(1) << "-- The " << compiled_num << "-th compilation (" @@ -201,7 +203,6 @@ std::unique_ptr CinnCompiler::CompileGraph( ApplyPass(cinn_graph.get(), "OpFusion"); auto scope = BuildScope(target, cinn_graph); - auto fetch_ids = symbol.GetFetchIds(); VLOG(4) << "All fetch var ids in CINN: " << string::join_strings(fetch_ids, ','); -- GitLab