diff --git a/cmake/external/cinn.cmake b/cmake/external/cinn.cmake index 1ca029b3add4cedab3c5a22487cb31929cf126a9..004bf353d34e8ec0f19582663dfb18231152722e 100644 --- a/cmake/external/cinn.cmake +++ b/cmake/external/cinn.cmake @@ -26,7 +26,7 @@ add_definitions(-w) ###################################### include(ExternalProject) set(CINN_PREFIX_DIR ${THIRD_PARTY_PATH}/CINN) -set(CINN_GIT_TAG 08d7680dd91dfaa65787969050eb8f1143654f10) +set(CINN_GIT_TAG eedb801ca39bfc6b9621bc76c24a0bf98cb8404b) set(CINN_OPTIONAL_ARGS -DPY_VERSION=${PY_VERSION} -DWITH_CUDA=${WITH_GPU} -DWITH_CUDNN=${WITH_GPU} diff --git a/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc index 67393c288df86b083244ad3f60e2172a13562a00..51dca93c7c7f0c37bea5aa5bd6458f9953814f69 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc @@ -25,14 +25,10 @@ #include "cinn/auto_schedule/tuning.h" #include "cinn/common/target.h" #include "cinn/common/type.h" -#include "cinn/frontend/decomposer/use_decomposer.h" -#include "cinn/frontend/pass/use_program_pass.h" -#include "cinn/frontend/program_pass.h" +#include "cinn/frontend/optimize.h" #include "cinn/frontend/syntax.h" #include "cinn/hlir/framework/graph.h" #include "cinn/hlir/framework/graph_compiler.h" -#include "cinn/hlir/framework/pass.h" -#include "cinn/hlir/pass/use_pass.h" #include "gflags/gflags.h" #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/ir/graph.h" @@ -58,13 +54,11 @@ namespace paddle2cinn { using ir::Graph; using ir::Node; using inference::analysis::Dot; -using ::cinn::common::Target; -using ::cinn::common::Float; -using ::cinn::hlir::framework::GraphCompiler; using ::cinn::auto_schedule::AutoTuner; +using ::cinn::common::Target; +using ::cinn::frontend::Optimize; using ::cinn::hlir::framework::BuildScope; -using ::cinn::frontend::ProgramPass; -using ::cinn::hlir::framework::ApplyPass; +using ::cinn::hlir::framework::GraphCompiler; CinnCompiler* CinnCompiler::GetInstance() { static CinnCompiler instance; @@ -75,7 +69,7 @@ const CinnCompiledObject& CinnCompiler::Compile( const Graph& graph, const std::map& input_tensors, const Target& target, void* stream) { - VLOG(1) << "-- The graph to be compiled is:\n" << VizGraph(graph); + VLOG(4) << "-- The graph to be compiled is:\n" << VizGraph(graph); CinnCacheKeyByAddress cur_key_by_address(graph, input_tensors, target.arch_str()); CinnCacheKeyByStructure cur_key_by_struct; @@ -258,22 +252,15 @@ std::unique_ptr CinnCompiler::CompileGraph( CinnGraphSymbolization symbol{compiled_num, graph, target, input_tensors}; auto frontend_program = symbol(); auto fetch_ids = symbol.GetFetchIds(); - ProgramPass::Apply(&frontend_program, fetch_ids, target, {"Decomposer"}); - ::cinn::frontend::ApplyPass(&frontend_program, fetch_ids, "RemoveIdentity"); - ::cinn::frontend::ApplyPass(&frontend_program, fetch_ids, "TransposeFolding"); - ProgramPass::Apply(&frontend_program, fetch_ids, target, {"GemmRewriter"}); + VLOG(4) << "All fetch var ids in CINN: " + << string::join_strings(fetch_ids, ','); - auto cinn_graph = std::make_shared<::cinn::hlir::framework::Graph>( - frontend_program, target); - VLOG(1) << "-- The " << compiled_num << "-th compilation (" + auto cinn_graph = Optimize(&frontend_program, fetch_ids, target); + VLOG(4) << "-- The " << compiled_num << "-th compilation (" << target.arch_str() << "), and its related graph:\n" << cinn_graph->Visualize(); - ApplyPass(cinn_graph.get(), "OpFusion"); - auto scope = BuildScope(target, cinn_graph); - - VLOG(4) << "All fetch var ids in CINN: " - << string::join_strings(fetch_ids, ','); + auto scope = BuildScope(target, cinn_graph); auto graph_compiler = std::make_unique(target, scope, cinn_graph); GraphCompiler::CompileOptions options;