From c761b48be104838a9e2576d7b9ff713ea5ef75f5 Mon Sep 17 00:00:00 2001 From: Zhen Wang Date: Tue, 29 Mar 2022 21:44:32 -0700 Subject: [PATCH] Apply TransposeFolding & GemmRewriter passes. (#41084) --- cmake/external/cinn.cmake | 2 +- paddle/fluid/framework/paddle2cinn/cinn_compiler.cc | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/cmake/external/cinn.cmake b/cmake/external/cinn.cmake index 75df827a43..cd4e0157f2 100644 --- a/cmake/external/cinn.cmake +++ b/cmake/external/cinn.cmake @@ -26,7 +26,7 @@ add_definitions(-w) ###################################### include(ExternalProject) set(CINN_PREFIX_DIR ${THIRD_PARTY_PATH}/CINN) -set(CINN_GIT_TAG e11c5e672f9961e28cfa403d86f99808beb58817) +set(CINN_GIT_TAG 1fd85187b6c18da4dd51f22619d093ef08d61b01) set(CINN_OPTIONAL_ARGS -DPY_VERSION=${PY_VERSION} -DWITH_CUDA=${WITH_GPU} -DWITH_CUDNN=${WITH_GPU} diff --git a/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc index c015e90f71..6cde65f6ab 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc @@ -223,9 +223,12 @@ std::unique_ptr CinnCompiler::CompileGraph( const Target& target, std::int64_t compiled_num, void* stream) const { CinnGraphSymbolization symbol{compiled_num, graph, target, input_tensors}; auto frontend_program = symbol(); - ProgramPass::Apply(&frontend_program, target, {"Decomposer"}); auto fetch_ids = symbol.GetFetchIds(); + ProgramPass::Apply(&frontend_program, fetch_ids, target, {"Decomposer"}); ::cinn::frontend::ApplyPass(&frontend_program, fetch_ids, "RemoveIdentity"); + ::cinn::frontend::ApplyPass(&frontend_program, fetch_ids, "TransposeFolding"); + ProgramPass::Apply(&frontend_program, fetch_ids, target, {"GemmRewriter"}); + auto cinn_graph = std::make_shared<::cinn::hlir::framework::Graph>( frontend_program, target); VLOG(1) << "-- The " << compiled_num << "-th compilation (" -- GitLab