未验证 提交 f5d356b8 编写于 作者: T TeFeng Chen 提交者: GitHub

[cherry-pick] enable auto-tune when using cinn (#41795) (#42006)

cherry-pick #41795
上级 efddf9ea
......@@ -26,7 +26,7 @@ add_definitions(-w)
######################################
include(ExternalProject)
set(CINN_PREFIX_DIR ${THIRD_PARTY_PATH}/CINN)
set(CINN_GIT_TAG 1fd85187b6c18da4dd51f22619d093ef08d61b01)
set(CINN_GIT_TAG 08d7680dd91dfaa65787969050eb8f1143654f10)
set(CINN_OPTIONAL_ARGS -DPY_VERSION=${PY_VERSION}
-DWITH_CUDA=${WITH_GPU}
-DWITH_CUDNN=${WITH_GPU}
......@@ -85,4 +85,3 @@ add_library(cinn SHARED IMPORTED GLOBAL)
set_target_properties(cinn PROPERTIES IMPORTED_LOCATION "${CINN_LIB_LOCATION}/${CINN_LIB_NAME}")
include_directories(${CINN_INCLUDE_DIR})
add_dependencies(cinn external_cinn)
......@@ -36,7 +36,7 @@ ENDIF()
if(NOT DEFINED XPU_BASE_URL)
SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220411")
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220415")
else()
SET(XPU_BASE_URL "${XPU_BASE_URL}")
endif()
......
......@@ -21,6 +21,8 @@
#include <string>
#include <unordered_map>
#include "cinn/auto_schedule/auto_tuner.h"
#include "cinn/auto_schedule/tuning.h"
#include "cinn/common/target.h"
#include "cinn/common/type.h"
#include "cinn/frontend/decomposer/use_decomposer.h"
......@@ -48,6 +50,7 @@
#include "paddle/phi/core/utils/rw_lock.h"
DECLARE_bool(enable_pe_launch_cinn);
DECLARE_bool(enable_cinn_auto_tune);
namespace paddle {
namespace framework {
namespace paddle2cinn {
......@@ -58,6 +61,7 @@ using inference::analysis::Dot;
using ::cinn::common::Target;
using ::cinn::common::Float;
using ::cinn::hlir::framework::GraphCompiler;
using ::cinn::auto_schedule::AutoTuner;
using ::cinn::hlir::framework::BuildScope;
using ::cinn::frontend::ProgramPass;
using ::cinn::hlir::framework::ApplyPass;
......@@ -277,10 +281,20 @@ std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
if (!FLAGS_enable_pe_launch_cinn) {
options.with_buffer_handle_instruction_inserted = true;
}
std::unique_ptr<AutoTuner> auto_tuner;
if (FLAGS_enable_cinn_auto_tune) {
VLOG(4) << "Compile with auto-tune";
auto_tuner = std::make_unique<AutoTuner>(target, cinn_graph.get());
auto_tuner->Initialize(AutoTuner::Config(), graph_compiler.get());
::cinn::auto_schedule::TuningOptions tuning_options;
tuning_options.num_measure_trials = 0;
auto tuning_result = auto_tuner->Tune(tuning_options);
options.Apply(tuning_result);
}
auto compiled_res =
graph_compiler->Build(options, std::move(fetch_ids), stream);
auto compiled_obj = std::make_unique<CinnCompiledObject>();
*compiled_obj = {std::move(graph_compiler),
*compiled_obj = {std::move(graph_compiler), std::move(auto_tuner),
std::move(compiled_res.runtime_program), scope,
symbol.var_model_to_program_map()};
compiled_obj->cached_index = compiled_num;
......
......@@ -37,6 +37,10 @@ class GraphCompiler;
class Program;
class Scope;
} // namespace hlir::framework
namespace auto_schedule {
class AutoTuner;
} // namespace auto_schedule
} // namespace cinn
namespace paddle {
......@@ -49,6 +53,7 @@ namespace paddle2cinn {
struct CinnCompiledObject {
std::unique_ptr<::cinn::hlir::framework::GraphCompiler> compiler;
std::unique_ptr<::cinn::auto_schedule::AutoTuner> auto_tuner;
std::unique_ptr<::cinn::hlir::framework::Program> runtime_program;
std::shared_ptr<::cinn::hlir::framework::Scope> scope;
std::unordered_map<std::string, std::string> paddle2cinn_varmap;
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <memory>
#include <set>
#include <utility>
#include "cinn/auto_schedule/auto_tuner.h"
#include "cinn/common/target.h"
#include "cinn/common/type.h"
#include "cinn/hlir/framework/graph_compiler.h"
......
......@@ -33,6 +33,7 @@ USE_OP(cinn_instruction_run);
USE_OP_ITSELF(elementwise_add);
DECLARE_double(eager_delete_tensor_gb);
DECLARE_bool(enable_pe_launch_cinn);
DECLARE_bool(enable_cinn_auto_tune);
PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT);
#ifdef PADDLE_WITH_CUDA
......@@ -107,6 +108,14 @@ TEST_F(TestCinnLaunchOp, TestRunInstructionByCinnProgram) {
#endif
}
TEST_F(TestCinnLaunchOp, TestRunWithAutoTuneEnabled) {
FLAGS_enable_cinn_auto_tune = true;
// currently only check on cpu, will add a test for gpu after CINN ready
RunAndCheck(platform::CPUPlace());
RunAndCheck(platform::CPUPlace());
}
namespace details {
// Testing helper function used on CinnLaunchOpKernel in the following:
// firstly build test data, then check both expected and illegal situations
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册