From f5d356b8f4e198eefb5c6e2b8e936bb6c1eb4aa7 Mon Sep 17 00:00:00 2001 From: TeFeng Chen Date: Thu, 21 Apr 2022 10:25:21 +0800 Subject: [PATCH] [cherry-pick] enable auto-tune when using cinn (#41795) (#42006) cherry-pick #41795 --- cmake/external/cinn.cmake | 3 +-- cmake/external/xpu.cmake | 2 +- .../fluid/framework/paddle2cinn/cinn_compiler.cc | 16 +++++++++++++++- .../fluid/framework/paddle2cinn/cinn_compiler.h | 5 +++++ .../operators/cinn/cinn_launch_context_test.cc | 1 + .../fluid/operators/cinn/cinn_launch_op_test.cc | 9 +++++++++ 6 files changed, 32 insertions(+), 4 deletions(-) diff --git a/cmake/external/cinn.cmake b/cmake/external/cinn.cmake index cd4e0157f2a..1ca029b3add 100644 --- a/cmake/external/cinn.cmake +++ b/cmake/external/cinn.cmake @@ -26,7 +26,7 @@ add_definitions(-w) ###################################### include(ExternalProject) set(CINN_PREFIX_DIR ${THIRD_PARTY_PATH}/CINN) -set(CINN_GIT_TAG 1fd85187b6c18da4dd51f22619d093ef08d61b01) +set(CINN_GIT_TAG 08d7680dd91dfaa65787969050eb8f1143654f10) set(CINN_OPTIONAL_ARGS -DPY_VERSION=${PY_VERSION} -DWITH_CUDA=${WITH_GPU} -DWITH_CUDNN=${WITH_GPU} @@ -85,4 +85,3 @@ add_library(cinn SHARED IMPORTED GLOBAL) set_target_properties(cinn PROPERTIES IMPORTED_LOCATION "${CINN_LIB_LOCATION}/${CINN_LIB_NAME}") include_directories(${CINN_INCLUDE_DIR}) add_dependencies(cinn external_cinn) - diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake index 90cb686700e..76e0a2e29ed 100644 --- a/cmake/external/xpu.cmake +++ b/cmake/external/xpu.cmake @@ -36,7 +36,7 @@ ENDIF() if(NOT DEFINED XPU_BASE_URL) SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev") - SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220411") + SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220415") else() SET(XPU_BASE_URL "${XPU_BASE_URL}") endif() diff --git a/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc index 83a5b6f8213..67393c288df 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc @@ -21,6 +21,8 @@ #include #include +#include "cinn/auto_schedule/auto_tuner.h" +#include "cinn/auto_schedule/tuning.h" #include "cinn/common/target.h" #include "cinn/common/type.h" #include "cinn/frontend/decomposer/use_decomposer.h" @@ -48,6 +50,7 @@ #include "paddle/phi/core/utils/rw_lock.h" DECLARE_bool(enable_pe_launch_cinn); +DECLARE_bool(enable_cinn_auto_tune); namespace paddle { namespace framework { namespace paddle2cinn { @@ -58,6 +61,7 @@ using inference::analysis::Dot; using ::cinn::common::Target; using ::cinn::common::Float; using ::cinn::hlir::framework::GraphCompiler; +using ::cinn::auto_schedule::AutoTuner; using ::cinn::hlir::framework::BuildScope; using ::cinn::frontend::ProgramPass; using ::cinn::hlir::framework::ApplyPass; @@ -277,10 +281,20 @@ std::unique_ptr CinnCompiler::CompileGraph( if (!FLAGS_enable_pe_launch_cinn) { options.with_buffer_handle_instruction_inserted = true; } + std::unique_ptr auto_tuner; + if (FLAGS_enable_cinn_auto_tune) { + VLOG(4) << "Compile with auto-tune"; + auto_tuner = std::make_unique(target, cinn_graph.get()); + auto_tuner->Initialize(AutoTuner::Config(), graph_compiler.get()); + ::cinn::auto_schedule::TuningOptions tuning_options; + tuning_options.num_measure_trials = 0; + auto tuning_result = auto_tuner->Tune(tuning_options); + options.Apply(tuning_result); + } auto compiled_res = graph_compiler->Build(options, std::move(fetch_ids), stream); auto compiled_obj = std::make_unique(); - *compiled_obj = {std::move(graph_compiler), + *compiled_obj = {std::move(graph_compiler), std::move(auto_tuner), std::move(compiled_res.runtime_program), scope, symbol.var_model_to_program_map()}; compiled_obj->cached_index = compiled_num; diff --git a/paddle/fluid/framework/paddle2cinn/cinn_compiler.h b/paddle/fluid/framework/paddle2cinn/cinn_compiler.h index cf17e68156b..7e5df6faf08 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_compiler.h +++ b/paddle/fluid/framework/paddle2cinn/cinn_compiler.h @@ -37,6 +37,10 @@ class GraphCompiler; class Program; class Scope; } // namespace hlir::framework + +namespace auto_schedule { +class AutoTuner; +} // namespace auto_schedule } // namespace cinn namespace paddle { @@ -49,6 +53,7 @@ namespace paddle2cinn { struct CinnCompiledObject { std::unique_ptr<::cinn::hlir::framework::GraphCompiler> compiler; + std::unique_ptr<::cinn::auto_schedule::AutoTuner> auto_tuner; std::unique_ptr<::cinn::hlir::framework::Program> runtime_program; std::shared_ptr<::cinn::hlir::framework::Scope> scope; std::unordered_map paddle2cinn_varmap; diff --git a/paddle/fluid/operators/cinn/cinn_launch_context_test.cc b/paddle/fluid/operators/cinn/cinn_launch_context_test.cc index 15ea9a6926a..ecbfbf2f92e 100644 --- a/paddle/fluid/operators/cinn/cinn_launch_context_test.cc +++ b/paddle/fluid/operators/cinn/cinn_launch_context_test.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include #include #include +#include "cinn/auto_schedule/auto_tuner.h" #include "cinn/common/target.h" #include "cinn/common/type.h" #include "cinn/hlir/framework/graph_compiler.h" diff --git a/paddle/fluid/operators/cinn/cinn_launch_op_test.cc b/paddle/fluid/operators/cinn/cinn_launch_op_test.cc index 3e363c56eb9..3d6aee1d355 100644 --- a/paddle/fluid/operators/cinn/cinn_launch_op_test.cc +++ b/paddle/fluid/operators/cinn/cinn_launch_op_test.cc @@ -33,6 +33,7 @@ USE_OP(cinn_instruction_run); USE_OP_ITSELF(elementwise_add); DECLARE_double(eager_delete_tensor_gb); DECLARE_bool(enable_pe_launch_cinn); +DECLARE_bool(enable_cinn_auto_tune); PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT); #ifdef PADDLE_WITH_CUDA @@ -107,6 +108,14 @@ TEST_F(TestCinnLaunchOp, TestRunInstructionByCinnProgram) { #endif } +TEST_F(TestCinnLaunchOp, TestRunWithAutoTuneEnabled) { + FLAGS_enable_cinn_auto_tune = true; + + // currently only check on cpu, will add a test for gpu after CINN ready + RunAndCheck(platform::CPUPlace()); + RunAndCheck(platform::CPUPlace()); +} + namespace details { // Testing helper function used on CinnLaunchOpKernel in the following: // firstly build test data, then check both expected and illegal situations -- GitLab