diff --git a/cmake/external/cinn.cmake b/cmake/external/cinn.cmake index 75df827a438841f71fce8170c7e12a96bca26439..cd4e0157f2a324005864f82fb0a53334dd060b97 100644 --- a/cmake/external/cinn.cmake +++ b/cmake/external/cinn.cmake @@ -26,7 +26,7 @@ add_definitions(-w) ###################################### include(ExternalProject) set(CINN_PREFIX_DIR ${THIRD_PARTY_PATH}/CINN) -set(CINN_GIT_TAG e11c5e672f9961e28cfa403d86f99808beb58817) +set(CINN_GIT_TAG 1fd85187b6c18da4dd51f22619d093ef08d61b01) set(CINN_OPTIONAL_ARGS -DPY_VERSION=${PY_VERSION} -DWITH_CUDA=${WITH_GPU} -DWITH_CUDNN=${WITH_GPU} diff --git a/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc index c015e90f71e54691e92c3a36c3d6e053372f64f3..6cde65f6ab580eba1d8ccaa2d08a4b9ccc097be6 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc @@ -223,9 +223,12 @@ std::unique_ptr CinnCompiler::CompileGraph( const Target& target, std::int64_t compiled_num, void* stream) const { CinnGraphSymbolization symbol{compiled_num, graph, target, input_tensors}; auto frontend_program = symbol(); - ProgramPass::Apply(&frontend_program, target, {"Decomposer"}); auto fetch_ids = symbol.GetFetchIds(); + ProgramPass::Apply(&frontend_program, fetch_ids, target, {"Decomposer"}); ::cinn::frontend::ApplyPass(&frontend_program, fetch_ids, "RemoveIdentity"); + ::cinn::frontend::ApplyPass(&frontend_program, fetch_ids, "TransposeFolding"); + ProgramPass::Apply(&frontend_program, fetch_ids, target, {"GemmRewriter"}); + auto cinn_graph = std::make_shared<::cinn::hlir::framework::Graph>( frontend_program, target); VLOG(1) << "-- The " << compiled_num << "-th compilation ("