未验证 提交 af127342 编写于 作者: 张经纬 提交者: GitHub

[CodeStyle][CINN] format cpp code via clang-format (#54961)

* fix clang-format

* 'fix_clang-format'

* fix remaining errors

* format

* empty commit, re-trigger all ci

* empty commit, re-trigger all ci

---------
Co-authored-by: NSigureMo <sigure.qaq@gmail.com>
上级 a7419ff5
...@@ -47,7 +47,8 @@ repos: ...@@ -47,7 +47,8 @@ repos:
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|xpu|kps|py|sh)$ files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|xpu|kps|py|sh)$
exclude: | exclude: |
(?x)^( (?x)^(
paddle/utils/.* paddle/utils/.*|
paddle/cinn/utils/registry.h
)$ )$
# For Python files # For Python files
- repo: https://github.com/psf/black.git - repo: https://github.com/psf/black.git
......
...@@ -58,26 +58,32 @@ void AnalyzeScheduleBlockReadWriteBuffer(ir::ScheduleBlock* sche_block) { ...@@ -58,26 +58,32 @@ void AnalyzeScheduleBlockReadWriteBuffer(ir::ScheduleBlock* sche_block) {
const ir::Load* load_expr = x->As<ir::Load>(); const ir::Load* load_expr = x->As<ir::Load>();
if (load_expr != nullptr) { if (load_expr != nullptr) {
const ir::Tensor t = load_expr->tensor.as_tensor_ref(); const ir::Tensor t = load_expr->tensor.as_tensor_ref();
sche_block->read_buffers.emplace_back(ir::BufferRange(t->buffer, IndicesToVars(load_expr->indices))); sche_block->read_buffers.emplace_back(
ir::BufferRange(t->buffer, IndicesToVars(load_expr->indices)));
return false; return false;
} }
const ir::Store* store_expr = x->As<ir::Store>(); const ir::Store* store_expr = x->As<ir::Store>();
if (store_expr != nullptr) { if (store_expr != nullptr) {
const ir::Tensor t = store_expr->tensor.as_tensor_ref(); const ir::Tensor t = store_expr->tensor.as_tensor_ref();
sche_block->write_buffers.emplace_back(ir::BufferRange(t->buffer, IndicesToVars(store_expr->indices))); sche_block->write_buffers.emplace_back(
ir::BufferRange(t->buffer, IndicesToVars(store_expr->indices)));
return false; return false;
} }
return false; return false;
}); });
} }
bool ContainsNodeType(ir::Expr expr, const std::unordered_set<ir::IrNodeTy>& node_types) { bool ContainsNodeType(ir::Expr expr,
std::set<ir::Expr> collection = ir::CollectIRNodesWithoutTensor( const std::unordered_set<ir::IrNodeTy>& node_types) {
expr, [&](const Expr* x) { return node_types.find(x->node_type()) != node_types.end(); }); std::set<ir::Expr> collection =
ir::CollectIRNodesWithoutTensor(expr, [&](const Expr* x) {
return node_types.find(x->node_type()) != node_types.end();
});
return !collection.empty(); return !collection.empty();
} }
std::unordered_set<std::string> GetOutputNamesFromLoweredFunc(const std::vector<ir::LoweredFunc>& lowered_funcs) { std::unordered_set<std::string> GetOutputNamesFromLoweredFunc(
const std::vector<ir::LoweredFunc>& lowered_funcs) {
std::unordered_set<std::string> result; std::unordered_set<std::string> result;
for (const ir::LoweredFunc& func : lowered_funcs) { for (const ir::LoweredFunc& func : lowered_funcs) {
for (const ir::Argument& arg : func->args) { for (const ir::Argument& arg : func->args) {
...@@ -90,18 +96,22 @@ std::unordered_set<std::string> GetOutputNamesFromLoweredFunc(const std::vector< ...@@ -90,18 +96,22 @@ std::unordered_set<std::string> GetOutputNamesFromLoweredFunc(const std::vector<
} }
bool NeedsMultiLevelTiling(const ir::ScheduleBlockRealize& sche_block_realize) { bool NeedsMultiLevelTiling(const ir::ScheduleBlockRealize& sche_block_realize) {
const ir::ScheduleBlock* sche_block = sche_block_realize.schedule_block.As<ir::ScheduleBlock>(); const ir::ScheduleBlock* sche_block =
if (sche_block->write_buffers.size() != 1 || sche_block->read_buffers.empty()) { sche_block_realize.schedule_block.As<ir::ScheduleBlock>();
if (sche_block->write_buffers.size() != 1 ||
sche_block->read_buffers.empty()) {
return false; return false;
} }
const ir::Expr& write_buffer = sche_block->write_buffers[0].As<ir::_BufferRange_>()->buffer; const ir::Expr& write_buffer =
sche_block->write_buffers[0].As<ir::_BufferRange_>()->buffer;
// Enumerate each read region, get the number of schedule block iter vars // Enumerate each read region, get the number of schedule block iter vars
// which are not used to index the read region // which are not used to index the read region
int total_unused_iter_vars = 0; int total_unused_iter_vars = 0;
for (const ir::Expr& read_buffer_expr : sche_block->read_buffers) { for (const ir::Expr& read_buffer_expr : sche_block->read_buffers) {
const ir::_BufferRange_* read_buffer = read_buffer_expr.As<ir::_BufferRange_>(); const ir::_BufferRange_* read_buffer =
read_buffer_expr.As<ir::_BufferRange_>();
// Skip the reduction buffer // Skip the reduction buffer
if (read_buffer->buffer == write_buffer) { if (read_buffer->buffer == write_buffer) {
continue; continue;
...@@ -133,7 +143,9 @@ bool NeedsMultiLevelTiling(const ir::ScheduleBlockRealize& sche_block_realize) { ...@@ -133,7 +143,9 @@ bool NeedsMultiLevelTiling(const ir::ScheduleBlockRealize& sche_block_realize) {
return total_unused_iter_vars >= 1; return total_unused_iter_vars >= 1;
} }
ir::LoweredFunc UpdateFuncWithNewBody(const common::Target& target, const ir::LoweredFunc& old_func, ir::Expr& body) { ir::LoweredFunc UpdateFuncWithNewBody(const common::Target& target,
const ir::LoweredFunc& old_func,
ir::Expr& body) {
ir::ModuleExpr mod_expr(std::vector<ir::Expr>({body})); ir::ModuleExpr mod_expr(std::vector<ir::Expr>({body}));
ir::IRSchedule ir_sch(mod_expr); ir::IRSchedule ir_sch(mod_expr);
...@@ -143,8 +155,10 @@ ir::LoweredFunc UpdateFuncWithNewBody(const common::Target& target, const ir::Lo ...@@ -143,8 +155,10 @@ ir::LoweredFunc UpdateFuncWithNewBody(const common::Target& target, const ir::Lo
const std::string& buf_name = buf->name; const std::string& buf_name = buf->name;
std::vector<ir::Expr> all_block_realizes = ir_sch.GetAllBlocks(); std::vector<ir::Expr> all_block_realizes = ir_sch.GetAllBlocks();
for (ir::Expr& e : all_block_realizes) { for (ir::Expr& e : all_block_realizes) {
const ir::ScheduleBlockRealize* sche_block_realize = e.As<ir::ScheduleBlockRealize>(); const ir::ScheduleBlockRealize* sche_block_realize =
const std::string& sche_name = sche_block_realize->schedule_block.As<ir::ScheduleBlock>()->name; e.As<ir::ScheduleBlockRealize>();
const std::string& sche_name =
sche_block_realize->schedule_block.As<ir::ScheduleBlock>()->name;
if (buf_name == "_" + sche_name) { if (buf_name == "_" + sche_name) {
VLOG(6) << "Set local buffer for temp buffer " << buf_name; VLOG(6) << "Set local buffer for temp buffer " << buf_name;
ir_sch.SetBuffer(e, "local", true); ir_sch.SetBuffer(e, "local", true);
...@@ -159,14 +173,17 @@ ir::LoweredFunc UpdateFuncWithNewBody(const common::Target& target, const ir::Lo ...@@ -159,14 +173,17 @@ ir::LoweredFunc UpdateFuncWithNewBody(const common::Target& target, const ir::Lo
#endif #endif
// Get new temp bufs by analyzing. // Get new temp bufs by analyzing.
std::vector<ir::Buffer> new_temp_bufs = lang::GetTempBuffers(old_func->args, updated_body); std::vector<ir::Buffer> new_temp_bufs =
ir::LoweredFunc new_func = ir::_LoweredFunc_::Make(old_func->name, old_func->args, updated_body, new_temp_bufs); lang::GetTempBuffers(old_func->args, updated_body);
ir::LoweredFunc new_func = ir::_LoweredFunc_::Make(
old_func->name, old_func->args, updated_body, new_temp_bufs);
#ifdef CINN_WITH_CUDA #ifdef CINN_WITH_CUDA
if (target == common::DefaultNVGPUTarget()) { if (target == common::DefaultNVGPUTarget()) {
new_func->PrepareCudaAxisInfoFromBody(); new_func->PrepareCudaAxisInfoFromBody();
} }
#endif #endif
new_func = optim::Optimize(Expr(new_func), target, false).as_lowered_func_ref(); new_func =
optim::Optimize(Expr(new_func), target, false).as_lowered_func_ref();
new_func->PrepareBufferCastExprs(/*with_expr_gen_tensor = */ false); new_func->PrepareBufferCastExprs(/*with_expr_gen_tensor = */ false);
return new_func; return new_func;
......
...@@ -27,12 +27,14 @@ namespace auto_schedule { ...@@ -27,12 +27,14 @@ namespace auto_schedule {
void AnalyzeScheduleBlockReadWriteBuffer(ir::ScheduleBlock* sche_block); void AnalyzeScheduleBlockReadWriteBuffer(ir::ScheduleBlock* sche_block);
bool ContainsNodeType(ir::Expr expr, const std::unordered_set<ir::IrNodeTy>& node_types); bool ContainsNodeType(ir::Expr expr,
const std::unordered_set<ir::IrNodeTy>& node_types);
/** /**
* Collects all input lowered_funcs and return names of all output arguments * Collects all input lowered_funcs and return names of all output arguments
*/ */
std::unordered_set<std::string> GetOutputNamesFromLoweredFunc(const std::vector<ir::LoweredFunc>& lowered_funcs); std::unordered_set<std::string> GetOutputNamesFromLoweredFunc(
const std::vector<ir::LoweredFunc>& lowered_funcs);
/** /**
* Determine whether a schedule block needs multileveltiling * Determine whether a schedule block needs multileveltiling
...@@ -42,7 +44,9 @@ bool NeedsMultiLevelTiling(const ir::ScheduleBlockRealize& sche_block_realize); ...@@ -42,7 +44,9 @@ bool NeedsMultiLevelTiling(const ir::ScheduleBlockRealize& sche_block_realize);
/** /**
* Update a LoweredFunc by regenerating related fields with a new function body * Update a LoweredFunc by regenerating related fields with a new function body
*/ */
ir::LoweredFunc UpdateFuncWithNewBody(const common::Target& target, const ir::LoweredFunc& old_func, ir::Expr& body); ir::LoweredFunc UpdateFuncWithNewBody(const common::Target& target,
const ir::LoweredFunc& old_func,
ir::Expr& body);
} // namespace auto_schedule } // namespace auto_schedule
} // namespace cinn } // namespace cinn
...@@ -50,7 +50,8 @@ TEST(AnalyzeIr, AnalyzeScheduleBlockReadWriteBuffer_SimpleAssign) { ...@@ -50,7 +50,8 @@ TEST(AnalyzeIr, AnalyzeScheduleBlockReadWriteBuffer_SimpleAssign) {
{M, N}, [&](Var i, Var j) { return A(i, j); }, "B"); {M, N}, [&](Var i, Var j) { return A(i, j); }, "B");
poly::StageMap stages = poly::CreateStages({A, B}); poly::StageMap stages = poly::CreateStages({A, B});
std::vector<ir::LoweredFunc> funcs = lang::LowerVec("SimpleAssign", stages, {A, B}, {}, {}, nullptr, target, true); std::vector<ir::LoweredFunc> funcs = lang::LowerVec(
"SimpleAssign", stages, {A, B}, {}, {}, nullptr, target, true);
ASSERT_FALSE(funcs.empty()); ASSERT_FALSE(funcs.empty());
ir::Expr ast_expr = funcs[0]->body; ir::Expr ast_expr = funcs[0]->body;
...@@ -65,8 +66,10 @@ TEST(AnalyzeIr, AnalyzeScheduleBlockReadWriteBuffer_SimpleAssign) { ...@@ -65,8 +66,10 @@ TEST(AnalyzeIr, AnalyzeScheduleBlockReadWriteBuffer_SimpleAssign) {
std::vector<ir::Expr> all_block_realizes = ir_sch.GetAllBlocks(); std::vector<ir::Expr> all_block_realizes = ir_sch.GetAllBlocks();
ASSERT_EQ(all_block_realizes.size(), 1UL); ASSERT_EQ(all_block_realizes.size(), 1UL);
ir::ScheduleBlockRealize* sche_block_realize = all_block_realizes[0].As<ir::ScheduleBlockRealize>(); ir::ScheduleBlockRealize* sche_block_realize =
ir::ScheduleBlock* sche_block = sche_block_realize->schedule_block.As<ir::ScheduleBlock>(); all_block_realizes[0].As<ir::ScheduleBlockRealize>();
ir::ScheduleBlock* sche_block =
sche_block_realize->schedule_block.As<ir::ScheduleBlock>();
AnalyzeScheduleBlockReadWriteBuffer(sche_block); AnalyzeScheduleBlockReadWriteBuffer(sche_block);
/* /*
...@@ -113,7 +116,8 @@ TEST(AnalyzeIr, AnalyzeScheduleBlockReadWriteBuffer_AddDiffShape) { ...@@ -113,7 +116,8 @@ TEST(AnalyzeIr, AnalyzeScheduleBlockReadWriteBuffer_AddDiffShape) {
{M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C"); {M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C");
poly::StageMap stages = poly::CreateStages({C}); poly::StageMap stages = poly::CreateStages({C});
std::vector<ir::LoweredFunc> funcs = lang::LowerVec("AddDiffShape", stages, {C}, {}, {}, nullptr, target, true); std::vector<ir::LoweredFunc> funcs = lang::LowerVec(
"AddDiffShape", stages, {C}, {}, {}, nullptr, target, true);
ir::Expr ast_expr = funcs[0]->body; ir::Expr ast_expr = funcs[0]->body;
VLOG(6) << "Expr before MultiLevelTiling: "; VLOG(6) << "Expr before MultiLevelTiling: ";
...@@ -126,8 +130,10 @@ TEST(AnalyzeIr, AnalyzeScheduleBlockReadWriteBuffer_AddDiffShape) { ...@@ -126,8 +130,10 @@ TEST(AnalyzeIr, AnalyzeScheduleBlockReadWriteBuffer_AddDiffShape) {
std::vector<ir::Expr> all_block_realizes = ir_sch.GetAllBlocks(); std::vector<ir::Expr> all_block_realizes = ir_sch.GetAllBlocks();
ASSERT_EQ(all_block_realizes.size(), 1UL); ASSERT_EQ(all_block_realizes.size(), 1UL);
ir::ScheduleBlockRealize* sche_block_realize = all_block_realizes[0].As<ir::ScheduleBlockRealize>(); ir::ScheduleBlockRealize* sche_block_realize =
ir::ScheduleBlock* sche_block = sche_block_realize->schedule_block.As<ir::ScheduleBlock>(); all_block_realizes[0].As<ir::ScheduleBlockRealize>();
ir::ScheduleBlock* sche_block =
sche_block_realize->schedule_block.As<ir::ScheduleBlock>();
AnalyzeScheduleBlockReadWriteBuffer(sche_block); AnalyzeScheduleBlockReadWriteBuffer(sche_block);
VLOG(6) << "ScheduleBlockRealize: "; VLOG(6) << "ScheduleBlockRealize: ";
...@@ -164,7 +170,8 @@ TEST(AnalyzeIr, ContainsNodeType) { ...@@ -164,7 +170,8 @@ TEST(AnalyzeIr, ContainsNodeType) {
{M, N}, [&](Var i, Var j) { return A(i, j); }, "B"); {M, N}, [&](Var i, Var j) { return A(i, j); }, "B");
poly::StageMap stages = poly::CreateStages({A, B}); poly::StageMap stages = poly::CreateStages({A, B});
std::vector<ir::LoweredFunc> funcs = lang::LowerVec("SimpleAssign", stages, {A, B}, {}, {}, nullptr, target, true); std::vector<ir::LoweredFunc> funcs = lang::LowerVec(
"SimpleAssign", stages, {A, B}, {}, {}, nullptr, target, true);
ASSERT_FALSE(funcs.empty()); ASSERT_FALSE(funcs.empty());
ir::Expr ast_expr = funcs[0]->body; ir::Expr ast_expr = funcs[0]->body;
...@@ -172,9 +179,12 @@ TEST(AnalyzeIr, ContainsNodeType) { ...@@ -172,9 +179,12 @@ TEST(AnalyzeIr, ContainsNodeType) {
VLOG(6) << "Analyzing for Expr:"; VLOG(6) << "Analyzing for Expr:";
VLOG(6) << ast_expr; VLOG(6) << ast_expr;
ASSERT_TRUE(ContainsNodeType(ast_expr, {ir::IrNodeTy::Load, ir::IrNodeTy::Store})); ASSERT_TRUE(
ASSERT_TRUE(ContainsNodeType(ast_expr, {ir::IrNodeTy::Load, ir::IrNodeTy::IfThenElse})); ContainsNodeType(ast_expr, {ir::IrNodeTy::Load, ir::IrNodeTy::Store}));
ASSERT_FALSE(ContainsNodeType(ast_expr, {ir::IrNodeTy::IfThenElse, ir::IrNodeTy::Sum})); ASSERT_TRUE(ContainsNodeType(ast_expr,
{ir::IrNodeTy::Load, ir::IrNodeTy::IfThenElse}));
ASSERT_FALSE(ContainsNodeType(ast_expr,
{ir::IrNodeTy::IfThenElse, ir::IrNodeTy::Sum}));
} }
} // namespace auto_schedule } // namespace auto_schedule
......
...@@ -38,13 +38,17 @@ ...@@ -38,13 +38,17 @@
namespace cinn { namespace cinn {
namespace auto_schedule { namespace auto_schedule {
AutoTuner::AutoTuner(const common::Target& target, hlir::framework::Graph* graph) : target_(target), graph_(graph) {} AutoTuner::AutoTuner(const common::Target& target,
hlir::framework::Graph* graph)
: target_(target), graph_(graph) {}
void AutoTuner::Initialize(const Config& config, hlir::framework::GraphCompiler* graph_compiler) { void AutoTuner::Initialize(const Config& config,
hlir::framework::GraphCompiler* graph_compiler) {
// create builder, runner, and schedule measurer // create builder, runner, and schedule measurer
builder_ = std::make_unique<SimpleBuilder>(graph_compiler); builder_ = std::make_unique<SimpleBuilder>(graph_compiler);
runner_ = std::make_unique<SimpleRunner>(config.runner_repeat_times); runner_ = std::make_unique<SimpleRunner>(config.runner_repeat_times);
schedule_measurer_ = std::make_unique<ScheduleMeasurer>(builder_.get(), runner_.get()); schedule_measurer_ =
std::make_unique<ScheduleMeasurer>(builder_.get(), runner_.get());
// initialize database // initialize database
database_ = std::move(Database::Make(config.database_config)); database_ = std::move(Database::Make(config.database_config));
...@@ -53,29 +57,43 @@ void AutoTuner::Initialize(const Config& config, hlir::framework::GraphCompiler* ...@@ -53,29 +57,43 @@ void AutoTuner::Initialize(const Config& config, hlir::framework::GraphCompiler*
TaskCreator task_creator; TaskCreator task_creator;
tasks_ = task_creator.CreateTuneTaskOpLevel(graph_); tasks_ = task_creator.CreateTuneTaskOpLevel(graph_);
const auto& dtype_dict = graph_->GetAttrs<absl::flat_hash_map<std::string, common::Type>>("inferdtype"); const auto& dtype_dict =
const auto& shape_dict = graph_->GetAttrs<absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape"); graph_->GetAttrs<absl::flat_hash_map<std::string, common::Type>>(
"inferdtype");
const auto& shape_dict = graph_->GetAttrs<
absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");
op_lowerer_ = std::make_unique<hlir::framework::OpLowerer>(dtype_dict, shape_dict, target_); op_lowerer_ = std::make_unique<hlir::framework::OpLowerer>(
dtype_dict, shape_dict, target_);
InitialTaskRegistry* task_registry = InitialTaskRegistry::Global(); InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
for (auto i = 0; i < tasks_.size(); ++i) { for (auto i = 0; i < tasks_.size(); ++i) {
auto&& task = tasks_[i]; auto&& task = tasks_[i];
task.Initialize(shape_dict, dtype_dict, op_lowerer_.get()); task.Initialize(shape_dict, dtype_dict, op_lowerer_.get());
// Register the initial ModuleExpr corresponding to the task // Register the initial ModuleExpr corresponding to the task
task_registry->Regist(task.serialized_key, ir::ModuleExpr(task.GetLoweredFuncBodyExprs())); task_registry->Regist(task.serialized_key,
VLOG(3) << "Add a task, id:" << i << ", serialized_key:\n" << task.serialized_key; ir::ModuleExpr(task.GetLoweredFuncBodyExprs()));
VLOG(3) << "Add a task, id:" << i << ", serialized_key:\n"
<< task.serialized_key;
} }
// create task optimizers // create task optimizers
utils::LinearRandomEngine::StateType initial_seed = utils::LinearRandomEngine::GetDeviceRandomValue(); utils::LinearRandomEngine::StateType initial_seed =
utils::LinearRandomEngine::GetDeviceRandomValue();
task_optimizers_.resize(tasks_.size()); task_optimizers_.resize(tasks_.size());
std::transform(tasks_.begin(), tasks_.end(), task_optimizers_.begin(), [&](TuneTask& task) { std::transform(tasks_.begin(),
tasks_.end(),
task_optimizers_.begin(),
[&](TuneTask& task) {
return std::make_unique<TaskOptimizer>( return std::make_unique<TaskOptimizer>(
&task, schedule_measurer_.get(), database_.get(), utils::ForkRandomState(&initial_seed)); &task,
schedule_measurer_.get(),
database_.get(),
utils::ForkRandomState(&initial_seed));
}); });
// create task scheduler // create task scheduler
task_scheduler_ = TaskScheduler::Make(tasks_, config.task_schedule_config, config.task_schedule_strategy); task_scheduler_ = TaskScheduler::Make(
tasks_, config.task_schedule_config, config.task_schedule_strategy);
} }
void PrintResult(std::shared_ptr<hlir::framework::Graph::Group> group) { void PrintResult(std::shared_ptr<hlir::framework::Graph::Group> group) {
...@@ -127,7 +145,8 @@ void PrintResult(const TuningResult& result) { ...@@ -127,7 +145,8 @@ void PrintResult(const TuningResult& result) {
TuningResult AutoTuner::Tune(const TuningOptions& options) { TuningResult AutoTuner::Tune(const TuningOptions& options) {
CHECK_GT(options.num_tuning_rounds, 0) << "Invalid config"; CHECK_GT(options.num_tuning_rounds, 0) << "Invalid config";
VLOG(3) << "Begin tuning with round num=" << options.num_tuning_rounds << ", tasks size=" << tasks_.size(); VLOG(3) << "Begin tuning with round num=" << options.num_tuning_rounds
<< ", tasks size=" << tasks_.size();
TuningResult result; TuningResult result;
result.subgraphs.resize(tasks_.size()); result.subgraphs.resize(tasks_.size());
......
...@@ -49,7 +49,8 @@ class AutoTuner { ...@@ -49,7 +49,8 @@ class AutoTuner {
AutoTuner(const common::Target& target, hlir::framework::Graph* graph); AutoTuner(const common::Target& target, hlir::framework::Graph* graph);
// Initialize tuner with specific config and auxiliary objects. // Initialize tuner with specific config and auxiliary objects.
void Initialize(const Config& config, hlir::framework::GraphCompiler* graph_compiler); void Initialize(const Config& config,
hlir::framework::GraphCompiler* graph_compiler);
// Perform the tuning process and return the final result // Perform the tuning process and return the final result
TuningResult Tune(const TuningOptions& options); TuningResult Tune(const TuningOptions& options);
......
...@@ -76,11 +76,13 @@ class TestAutoTuner : public ::testing::Test { ...@@ -76,11 +76,13 @@ class TestAutoTuner : public ::testing::Test {
auto program = CreateAddReluProgram(); auto program = CreateAddReluProgram();
auto graph = cinn::frontend::Optimize(&program, fetch_ids, target); auto graph = cinn::frontend::Optimize(&program, fetch_ids, target);
compiled_scope = BuildScope(target, graph); compiled_scope = BuildScope(target, graph);
graph_compiler = std::make_unique<GraphCompiler>(target, compiled_scope, graph); graph_compiler =
std::make_unique<GraphCompiler>(target, compiled_scope, graph);
tuner = std::make_unique<AutoTuner>(target, graph.get()); tuner = std::make_unique<AutoTuner>(target, graph.get());
} }
TuningResult InitializeAndTune(const AutoTuner::Config& config, const TuningOptions& options) { TuningResult InitializeAndTune(const AutoTuner::Config& config,
const TuningOptions& options) {
tuner->Initialize(config, graph_compiler.get()); tuner->Initialize(config, graph_compiler.get());
return tuner->Tune(options); return tuner->Tune(options);
} }
...@@ -108,7 +110,8 @@ class TestAutoTuner : public ::testing::Test { ...@@ -108,7 +110,8 @@ class TestAutoTuner : public ::testing::Test {
VLOG(6) << "Print lowered_funcs before building"; VLOG(6) << "Print lowered_funcs before building";
VLOG(6) << compile_options.lowered_funcs[0][0]; VLOG(6) << compile_options.lowered_funcs[0][0];
VLOG(6) << compile_options.lowered_funcs[1][0]; VLOG(6) << compile_options.lowered_funcs[1][0];
auto runtime_program = graph_compiler->Build(compile_options).runtime_program; auto runtime_program =
graph_compiler->Build(compile_options).runtime_program;
ASSERT_EQ(1, runtime_program->size()); ASSERT_EQ(1, runtime_program->size());
runtime_program->Execute(); runtime_program->Execute();
} }
......
...@@ -28,7 +28,8 @@ ...@@ -28,7 +28,8 @@
namespace cinn { namespace cinn {
namespace auto_schedule { namespace auto_schedule {
float ExprCostModel::Predict(const ir::ModuleExpr& sample, const common::Target& target) const { float ExprCostModel::Predict(const ir::ModuleExpr& sample,
const common::Target& target) const {
if (trained_times_.load() == 0) { if (trained_times_.load() == 0) {
return SearchState::NOT_INIT_COST; return SearchState::NOT_INIT_COST;
} }
...@@ -44,7 +45,8 @@ void ExprCostModel::Train(const std::vector<const ir::ModuleExpr*>& samples, ...@@ -44,7 +45,8 @@ void ExprCostModel::Train(const std::vector<const ir::ModuleExpr*>& samples,
const common::Target& target) { const common::Target& target) {
trained_times_.store(1); trained_times_.store(1);
size_t total_size = samples.size(); size_t total_size = samples.size();
CHECK_EQ(total_size, labels.size()) << "Samples must have same size as labels"; CHECK_EQ(total_size, labels.size())
<< "Samples must have same size as labels";
std::vector<std::vector<float>> train_feature_numbers(total_size); std::vector<std::vector<float>> train_feature_numbers(total_size);
FeatureExtractor extractor; FeatureExtractor extractor;
for (size_t i = 0; i < total_size; ++i) { for (size_t i = 0; i < total_size; ++i) {
...@@ -61,7 +63,8 @@ void ExprCostModel::Update(const std::vector<const ir::ModuleExpr*>& samples, ...@@ -61,7 +63,8 @@ void ExprCostModel::Update(const std::vector<const ir::ModuleExpr*>& samples,
const common::Target& target) { const common::Target& target) {
++trained_times_; ++trained_times_;
size_t total_size = samples.size(); size_t total_size = samples.size();
CHECK_EQ(total_size, labels.size()) << "Samples must have same size as labels"; CHECK_EQ(total_size, labels.size())
<< "Samples must have same size as labels";
std::vector<std::vector<float>> train_feature_numbers(total_size); std::vector<std::vector<float>> train_feature_numbers(total_size);
FeatureExtractor extractor; FeatureExtractor extractor;
for (size_t i = 0; i < total_size; ++i) { for (size_t i = 0; i < total_size; ++i) {
......
...@@ -29,7 +29,8 @@ namespace auto_schedule { ...@@ -29,7 +29,8 @@ namespace auto_schedule {
*/ */
class ExprCostModel : public XgbCostModel { class ExprCostModel : public XgbCostModel {
public: public:
virtual float Predict(const ir::ModuleExpr& sample, const common::Target& target) const; virtual float Predict(const ir::ModuleExpr& sample,
const common::Target& target) const;
void Train(const std::vector<const ir::ModuleExpr*>& samples, void Train(const std::vector<const ir::ModuleExpr*>& samples,
const std::vector<float>& labels, const std::vector<float>& labels,
const common::Target& target); const common::Target& target);
......
...@@ -49,7 +49,8 @@ Feature::Feature(const common::Target& target) ...@@ -49,7 +49,8 @@ Feature::Feature(const common::Target& target)
parent_indices_(1, -1) {} parent_indices_(1, -1) {}
std::vector<float> Feature::ToFixedSizeVector() { std::vector<float> Feature::ToFixedSizeVector() {
std::vector<float> ret(LoopBlockFeature::kTotalSize + 1, 0); // LoopBlockFeature::kTotalSize plus 1 for target std::vector<float> ret(LoopBlockFeature::kTotalSize + 1,
0); // LoopBlockFeature::kTotalSize plus 1 for target
if (target_ == common::DefaultNVGPUTarget()) { if (target_ == common::DefaultNVGPUTarget()) {
ret[0] = 1; ret[0] = 1;
...@@ -165,11 +166,17 @@ void Feature::IntoLoopBlock() { ...@@ -165,11 +166,17 @@ void Feature::IntoLoopBlock() {
current_loop_block_index_ = stack_encoded_feature_.size() - 1; current_loop_block_index_ = stack_encoded_feature_.size() - 1;
} }
void Feature::ExitLoopBlock() { current_loop_block_index_ = parent_indices_[current_loop_block_index_]; } void Feature::ExitLoopBlock() {
current_loop_block_index_ = parent_indices_[current_loop_block_index_];
}
LoopBlockFeature& Feature::CurrentLoopBlock() { return stack_encoded_feature_[current_loop_block_index_]; } LoopBlockFeature& Feature::CurrentLoopBlock() {
return stack_encoded_feature_[current_loop_block_index_];
}
const LoopBlockFeature& Feature::CurrentLoopBlock() const { return stack_encoded_feature_[current_loop_block_index_]; } const LoopBlockFeature& Feature::CurrentLoopBlock() const {
return stack_encoded_feature_[current_loop_block_index_];
}
} // namespace auto_schedule } // namespace auto_schedule
} // namespace cinn } // namespace cinn
...@@ -24,10 +24,18 @@ namespace cinn { ...@@ -24,10 +24,18 @@ namespace cinn {
namespace auto_schedule { namespace auto_schedule {
/* Loop feature enums */ /* Loop feature enums */
enum class ForOptimizeFeatureEnum : int { kNone, kGpuBind, kParallel, kUnroll, kVectorize }; enum class ForOptimizeFeatureEnum : int {
kNone,
kGpuBind,
kParallel,
kUnroll,
kVectorize
};
/* function to scale feature numbers */ /* function to scale feature numbers */
inline float slog(float x) { return x < 0 ? std::log2(-x + 1) : std::log2(x + 1); } inline float slog(float x) {
return x < 0 ? std::log2(-x + 1) : std::log2(x + 1);
}
class LoopBlockFeature { class LoopBlockFeature {
public: public:
...@@ -106,7 +114,9 @@ class LoopBlockFeature { ...@@ -106,7 +114,9 @@ class LoopBlockFeature {
static constexpr int kThreadFeatureSize = 8; static constexpr int kThreadFeatureSize = 8;
static constexpr int kTotalSize = kArithSize + kMemSize + kReduceBroadcastSize + kOptApplySize + kThreadFeatureSize; static constexpr int kTotalSize = kArithSize + kMemSize +
kReduceBroadcastSize + kOptApplySize +
kThreadFeatureSize;
/* Non-feature attributes, used to maintain during feature_extractor */ /* Non-feature attributes, used to maintain during feature_extractor */
...@@ -158,10 +168,11 @@ class Feature { ...@@ -158,10 +168,11 @@ class Feature {
// some_compute_3 // some_compute_3
// } // }
// //
// We go through the code and push loops into stack, then the features are encoded as // We go through the code and push loops into stack, then the features are
// [loop_block_feature_0, loop_block_feature_1, loop_block_feature_2, loop_block_feature_3] // encoded as [loop_block_feature_0, loop_block_feature_1,
// where loop_block_feature_i stores the features of some_compute_i (such // loop_block_feature_2, loop_block_feature_3] where loop_block_feature_i
// as number of arithmetic operations) // stores the features of some_compute_i (such as number of arithmetic
// operations)
// //
// loop_block_feature_0.num_sub_loops = 2 // loop_block_feature_0.num_sub_loops = 2
// loop_block_feature_1.num_sub_loops = 1 // loop_block_feature_1.num_sub_loops = 1
......
...@@ -47,7 +47,8 @@ FeatureExtractor::FeatureExtractor() {} ...@@ -47,7 +47,8 @@ FeatureExtractor::FeatureExtractor() {}
void FeatureExtractor::Visit(const Expr *x) { IRVisitor::Visit(x); } void FeatureExtractor::Visit(const Expr *x) { IRVisitor::Visit(x); }
Feature FeatureExtractor::Extract(const ir::ModuleExpr &mod_expr, const common::Target &target) { Feature FeatureExtractor::Extract(const ir::ModuleExpr &mod_expr,
const common::Target &target) {
feature_ = Feature(target); feature_ = Feature(target);
for (const ir::Expr &e : mod_expr.GetExprs()) { for (const ir::Expr &e : mod_expr.GetExprs()) {
Visit(&e); Visit(&e);
...@@ -87,7 +88,8 @@ NotVisitExprFields(_Tensor_) ...@@ -87,7 +88,8 @@ NotVisitExprFields(_Tensor_)
#define VisitForDtypePattern(NodeType, member) \ #define VisitForDtypePattern(NodeType, member) \
void FeatureExtractor::Visit(const NodeType *x) { \ void FeatureExtractor::Visit(const NodeType *x) { \
if (x->type() == common::F32() || x->type() == common::F16() || x->type() == common::F64()) { \ if (x->type() == common::F32() || x->type() == common::F16() || \
x->type() == common::F64()) { \
feature_.CurrentLoopBlock().float_##member += x->type().lanes(); \ feature_.CurrentLoopBlock().float_##member += x->type().lanes(); \
} else { \ } else { \
feature_.CurrentLoopBlock().int_##member += x->type().lanes(); \ feature_.CurrentLoopBlock().int_##member += x->type().lanes(); \
...@@ -120,8 +122,10 @@ VisitForDtypePattern(Let, other_call); ...@@ -120,8 +122,10 @@ VisitForDtypePattern(Let, other_call);
#define VisitForMultiOperandsDtypePattern(NodeType, member) \ #define VisitForMultiOperandsDtypePattern(NodeType, member) \
void FeatureExtractor::Visit(const NodeType *x) { \ void FeatureExtractor::Visit(const NodeType *x) { \
if (x->type() == common::F32() || x->type() == common::F16() || x->type() == common::F64()) { \ if (x->type() == common::F32() || x->type() == common::F16() || \
feature_.CurrentLoopBlock().float_##member += (x->operands().size() - 1); \ x->type() == common::F64()) { \
feature_.CurrentLoopBlock().float_##member += \
(x->operands().size() - 1); \
} else { \ } else { \
feature_.CurrentLoopBlock().int_##member += (x->operands().size() - 1); \ feature_.CurrentLoopBlock().int_##member += (x->operands().size() - 1); \
} \ } \
...@@ -166,7 +170,8 @@ void FeatureExtractor::Visit(const For *x) { ...@@ -166,7 +170,8 @@ void FeatureExtractor::Visit(const For *x) {
LoopBlockFeature &loop_feature = feature_.CurrentLoopBlock(); LoopBlockFeature &loop_feature = feature_.CurrentLoopBlock();
if (x->min.is_constant() && x->extent.is_constant()) { if (x->min.is_constant() && x->extent.is_constant()) {
loop_feature.loop_length = (x->extent.get_constant() - x->min.get_constant()); loop_feature.loop_length =
(x->extent.get_constant() - x->min.get_constant());
} else { } else {
loop_feature.loop_length = -1; // -1 represents unknown loop_feature.loop_length = -1; // -1 represents unknown
} }
...@@ -223,13 +228,16 @@ void FeatureExtractor::Visit(const PolyFor *x) { ...@@ -223,13 +228,16 @@ void FeatureExtractor::Visit(const PolyFor *x) {
/* Visit for Reduce and Broadcast */ /* Visit for Reduce and Broadcast */
void FeatureExtractor::Visit(const Reduce *x) { void FeatureExtractor::Visit(const Reduce *x) {
if (x->type() == common::F32() || x->type() == common::F16() || x->type() == common::F64()) { if (x->type() == common::F32() || x->type() == common::F16() ||
x->type() == common::F64()) {
switch (x->reduce_type) { switch (x->reduce_type) {
case Reduce::ReduceType::kSum: case Reduce::ReduceType::kSum:
feature_.CurrentLoopBlock().float_reduce_sum_or_sub += x->type().lanes(); feature_.CurrentLoopBlock().float_reduce_sum_or_sub +=
x->type().lanes();
break; break;
case Reduce::ReduceType::kSub: case Reduce::ReduceType::kSub:
feature_.CurrentLoopBlock().float_reduce_sum_or_sub += x->type().lanes(); feature_.CurrentLoopBlock().float_reduce_sum_or_sub +=
x->type().lanes();
break; break;
case Reduce::ReduceType::kDiv: case Reduce::ReduceType::kDiv:
feature_.CurrentLoopBlock().float_reduce_div += x->type().lanes(); feature_.CurrentLoopBlock().float_reduce_div += x->type().lanes();
...@@ -238,10 +246,12 @@ void FeatureExtractor::Visit(const Reduce *x) { ...@@ -238,10 +246,12 @@ void FeatureExtractor::Visit(const Reduce *x) {
feature_.CurrentLoopBlock().float_reduce_mul += x->type().lanes(); feature_.CurrentLoopBlock().float_reduce_mul += x->type().lanes();
break; break;
case Reduce::ReduceType::kMax: case Reduce::ReduceType::kMax:
feature_.CurrentLoopBlock().float_reduce_max_or_min += x->type().lanes(); feature_.CurrentLoopBlock().float_reduce_max_or_min +=
x->type().lanes();
break; break;
case Reduce::ReduceType::kMin: case Reduce::ReduceType::kMin:
feature_.CurrentLoopBlock().float_reduce_max_or_min += x->type().lanes(); feature_.CurrentLoopBlock().float_reduce_max_or_min +=
x->type().lanes();
break; break;
} }
} else { } else {
......
...@@ -49,7 +49,8 @@ TEST(FeatureExtractor, SimpleAssign) { ...@@ -49,7 +49,8 @@ TEST(FeatureExtractor, SimpleAssign) {
{M, N}, [&](Var i, Var j) { return A(i, j); }, "B"); {M, N}, [&](Var i, Var j) { return A(i, j); }, "B");
poly::StageMap stages = poly::CreateStages({A, B}); poly::StageMap stages = poly::CreateStages({A, B});
std::vector<ir::LoweredFunc> funcs = lang::LowerVec("SimpleAssign", stages, {A, B}, {}, {}, nullptr, target, true); std::vector<ir::LoweredFunc> funcs = lang::LowerVec(
"SimpleAssign", stages, {A, B}, {}, {}, nullptr, target, true);
ir::Expr ast_expr = funcs[0]->body; ir::Expr ast_expr = funcs[0]->body;
VLOG(6) << "Expr to test: " << ast_expr; VLOG(6) << "Expr to test: " << ast_expr;
...@@ -62,7 +63,8 @@ TEST(FeatureExtractor, SimpleAssign) { ...@@ -62,7 +63,8 @@ TEST(FeatureExtractor, SimpleAssign) {
std::vector<float> to_check = feature.ToFixedSizeVector(); std::vector<float> to_check = feature.ToFixedSizeVector();
ASSERT_EQ(to_check.size(), static_cast<size_t>(LoopBlockFeature::kTotalSize + 1)); ASSERT_EQ(to_check.size(),
static_cast<size_t>(LoopBlockFeature::kTotalSize + 1));
VLOG(6) << "Feature data before slog:"; VLOG(6) << "Feature data before slog:";
for (size_t i = 0; i < to_check.size(); ++i) { for (size_t i = 0; i < to_check.size(); ++i) {
VLOG(6) << i << " " << (std::pow(2, to_check[i]) - 1); VLOG(6) << i << " " << (std::pow(2, to_check[i]) - 1);
...@@ -77,9 +79,11 @@ TEST(FeatureExtractor, SimpleAssign) { ...@@ -77,9 +79,11 @@ TEST(FeatureExtractor, SimpleAssign) {
ASSERT_EQ(to_check[0], 0); ASSERT_EQ(to_check[0], 0);
#endif #endif
// mem_read // mem_read
ASSERT_EQ(to_check[17], slog(M.get_constant() * N.get_constant())); // mem_read ASSERT_EQ(to_check[17],
slog(M.get_constant() * N.get_constant())); // mem_read
// mem_write // mem_write
ASSERT_EQ(to_check[18], slog(M.get_constant() * N.get_constant())); // mem_write ASSERT_EQ(to_check[18],
slog(M.get_constant() * N.get_constant())); // mem_write
// non-opt loops, including root block // non-opt loops, including root block
ASSERT_EQ(to_check[29], slog(3)); ASSERT_EQ(to_check[29], slog(3));
} }
...@@ -101,10 +105,13 @@ TEST(FeatureExtractor, MatrixMultiply) { ...@@ -101,10 +105,13 @@ TEST(FeatureExtractor, MatrixMultiply) {
ir::Var k(K.as_int32(), "reduce_axis_k"); ir::Var k(K.as_int32(), "reduce_axis_k");
ir::Tensor C = lang::Compute( ir::Tensor C = lang::Compute(
{M, N}, [&](Var i, Var j) { return lang::ReduceSum(A(i, k) * B(k, j), {k}); }, "C"); {M, N},
[&](Var i, Var j) { return lang::ReduceSum(A(i, k) * B(k, j), {k}); },
"C");
poly::StageMap stages = poly::CreateStages({C}); poly::StageMap stages = poly::CreateStages({C});
std::vector<ir::LoweredFunc> funcs = lang::LowerVec("MatrixMultiply", stages, {C}, {}, {}, nullptr, target, true); std::vector<ir::LoweredFunc> funcs = lang::LowerVec(
"MatrixMultiply", stages, {C}, {}, {}, nullptr, target, true);
std::vector<Expr> vec_ast{funcs[0]->body}; std::vector<Expr> vec_ast{funcs[0]->body};
ir::ModuleExpr mod_expr(vec_ast); ir::ModuleExpr mod_expr(vec_ast);
...@@ -121,7 +128,8 @@ TEST(FeatureExtractor, MatrixMultiply) { ...@@ -121,7 +128,8 @@ TEST(FeatureExtractor, MatrixMultiply) {
std::vector<float> to_check = feature.ToFixedSizeVector(); std::vector<float> to_check = feature.ToFixedSizeVector();
ASSERT_EQ(to_check.size(), static_cast<size_t>(LoopBlockFeature::kTotalSize + 1)); ASSERT_EQ(to_check.size(),
static_cast<size_t>(LoopBlockFeature::kTotalSize + 1));
std::unordered_set<size_t> non_zero_indice = {0, 1, 2, 17, 18, 29, 30, 37}; std::unordered_set<size_t> non_zero_indice = {0, 1, 2, 17, 18, 29, 30, 37};
for (size_t i = 0; i < to_check.size(); ++i) { for (size_t i = 0; i < to_check.size(); ++i) {
VLOG(6) << i << " " << (std::pow(2, to_check[i]) - 1); VLOG(6) << i << " " << (std::pow(2, to_check[i]) - 1);
......
...@@ -57,7 +57,8 @@ pybind11::array VectorToNumpy(const std::vector<std::vector<Dtype>>& vec) { ...@@ -57,7 +57,8 @@ pybind11::array VectorToNumpy(const std::vector<std::vector<Dtype>>& vec) {
Dtype* py_data = static_cast<Dtype*>(ret.mutable_data()); Dtype* py_data = static_cast<Dtype*>(ret.mutable_data());
for (size_t i = 0; i < vec.size(); ++i) { for (size_t i = 0; i < vec.size(); ++i) {
assert(vec[i].size() == shape[1] && "Sub vectors must have same size in VectorToNumpy"); assert(vec[i].size() == shape[1] &&
"Sub vectors must have same size in VectorToNumpy");
memcpy(py_data + (shape[1] * i), vec[i].data(), shape[1] * sizeof(Dtype)); memcpy(py_data + (shape[1] * i), vec[i].data(), shape[1] * sizeof(Dtype));
} }
return ret; return ret;
...@@ -71,19 +72,23 @@ pybind11::array VectorToNumpy(const std::vector<std::vector<Dtype>>& vec) { ...@@ -71,19 +72,23 @@ pybind11::array VectorToNumpy(const std::vector<std::vector<Dtype>>& vec) {
void AddDistPkgToPythonSysPath() { void AddDistPkgToPythonSysPath() {
pybind11::module sys_py_mod = pybind11::module::import("sys"); pybind11::module sys_py_mod = pybind11::module::import("sys");
// short version such as "3.7", "3.8", ... // short version such as "3.7", "3.8", ...
std::string py_short_version = sys_py_mod.attr("version").cast<std::string>().substr(0, 3); std::string py_short_version =
sys_py_mod.attr("version").cast<std::string>().substr(0, 3);
std::string site_pkg_str = "/usr/local/lib/python" + py_short_version + "/dist-packages"; std::string site_pkg_str =
"/usr/local/lib/python" + py_short_version + "/dist-packages";
sys_py_mod.attr("path").attr("append")(site_pkg_str); sys_py_mod.attr("path").attr("append")(site_pkg_str);
// TODO(zhhsplendid): warning to users if setuptools hasn't been installed // TODO(zhhsplendid): warning to users if setuptools hasn't been installed
DIR* site_pkg_dir = opendir(site_pkg_str.c_str()); DIR* site_pkg_dir = opendir(site_pkg_str.c_str());
if (site_pkg_dir != nullptr) { if (site_pkg_dir != nullptr) {
std::regex setuptool_regex("setuptools-.*-py" + py_short_version + "\\.egg"); std::regex setuptool_regex("setuptools-.*-py" + py_short_version +
"\\.egg");
struct dirent* entry = nullptr; struct dirent* entry = nullptr;
while ((entry = readdir(site_pkg_dir)) != nullptr) { while ((entry = readdir(site_pkg_dir)) != nullptr) {
if (std::regex_match(entry->d_name, setuptool_regex)) { if (std::regex_match(entry->d_name, setuptool_regex)) {
sys_py_mod.attr("path").attr("append")(site_pkg_str + "/" + entry->d_name); sys_py_mod.attr("path").attr("append")(site_pkg_str + "/" +
entry->d_name);
} }
} }
closedir(site_pkg_dir); closedir(site_pkg_dir);
...@@ -100,36 +105,45 @@ XgbCostModel::XgbCostModel() { ...@@ -100,36 +105,45 @@ XgbCostModel::XgbCostModel() {
xgb_booster_ = xgb_module_.attr("Booster")(); xgb_booster_ = xgb_module_.attr("Booster")();
} }
void XgbCostModel::Train(const std::vector<std::vector<float>>& samples, const std::vector<float>& labels) { void XgbCostModel::Train(const std::vector<std::vector<float>>& samples,
const std::vector<float>& labels) {
update_samples_ = samples; update_samples_ = samples;
update_labels_ = labels; update_labels_ = labels;
pybind11::array np_samples = VectorToNumpy<float>(samples); pybind11::array np_samples = VectorToNumpy<float>(samples);
pybind11::array np_labels = VectorToNumpy<float>(labels); pybind11::array np_labels = VectorToNumpy<float>(labels);
pybind11::object dmatrix = xgb_module_.attr("DMatrix")(np_samples, np_labels); pybind11::object dmatrix = xgb_module_.attr("DMatrix")(np_samples, np_labels);
xgb_booster_ = xgb_module_.attr("train")(pybind11::dict(), dmatrix, pybind11::int_(kTrainRound_)); xgb_booster_ = xgb_module_.attr("train")(
pybind11::dict(), dmatrix, pybind11::int_(kTrainRound_));
} }
std::vector<float> XgbCostModel::Predict(const std::vector<std::vector<float>>& samples) const { std::vector<float> XgbCostModel::Predict(
const std::vector<std::vector<float>>& samples) const {
pybind11::array np_samples = VectorToNumpy<float>(samples); pybind11::array np_samples = VectorToNumpy<float>(samples);
pybind11::object dmatrix = xgb_module_.attr("DMatrix")(np_samples); pybind11::object dmatrix = xgb_module_.attr("DMatrix")(np_samples);
pybind11::array py_result = xgb_booster_.attr("predict")(dmatrix); pybind11::array py_result = xgb_booster_.attr("predict")(dmatrix);
return py_result.cast<std::vector<float>>(); return py_result.cast<std::vector<float>>();
} }
void XgbCostModel::Update(const std::vector<std::vector<float>>& samples, const std::vector<float>& labels) { void XgbCostModel::Update(const std::vector<std::vector<float>>& samples,
const std::vector<float>& labels) {
update_samples_.insert(update_samples_.end(), samples.begin(), samples.end()); update_samples_.insert(update_samples_.end(), samples.begin(), samples.end());
update_labels_.insert(update_labels_.end(), labels.begin(), labels.end()); update_labels_.insert(update_labels_.end(), labels.begin(), labels.end());
pybind11::array np_samples = VectorToNumpy<float>(update_samples_); pybind11::array np_samples = VectorToNumpy<float>(update_samples_);
pybind11::array np_labels = VectorToNumpy<float>(update_labels_); pybind11::array np_labels = VectorToNumpy<float>(update_labels_);
pybind11::object dmatrix = xgb_module_.attr("DMatrix")(np_samples, np_labels); pybind11::object dmatrix = xgb_module_.attr("DMatrix")(np_samples, np_labels);
xgb_booster_ = xgb_module_.attr("train")(pybind11::dict(), dmatrix, pybind11::int_(kTrainRound_)); xgb_booster_ = xgb_module_.attr("train")(
pybind11::dict(), dmatrix, pybind11::int_(kTrainRound_));
} }
void XgbCostModel::Save(const std::string& path) { xgb_booster_.attr("save_model")(pybind11::str(path)); } void XgbCostModel::Save(const std::string& path) {
xgb_booster_.attr("save_model")(pybind11::str(path));
}
void XgbCostModel::Load(const std::string& path) { xgb_booster_.attr("load_model")(pybind11::str(path)); } void XgbCostModel::Load(const std::string& path) {
xgb_booster_.attr("load_model")(pybind11::str(path));
}
} // namespace auto_schedule } // namespace auto_schedule
} // namespace cinn } // namespace cinn
...@@ -47,11 +47,14 @@ class XgbCostModel : public CostModel { ...@@ -47,11 +47,14 @@ class XgbCostModel : public CostModel {
XgbCostModel(); XgbCostModel();
~XgbCostModel() = default; ~XgbCostModel() = default;
void Train(const std::vector<std::vector<float>>& samples, const std::vector<float>& labels) override; void Train(const std::vector<std::vector<float>>& samples,
const std::vector<float>& labels) override;
std::vector<float> Predict(const std::vector<std::vector<float>>& samples) const override; std::vector<float> Predict(
const std::vector<std::vector<float>>& samples) const override;
void Update(const std::vector<std::vector<float>>& samples, const std::vector<float>& labels) override; void Update(const std::vector<std::vector<float>>& samples,
const std::vector<float>& labels) override;
void Save(const std::string& path) override; void Save(const std::string& path) override;
......
...@@ -34,7 +34,8 @@ TEST(CostModel, Basic) { ...@@ -34,7 +34,8 @@ TEST(CostModel, Basic) {
int batch_size = 16; int batch_size = 16;
int feature_size = 8; int feature_size = 8;
std::vector<float> labels(batch_size, 1.0); std::vector<float> labels(batch_size, 1.0);
std::vector<std::vector<float>> samples(batch_size, std::vector<float>(feature_size)); std::vector<std::vector<float>> samples(batch_size,
std::vector<float>(feature_size));
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
for (int j = 0; j < feature_size; ++j) { for (int j = 0; j < feature_size; ++j) {
samples[i][j] = rand() % 10; samples[i][j] = rand() % 10;
......
...@@ -26,7 +26,8 @@ ...@@ -26,7 +26,8 @@
namespace cinn { namespace cinn {
namespace auto_schedule { namespace auto_schedule {
bool TuningRecord::Compare::operator()(const TuningRecord& lhs, const TuningRecord& rhs) const { bool TuningRecord::Compare::operator()(const TuningRecord& lhs,
const TuningRecord& rhs) const {
return lhs.execution_cost < rhs.execution_cost; return lhs.execution_cost < rhs.execution_cost;
} }
...@@ -39,15 +40,18 @@ proto::TuningRecord TuningRecord::ToProto() const { ...@@ -39,15 +40,18 @@ proto::TuningRecord TuningRecord::ToProto() const {
return record_proto; return record_proto;
} }
Database::Database(int capacity_per_task) : capacity_per_task_(capacity_per_task) { Database::Database(int capacity_per_task)
CHECK_GT(capacity_per_task_, 0) << "capacity_per_task_ should be greater than 0"; : capacity_per_task_(capacity_per_task) {
CHECK_GT(capacity_per_task_, 0)
<< "capacity_per_task_ should be greater than 0";
} }
std::unique_ptr<Database> Database::Make(const DatabaseConfig& config) { std::unique_ptr<Database> Database::Make(const DatabaseConfig& config) {
if (config.type == DatabaseType::kMemory) { if (config.type == DatabaseType::kMemory) {
return std::make_unique<Database>(config.capacity_per_task); return std::make_unique<Database>(config.capacity_per_task);
} else if (config.type == DatabaseType::kJSONFile) { } else if (config.type == DatabaseType::kJSONFile) {
return std::make_unique<JSONFileDatabase>(config.capacity_per_task, config.record_file_path, true); return std::make_unique<JSONFileDatabase>(
config.capacity_per_task, config.record_file_path, true);
} }
LOG(FATAL) << "Unimplemented database type."; LOG(FATAL) << "Unimplemented database type.";
...@@ -81,13 +85,16 @@ std::vector<TuningRecord> Database::LookUp(const std::string& task_key) { ...@@ -81,13 +85,16 @@ std::vector<TuningRecord> Database::LookUp(const std::string& task_key) {
return results; return results;
} }
std::vector<TuningRecord> Database::GetTopK(const std::string& task_key, int k) { std::vector<TuningRecord> Database::GetTopK(const std::string& task_key,
int k) {
auto fit = key2record_.find(task_key); auto fit = key2record_.find(task_key);
if (fit == key2record_.end() || k <= 0) { if (fit == key2record_.end() || k <= 0) {
return {}; return {};
} }
if (k > capacity_per_task_) { if (k > capacity_per_task_) {
LOG(WARNING) << "Top k=" << k << " is greater than the capacity, will adjust k=" << capacity_per_task_; LOG(WARNING) << "Top k=" << k
<< " is greater than the capacity, will adjust k="
<< capacity_per_task_;
k = capacity_per_task_; k = capacity_per_task_;
} }
...@@ -103,8 +110,10 @@ std::vector<TuningRecord> Database::GetTopK(const std::string& task_key, int k) ...@@ -103,8 +110,10 @@ std::vector<TuningRecord> Database::GetTopK(const std::string& task_key, int k)
} }
size_t Database::Size() { size_t Database::Size() {
auto res = auto res = std::accumulate(key2record_.begin(),
std::accumulate(key2record_.begin(), key2record_.end(), size_t(0), [](size_t res, const auto& kv) -> size_t { key2record_.end(),
size_t(0),
[](size_t res, const auto& kv) -> size_t {
return std::move(res) + kv.second.size(); return std::move(res) + kv.second.size();
}); });
return res; return res;
......
...@@ -39,7 +39,9 @@ struct TuningRecord { ...@@ -39,7 +39,9 @@ struct TuningRecord {
predicted_cost(record.predicted_cost()), predicted_cost(record.predicted_cost()),
trace(record.trace()), trace(record.trace()),
execution_cost(record.execution_cost()) {} execution_cost(record.execution_cost()) {}
TuningRecord(const std::string& task_key, const SearchState& state, double execution_cost) TuningRecord(const std::string& task_key,
const SearchState& state,
double execution_cost)
: task_key(task_key), : task_key(task_key),
predicted_cost(state->predicted_cost), predicted_cost(state->predicted_cost),
trace(state->ir_schedule.GetTraceDesc().ToProto()), trace(state->ir_schedule.GetTraceDesc().ToProto()),
...@@ -63,10 +65,10 @@ struct DatabaseConfig { ...@@ -63,10 +65,10 @@ struct DatabaseConfig {
std::string record_file_path = "/tmp/tuning_record.json"; std::string record_file_path = "/tmp/tuning_record.json";
}; };
// A database supports insert or lookup historial tuning result with specified traits. // A database supports insert or lookup historial tuning result with specified
// It can be implemented with a concrete storage to save/load underlying data, // traits. It can be implemented with a concrete storage to save/load underlying
// such as memory, file, database server and so on, this base class can be regarded as // data, such as memory, file, database server and so on, this base class can be
// one using memory as its underlying storage medium. // regarded as one using memory as its underlying storage medium.
class Database { class Database {
public: public:
explicit Database(int capacity_per_task); explicit Database(int capacity_per_task);
...@@ -93,7 +95,9 @@ class Database { ...@@ -93,7 +95,9 @@ class Database {
void Insert(const TuningRecord& record); void Insert(const TuningRecord& record);
// map task_key to its records // map task_key to its records
std::unordered_map<std::string, std::multiset<TuningRecord, TuningRecord::Compare>> key2record_; std::unordered_map<std::string,
std::multiset<TuningRecord, TuningRecord::Compare>>
key2record_;
// the max number of candidates stored // the max number of candidates stored
const int capacity_per_task_; const int capacity_per_task_;
}; };
......
...@@ -57,8 +57,10 @@ TEST_F(TestDatabase, GetTopK) { ...@@ -57,8 +57,10 @@ TEST_F(TestDatabase, GetTopK) {
ASSERT_TRUE(test_db.GetTopK("k5", 2).empty()); ASSERT_TRUE(test_db.GetTopK("k5", 2).empty());
ASSERT_EQ(test_db.GetTopK("k4", 3).size(), 1); ASSERT_EQ(test_db.GetTopK("k4", 3).size(), 1);
test_db.AddRecord(TuningRecord("k4", SearchState(ir::IRSchedule(), 1.2), 2.0)); test_db.AddRecord(
test_db.AddRecord(TuningRecord("k4", SearchState(ir::IRSchedule(), 1.0), 3.0)); TuningRecord("k4", SearchState(ir::IRSchedule(), 1.2), 2.0));
test_db.AddRecord(
TuningRecord("k4", SearchState(ir::IRSchedule(), 1.0), 3.0));
auto records = test_db.GetTopK("k4", 3); auto records = test_db.GetTopK("k4", 3);
ASSERT_EQ(records.size(), 2); ASSERT_EQ(records.size(), 2);
......
...@@ -35,7 +35,8 @@ void AppendLineToFile(const std::string& file_path, const std::string& line) { ...@@ -35,7 +35,8 @@ void AppendLineToFile(const std::string& file_path, const std::string& line) {
} }
// read lines from a json file // read lines from a json file
std::vector<std::string> ReadLinesFromFile(const std::string& file_path, bool allow_new_file) { std::vector<std::string> ReadLinesFromFile(const std::string& file_path,
bool allow_new_file) {
std::ifstream is(file_path); std::ifstream is(file_path);
if (is.good()) { if (is.good()) {
std::vector<std::string> json_strs; std::vector<std::string> json_strs;
...@@ -51,20 +52,26 @@ std::vector<std::string> ReadLinesFromFile(const std::string& file_path, bool al ...@@ -51,20 +52,26 @@ std::vector<std::string> ReadLinesFromFile(const std::string& file_path, bool al
return {}; return {};
} }
JSONFileDatabase::JSONFileDatabase(int capacity_per_task, const std::string& record_file_path, bool allow_new_file) JSONFileDatabase::JSONFileDatabase(int capacity_per_task,
const std::string& record_file_path,
bool allow_new_file)
: Database(capacity_per_task), record_file_path_(record_file_path) { : Database(capacity_per_task), record_file_path_(record_file_path) {
VLOG(3) << "Auto schedule will save/load tuning records on file:" << record_file_path; VLOG(3) << "Auto schedule will save/load tuning records on file:"
<< record_file_path;
auto json_lines = ReadLinesFromFile(record_file_path_, allow_new_file); auto json_lines = ReadLinesFromFile(record_file_path_, allow_new_file);
std::vector<cinn::auto_schedule::proto::TuningRecord> all_records_proto(json_lines.size()); std::vector<cinn::auto_schedule::proto::TuningRecord> all_records_proto(
json_lines.size());
// convert JSON string to proto object // convert JSON string to proto object
auto worker_fn = [this, &json_lines, &all_records_proto](int index) { auto worker_fn = [this, &json_lines, &all_records_proto](int index) {
cinn::auto_schedule::proto::TuningRecord record_proto; cinn::auto_schedule::proto::TuningRecord record_proto;
auto status = google::protobuf::util::JsonStringToMessage(json_lines[index], &record_proto); auto status = google::protobuf::util::JsonStringToMessage(json_lines[index],
&record_proto);
CHECK(status.ok()) << "Failed to parse JSON: " << json_lines[index]; CHECK(status.ok()) << "Failed to parse JSON: " << json_lines[index];
all_records_proto[index].Swap(&record_proto); all_records_proto[index].Swap(&record_proto);
}; };
utils::parallel_run(worker_fn, utils::SequenceDispatcher(0, json_lines.size()), -1); utils::parallel_run(
worker_fn, utils::SequenceDispatcher(0, json_lines.size()), -1);
InitialTaskRegistry* task_registry = InitialTaskRegistry::Global(); InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
...@@ -81,8 +88,10 @@ JSONFileDatabase::JSONFileDatabase(int capacity_per_task, const std::string& rec ...@@ -81,8 +88,10 @@ JSONFileDatabase::JSONFileDatabase(int capacity_per_task, const std::string& rec
std::string JSONFileDatabase::RecordToJSON(const TuningRecord& record) { std::string JSONFileDatabase::RecordToJSON(const TuningRecord& record) {
proto::TuningRecord record_proto = record.ToProto(); proto::TuningRecord record_proto = record.ToProto();
std::string json_string; std::string json_string;
auto status = google::protobuf::util::MessageToJsonString(record_proto, &json_string); auto status =
CHECK(status.ok()) << "Failed to serialize record to JSON, task key = " << record.task_key; google::protobuf::util::MessageToJsonString(record_proto, &json_string);
CHECK(status.ok()) << "Failed to serialize record to JSON, task key = "
<< record.task_key;
VLOG(4) << "json_string = \n" << json_string; VLOG(4) << "json_string = \n" << json_string;
return json_string; return json_string;
......
...@@ -19,16 +19,20 @@ ...@@ -19,16 +19,20 @@
namespace cinn { namespace cinn {
namespace auto_schedule { namespace auto_schedule {
// JSONFileDatabase is a database implemented by JSON file to save/load underlying data. // JSONFileDatabase is a database implemented by JSON file to save/load
// underlying data.
class JSONFileDatabase : public Database { class JSONFileDatabase : public Database {
public: public:
/*! /*!
* \brief Build a JSONFileDatabase object from a json file. * \brief Build a JSONFileDatabase object from a json file.
* \param capacity_per_task The max number of candidates stored. * \param capacity_per_task The max number of candidates stored.
* \param record_file_path The path of the json file. * \param record_file_path The path of the json file.
* \param allow_new_file Whether to create new file when the given path is not found. * \param allow_new_file Whether to create new file when the given path is not
* found.
*/ */
JSONFileDatabase(int capacity_per_task, const std::string& record_file_path, bool allow_new_file); JSONFileDatabase(int capacity_per_task,
const std::string& record_file_path,
bool allow_new_file);
~JSONFileDatabase() = default; ~JSONFileDatabase() = default;
// convert a TuningRecord object to string in JSON format // convert a TuningRecord object to string in JSON format
...@@ -46,7 +50,8 @@ class JSONFileDatabase : public Database { ...@@ -46,7 +50,8 @@ class JSONFileDatabase : public Database {
void AppendLineToFile(const std::string& file_path, const std::string& line); void AppendLineToFile(const std::string& file_path, const std::string& line);
// read lines from a json file // read lines from a json file
std::vector<std::string> ReadLinesFromFile(const std::string& file_path, bool allow_new_file = true); std::vector<std::string> ReadLinesFromFile(const std::string& file_path,
bool allow_new_file = true);
} // namespace auto_schedule } // namespace auto_schedule
} // namespace cinn } // namespace cinn
...@@ -31,7 +31,8 @@ namespace cinn { ...@@ -31,7 +31,8 @@ namespace cinn {
namespace auto_schedule { namespace auto_schedule {
// Return lowerd ir AST for example functions used in this test // Return lowerd ir AST for example functions used in this test
std::vector<ir::LoweredFunc> LowerCompute(const std::vector<int>& shape, const Target& target) { std::vector<ir::LoweredFunc> LowerCompute(const std::vector<int>& shape,
const Target& target) {
CHECK(shape.size() == 2) << "shape should be 2"; CHECK(shape.size() == 2) << "shape should be 2";
std::vector<Expr> domain; std::vector<Expr> domain;
for (auto i = 0; i < shape.size(); ++i) { for (auto i = 0; i < shape.size(); ++i) {
...@@ -46,11 +47,13 @@ std::vector<ir::LoweredFunc> LowerCompute(const std::vector<int>& shape, const T ...@@ -46,11 +47,13 @@ std::vector<ir::LoweredFunc> LowerCompute(const std::vector<int>& shape, const T
C = Compute( C = Compute(
domain, [&B](Var i, Var j) { return B(i, j); }, "C"); domain, [&B](Var i, Var j) { return B(i, j); }, "C");
return cinn::lang::LowerVec("test_func", CreateStages({A, B}), {A, B}, {}, {}, nullptr, target, true); return cinn::lang::LowerVec(
"test_func", CreateStages({A, B}), {A, B}, {}, {}, nullptr, target, true);
} }
// Create a new IRSchedule with copied ir::LoweredFunc AST // Create a new IRSchedule with copied ir::LoweredFunc AST
ir::IRSchedule MakeIRSchedule(const std::vector<ir::LoweredFunc>& lowered_funcs, const std::string& task_key) { ir::IRSchedule MakeIRSchedule(const std::vector<ir::LoweredFunc>& lowered_funcs,
const std::string& task_key) {
std::vector<Expr> exprs; std::vector<Expr> exprs;
for (auto&& func : lowered_funcs) { for (auto&& func : lowered_funcs) {
exprs.emplace_back(optim::IRCopy(func->body)); exprs.emplace_back(optim::IRCopy(func->body));
...@@ -63,7 +66,9 @@ ir::IRSchedule MakeIRSchedule(const std::vector<ir::LoweredFunc>& lowered_funcs, ...@@ -63,7 +66,9 @@ ir::IRSchedule MakeIRSchedule(const std::vector<ir::LoweredFunc>& lowered_funcs,
class TestJSONFileDatabase : public ::testing::Test { class TestJSONFileDatabase : public ::testing::Test {
public: public:
TestJSONFileDatabase() : record_file_path("/tmp/test_record.json"), test_db(2, record_file_path, true) {} TestJSONFileDatabase()
: record_file_path("/tmp/test_record.json"),
test_db(2, record_file_path, true) {}
void SetUp() override { lowered_funcs = LowerCompute({32, 32}, target); } void SetUp() override { lowered_funcs = LowerCompute({32, 32}, target); }
...@@ -97,14 +102,19 @@ TEST_F(TestJSONFileDatabase, Serialize) { ...@@ -97,14 +102,19 @@ TEST_F(TestJSONFileDatabase, Serialize) {
TuningRecord record1("test", SearchState(std::move(ir_sch), 2.0), 1.0); TuningRecord record1("test", SearchState(std::move(ir_sch), 2.0), 1.0);
std::string str = test_db.RecordToJSON(record1); std::string str = test_db.RecordToJSON(record1);
VLOG(3) << "RecordToJSON: " << str; VLOG(3) << "RecordToJSON: " << str;
// Because the serialization of protobuf does not guarantee the order, we give all possible results. // Because the serialization of protobuf does not guarantee the order, we give
// all possible results.
std::string case1 = std::string case1 =
"{\"taskKey\":\"test\",\"executionCost\":1,\"predictedCost\":2,\"trace\":{\"steps\":[{\"type\":\"FuseWithName\"," "{\"taskKey\":\"test\",\"executionCost\":1,\"predictedCost\":2,\"trace\":"
"\"outputs\":[\"e0\"],\"attrs\":[{\"name\":\"loops_index\",\"dtype\":\"INTS\",\"ints\":[0,1]},{\"name\":\"block_" "{\"steps\":[{\"type\":\"FuseWithName\","
"\"outputs\":[\"e0\"],\"attrs\":[{\"name\":\"loops_index\",\"dtype\":"
"\"INTS\",\"ints\":[0,1]},{\"name\":\"block_"
"name\",\"dtype\":\"STRING\",\"s\":\"B\"}]}]}}"; "name\",\"dtype\":\"STRING\",\"s\":\"B\"}]}]}}";
std::string case2 = std::string case2 =
"{\"taskKey\":\"test\",\"executionCost\":1,\"predictedCost\":2,\"trace\":{\"steps\":[{\"type\":\"FuseWithName\"," "{\"taskKey\":\"test\",\"executionCost\":1,\"predictedCost\":2,\"trace\":"
"\"outputs\":[\"e0\"],\"attrs\":[{\"name\":\"block_name\",\"dtype\":\"STRING\",\"s\":\"B\"},{\"name\":\"loops_" "{\"steps\":[{\"type\":\"FuseWithName\","
"\"outputs\":[\"e0\"],\"attrs\":[{\"name\":\"block_name\",\"dtype\":"
"\"STRING\",\"s\":\"B\"},{\"name\":\"loops_"
"index\",\"dtype\":\"INTS\",\"ints\":[0,1]}]}]}}"; "index\",\"dtype\":\"INTS\",\"ints\":[0,1]}]}]}}";
EXPECT_EQ(true, str == case1 || str == case2); EXPECT_EQ(true, str == case1 || str == case2);
} }
...@@ -114,32 +124,48 @@ TEST_F(TestJSONFileDatabase, SaveLoad) { ...@@ -114,32 +124,48 @@ TEST_F(TestJSONFileDatabase, SaveLoad) {
auto fused1 = ir_sch1.Fuse("B", {0, 1}); auto fused1 = ir_sch1.Fuse("B", {0, 1});
ir::IRSchedule ir_sch2 = MakeIRSchedule(lowered_funcs, "k2"); ir::IRSchedule ir_sch2 = MakeIRSchedule(lowered_funcs, "k2");
test_db.AddRecord(TuningRecord("k1", SearchState(std::move(ir_sch1), 1.5), 1.0)); test_db.AddRecord(
test_db.AddRecord(TuningRecord("k2", SearchState(std::move(ir_sch2), 3.5), 3.0)); TuningRecord("k1", SearchState(std::move(ir_sch1), 1.5), 1.0));
test_db.AddRecord(
TuningRecord("k2", SearchState(std::move(ir_sch2), 3.5), 3.0));
std::vector<std::string> strs = ReadLinesFromFile(record_file_path); std::vector<std::string> strs = ReadLinesFromFile(record_file_path);
ASSERT_EQ(strs.size(), 2); ASSERT_EQ(strs.size(), 2);
// Because the serialization of protobuf does not guarantee the order, we give all possible results. // Because the serialization of protobuf does not guarantee the order, we give
// all possible results.
std::string case1 = std::string case1 =
"{\"taskKey\":\"k1\",\"executionCost\":1,\"predictedCost\":1.5,\"trace\":{\"steps\":[{\"type\":\"FuseWithName\"," "{\"taskKey\":\"k1\",\"executionCost\":1,\"predictedCost\":1.5,\"trace\":"
"\"outputs\":[\"e0\"],\"attrs\":[{\"name\":\"loops_index\",\"dtype\":\"INTS\",\"ints\":[0,1]},{\"name\":\"block_" "{\"steps\":[{\"type\":\"FuseWithName\","
"\"outputs\":[\"e0\"],\"attrs\":[{\"name\":\"loops_index\",\"dtype\":"
"\"INTS\",\"ints\":[0,1]},{\"name\":\"block_"
"name\",\"dtype\":\"STRING\",\"s\":\"B\"}]}]}}"; "name\",\"dtype\":\"STRING\",\"s\":\"B\"}]}]}}";
std::string case2 = std::string case2 =
"{\"taskKey\":\"k1\",\"executionCost\":1,\"predictedCost\":1.5,\"trace\":{\"steps\":[{\"type\":\"FuseWithName\"," "{\"taskKey\":\"k1\",\"executionCost\":1,\"predictedCost\":1.5,\"trace\":"
"\"outputs\":[\"e0\"],\"attrs\":[{\"name\":\"block_name\",\"dtype\":\"STRING\",\"s\":\"B\"},{\"name\":\"loops_" "{\"steps\":[{\"type\":\"FuseWithName\","
"\"outputs\":[\"e0\"],\"attrs\":[{\"name\":\"block_name\",\"dtype\":"
"\"STRING\",\"s\":\"B\"},{\"name\":\"loops_"
"index\",\"dtype\":\"INTS\",\"ints\":[0,1]}]}]}}"; "index\",\"dtype\":\"INTS\",\"ints\":[0,1]}]}]}}";
EXPECT_EQ(true, strs[0] == case1 || strs[0] == case2); EXPECT_EQ(true, strs[0] == case1 || strs[0] == case2);
EXPECT_EQ(strs[1], "{\"taskKey\":\"k2\",\"executionCost\":3,\"predictedCost\":3.5,\"trace\":{}}"); EXPECT_EQ(strs[1],
"{\"taskKey\":\"k2\",\"executionCost\":3,\"predictedCost\":3.5,"
"\"trace\":{}}");
} }
TEST_F(TestJSONFileDatabase, Basic) { TEST_F(TestJSONFileDatabase, Basic) {
test_db.AddRecord(TuningRecord("k1", SearchState(MakeIRSchedule(lowered_funcs, "k1"), 1.0), 1.0)); test_db.AddRecord(TuningRecord(
test_db.AddRecord(TuningRecord("k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 2.0)); "k1", SearchState(MakeIRSchedule(lowered_funcs, "k1"), 1.0), 1.0));
test_db.AddRecord(TuningRecord("k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 3.0)); test_db.AddRecord(TuningRecord(
test_db.AddRecord(TuningRecord("k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 8.0), 3.0)); "k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 2.0));
test_db.AddRecord(TuningRecord("k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 7.0), 4.0)); test_db.AddRecord(TuningRecord(
test_db.AddRecord(TuningRecord("k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 6.0), 5.0)); "k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 3.0));
test_db.AddRecord(TuningRecord("k4", SearchState(MakeIRSchedule(lowered_funcs, "k4"), 1.0), 4.0)); test_db.AddRecord(TuningRecord(
"k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 8.0), 3.0));
test_db.AddRecord(TuningRecord(
"k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 7.0), 4.0));
test_db.AddRecord(TuningRecord(
"k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 6.0), 5.0));
test_db.AddRecord(TuningRecord(
"k4", SearchState(MakeIRSchedule(lowered_funcs, "k4"), 1.0), 4.0));
ASSERT_EQ(test_db.Size(), 6); ASSERT_EQ(test_db.Size(), 6);
auto records = test_db.LookUp("k3"); auto records = test_db.LookUp("k3");
...@@ -152,15 +178,24 @@ TEST_F(TestJSONFileDatabase, Basic) { ...@@ -152,15 +178,24 @@ TEST_F(TestJSONFileDatabase, Basic) {
} }
TEST_F(TestJSONFileDatabase, GetTopK) { TEST_F(TestJSONFileDatabase, GetTopK) {
test_db.AddRecord(TuningRecord("k1", SearchState(MakeIRSchedule(lowered_funcs, "k1"), 1.0), 1.0)); test_db.AddRecord(TuningRecord(
test_db.AddRecord(TuningRecord("k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 2.0)); "k1", SearchState(MakeIRSchedule(lowered_funcs, "k1"), 1.0), 1.0));
test_db.AddRecord(TuningRecord("k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 3.0)); test_db.AddRecord(TuningRecord(
test_db.AddRecord(TuningRecord("k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 1.0), 3.0)); "k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 2.0));
test_db.AddRecord(TuningRecord("k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 1.0), 4.0)); test_db.AddRecord(TuningRecord(
test_db.AddRecord(TuningRecord("k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 1.0), 5.0)); "k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 3.0));
test_db.AddRecord(TuningRecord("k4", SearchState(MakeIRSchedule(lowered_funcs, "k4"), 2.0), 4.0)); test_db.AddRecord(TuningRecord(
test_db.AddRecord(TuningRecord("k4", SearchState(MakeIRSchedule(lowered_funcs, "k4"), 1.2), 2.0)); "k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 1.0), 3.0));
test_db.AddRecord(TuningRecord("k4", SearchState(MakeIRSchedule(lowered_funcs, "k4"), 1.0), 3.0)); test_db.AddRecord(TuningRecord(
"k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 1.0), 4.0));
test_db.AddRecord(TuningRecord(
"k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 1.0), 5.0));
test_db.AddRecord(TuningRecord(
"k4", SearchState(MakeIRSchedule(lowered_funcs, "k4"), 2.0), 4.0));
test_db.AddRecord(TuningRecord(
"k4", SearchState(MakeIRSchedule(lowered_funcs, "k4"), 1.2), 2.0));
test_db.AddRecord(TuningRecord(
"k4", SearchState(MakeIRSchedule(lowered_funcs, "k4"), 1.0), 3.0));
auto records = test_db.GetTopK("k4", 3); auto records = test_db.GetTopK("k4", 3);
ASSERT_EQ(records.size(), 2); ASSERT_EQ(records.size(), 2);
...@@ -171,8 +206,10 @@ TEST_F(TestJSONFileDatabase, GetTopK) { ...@@ -171,8 +206,10 @@ TEST_F(TestJSONFileDatabase, GetTopK) {
TEST_F(TestJSONFileDatabase, Reload) { TEST_F(TestJSONFileDatabase, Reload) {
ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs, "k1"); ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs, "k1");
auto fused = ir_sch.Fuse("B", {0, 1}); auto fused = ir_sch.Fuse("B", {0, 1});
test_db.AddRecord(TuningRecord("k1", SearchState(std::move(ir_sch), 1.0), 1.0)); test_db.AddRecord(
test_db.AddRecord(TuningRecord("k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 2.0)); TuningRecord("k1", SearchState(std::move(ir_sch), 1.0), 1.0));
test_db.AddRecord(TuningRecord(
"k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 2.0));
auto records = test_db.LookUp("k1"); auto records = test_db.LookUp("k1");
ASSERT_EQ(records.size(), 1); ASSERT_EQ(records.size(), 1);
...@@ -184,11 +221,13 @@ TEST_F(TestJSONFileDatabase, Reload) { ...@@ -184,11 +221,13 @@ TEST_F(TestJSONFileDatabase, Reload) {
EXPECT_EQ(records[0].execution_cost, loaded_records[0].execution_cost); EXPECT_EQ(records[0].execution_cost, loaded_records[0].execution_cost);
EXPECT_EQ(records[0].predicted_cost, loaded_records[0].predicted_cost); EXPECT_EQ(records[0].predicted_cost, loaded_records[0].predicted_cost);
// check the equality of trace info between original TuningRecord and the loaded TuningRecord // check the equality of trace info between original TuningRecord and the
// loaded TuningRecord
const auto& lhs_trace = records[0].trace; const auto& lhs_trace = records[0].trace;
const auto& rhs_trace = loaded_records[0].trace; const auto& rhs_trace = loaded_records[0].trace;
google::protobuf::util::MessageDifferencer dif; google::protobuf::util::MessageDifferencer dif;
static const google::protobuf::Descriptor* descriptor = cinn::ir::proto::ScheduleDesc_Step::descriptor(); static const google::protobuf::Descriptor* descriptor =
cinn::ir::proto::ScheduleDesc_Step::descriptor();
dif.TreatAsSet(descriptor->FindFieldByName("attrs")); dif.TreatAsSet(descriptor->FindFieldByName("attrs"));
EXPECT_TRUE(dif.Compare(lhs_trace, rhs_trace)); EXPECT_TRUE(dif.Compare(lhs_trace, rhs_trace));
......
...@@ -53,7 +53,8 @@ struct MeasureResult { ...@@ -53,7 +53,8 @@ struct MeasureResult {
// The result of building with input schedule // The result of building with input schedule
struct BuildResult { struct BuildResult {
// The scope that owns detailed compilation info of parameters in the runtime program // The scope that owns detailed compilation info of parameters in the runtime
// program
const hlir::framework::Scope* compiled_scope; const hlir::framework::Scope* compiled_scope;
// The executable program // The executable program
std::unique_ptr<hlir::framework::Program> runtime_program; std::unique_ptr<hlir::framework::Program> runtime_program;
...@@ -68,11 +69,13 @@ class ScheduleBuilder { ...@@ -68,11 +69,13 @@ class ScheduleBuilder {
virtual BuildResult Build(const MeasureInput& input) = 0; virtual BuildResult Build(const MeasureInput& input) = 0;
}; };
// This interface defines how to run the built result. Like above ScheduleBuilder, // This interface defines how to run the built result. Like above
// a runner should be implemented without being bound to a specific task. // ScheduleBuilder, a runner should be implemented without being bound to a specific
// task.
class ScheduleRunner { class ScheduleRunner {
public: public:
virtual MeasureResult Run(const MeasureInput& input, const BuildResult& build_result) = 0; virtual MeasureResult Run(const MeasureInput& input,
const BuildResult& build_result) = 0;
}; };
} // namespace auto_schedule } // namespace auto_schedule
......
...@@ -68,10 +68,15 @@ class TestMeasurer : public ::testing::Test { ...@@ -68,10 +68,15 @@ class TestMeasurer : public ::testing::Test {
graph_compiler = std::make_unique<GraphCompiler>(target, scope, graph); graph_compiler = std::make_unique<GraphCompiler>(target, scope, graph);
TaskCreator task_creator; TaskCreator task_creator;
tasks = task_creator.CreateTuneTaskOpLevel(graph.get()); tasks = task_creator.CreateTuneTaskOpLevel(graph.get());
const auto& dtype_dict = graph->GetAttrs<absl::flat_hash_map<std::string, common::Type>>("inferdtype"); const auto& dtype_dict =
const auto& shape_dict = graph->GetAttrs<absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape"); graph->GetAttrs<absl::flat_hash_map<std::string, common::Type>>(
"inferdtype");
auto op_lowerer = std::make_unique<hlir::framework::OpLowerer>(dtype_dict, shape_dict, target); const auto& shape_dict = graph->GetAttrs<
absl::flat_hash_map<std::string, hlir::framework::shape_t>>(
"infershape");
auto op_lowerer = std::make_unique<hlir::framework::OpLowerer>(
dtype_dict, shape_dict, target);
inputs.reserve(tasks.size()); inputs.reserve(tasks.size());
for (int i = 0; i < tasks.size(); ++i) { for (int i = 0; i < tasks.size(); ++i) {
auto* task = &tasks[i]; auto* task = &tasks[i];
...@@ -95,13 +100,17 @@ class ThrowExceptionRunner : public ScheduleRunner { ...@@ -95,13 +100,17 @@ class ThrowExceptionRunner : public ScheduleRunner {
struct Exception : public std::exception { struct Exception : public std::exception {
const char* what() const throw() { return "RunError"; } const char* what() const throw() { return "RunError"; }
}; };
MeasureResult Run(const MeasureInput& input, const BuildResult& build_result) override { throw Exception(); } MeasureResult Run(const MeasureInput& input,
const BuildResult& build_result) override {
throw Exception();
}
}; };
TEST_F(TestMeasurer, Basic) { TEST_F(TestMeasurer, Basic) {
auto builder = std::make_unique<SimpleBuilder>(graph_compiler.get()); auto builder = std::make_unique<SimpleBuilder>(graph_compiler.get());
auto runner = std::make_unique<SimpleRunner>(1); auto runner = std::make_unique<SimpleRunner>(1);
auto measurer = std::make_unique<ScheduleMeasurer>(builder.get(), runner.get()); auto measurer =
std::make_unique<ScheduleMeasurer>(builder.get(), runner.get());
std::vector<MeasureResult> results = measurer->Measure(inputs); std::vector<MeasureResult> results = measurer->Measure(inputs);
ASSERT_EQ(inputs.size(), results.size()); ASSERT_EQ(inputs.size(), results.size());
} }
...@@ -111,13 +120,16 @@ TEST_F(TestMeasurer, CatchException) { ...@@ -111,13 +120,16 @@ TEST_F(TestMeasurer, CatchException) {
auto runner = std::make_unique<SimpleRunner>(1); auto runner = std::make_unique<SimpleRunner>(1);
auto throw_builder = std::make_unique<ThrowExceptionBuilder>(); auto throw_builder = std::make_unique<ThrowExceptionBuilder>();
auto throw_runner = std::make_unique<ThrowExceptionRunner>(); auto throw_runner = std::make_unique<ThrowExceptionRunner>();
auto measurer_with_build_error = std::make_unique<ScheduleMeasurer>(throw_builder.get(), runner.get(), 2); auto measurer_with_build_error =
std::vector<MeasureResult> results = measurer_with_build_error->Measure(inputs); std::make_unique<ScheduleMeasurer>(throw_builder.get(), runner.get(), 2);
std::vector<MeasureResult> results =
measurer_with_build_error->Measure(inputs);
ASSERT_EQ(inputs.size(), results.size()); ASSERT_EQ(inputs.size(), results.size());
EXPECT_EQ(results[0].error_msg, "Build failed, error: BuildError\n"); EXPECT_EQ(results[0].error_msg, "Build failed, error: BuildError\n");
// TODO(CtfGo): test parallel build after we support thread-safe compilation // TODO(CtfGo): test parallel build after we support thread-safe compilation
auto measurer_with_run_error = std::make_unique<ScheduleMeasurer>(builder.get(), throw_runner.get(), 1); auto measurer_with_run_error =
std::make_unique<ScheduleMeasurer>(builder.get(), throw_runner.get(), 1);
results = measurer_with_run_error->Measure(inputs); results = measurer_with_run_error->Measure(inputs);
ASSERT_EQ(inputs.size(), results.size()); ASSERT_EQ(inputs.size(), results.size());
EXPECT_EQ(results[0].error_msg, "Run failed, error: RunError\n"); EXPECT_EQ(results[0].error_msg, "Run failed, error: RunError\n");
......
...@@ -21,10 +21,13 @@ ...@@ -21,10 +21,13 @@
namespace cinn { namespace cinn {
namespace auto_schedule { namespace auto_schedule {
ScheduleMeasurer::ScheduleMeasurer(ScheduleBuilder* builder, ScheduleRunner* runner, int num_threads) ScheduleMeasurer::ScheduleMeasurer(ScheduleBuilder* builder,
ScheduleRunner* runner,
int num_threads)
: builder_(builder), runner_(runner), num_threads_(num_threads) {} : builder_(builder), runner_(runner), num_threads_(num_threads) {}
std::vector<MeasureResult> ScheduleMeasurer::Measure(const std::vector<MeasureInput>& inputs) { std::vector<MeasureResult> ScheduleMeasurer::Measure(
const std::vector<MeasureInput>& inputs) {
if (inputs.empty()) { if (inputs.empty()) {
LOG(WARNING) << "inputs is empty"; LOG(WARNING) << "inputs is empty";
return {}; return {};
...@@ -33,20 +36,24 @@ std::vector<MeasureResult> ScheduleMeasurer::Measure(const std::vector<MeasureIn ...@@ -33,20 +36,24 @@ std::vector<MeasureResult> ScheduleMeasurer::Measure(const std::vector<MeasureIn
std::vector<MeasureResult> results(inputs.size()); std::vector<MeasureResult> results(inputs.size());
// define how to build a candidate with the specified index // define how to build a candidate with the specified index
auto build_fn = [builder = builder_, &inputs, &build_results, &results](int index) { auto build_fn =
[builder = builder_, &inputs, &build_results, &results](int index) {
VLOG(6) << "Build candidate index: " << index; VLOG(6) << "Build candidate index: " << index;
auto m_start = std::chrono::steady_clock::now(); auto m_start = std::chrono::steady_clock::now();
try { try {
build_results[index] = builder->Build(inputs[index]); build_results[index] = builder->Build(inputs[index]);
} catch (std::exception& e) { } catch (std::exception& e) {
results[index].error_msg = utils::StringFormat("Build failed, error: %s\n", e.what()); results[index].error_msg =
utils::StringFormat("Build failed, error: %s\n", e.what());
} }
auto time_span = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::steady_clock::now() - m_start); auto time_span = std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::steady_clock::now() - m_start);
results[index].elapsed_time += static_cast<double>(time_span.count()); results[index].elapsed_time += static_cast<double>(time_span.count());
}; };
// define how to run a candidate with the specified index // define how to run a candidate with the specified index
auto run_fn = [runner = runner_, &inputs, &build_results, &results](int index) { auto run_fn =
[runner = runner_, &inputs, &build_results, &results](int index) {
VLOG(6) << "Run candidate index: " << index; VLOG(6) << "Run candidate index: " << index;
auto m_start = std::chrono::steady_clock::now(); auto m_start = std::chrono::steady_clock::now();
try { try {
...@@ -55,9 +62,11 @@ std::vector<MeasureResult> ScheduleMeasurer::Measure(const std::vector<MeasureIn ...@@ -55,9 +62,11 @@ std::vector<MeasureResult> ScheduleMeasurer::Measure(const std::vector<MeasureIn
results[index] = runner->Run(inputs[index], build_results[index]); results[index] = runner->Run(inputs[index], build_results[index]);
} }
} catch (std::exception& e) { } catch (std::exception& e) {
results[index].error_msg = utils::StringFormat("Run failed, error: %s\n", e.what()); results[index].error_msg =
utils::StringFormat("Run failed, error: %s\n", e.what());
} }
auto time_span = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::steady_clock::now() - m_start); auto time_span = std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::steady_clock::now() - m_start);
results[index].elapsed_time += static_cast<double>(time_span.count()); results[index].elapsed_time += static_cast<double>(time_span.count());
}; };
...@@ -66,8 +75,10 @@ std::vector<MeasureResult> ScheduleMeasurer::Measure(const std::vector<MeasureIn ...@@ -66,8 +75,10 @@ std::vector<MeasureResult> ScheduleMeasurer::Measure(const std::vector<MeasureIn
build_fn(index); build_fn(index);
run_fn(index); run_fn(index);
}; };
// default num_threads_ is 1 and in that case it will perform all measurements sequentially inplace. // default num_threads_ is 1 and in that case it will perform all measurements
utils::parallel_run(measure_fn, utils::SequenceDispatcher(0, inputs.size()), num_threads_); // sequentially inplace.
utils::parallel_run(
measure_fn, utils::SequenceDispatcher(0, inputs.size()), num_threads_);
VLOG(4) << "Measure " << inputs.size() << " candidates"; VLOG(4) << "Measure " << inputs.size() << " candidates";
return results; return results;
......
...@@ -25,7 +25,9 @@ namespace auto_schedule { ...@@ -25,7 +25,9 @@ namespace auto_schedule {
// which are building the input schedules and running the generated codes. // which are building the input schedules and running the generated codes.
class ScheduleMeasurer { class ScheduleMeasurer {
public: public:
ScheduleMeasurer(ScheduleBuilder* builder, ScheduleRunner* runner, int num_threads = 1); ScheduleMeasurer(ScheduleBuilder* builder,
ScheduleRunner* runner,
int num_threads = 1);
// Measure a batch of inputs and return all results once. // Measure a batch of inputs and return all results once.
std::vector<MeasureResult> Measure(const std::vector<MeasureInput>& inputs); std::vector<MeasureResult> Measure(const std::vector<MeasureInput>& inputs);
......
...@@ -19,17 +19,21 @@ namespace auto_schedule { ...@@ -19,17 +19,21 @@ namespace auto_schedule {
using hlir::framework::GraphCompiler; using hlir::framework::GraphCompiler;
SimpleBuilder::SimpleBuilder(hlir::framework::GraphCompiler* graph_compiler) : graph_compiler_(graph_compiler) {} SimpleBuilder::SimpleBuilder(hlir::framework::GraphCompiler* graph_compiler)
: graph_compiler_(graph_compiler) {}
BuildResult SimpleBuilder::Build(const MeasureInput& input) { BuildResult SimpleBuilder::Build(const MeasureInput& input) {
CHECK_NE(graph_compiler_, static_cast<GraphCompiler*>(nullptr)) << "empty handle to GraphCompiler"; CHECK_NE(graph_compiler_, static_cast<GraphCompiler*>(nullptr))
<< "empty handle to GraphCompiler";
GraphCompiler::CompileOptions compile_options; GraphCompiler::CompileOptions compile_options;
compile_options.groups.emplace_back(input.task->subgraph); compile_options.groups.emplace_back(input.task->subgraph);
compile_options.lowered_funcs.emplace_back(input.lowered_funcs); compile_options.lowered_funcs.emplace_back(input.lowered_funcs);
compile_options.remove_unused_variables = false; compile_options.remove_unused_variables = false;
VLOG(5) << "call GraphCompiler to Build with Graph::Group size=" << compile_options.groups.size() VLOG(5) << "call GraphCompiler to Build with Graph::Group size="
<< ", lowered_funcs group size=" << compile_options.lowered_funcs.size(); << compile_options.groups.size() << ", lowered_funcs group size="
GraphCompiler::CompilationResult compiled_result = graph_compiler_->Build(compile_options); << compile_options.lowered_funcs.size();
GraphCompiler::CompilationResult compiled_result =
graph_compiler_->Build(compile_options);
BuildResult build_result; BuildResult build_result;
build_result.compiled_scope = graph_compiler_->GetScope().get(); build_result.compiled_scope = graph_compiler_->GetScope().get();
......
...@@ -35,7 +35,8 @@ using hlir::framework::Tensor; ...@@ -35,7 +35,8 @@ using hlir::framework::Tensor;
// Parameters that need to be initialized to 0. // Parameters that need to be initialized to 0.
// Key is the Op name, and value is the index of the input parameter in the Op. // Key is the Op name, and value is the index of the input parameter in the Op.
static const std::unordered_map<std::string, std::vector<int>> kInitWithZeroParams = { static const std::unordered_map<std::string, std::vector<int>>
kInitWithZeroParams = {
{"lookup_table", {1}}, {"lookup_table", {1}},
{"gather", {1}}, {"gather", {1}},
{"gather_nd", {1}}, {"gather_nd", {1}},
...@@ -44,38 +45,53 @@ static const std::unordered_map<std::string, std::vector<int>> kInitWithZeroPara ...@@ -44,38 +45,53 @@ static const std::unordered_map<std::string, std::vector<int>> kInitWithZeroPara
}; };
// Generate random values and populate them to the output address of memory // Generate random values and populate them to the output address of memory
static void PopulateRandomValue(const common::Type& type, const int numel, void* raw_ptr) { static void PopulateRandomValue(const common::Type& type,
const int numel,
void* raw_ptr) {
std::random_device seed; std::random_device seed;
std::default_random_engine engine(seed()); std::default_random_engine engine(seed());
if (type == common::Bool()) { if (type == common::Bool()) {
auto* fmt_ptr = reinterpret_cast<bool*>(raw_ptr); auto* fmt_ptr = reinterpret_cast<bool*>(raw_ptr);
std::bernoulli_distribution dist(0.5); std::bernoulli_distribution dist(0.5);
std::generate_n(fmt_ptr, numel, [&engine, &dist]() { return dist(engine); }); std::generate_n(
fmt_ptr, numel, [&engine, &dist]() { return dist(engine); });
} else if (type == common::I32()) { } else if (type == common::I32()) {
auto* fmt_ptr = reinterpret_cast<int*>(raw_ptr); auto* fmt_ptr = reinterpret_cast<int*>(raw_ptr);
std::uniform_int_distribution<int> dist(std::numeric_limits<int>::min(), std::numeric_limits<int>::max()); std::uniform_int_distribution<int> dist(std::numeric_limits<int>::min(),
std::generate_n(fmt_ptr, numel, [&engine, &dist]() { return dist(engine); }); std::numeric_limits<int>::max());
std::generate_n(
fmt_ptr, numel, [&engine, &dist]() { return dist(engine); });
} else if (type == common::I64()) { } else if (type == common::I64()) {
auto* fmt_ptr = reinterpret_cast<int64_t*>(raw_ptr); auto* fmt_ptr = reinterpret_cast<int64_t*>(raw_ptr);
std::uniform_int_distribution<int64_t> dist(std::numeric_limits<int64_t>::min(), std::uniform_int_distribution<int64_t> dist(
std::numeric_limits<int64_t>::min(),
std::numeric_limits<int64_t>::max()); std::numeric_limits<int64_t>::max());
std::generate_n(fmt_ptr, numel, [&engine, &dist]() { return dist(engine); }); std::generate_n(
fmt_ptr, numel, [&engine, &dist]() { return dist(engine); });
} else if (type == common::F32()) { } else if (type == common::F32()) {
auto* fmt_ptr = reinterpret_cast<float*>(raw_ptr); auto* fmt_ptr = reinterpret_cast<float*>(raw_ptr);
std::uniform_real_distribution<float> dist(std::numeric_limits<float>::min(), std::numeric_limits<float>::max()); std::uniform_real_distribution<float> dist(
std::generate_n(fmt_ptr, numel, [&engine, &dist]() { return dist(engine); }); std::numeric_limits<float>::min(), std::numeric_limits<float>::max());
std::generate_n(
fmt_ptr, numel, [&engine, &dist]() { return dist(engine); });
} else { } else {
CHECK_EQ(type.bytes(), 8) << "Unsupported type: " << type << ", type.bytes = " << type.bytes(); CHECK_EQ(type.bytes(), 8)
<< "Unsupported type: " << type << ", type.bytes = " << type.bytes();
auto* fmt_ptr = reinterpret_cast<uint8_t*>(raw_ptr); auto* fmt_ptr = reinterpret_cast<uint8_t*>(raw_ptr);
std::uniform_int_distribution<uint8_t> dist(std::numeric_limits<uint8_t>::min(), std::uniform_int_distribution<uint8_t> dist(
std::numeric_limits<uint8_t>::min(),
std::numeric_limits<uint8_t>::max()); std::numeric_limits<uint8_t>::max());
std::generate_n(fmt_ptr, numel, [&engine, &dist]() { return dist(engine); }); std::generate_n(
fmt_ptr, numel, [&engine, &dist]() { return dist(engine); });
} }
} }
// Initialize a tensor with 0 if init_with_zero == true, otherwise initialize the tensor with random value. // Initialize a tensor with 0 if init_with_zero == true, otherwise initialize
static void InitTensorData(Tensor tensor, const common::Target& target, bool init_with_zero) { // the tensor with random value.
static void InitTensorData(Tensor tensor,
const common::Target& target,
bool init_with_zero) {
int mem_size = tensor->shape().numel() * tensor->type().bytes(); int mem_size = tensor->shape().numel() * tensor->type().bytes();
auto* tensor_data = tensor->mutable_data(target, tensor->type()); auto* tensor_data = tensor->mutable_data(target, tensor->type());
#ifdef CINN_WITH_CUDA #ifdef CINN_WITH_CUDA
...@@ -101,9 +117,11 @@ static void InitTensorData(Tensor tensor, const common::Target& target, bool ini ...@@ -101,9 +117,11 @@ static void InitTensorData(Tensor tensor, const common::Target& target, bool ini
// Find all parameter names in the task corresponding to the MeasureInput // Find all parameter names in the task corresponding to the MeasureInput
// that need to be initialized to 0 when measuring. // that need to be initialized to 0 when measuring.
static std::unordered_set<std::string> ParamsNeedInitWithZero(const MeasureInput& input) { static std::unordered_set<std::string> ParamsNeedInitWithZero(
const MeasureInput& input) {
std::unordered_set<std::string> res; std::unordered_set<std::string> res;
std::vector<hlir::framework::Node*> nodes = input.task->subgraph->CollectNodes(); std::vector<hlir::framework::Node*> nodes =
input.task->subgraph->CollectNodes();
for (auto* node : nodes) { for (auto* node : nodes) {
if (kInitWithZeroParams.count(node->op()->name) != 0) { if (kInitWithZeroParams.count(node->op()->name) != 0) {
std::vector<int> param_idxs = kInitWithZeroParams.at(node->op()->name); std::vector<int> param_idxs = kInitWithZeroParams.at(node->op()->name);
...@@ -111,7 +129,8 @@ static std::unordered_set<std::string> ParamsNeedInitWithZero(const MeasureInput ...@@ -111,7 +129,8 @@ static std::unordered_set<std::string> ParamsNeedInitWithZero(const MeasureInput
for (int param_idx : param_idxs) { for (int param_idx : param_idxs) {
CHECK_GT(inlinks.size(), param_idx); CHECK_GT(inlinks.size(), param_idx);
auto& edge = inlinks.at(param_idx); auto& edge = inlinks.at(param_idx);
std::string param_name = edge->source()->as<hlir::framework::NodeData>()->id(); std::string param_name =
edge->source()->as<hlir::framework::NodeData>()->id();
VLOG(6) << "param needs to be init with 0: " << param_name; VLOG(6) << "param needs to be init with 0: " << param_name;
res.insert(param_name); res.insert(param_name);
} }
...@@ -128,7 +147,8 @@ SimpleRunner::SimpleRunner(int repeat_times) : repeat_times_(repeat_times) { ...@@ -128,7 +147,8 @@ SimpleRunner::SimpleRunner(int repeat_times) : repeat_times_(repeat_times) {
// Prepare execution arguments of all instructions to run, an argument // Prepare execution arguments of all instructions to run, an argument
// may be obtained from the input of measurement or allocating new buffer // may be obtained from the input of measurement or allocating new buffer
// with random value. // with random value.
std::map<std::string, cinn_pod_value_t> SimpleRunner::PrepareArgs(const MeasureInput& input, std::map<std::string, cinn_pod_value_t> SimpleRunner::PrepareArgs(
const MeasureInput& input,
const BuildResult& build_result, const BuildResult& build_result,
hlir::framework::Scope* temp_scope) { hlir::framework::Scope* temp_scope) {
std::map<std::string, cinn_pod_value_t> result; std::map<std::string, cinn_pod_value_t> result;
...@@ -138,7 +158,8 @@ std::map<std::string, cinn_pod_value_t> SimpleRunner::PrepareArgs(const MeasureI ...@@ -138,7 +158,8 @@ std::map<std::string, cinn_pod_value_t> SimpleRunner::PrepareArgs(const MeasureI
const auto* compiled_scope = build_result.compiled_scope; const auto* compiled_scope = build_result.compiled_scope;
const auto& instructions = build_result.runtime_program->GetRunInstructions(); const auto& instructions = build_result.runtime_program->GetRunInstructions();
std::unordered_set<std::string> params_need_init_with_zero = ParamsNeedInitWithZero(input); std::unordered_set<std::string> params_need_init_with_zero =
ParamsNeedInitWithZero(input);
auto fill_arg_fn = [&](const std::string& param) { auto fill_arg_fn = [&](const std::string& param) {
VLOG(6) << "Filling argument:" << param; VLOG(6) << "Filling argument:" << param;
...@@ -169,7 +190,8 @@ std::map<std::string, cinn_pod_value_t> SimpleRunner::PrepareArgs(const MeasureI ...@@ -169,7 +190,8 @@ std::map<std::string, cinn_pod_value_t> SimpleRunner::PrepareArgs(const MeasureI
temp_tensor->Resize(compiled_tensor->shape()); temp_tensor->Resize(compiled_tensor->shape());
temp_tensor->set_type(compiled_tensor->type()); temp_tensor->set_type(compiled_tensor->type());
temp_tensor->mutable_data(target, compiled_tensor->type()); temp_tensor->mutable_data(target, compiled_tensor->type());
InitTensorData(temp_tensor, target, params_need_init_with_zero.count(param) != 0); InitTensorData(
temp_tensor, target, params_need_init_with_zero.count(param) != 0);
result.emplace(param, temp_tensor->buffer()); result.emplace(param, temp_tensor->buffer());
}; };
...@@ -186,7 +208,8 @@ std::map<std::string, cinn_pod_value_t> SimpleRunner::PrepareArgs(const MeasureI ...@@ -186,7 +208,8 @@ std::map<std::string, cinn_pod_value_t> SimpleRunner::PrepareArgs(const MeasureI
return result; return result;
} }
MeasureResult SimpleRunner::Run(const MeasureInput& input, const BuildResult& build_result) { MeasureResult SimpleRunner::Run(const MeasureInput& input,
const BuildResult& build_result) {
MeasureResult result; MeasureResult result;
auto t_start = std::chrono::steady_clock::now(); auto t_start = std::chrono::steady_clock::now();
// prepare execution arguments // prepare execution arguments
...@@ -209,16 +232,18 @@ MeasureResult SimpleRunner::Run(const MeasureInput& input, const BuildResult& bu ...@@ -209,16 +232,18 @@ MeasureResult SimpleRunner::Run(const MeasureInput& input, const BuildResult& bu
CUDA_CALL(cudaDeviceSynchronize()); CUDA_CALL(cudaDeviceSynchronize());
} }
#endif #endif
auto time_span = auto time_span = std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::steady_clock::now() - run_start); std::chrono::steady_clock::now() - run_start);
auto cost_avg = static_cast<double>(time_span.count()) / repeat_times_; auto cost_avg = static_cast<double>(time_span.count()) / repeat_times_;
result.execution_cost += cost_avg; result.execution_cost += cost_avg;
} }
auto time_span = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::steady_clock::now() - t_start); auto time_span = std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::steady_clock::now() - t_start);
result.elapsed_time = static_cast<double>(time_span.count()); result.elapsed_time = static_cast<double>(time_span.count());
VLOG(4) << "A measurement done:repeat_times[" << repeat_times_ << "]total_elapsed_time[" << result.elapsed_time VLOG(4) << "A measurement done:repeat_times[" << repeat_times_
<< "]total_elapsed_time[" << result.elapsed_time
<< "]us,execution_cost[" << result.execution_cost << "]us"; << "]us,execution_cost[" << result.execution_cost << "]us";
return result; return result;
} }
......
...@@ -26,10 +26,12 @@ class SimpleRunner : public ScheduleRunner { ...@@ -26,10 +26,12 @@ class SimpleRunner : public ScheduleRunner {
public: public:
SimpleRunner(int repeat_times); SimpleRunner(int repeat_times);
MeasureResult Run(const MeasureInput& input, const BuildResult& build_result) override; MeasureResult Run(const MeasureInput& input,
const BuildResult& build_result) override;
private: private:
std::map<std::string, cinn_pod_value_t> PrepareArgs(const MeasureInput& input, std::map<std::string, cinn_pod_value_t> PrepareArgs(
const MeasureInput& input,
const BuildResult& build_result, const BuildResult& build_result,
hlir::framework::Scope* temp_scope); hlir::framework::Scope* temp_scope);
......
...@@ -56,7 +56,8 @@ class TestSimpleRunner : public ::testing::Test { ...@@ -56,7 +56,8 @@ class TestSimpleRunner : public ::testing::Test {
auto program = CreateAddReluProgram(); auto program = CreateAddReluProgram();
auto graph = cinn::frontend::Optimize(&program, fetch_ids, target); auto graph = cinn::frontend::Optimize(&program, fetch_ids, target);
compiled_scope = BuildScope(target, graph); compiled_scope = BuildScope(target, graph);
graph_compiler = std::make_unique<GraphCompiler>(target, compiled_scope, graph); graph_compiler =
std::make_unique<GraphCompiler>(target, compiled_scope, graph);
auto runtime_program = graph_compiler->Build(); auto runtime_program = graph_compiler->Build();
const auto& instructions = runtime_program->GetRunInstructions(); const auto& instructions = runtime_program->GetRunInstructions();
ASSERT_EQ(1, instructions.size()); ASSERT_EQ(1, instructions.size());
...@@ -115,11 +116,15 @@ TEST_F(TestSimpleRunner, TimeMeasured) { ...@@ -115,11 +116,15 @@ TEST_F(TestSimpleRunner, TimeMeasured) {
BuildResult build_result; BuildResult build_result;
build_result.compiled_scope = nullptr; build_result.compiled_scope = nullptr;
std::vector<std::unique_ptr<Instruction>> instructions; std::vector<std::unique_ptr<Instruction>> instructions;
instructions.emplace_back( instructions.emplace_back(new Instruction(common::DefaultHostTarget(),
new Instruction(common::DefaultHostTarget(), nullptr, {}, {"empty_placeholder"}, "sleep_fn")); nullptr,
{},
{"empty_placeholder"},
"sleep_fn"));
instructions.back()->SetLoweredFunc(reinterpret_cast<void*>(sleep_fn)); instructions.back()->SetLoweredFunc(reinterpret_cast<void*>(sleep_fn));
instructions.back()->Finalize(); instructions.back()->Finalize();
build_result.runtime_program.reset(new hlir::framework::Program(nullptr, std::move(instructions))); build_result.runtime_program.reset(
new hlir::framework::Program(nullptr, std::move(instructions)));
// to skip the condition check of params in Instruction::PreparePodArgs // to skip the condition check of params in Instruction::PreparePodArgs
std::map<std::string, cinn_pod_value_t> preset_args; std::map<std::string, cinn_pod_value_t> preset_args;
......
...@@ -22,10 +22,12 @@ ...@@ -22,10 +22,12 @@
namespace cinn { namespace cinn {
namespace auto_schedule { namespace auto_schedule {
int ExtractNumThreads(const ir::IRSchedule& ir_schedule, const std::string& bind_axis) { int ExtractNumThreads(const ir::IRSchedule& ir_schedule,
const std::string& bind_axis) {
const ir::ScheduleDesc& trace = ir_schedule.GetTraceDesc(); const ir::ScheduleDesc& trace = ir_schedule.GetTraceDesc();
for (auto&& step : trace.Steps()) { for (auto&& step : trace.Steps()) {
if (step.type == "Bind" && step.attrs.find("thread_axis") != step.attrs.end() && if (step.type == "Bind" &&
step.attrs.find("thread_axis") != step.attrs.end() &&
absl::get<std::string>(step.attrs.at("thread_axis")) == bind_axis) { absl::get<std::string>(step.attrs.at("thread_axis")) == bind_axis) {
CHECK_EQ(step.inputs.at("loop").size(), 1); CHECK_EQ(step.inputs.at("loop").size(), 1);
return step.inputs.at("loop")[0].As<ir::For>()->extent.as_int32(); return step.inputs.at("loop")[0].As<ir::For>()->extent.as_int32();
...@@ -38,9 +40,13 @@ std::vector<std::string> FindCandidates(const ir::ScheduleDesc& trace) { ...@@ -38,9 +40,13 @@ std::vector<std::string> FindCandidates(const ir::ScheduleDesc& trace) {
std::vector<std::string> candidate_block_names; std::vector<std::string> candidate_block_names;
for (auto&& step : trace.Steps()) { for (auto&& step : trace.Steps()) {
if (step.type == "AnnotateIntAttr" && if (step.type == "AnnotateIntAttr" &&
absl::get<std::string>(step.attrs.at("key")) == ir::attr::cooperative_process) { absl::get<std::string>(step.attrs.at("key")) ==
ir::attr::cooperative_process) {
candidate_block_names.push_back( candidate_block_names.push_back(
step.inputs.at("block")[0].As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>()->name); step.inputs.at("block")[0]
.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->name);
} }
} }
return candidate_block_names; return candidate_block_names;
......
...@@ -20,8 +20,9 @@ namespace cinn { ...@@ -20,8 +20,9 @@ namespace cinn {
namespace auto_schedule { namespace auto_schedule {
/* /*
* @brief Rewrite the cooperative_process annotation to actually bind the loop on threadIdx. * @brief Rewrite the cooperative_process annotation to actually bind the loop
* This rule is used for collaborative data handling of multiple threads within the same block. * on threadIdx. This rule is used for collaborative data handling of multiple
* threads within the same block.
*/ */
class CooperativeProcess : public PostScheduleRule { class CooperativeProcess : public PostScheduleRule {
public: public:
......
...@@ -44,17 +44,27 @@ TEST_F(TestCooperativeProcess, Matmul) { ...@@ -44,17 +44,27 @@ TEST_F(TestCooperativeProcess, Matmul) {
int steps_k = 8; int steps_k = 8;
Initialize(common::DefaultNVGPUTarget()); Initialize(common::DefaultNVGPUTarget());
frontend::Program matmul_op = tests::OpBuilder("matmul").Build({{"X", X_shape}, {"Y", Y_shape}}); frontend::Program matmul_op =
tests::OpBuilder("matmul").Build({{"X", X_shape}, {"Y", Y_shape}});
ir::IRSchedule ir_schedule = MakeIRSchedule(matmul_op, fixed_rand_seed); ir::IRSchedule ir_schedule = MakeIRSchedule(matmul_op, fixed_rand_seed);
// split loops // split loops
std::vector<ir::Expr> loops = ir_schedule.GetLoops("temp_matmul_out"); std::vector<ir::Expr> loops = ir_schedule.GetLoops("temp_matmul_out");
std::vector<ir::Expr> k_loops = ir_schedule.Split(loops[2], {steps_k, -1}); std::vector<ir::Expr> k_loops = ir_schedule.Split(loops[2], {steps_k, -1});
std::vector<ir::Expr> j_loops = ir_schedule.Split(loops[1], {num_blocks_x, num_threads_x, -1}); std::vector<ir::Expr> j_loops =
std::vector<ir::Expr> i_loops = ir_schedule.Split(loops[0], {num_blocks_y, num_threads_y, -1}); ir_schedule.Split(loops[1], {num_blocks_x, num_threads_x, -1});
std::vector<ir::Expr> i_loops =
ir_schedule.Split(loops[0], {num_blocks_y, num_threads_y, -1});
// reorder to "SSRRS": i0, j0, i1, j1, k0, k1, j2, i2 // reorder to "SSRRS": i0, j0, i1, j1, k0, k1, j2, i2
loops = ir_schedule.GetLoops("temp_matmul_out"); loops = ir_schedule.GetLoops("temp_matmul_out");
ir_schedule.Reorder({loops[0], loops[3], loops[1], loops[4], loops[6], loops[7], loops[2], loops[5]}); ir_schedule.Reorder({loops[0],
loops[3],
loops[1],
loops[4],
loops[6],
loops[7],
loops[2],
loops[5]});
// fuse and bind // fuse and bind
loops = ir_schedule.GetLoops("temp_matmul_out"); loops = ir_schedule.GetLoops("temp_matmul_out");
ir::Expr i1_j1_fused = ir_schedule.Fuse({loops[2], loops[3]}); ir::Expr i1_j1_fused = ir_schedule.Fuse({loops[2], loops[3]});
...@@ -65,23 +75,31 @@ TEST_F(TestCooperativeProcess, Matmul) { ...@@ -65,23 +75,31 @@ TEST_F(TestCooperativeProcess, Matmul) {
// cache read // cache read
ir::Expr out_block = ir_schedule.GetBlock("temp_matmul_out"); ir::Expr out_block = ir_schedule.GetBlock("temp_matmul_out");
ir::Expr X_cache_block = ir_schedule.CacheRead(out_block, 1, "shared"); ir::Expr X_cache_block = ir_schedule.CacheRead(out_block, 1, "shared");
std::string X_cache_block_name = std::string X_cache_block_name = X_cache_block.As<ir::ScheduleBlockRealize>()
X_cache_block.As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>()->name; ->schedule_block.As<ir::ScheduleBlock>()
->name;
loops = ir_schedule.GetLoops("temp_matmul_out"); loops = ir_schedule.GetLoops("temp_matmul_out");
ir_schedule.ComputeAt(X_cache_block, loops[2]); ir_schedule.ComputeAt(X_cache_block, loops[2]);
std::vector<ir::Expr> X_cache_loops = ir_schedule.GetLoops(X_cache_block_name); std::vector<ir::Expr> X_cache_loops =
ir_schedule.GetLoops(X_cache_block_name);
ir_schedule.Fuse({X_cache_loops[3], X_cache_loops[4]}); ir_schedule.Fuse({X_cache_loops[3], X_cache_loops[4]});
ir_schedule.Annotate(ir_schedule.GetBlock(X_cache_block_name), ir::attr::cooperative_process, 0); ir_schedule.Annotate(ir_schedule.GetBlock(X_cache_block_name),
ir::attr::cooperative_process,
0);
out_block = ir_schedule.GetBlock("temp_matmul_out"); out_block = ir_schedule.GetBlock("temp_matmul_out");
ir::Expr Y_cache_block = ir_schedule.CacheRead(out_block, 2, "shared"); ir::Expr Y_cache_block = ir_schedule.CacheRead(out_block, 2, "shared");
std::string Y_cache_block_name = std::string Y_cache_block_name = Y_cache_block.As<ir::ScheduleBlockRealize>()
Y_cache_block.As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>()->name; ->schedule_block.As<ir::ScheduleBlock>()
->name;
loops = ir_schedule.GetLoops("temp_matmul_out"); loops = ir_schedule.GetLoops("temp_matmul_out");
ir_schedule.ComputeAt(Y_cache_block, loops[2]); ir_schedule.ComputeAt(Y_cache_block, loops[2]);
std::vector<ir::Expr> Y_cache_loops = ir_schedule.GetLoops(Y_cache_block_name); std::vector<ir::Expr> Y_cache_loops =
ir_schedule.GetLoops(Y_cache_block_name);
ir_schedule.Fuse({Y_cache_loops[3], Y_cache_loops[4]}); ir_schedule.Fuse({Y_cache_loops[3], Y_cache_loops[4]});
ir_schedule.Annotate(ir_schedule.GetBlock(Y_cache_block_name), ir::attr::cooperative_process, 0); ir_schedule.Annotate(ir_schedule.GetBlock(Y_cache_block_name),
ir::attr::cooperative_process,
0);
// apply CooperativeProcess // apply CooperativeProcess
CooperativeProcess cooperative_process; CooperativeProcess cooperative_process;
...@@ -187,7 +205,8 @@ TEST_F(TestCooperativeProcess, Matmul) { ...@@ -187,7 +205,8 @@ TEST_F(TestCooperativeProcess, Matmul) {
// execute and check precision // execute and check precision
CheckResult( CheckResult(
GenExecutableKernel(ir_module), GenExecutableKernel(ir_module),
GenExecutableKernel(BuildIRModule(MakeIRSchedule(matmul_op, fixed_rand_seed, /* apply_manual_schedule*/ true))), GenExecutableKernel(BuildIRModule(MakeIRSchedule(
matmul_op, fixed_rand_seed, /* apply_manual_schedule*/ true))),
default_input_names, default_input_names,
default_output_names, default_output_names,
{X_shape, Y_shape}, {X_shape, Y_shape},
......
...@@ -29,21 +29,28 @@ static constexpr uint32_t kMaxBlocks = 256; ...@@ -29,21 +29,28 @@ static constexpr uint32_t kMaxBlocks = 256;
bool IsSpatialLoop(const ir::For* for_node) { bool IsSpatialLoop(const ir::For* for_node) {
if (for_node->for_type() != ir::ForType::Serial) return false; if (for_node->for_type() != ir::ForType::Serial) return false;
const auto& loop_var = for_node->loop_var; const auto& loop_var = for_node->loop_var;
// collect cases where the loop_var used in one of reduce axis in underneath ScheduleBlock // collect cases where the loop_var used in one of reduce axis in underneath
auto used_for_reduce_axis = ir::CollectIRNodesWithoutTensor(for_node->body, [&loop_var](const Expr* x) { // ScheduleBlock
auto used_for_reduce_axis = ir::CollectIRNodesWithoutTensor(
for_node->body, [&loop_var](const Expr* x) {
const auto* block_realize = x->As<ir::ScheduleBlockRealize>(); const auto* block_realize = x->As<ir::ScheduleBlockRealize>();
if (!block_realize) return false; if (!block_realize) return false;
const auto* schedule_block = block_realize->schedule_block.As<ir::ScheduleBlock>(); const auto* schedule_block =
block_realize->schedule_block.As<ir::ScheduleBlock>();
CHECK(schedule_block) << "schedule_block field is not a ScheduleBlock"; CHECK(schedule_block) << "schedule_block field is not a ScheduleBlock";
CHECK_EQ(block_realize->iter_values.size(), schedule_block->iter_vars.size()); CHECK_EQ(block_realize->iter_values.size(),
schedule_block->iter_vars.size());
for (int i = 0; i < block_realize->iter_values.size(); ++i) { for (int i = 0; i < block_realize->iter_values.size(); ++i) {
const ir::Var& iter_var = schedule_block->iter_vars[i]; const ir::Var& iter_var = schedule_block->iter_vars[i];
const ir::Expr& binding = block_realize->iter_values[i]; const ir::Expr& binding = block_realize->iter_values[i];
if (iter_var->is_reduce_axis || iter_var->name.substr(0, 6) == "reduce") { if (iter_var->is_reduce_axis ||
auto used_exprs = ir::CollectIRNodesWithoutTensor(binding, [&loop_var](const Expr* x) { iter_var->name.substr(0, 6) == "reduce") {
auto used_exprs = ir::CollectIRNodesWithoutTensor(
binding, [&loop_var](const Expr* x) {
const ir::_Var_* var = x->As<ir::_Var_>(); const ir::_Var_* var = x->As<ir::_Var_>();
if (var && (x->same_as(loop_var) || var->name == loop_var->name)) { if (var &&
(x->same_as(loop_var) || var->name == loop_var->name)) {
return true; return true;
} }
return false; return false;
...@@ -59,7 +66,8 @@ bool IsSpatialLoop(const ir::For* for_node) { ...@@ -59,7 +66,8 @@ bool IsSpatialLoop(const ir::For* for_node) {
return true; return true;
} }
// count the number of loops that can be binded from the input for_node to bottom // count the number of loops that can be binded from the input for_node to
// bottom
int CountLoopCanBinded(const ir::For* for_node) { int CountLoopCanBinded(const ir::For* for_node) {
int cnt = 0; int cnt = 0;
while (for_node) { while (for_node) {
...@@ -68,9 +76,11 @@ int CountLoopCanBinded(const ir::For* for_node) { ...@@ -68,9 +76,11 @@ int CountLoopCanBinded(const ir::For* for_node) {
cnt += 1; cnt += 1;
CHECK(for_node->body.defined() && for_node->body.As<ir::Block>()) << "Body is not defined"; CHECK(for_node->body.defined() && for_node->body.As<ir::Block>())
<< "Body is not defined";
const ir::Block* body = for_node->body.As<ir::Block>(); const ir::Block* body = for_node->body.As<ir::Block>();
// terminate when body of this loop has more than one statement or the body is not a ir::For node // terminate when body of this loop has more than one statement or the body
// is not a ir::For node
for_node = body->stmts.size() == 1 ? body->stmts[0].As<ir::For>() : nullptr; for_node = body->stmts.size() == 1 ? body->stmts[0].As<ir::For>() : nullptr;
} }
return cnt; return cnt;
...@@ -82,13 +92,17 @@ void BindGPUIndex(ir::IRSchedule* ir_schedule, ...@@ -82,13 +92,17 @@ void BindGPUIndex(ir::IRSchedule* ir_schedule,
int max_blocks, int max_blocks,
int max_threads_per_block) { int max_threads_per_block) {
auto all_loops = ir_schedule->GetLoops(block_name); auto all_loops = ir_schedule->GetLoops(block_name);
CHECK_LE(num_loops_to_bind, all_loops.size()) << "The number of loops to be bind is greater than size of all_loops"; CHECK_LE(num_loops_to_bind, all_loops.size())
// check whether it is the case that threadIdx has been binded but blockIdx not, << "The number of loops to be bind is greater than size of all_loops";
// the threadIdx can only be binded in the first loop after num_loops_to_bind loops // check whether it is the case that threadIdx has been binded but blockIdx
// because we has excluded other cases in CountLoopCanBinded // not, the threadIdx can only be binded in the first loop after
// num_loops_to_bind loops because we has excluded other cases in
// CountLoopCanBinded
bool gpu_thread_has_binded = bool gpu_thread_has_binded =
num_loops_to_bind < all_loops.size() && all_loops[num_loops_to_bind].As<ir::For>()->is_gpu_thread_binded(); num_loops_to_bind < all_loops.size() &&
Expr fused_loop = ir_schedule->Fuse({all_loops.begin(), all_loops.begin() + num_loops_to_bind}); all_loops[num_loops_to_bind].As<ir::For>()->is_gpu_thread_binded();
Expr fused_loop = ir_schedule->Fuse(
{all_loops.begin(), all_loops.begin() + num_loops_to_bind});
int32_t extent = fused_loop.As<ir::For>()->extent.as_int32(); int32_t extent = fused_loop.As<ir::For>()->extent.as_int32();
if (gpu_thread_has_binded) { if (gpu_thread_has_binded) {
ir_schedule->Bind(fused_loop, "blockIdx.x"); ir_schedule->Bind(fused_loop, "blockIdx.x");
...@@ -106,7 +120,8 @@ void BindGPUIndex(ir::IRSchedule* ir_schedule, ...@@ -106,7 +120,8 @@ void BindGPUIndex(ir::IRSchedule* ir_schedule,
ir_schedule->Bind(splits[0], "blockIdx.x"); ir_schedule->Bind(splits[0], "blockIdx.x");
ir_schedule->Bind(splits[1], "threadIdx.x"); ir_schedule->Bind(splits[1], "threadIdx.x");
} else { } else {
auto splits = ir_schedule->Split(fused_loop, {-1, max_blocks, max_threads_per_block}); auto splits =
ir_schedule->Split(fused_loop, {-1, max_blocks, max_threads_per_block});
CHECK_EQ(splits.size(), 3); CHECK_EQ(splits.size(), 3);
ir_schedule->Reorder({splits[1], splits[2], splits[0]}); ir_schedule->Reorder({splits[1], splits[2], splits[0]});
all_loops = ir_schedule->GetLoops(block_name); all_loops = ir_schedule->GetLoops(block_name);
...@@ -126,29 +141,36 @@ RuleApplyType AutoBind::Init(ir::IRSchedule* ir_schedule) { ...@@ -126,29 +141,36 @@ RuleApplyType AutoBind::Init(ir::IRSchedule* ir_schedule) {
} }
num_applicable_ = applicable_schedule_blocks_.size(); num_applicable_ = applicable_schedule_blocks_.size();
VLOG(6) << "Collect applicable_schedule_blocks_:" << num_applicable_; VLOG(6) << "Collect applicable_schedule_blocks_:" << num_applicable_;
return num_applicable_ > 0 ? RuleApplyType::kApplyAndPruneOtherRules : RuleApplyType::kCannotApply; return num_applicable_ > 0 ? RuleApplyType::kApplyAndPruneOtherRules
: RuleApplyType::kCannotApply;
} }
void AutoBind::Apply(int index) { void AutoBind::Apply(int index) {
CHECK_LT(index, applicable_schedule_blocks_.size()) << "invalid apply index:" << index; CHECK_LT(index, applicable_schedule_blocks_.size())
<< "invalid apply index:" << index;
auto applied_block = applicable_schedule_blocks_.at(index); auto applied_block = applicable_schedule_blocks_.at(index);
auto all_loops = ir_schedule_->GetLoops(applied_block); auto all_loops = ir_schedule_->GetLoops(applied_block);
BindGPUIndex(ir_schedule_, BindGPUIndex(ir_schedule_,
applied_block.As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>()->name, applied_block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->name,
CountLoopCanBinded(all_loops[0].As<ir::For>()), CountLoopCanBinded(all_loops[0].As<ir::For>()),
kMaxBlocks, kMaxBlocks,
target_->max_num_threads()); target_->max_num_threads());
return; return;
} }
RuleApplyType AutoBind::AnalyseApplyType(SearchState state, const std::string& block_name) const { RuleApplyType AutoBind::AnalyseApplyType(SearchState state,
const std::string& block_name) const {
Expr block_expr = state->ir_schedule.GetBlock(block_name); Expr block_expr = state->ir_schedule.GetBlock(block_name);
auto all_loops = state->ir_schedule.GetLoops(block_expr); auto all_loops = state->ir_schedule.GetLoops(block_expr);
return CountLoopCanBinded(all_loops[0].As<ir::For>()) > 0 ? RuleApplyType::kApplyAndPruneOtherRules return CountLoopCanBinded(all_loops[0].As<ir::For>()) > 0
? RuleApplyType::kApplyAndPruneOtherRules
: RuleApplyType::kCannotApply; : RuleApplyType::kCannotApply;
} }
std::vector<SearchState> AutoBind::ApplyOnBlock(SearchState state, const std::string& block_name) { std::vector<SearchState> AutoBind::ApplyOnBlock(SearchState state,
const std::string& block_name) {
SearchState new_state = state.Copy(); SearchState new_state = state.Copy();
auto all_loops = state->ir_schedule.GetLoops(block_name); auto all_loops = state->ir_schedule.GetLoops(block_name);
BindGPUIndex(&new_state->ir_schedule, BindGPUIndex(&new_state->ir_schedule,
......
...@@ -36,9 +36,11 @@ class AutoBind : public AutoGenRule { ...@@ -36,9 +36,11 @@ class AutoBind : public AutoGenRule {
std::string GetRuleName() const override { return "AutoBind"; } std::string GetRuleName() const override { return "AutoBind"; }
RuleApplyType AnalyseApplyType(SearchState state, const std::string& block_name) const override; RuleApplyType AnalyseApplyType(SearchState state,
const std::string& block_name) const override;
std::vector<SearchState> ApplyOnBlock(SearchState state, const std::string& block_name) override; std::vector<SearchState> ApplyOnBlock(SearchState state,
const std::string& block_name) override;
private: private:
std::vector<Expr> applicable_schedule_blocks_; std::vector<Expr> applicable_schedule_blocks_;
......
...@@ -36,9 +36,11 @@ class TestAutoBind : public TestAutoGenRuleBase { ...@@ -36,9 +36,11 @@ class TestAutoBind : public TestAutoGenRuleBase {
std::vector<std::string> default_input_names = {"X", "Y"}; std::vector<std::string> default_input_names = {"X", "Y"};
std::vector<std::string> default_output_names = {"temp_matmul_out"}; std::vector<std::string> default_output_names = {"temp_matmul_out"};
void TestApplyOnElementWiseAdd(const std::vector<int>& shape, const std::string& block_name) { void TestApplyOnElementWiseAdd(const std::vector<int>& shape,
const std::string& block_name) {
Initialize(common::DefaultNVGPUTarget()); Initialize(common::DefaultNVGPUTarget());
auto test_program = tests::OpBuilder("elementwise_add").Build({{"X", shape}, {"Y", shape}}); auto test_program =
tests::OpBuilder("elementwise_add").Build({{"X", shape}, {"Y", shape}});
// construct input parameter // construct input parameter
ir::IRSchedule ir_schedule = MakeIRSchedule(test_program); ir::IRSchedule ir_schedule = MakeIRSchedule(test_program);
SearchState state(ir_schedule, 0, {}); SearchState state(ir_schedule, 0, {});
...@@ -48,7 +50,8 @@ class TestAutoBind : public TestAutoGenRuleBase { ...@@ -48,7 +50,8 @@ class TestAutoBind : public TestAutoGenRuleBase {
// apply // apply
AutoBind auto_bind(target_); AutoBind auto_bind(target_);
ASSERT_EQ(auto_bind.AnalyseApplyType(state, block_name), RuleApplyType::kApplyAndPruneOtherRules); ASSERT_EQ(auto_bind.AnalyseApplyType(state, block_name),
RuleApplyType::kApplyAndPruneOtherRules);
auto result = auto_bind.ApplyOnBlock(state, block_name)[0]; auto result = auto_bind.ApplyOnBlock(state, block_name)[0];
std::vector<ir::Expr> exprs = result->ir_schedule.GetModule().GetExprs(); std::vector<ir::Expr> exprs = result->ir_schedule.GetModule().GetExprs();
EXPECT_EQ(exprs.size(), 1UL); EXPECT_EQ(exprs.size(), 1UL);
...@@ -56,7 +59,8 @@ class TestAutoBind : public TestAutoGenRuleBase { ...@@ -56,7 +59,8 @@ class TestAutoBind : public TestAutoGenRuleBase {
// check bind result // check bind result
auto all_loops = result->ir_schedule.GetLoops(block_name); auto all_loops = result->ir_schedule.GetLoops(block_name);
int total_num = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()); int total_num =
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
if (total_num <= kMaxThreadsPerBlock) { if (total_num <= kMaxThreadsPerBlock) {
ASSERT_EQ(all_loops.size(), 1); ASSERT_EQ(all_loops.size(), 1);
EXPECT_EQ(all_loops[0].As<ir::For>()->extent.as_int32(), total_num); EXPECT_EQ(all_loops[0].As<ir::For>()->extent.as_int32(), total_num);
...@@ -64,18 +68,22 @@ class TestAutoBind : public TestAutoGenRuleBase { ...@@ -64,18 +68,22 @@ class TestAutoBind : public TestAutoGenRuleBase {
} else if (total_num <= kMaxBlocks * kMaxThreadsPerBlock) { } else if (total_num <= kMaxBlocks * kMaxThreadsPerBlock) {
ASSERT_EQ(all_loops.size(), 2); ASSERT_EQ(all_loops.size(), 2);
EXPECT_EQ(all_loops[0].As<ir::For>()->extent.as_int32(), EXPECT_EQ(all_loops[0].As<ir::For>()->extent.as_int32(),
static_cast<int32_t>(std::ceil(double(total_num) / kMaxThreadsPerBlock))); static_cast<int32_t>(
std::ceil(double(total_num) / kMaxThreadsPerBlock)));
EXPECT_TRUE(all_loops[0].As<ir::For>()->is_gpu_block_binded()); EXPECT_TRUE(all_loops[0].As<ir::For>()->is_gpu_block_binded());
EXPECT_EQ(all_loops[1].As<ir::For>()->extent.as_int32(), kMaxThreadsPerBlock); EXPECT_EQ(all_loops[1].As<ir::For>()->extent.as_int32(),
kMaxThreadsPerBlock);
EXPECT_TRUE(all_loops[1].As<ir::For>()->is_gpu_thread_binded()); EXPECT_TRUE(all_loops[1].As<ir::For>()->is_gpu_thread_binded());
} else { } else {
ASSERT_EQ(all_loops.size(), 3); ASSERT_EQ(all_loops.size(), 3);
EXPECT_EQ(all_loops[0].As<ir::For>()->extent.as_int32(), kMaxBlocks); EXPECT_EQ(all_loops[0].As<ir::For>()->extent.as_int32(), kMaxBlocks);
EXPECT_TRUE(all_loops[0].As<ir::For>()->is_gpu_block_binded()); EXPECT_TRUE(all_loops[0].As<ir::For>()->is_gpu_block_binded());
EXPECT_EQ(all_loops[1].As<ir::For>()->extent.as_int32(), kMaxThreadsPerBlock); EXPECT_EQ(all_loops[1].As<ir::For>()->extent.as_int32(),
kMaxThreadsPerBlock);
EXPECT_TRUE(all_loops[1].As<ir::For>()->is_gpu_thread_binded()); EXPECT_TRUE(all_loops[1].As<ir::For>()->is_gpu_thread_binded());
EXPECT_EQ(all_loops[2].As<ir::For>()->extent.as_int32(), EXPECT_EQ(all_loops[2].As<ir::For>()->extent.as_int32(),
static_cast<int32_t>(std::ceil(double(total_num) / (kMaxBlocks * kMaxThreadsPerBlock)))); static_cast<int32_t>(std::ceil(
double(total_num) / (kMaxBlocks * kMaxThreadsPerBlock))));
EXPECT_FALSE(all_loops[2].As<ir::For>()->is_binded()); EXPECT_FALSE(all_loops[2].As<ir::For>()->is_binded());
} }
...@@ -83,8 +91,10 @@ class TestAutoBind : public TestAutoGenRuleBase { ...@@ -83,8 +91,10 @@ class TestAutoBind : public TestAutoGenRuleBase {
auto ir_module = BuildIRModule(result->ir_schedule); auto ir_module = BuildIRModule(result->ir_schedule);
auto source_code = GenSourceCode(ir_module); auto source_code = GenSourceCode(ir_module);
VLOG(6) << "Optimized source code:\n" << source_code; VLOG(6) << "Optimized source code:\n" << source_code;
auto manual_ir_module = BuildIRModule(MakeIRSchedule(test_program, /* apply_manual_schedule*/ true)); auto manual_ir_module = BuildIRModule(
VLOG(6) << "Manual-schedule compiled source code:\n" << GenSourceCode(manual_ir_module); MakeIRSchedule(test_program, /* apply_manual_schedule*/ true));
VLOG(6) << "Manual-schedule compiled source code:\n"
<< GenSourceCode(manual_ir_module);
CheckResult(GenExecutableKernel(ir_module), CheckResult(GenExecutableKernel(ir_module),
GenExecutableKernel(manual_ir_module), GenExecutableKernel(manual_ir_module),
default_input_names, default_input_names,
...@@ -97,16 +107,20 @@ class TestAutoBind : public TestAutoGenRuleBase { ...@@ -97,16 +107,20 @@ class TestAutoBind : public TestAutoGenRuleBase {
TEST_F(TestAutoBind, AnalyseApplyType) { TEST_F(TestAutoBind, AnalyseApplyType) {
Initialize(common::DefaultNVGPUTarget()); Initialize(common::DefaultNVGPUTarget());
ir::IRSchedule ir_schedule = MakeIRSchedule(tests::OpBuilder("matmul").Build({{"X", {32, 64}}, {"Y", {64, 32}}})); ir::IRSchedule ir_schedule = MakeIRSchedule(
tests::OpBuilder("matmul").Build({{"X", {32, 64}}, {"Y", {64, 32}}}));
SearchState state(ir_schedule, 0, {}); SearchState state(ir_schedule, 0, {});
AutoBind auto_bind(target_); AutoBind auto_bind(target_);
const std::string& applied_block_name = default_output_names.back(); const std::string& applied_block_name = default_output_names.back();
// outer two loops of initial Expr are spatial loops, so it can be applied // outer two loops of initial Expr are spatial loops, so it can be applied
EXPECT_EQ(auto_bind.AnalyseApplyType(state, applied_block_name), RuleApplyType::kApplyAndPruneOtherRules); EXPECT_EQ(auto_bind.AnalyseApplyType(state, applied_block_name),
RuleApplyType::kApplyAndPruneOtherRules);
state->ir_schedule.Fuse(applied_block_name, {0, 1}); state->ir_schedule.Fuse(applied_block_name, {0, 1});
state->ir_schedule.Bind(state->ir_schedule.GetLoops(applied_block_name)[0], "threadIdx.x"); state->ir_schedule.Bind(state->ir_schedule.GetLoops(applied_block_name)[0],
"threadIdx.x");
// after fuse and bind, there is no loops to be binded. // after fuse and bind, there is no loops to be binded.
EXPECT_EQ(auto_bind.AnalyseApplyType(state, applied_block_name), RuleApplyType::kCannotApply); EXPECT_EQ(auto_bind.AnalyseApplyType(state, applied_block_name),
RuleApplyType::kCannotApply);
} }
TEST_F(TestAutoBind, ApplyOnBlock) { TEST_F(TestAutoBind, ApplyOnBlock) {
......
...@@ -27,12 +27,16 @@ namespace auto_schedule { ...@@ -27,12 +27,16 @@ namespace auto_schedule {
AutoGenRule::AutoGenRule(const common::Target& target) : target_(&target) {} AutoGenRule::AutoGenRule(const common::Target& target) : target_(&target) {}
int AutoGenRule::NumberApplicable() const { int AutoGenRule::NumberApplicable() const {
CHECK_GE(num_applicable_, 0) << "Call " << GetRuleName() << "::NumberApplicable() without initialization."; CHECK_GE(num_applicable_, 0)
<< "Call " << GetRuleName()
<< "::NumberApplicable() without initialization.";
return num_applicable_; return num_applicable_;
} }
void AutoGenRule::ApplyRandomly() { void AutoGenRule::ApplyRandomly() {
CHECK_GT(num_applicable_, 0) << "Call " << GetRuleName() << "::ApplyRandomly() with NumberApplicable() == 0"; CHECK_GT(num_applicable_, 0)
<< "Call " << GetRuleName()
<< "::ApplyRandomly() with NumberApplicable() == 0";
int index = rand() % num_applicable_; int index = rand() % num_applicable_;
return Apply(index); return Apply(index);
} }
......
...@@ -29,15 +29,18 @@ enum class RuleApplyType : int { ...@@ -29,15 +29,18 @@ enum class RuleApplyType : int {
// This rule cannot be applied to ModuleExpr. // This rule cannot be applied to ModuleExpr.
kCannotApply = 0, kCannotApply = 0,
// This rule can be applied to ModuleExpr, // This rule can be applied to ModuleExpr,
// and the original ModuleExpr will be retained for branching with other rules. // and the original ModuleExpr will be retained for branching with other
// rules.
kApply = 1, kApply = 1,
// This rule can be applied, but the original ModuleExpr will be deleted, // This rule can be applied, but the original ModuleExpr will be deleted,
// so the branches with other rules applied on the original ModuleExpr will be pruned. // so the branches with other rules applied on the original ModuleExpr will be
// pruned.
kApplyAndPruneOtherRules = 2, kApplyAndPruneOtherRules = 2,
}; };
/** /**
* Base class for rules of auto-generating schedule (like Ansor's sketch generation) * Base class for rules of auto-generating schedule (like Ansor's sketch
* generation)
* *
*/ */
class AutoGenRule { class AutoGenRule {
...@@ -46,7 +49,8 @@ class AutoGenRule { ...@@ -46,7 +49,8 @@ class AutoGenRule {
~AutoGenRule() = default; ~AutoGenRule() = default;
// Initialize the AutoGenRule, it must be called before further actions. // Initialize the AutoGenRule, it must be called before further actions.
// Returns false if the rule cannot be applied on the mod_expr, true otherwise. // Returns false if the rule cannot be applied on the mod_expr, true
// otherwise.
virtual RuleApplyType Init(ir::IRSchedule* ir_schedule) = 0; virtual RuleApplyType Init(ir::IRSchedule* ir_schedule) = 0;
// CINN IRSchedule can contain many ScheduleBlock(s) and Loop(s), so // CINN IRSchedule can contain many ScheduleBlock(s) and Loop(s), so
...@@ -65,11 +69,15 @@ class AutoGenRule { ...@@ -65,11 +69,15 @@ class AutoGenRule {
// Returns the name of the rule, used for debug. // Returns the name of the rule, used for debug.
virtual std::string GetRuleName() const = 0; virtual std::string GetRuleName() const = 0;
// Analyze the ApplyType of the rule used for a block determined by a specific SearchState and block name // Analyze the ApplyType of the rule used for a block determined by a specific
virtual RuleApplyType AnalyseApplyType(SearchState state, const std::string& block_name) const = 0; // SearchState and block name
virtual RuleApplyType AnalyseApplyType(
SearchState state, const std::string& block_name) const = 0;
// Apply the rule to a block determined by a specific SearchState and block name // Apply the rule to a block determined by a specific SearchState and block
virtual std::vector<SearchState> ApplyOnBlock(SearchState state, const std::string& block_name) = 0; // name
virtual std::vector<SearchState> ApplyOnBlock(
SearchState state, const std::string& block_name) = 0;
protected: protected:
// number of ScheduleBlock that can apply this auto gen rule // number of ScheduleBlock that can apply this auto gen rule
......
...@@ -34,18 +34,23 @@ ...@@ -34,18 +34,23 @@
namespace cinn { namespace cinn {
namespace auto_schedule { namespace auto_schedule {
AutoInline::AutoInline(const common::Target& target, const std::unordered_set<std::string>& no_inline_output_names) AutoInline::AutoInline(
const common::Target& target,
const std::unordered_set<std::string>& no_inline_output_names)
: AutoGenRule(target), no_inline_output_names_(no_inline_output_names) {} : AutoGenRule(target), no_inline_output_names_(no_inline_output_names) {}
bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr, ir::IRSchedule* ir_sch) const { bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr,
const ir::ScheduleBlockRealize* sche_block_realize = sche_block_realize_expr.As<ir::ScheduleBlockRealize>(); ir::IRSchedule* ir_sch) const {
const ir::ScheduleBlock* sche_block = sche_block_realize->schedule_block.As<ir::ScheduleBlock>(); const ir::ScheduleBlockRealize* sche_block_realize =
sche_block_realize_expr.As<ir::ScheduleBlockRealize>();
const ir::ScheduleBlock* sche_block =
sche_block_realize->schedule_block.As<ir::ScheduleBlock>();
ir::Expr compute_body = sche_block->body; ir::Expr compute_body = sche_block->body;
ir::Expr root = ir_sch->GetRootBlock(sche_block_realize_expr); ir::Expr root = ir_sch->GetRootBlock(sche_block_realize_expr);
// Check the schedule block to be inlined is not a reduce tensor. // Check the schedule block to be inlined is not a reduce tensor.
std::set<ir::Expr> find_store = std::set<ir::Expr> find_store = ir::CollectIRNodesWithoutTensor(
ir::CollectIRNodesWithoutTensor(compute_body, [&](const Expr* x) { return x->As<ir::Store>(); }); compute_body, [&](const Expr* x) { return x->As<ir::Store>(); });
if (find_store.size() != 1UL) { if (find_store.size() != 1UL) {
return false; return false;
} }
...@@ -57,8 +62,10 @@ bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr, ir:: ...@@ -57,8 +62,10 @@ bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr, ir::
} }
// LoweredFunc output can be tensor name or tensor buffer name // LoweredFunc output can be tensor name or tensor buffer name
if (no_inline_output_names_.find(tensor->name) != no_inline_output_names_.end() || if (no_inline_output_names_.find(tensor->name) !=
no_inline_output_names_.find(tensor->buffer->name) != no_inline_output_names_.end()) { no_inline_output_names_.end() ||
no_inline_output_names_.find(tensor->buffer->name) !=
no_inline_output_names_.end()) {
return false; return false;
} }
...@@ -70,26 +77,32 @@ bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr, ir:: ...@@ -70,26 +77,32 @@ bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr, ir::
// Check this schedule block is the only writer of the tensor. // Check this schedule block is the only writer of the tensor.
find_store = ir::CollectIRNodesWithoutTensor(root, [&](const Expr* x) { find_store = ir::CollectIRNodesWithoutTensor(root, [&](const Expr* x) {
return x->As<ir::Store>() && (x->As<ir::Store>()->tensor).as_tensor_ref()->name == tensor->name; return x->As<ir::Store>() &&
(x->As<ir::Store>()->tensor).as_tensor_ref()->name == tensor->name;
}); });
if (find_store.size() != 1UL) { if (find_store.size() != 1UL) {
return false; return false;
} }
// Check there is no overlap between the buffers the schedule block reads and writes. // Check there is no overlap between the buffers the schedule block reads and
std::set<ir::Expr> find_load = ir::CollectIRNodesWithoutTensor( // writes.
compute_body, [&](const Expr* x) { return x->As<ir::Load>() && x->As<ir::Load>()->tensor == tensor_expr; }); std::set<ir::Expr> find_load =
ir::CollectIRNodesWithoutTensor(compute_body, [&](const Expr* x) {
return x->As<ir::Load>() && x->As<ir::Load>()->tensor == tensor_expr;
});
if (!find_load.empty()) { if (!find_load.empty()) {
return false; return false;
} }
ir::Expr store = *(find_store.begin()); ir::Expr store = *(find_store.begin());
ir::ComputeInliner inliner(store.As<ir::Store>()->tensor.as_tensor_ref(), store); ir::ComputeInliner inliner(store.As<ir::Store>()->tensor.as_tensor_ref(),
store);
if (!inliner.BodyPatternAllowInline()) { if (!inliner.BodyPatternAllowInline()) {
return false; return false;
} }
ir::LeafBlockRemovalPlan remove_plan(sche_block_realize_expr, &inliner.src_stmt, &inliner.tgt_stmt); ir::LeafBlockRemovalPlan remove_plan(
sche_block_realize_expr, &inliner.src_stmt, &inliner.tgt_stmt);
remove_plan(&root); remove_plan(&root);
if (!inliner.src_stmt.defined() || !inliner.tgt_stmt.defined()) { if (!inliner.src_stmt.defined() || !inliner.tgt_stmt.defined()) {
return false; return false;
...@@ -99,16 +112,20 @@ bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr, ir:: ...@@ -99,16 +112,20 @@ bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr, ir::
return true; return true;
} }
AutoInlineType AutoInline::AnalyzeInlineType(const Expr& sche_block_realize_expr, ir::IRSchedule* ir_sch) const { AutoInlineType AutoInline::AnalyzeInlineType(
const ir::ScheduleBlockRealize* sche_block_realize = sche_block_realize_expr.As<ir::ScheduleBlockRealize>(); const Expr& sche_block_realize_expr, ir::IRSchedule* ir_sch) const {
const ir::ScheduleBlock* sche_block = sche_block_realize->schedule_block.As<ir::ScheduleBlock>(); const ir::ScheduleBlockRealize* sche_block_realize =
sche_block_realize_expr.As<ir::ScheduleBlockRealize>();
const ir::ScheduleBlock* sche_block =
sche_block_realize->schedule_block.As<ir::ScheduleBlock>();
// Inline if the block has only 1 write buffer // Inline if the block has only 1 write buffer
if (sche_block->write_buffers.size() != 1) { if (sche_block->write_buffers.size() != 1) {
return AutoInlineType::kCannotInline; return AutoInlineType::kCannotInline;
} }
std::unordered_set<ir::IrNodeTy> no_inline_node_types = {ir::IrNodeTy::IfThenElse}; std::unordered_set<ir::IrNodeTy> no_inline_node_types = {
ir::IrNodeTy::IfThenElse};
if (ContainsNodeType(sche_block->body, no_inline_node_types)) { if (ContainsNodeType(sche_block->body, no_inline_node_types)) {
return AutoInlineType::kCannotInline; return AutoInlineType::kCannotInline;
} }
...@@ -131,25 +148,32 @@ RuleApplyType AutoInline::Init(ir::IRSchedule* ir_schedule) { ...@@ -131,25 +148,32 @@ RuleApplyType AutoInline::Init(ir::IRSchedule* ir_schedule) {
num_applicable_ = 0; num_applicable_ = 0;
for (size_t i = 0; i < all_block_realizes_.size(); ++i) { for (size_t i = 0; i < all_block_realizes_.size(); ++i) {
ir::ScheduleBlockRealize* sche_block_realize = all_block_realizes_[i].As<ir::ScheduleBlockRealize>(); ir::ScheduleBlockRealize* sche_block_realize =
AnalyzeScheduleBlockReadWriteBuffer(sche_block_realize->schedule_block.As<ir::ScheduleBlock>()); all_block_realizes_[i].As<ir::ScheduleBlockRealize>();
AutoInlineType type = AnalyzeInlineType(all_block_realizes_[i], ir_schedule_); AnalyzeScheduleBlockReadWriteBuffer(
sche_block_realize->schedule_block.As<ir::ScheduleBlock>());
AutoInlineType type =
AnalyzeInlineType(all_block_realizes_[i], ir_schedule_);
if (type != AutoInlineType::kCannotInline) { if (type != AutoInlineType::kCannotInline) {
++num_applicable_; ++num_applicable_;
apply_indices_and_type_.push_back({i, type}); apply_indices_and_type_.push_back({i, type});
} }
} }
return num_applicable_ > 0 ? RuleApplyType::kApplyAndPruneOtherRules : RuleApplyType::kCannotApply; return num_applicable_ > 0 ? RuleApplyType::kApplyAndPruneOtherRules
: RuleApplyType::kCannotApply;
} }
void AutoInline::Apply(int index) { void AutoInline::Apply(int index) {
CHECK(ir_schedule_ != nullptr) << "Run AutoInline::Apply without Init"; CHECK(ir_schedule_ != nullptr) << "Run AutoInline::Apply without Init";
CHECK(num_applicable_ > 0 && apply_indices_and_type_.size() == num_applicable_) CHECK(num_applicable_ > 0 &&
apply_indices_and_type_.size() == num_applicable_)
<< "AutoInline::Apply pre-condition doesn't meet"; << "AutoInline::Apply pre-condition doesn't meet";
CHECK(index >= 0 && num_applicable_ > index) CHECK(index >= 0 && num_applicable_ > index)
<< "Invalid index for AutoInline::Apply, the index needs 0 <= index && index < NumberApplicable(), " << "Invalid index for AutoInline::Apply, the index needs 0 <= index && "
<< "Currently index = " << index << ", NumberApplicable() = " << num_applicable_; "index < NumberApplicable(), "
<< "Currently index = " << index
<< ", NumberApplicable() = " << num_applicable_;
int apply_index = apply_indices_and_type_[index].first; int apply_index = apply_indices_and_type_[index].first;
Apply(ir_schedule_, all_block_realizes_[apply_index]); Apply(ir_schedule_, all_block_realizes_[apply_index]);
...@@ -158,18 +182,23 @@ void AutoInline::Apply(int index) { ...@@ -158,18 +182,23 @@ void AutoInline::Apply(int index) {
std::string AutoInline::GetRuleName() const { return "AutoInline"; } std::string AutoInline::GetRuleName() const { return "AutoInline"; }
RuleApplyType AutoInline::AnalyseApplyType(SearchState state, const std::string& block_name) const { RuleApplyType AutoInline::AnalyseApplyType(
SearchState state, const std::string& block_name) const {
Expr block_expr = state->ir_schedule.GetBlock(block_name); Expr block_expr = state->ir_schedule.GetBlock(block_name);
auto* block_realize = block_expr.As<ir::ScheduleBlockRealize>(); auto* block_realize = block_expr.As<ir::ScheduleBlockRealize>();
CHECK(block_realize) << "stmt is not a ScheduleBlockRealize:" << block_expr; CHECK(block_realize) << "stmt is not a ScheduleBlockRealize:" << block_expr;
AnalyzeScheduleBlockReadWriteBuffer(block_realize->schedule_block.As<ir::ScheduleBlock>()); AnalyzeScheduleBlockReadWriteBuffer(
block_realize->schedule_block.As<ir::ScheduleBlock>());
AutoInlineType type = AnalyzeInlineType(block_expr, &state->ir_schedule); AutoInlineType type = AnalyzeInlineType(block_expr, &state->ir_schedule);
return type == AutoInlineType::kCannotInline ? RuleApplyType::kCannotApply : RuleApplyType::kApplyAndPruneOtherRules; return type == AutoInlineType::kCannotInline
? RuleApplyType::kCannotApply
: RuleApplyType::kApplyAndPruneOtherRules;
} }
std::vector<SearchState> AutoInline::ApplyOnBlock(SearchState state, const std::string& block_name) { std::vector<SearchState> AutoInline::ApplyOnBlock(
SearchState state, const std::string& block_name) {
SearchState new_state = state.Copy(); SearchState new_state = state.Copy();
Expr block_expr = new_state->ir_schedule.GetBlock(block_name); Expr block_expr = new_state->ir_schedule.GetBlock(block_name);
Apply(&new_state->ir_schedule, block_expr); Apply(&new_state->ir_schedule, block_expr);
...@@ -181,7 +210,8 @@ void AutoInline::Apply(ir::IRSchedule* ir_schedule, ir::Expr& block_expr) { ...@@ -181,7 +210,8 @@ void AutoInline::Apply(ir::IRSchedule* ir_schedule, ir::Expr& block_expr) {
auto* block_realize = block_expr.As<ir::ScheduleBlockRealize>(); auto* block_realize = block_expr.As<ir::ScheduleBlockRealize>();
CHECK(block_realize) << "stmt is not a ScheduleBlockRealize:" << block_expr; CHECK(block_realize) << "stmt is not a ScheduleBlockRealize:" << block_expr;
AnalyzeScheduleBlockReadWriteBuffer(block_realize->schedule_block.As<ir::ScheduleBlock>()); AnalyzeScheduleBlockReadWriteBuffer(
block_realize->schedule_block.As<ir::ScheduleBlock>());
AutoInlineType type = AnalyzeInlineType(block_expr, ir_schedule); AutoInlineType type = AnalyzeInlineType(block_expr, ir_schedule);
if (type == AutoInlineType::kInlineIntoConsumer) { if (type == AutoInlineType::kInlineIntoConsumer) {
...@@ -202,8 +232,10 @@ void AutoInline::Apply(ir::IRSchedule* ir_schedule, ir::Expr& block_expr) { ...@@ -202,8 +232,10 @@ void AutoInline::Apply(ir::IRSchedule* ir_schedule, ir::Expr& block_expr) {
// we need to re-analyze // we need to re-analyze
all_block_realizes_ = ir_schedule->GetAllBlocks(); all_block_realizes_ = ir_schedule->GetAllBlocks();
for (size_t i = 0; i < all_block_realizes_.size(); ++i) { for (size_t i = 0; i < all_block_realizes_.size(); ++i) {
ir::ScheduleBlockRealize* sche_block_realize = all_block_realizes_[i].As<ir::ScheduleBlockRealize>(); ir::ScheduleBlockRealize* sche_block_realize =
ir::ScheduleBlock* sche_block = sche_block_realize->schedule_block.As<ir::ScheduleBlock>(); all_block_realizes_[i].As<ir::ScheduleBlockRealize>();
ir::ScheduleBlock* sche_block =
sche_block_realize->schedule_block.As<ir::ScheduleBlock>();
sche_block->read_buffers = {}; sche_block->read_buffers = {};
sche_block->write_buffers = {}; sche_block->write_buffers = {};
AnalyzeScheduleBlockReadWriteBuffer(sche_block); AnalyzeScheduleBlockReadWriteBuffer(sche_block);
......
...@@ -41,7 +41,8 @@ enum class AutoInlineType : int { ...@@ -41,7 +41,8 @@ enum class AutoInlineType : int {
class AutoInline : public AutoGenRule { class AutoInline : public AutoGenRule {
public: public:
AutoInline(const common::Target& target, const std::unordered_set<std::string>& no_inline_output_names); AutoInline(const common::Target& target,
const std::unordered_set<std::string>& no_inline_output_names);
~AutoInline() = default; ~AutoInline() = default;
RuleApplyType Init(ir::IRSchedule* ir_schedule) override; RuleApplyType Init(ir::IRSchedule* ir_schedule) override;
...@@ -50,13 +51,17 @@ class AutoInline : public AutoGenRule { ...@@ -50,13 +51,17 @@ class AutoInline : public AutoGenRule {
std::string GetRuleName() const override; std::string GetRuleName() const override;
AutoInlineType AnalyzeInlineType(const Expr& sche_block_realize_expr, ir::IRSchedule* ir_sch) const; AutoInlineType AnalyzeInlineType(const Expr& sche_block_realize_expr,
ir::IRSchedule* ir_sch) const;
bool CanInlineIntoConsumer(const Expr& sche_block_realize_expr, ir::IRSchedule* ir_sch) const; bool CanInlineIntoConsumer(const Expr& sche_block_realize_expr,
ir::IRSchedule* ir_sch) const;
RuleApplyType AnalyseApplyType(SearchState state, const std::string& block_name) const override; RuleApplyType AnalyseApplyType(SearchState state,
const std::string& block_name) const override;
std::vector<SearchState> ApplyOnBlock(SearchState state, const std::string& block_name) override; std::vector<SearchState> ApplyOnBlock(SearchState state,
const std::string& block_name) override;
private: private:
void Apply(ir::IRSchedule* ir_schedule, ir::Expr& block_expr); void Apply(ir::IRSchedule* ir_schedule, ir::Expr& block_expr);
......
...@@ -63,7 +63,14 @@ TEST(AutoInline, SingleLoopInline) { ...@@ -63,7 +63,14 @@ TEST(AutoInline, SingleLoopInline) {
poly::StageMap stages = CreateStages({A, B, C}); poly::StageMap stages = CreateStages({A, B, C});
std::vector<ir::LoweredFunc> funcs = std::vector<ir::LoweredFunc> funcs =
lang::LowerVec("TestAutoInline_SingleLoopInline", stages, {A, C}, {}, {}, nullptr, target, true); lang::LowerVec("TestAutoInline_SingleLoopInline",
stages,
{A, C},
{},
{},
nullptr,
target,
true);
VLOG(6) << "Expr after lowering:"; VLOG(6) << "Expr after lowering:";
VLOG(6) << funcs[0]->body; VLOG(6) << funcs[0]->body;
...@@ -90,7 +97,8 @@ TEST(AutoInline, SingleLoopInline) { ...@@ -90,7 +97,8 @@ TEST(AutoInline, SingleLoopInline) {
EXPECT_EQ(exprs.size(), 1UL); EXPECT_EQ(exprs.size(), 1UL);
// ApplyOnBlock // ApplyOnBlock
EXPECT_EQ(auto_inline.AnalyseApplyType(state, "B"), RuleApplyType::kApplyAndPruneOtherRules); EXPECT_EQ(auto_inline.AnalyseApplyType(state, "B"),
RuleApplyType::kApplyAndPruneOtherRules);
auto new_states = auto_inline.ApplyOnBlock(state, "B"); auto new_states = auto_inline.ApplyOnBlock(state, "B");
auto test_func = [](ir::IRSchedule* ir_sch) { auto test_func = [](ir::IRSchedule* ir_sch) {
...@@ -130,7 +138,8 @@ TEST(AutoInline, SingleLoopInline) { ...@@ -130,7 +138,8 @@ TEST(AutoInline, SingleLoopInline) {
// Cannot inline above expr again // Cannot inline above expr again
EXPECT_EQ(auto_inline.Init(&ir_sch), RuleApplyType::kCannotApply); EXPECT_EQ(auto_inline.Init(&ir_sch), RuleApplyType::kCannotApply);
EXPECT_EQ(auto_inline.AnalyseApplyType(new_states[0], "C"), RuleApplyType::kCannotApply); EXPECT_EQ(auto_inline.AnalyseApplyType(new_states[0], "C"),
RuleApplyType::kCannotApply);
} }
TEST(AutoInline, AddReluInline) { TEST(AutoInline, AddReluInline) {
...@@ -151,12 +160,17 @@ TEST(AutoInline, AddReluInline) { ...@@ -151,12 +160,17 @@ TEST(AutoInline, AddReluInline) {
auto graph = std::make_shared<Graph>(program, target); auto graph = std::make_shared<Graph>(program, target);
hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); hlir::framework::ApplyPass(graph.get(), "OpFusionPass");
const auto& dtype_dict = graph->GetAttrs<absl::flat_hash_map<std::string, common::Type>>("inferdtype"); const auto& dtype_dict =
const auto& shape_dict = graph->GetAttrs<absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape"); graph->GetAttrs<absl::flat_hash_map<std::string, common::Type>>(
auto op_lowerer = std::make_unique<hlir::framework::OpLowerer>(dtype_dict, shape_dict, target); "inferdtype");
const auto& shape_dict = graph->GetAttrs<
absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");
auto op_lowerer = std::make_unique<hlir::framework::OpLowerer>(
dtype_dict, shape_dict, target);
EXPECT_EQ(graph->fusion_groups.size(), 1UL); EXPECT_EQ(graph->fusion_groups.size(), 1UL);
std::vector<ir::LoweredFunc> funcs = op_lowerer->LowerWithoutSchedule(graph->fusion_groups[0]); std::vector<ir::LoweredFunc> funcs =
op_lowerer->LowerWithoutSchedule(graph->fusion_groups[0]);
VLOG(6) << "Expr before auto inline: " << funcs[0]->body; VLOG(6) << "Expr before auto inline: " << funcs[0]->body;
...@@ -186,10 +200,12 @@ TEST(AutoInline, AddReluInline) { ...@@ -186,10 +200,12 @@ TEST(AutoInline, AddReluInline) {
auto_inline.Apply(0); auto_inline.Apply(0);
// ApplyOnBlock // ApplyOnBlock
EXPECT_EQ(auto_inline.AnalyseApplyType(state, "var_1"), RuleApplyType::kApplyAndPruneOtherRules); EXPECT_EQ(auto_inline.AnalyseApplyType(state, "var_1"),
RuleApplyType::kApplyAndPruneOtherRules);
auto new_states = auto_inline.ApplyOnBlock(state, "var_1"); auto new_states = auto_inline.ApplyOnBlock(state, "var_1");
// Auto Inline again // Auto Inline again
EXPECT_EQ(auto_inline.AnalyseApplyType(new_states[0], "var_3"), RuleApplyType::kApplyAndPruneOtherRules); EXPECT_EQ(auto_inline.AnalyseApplyType(new_states[0], "var_3"),
RuleApplyType::kApplyAndPruneOtherRules);
new_states = auto_inline.ApplyOnBlock(new_states[0], "var_3"); new_states = auto_inline.ApplyOnBlock(new_states[0], "var_3");
auto test_func = [](ir::IRSchedule* ir_sch) { auto test_func = [](ir::IRSchedule* ir_sch) {
...@@ -238,7 +254,8 @@ TEST(AutoInline, AddReluInline) { ...@@ -238,7 +254,8 @@ TEST(AutoInline, AddReluInline) {
// Cannot inline above expr again // Cannot inline above expr again
EXPECT_EQ(auto_inline.Init(&ir_sch), RuleApplyType::kCannotApply); EXPECT_EQ(auto_inline.Init(&ir_sch), RuleApplyType::kCannotApply);
EXPECT_EQ(auto_inline.AnalyseApplyType(new_states[0], "var_2"), RuleApplyType::kCannotApply); EXPECT_EQ(auto_inline.AnalyseApplyType(new_states[0], "var_2"),
RuleApplyType::kCannotApply);
} }
#ifdef CINN_WITH_CUDA #ifdef CINN_WITH_CUDA
...@@ -246,14 +263,8 @@ class TestAutoInline : public TestAutoGenRuleBase {}; ...@@ -246,14 +263,8 @@ class TestAutoInline : public TestAutoGenRuleBase {};
/* The single chain graph composed of multiple blocks can be inlined into one. /* The single chain graph composed of multiple blocks can be inlined into one.
* *
* Before AutoInline: The output of the previous block is the input of another
* block.
*   Loop1:
*     x1 = Add()
*   Loop2:
*     x2 = Multiply(x1)
*   Loop3:
*     x3 = Add(x2)
*   Loop4:
*     x4 = Relu(x3)
* *
* After AutoInline: All loops are inlined into a loop. * After AutoInline: All loops are inlined into a loop.
...@@ -263,18 +274,22 @@ class TestAutoInline : public TestAutoGenRuleBase {}; ...@@ -263,18 +274,22 @@ class TestAutoInline : public TestAutoGenRuleBase {};
TEST_F(TestAutoInline, SingleChain) { TEST_F(TestAutoInline, SingleChain) {
Target target = common::DefaultNVGPUTarget(); Target target = common::DefaultNVGPUTarget();
Initialize(target); Initialize(target);
std::vector<std::string> input_names = {"bias", "conv_output", "bn_scale", "bn_offset"}; std::vector<std::string> input_names = {
std::vector<std::string> output_names = {"var_6", "var_5", "var_1", "var", "var_0", "var_4", "var_3"}; "bias", "conv_output", "bn_scale", "bn_offset"};
std::vector<std::string> output_names = {
"var_6", "var_5", "var_1", "var", "var_0", "var_4", "var_3"};
std::vector<int32_t> conv_output_shape = {1, 512, 56, 56}; std::vector<int32_t> conv_output_shape = {1, 512, 56, 56};
int32_t channel = conv_output_shape[1]; int32_t channel = conv_output_shape[1];
std::vector<tests::VariableInfo> inputs_varinfo({{"conv_output", conv_output_shape}, std::vector<tests::VariableInfo> inputs_varinfo(
{{"conv_output", conv_output_shape},
{"bias", {channel, 1, 1}}, {"bias", {channel, 1, 1}},
{"bn_scale", {channel, 1, 1}}, {"bn_scale", {channel, 1, 1}},
{"bn_offset", {channel, 1, 1}}}); {"bn_offset", {channel, 1, 1}}});
// Construct the computation graph and convert it to ir::Expr // Construct the computation graph and convert it to ir::Expr
Context::Global().ResetNameId(); Context::Global().ResetNameId();
ir::IRSchedule ir_schedule = MakeIRSchedule(tests::BiasBnReLUBuilder().Build(inputs_varinfo)); ir::IRSchedule ir_schedule =
MakeIRSchedule(tests::BiasBnReLUBuilder().Build(inputs_varinfo));
SearchState state(ir_schedule, 0, {}); SearchState state(ir_schedule, 0, {});
std::vector<ir::Expr> func_bodys = ir_schedule.GetModule().GetExprs(); std::vector<ir::Expr> func_bodys = ir_schedule.GetModule().GetExprs();
ASSERT_EQ(func_bodys.size(), 1UL); ASSERT_EQ(func_bodys.size(), 1UL);
...@@ -282,20 +297,23 @@ TEST_F(TestAutoInline, SingleChain) { ...@@ -282,20 +297,23 @@ TEST_F(TestAutoInline, SingleChain) {
// Apply AutoInline for every block that can be inlined
AutoInline auto_inline(target_, {output_names.front()}); AutoInline auto_inline(target_, {output_names.front()});
EXPECT_EQ(auto_inline.AnalyseApplyType(state, "var_3"), RuleApplyType::kApplyAndPruneOtherRules); EXPECT_EQ(auto_inline.AnalyseApplyType(state, "var_3"),
RuleApplyType::kApplyAndPruneOtherRules);
auto new_states = auto_inline.ApplyOnBlock(state, "var_3"); auto new_states = auto_inline.ApplyOnBlock(state, "var_3");
std::vector<std::string> inline_block_names({"var_4", "var_5", "var_6", "var", "var_0", "var_1"}); std::vector<std::string> inline_block_names(
{"var_4", "var_5", "var_6", "var", "var_0", "var_1"});
for (const auto& inline_block_name : inline_block_names) { for (const auto& inline_block_name : inline_block_names) {
new_states = auto_inline.ApplyOnBlock(new_states[0], inline_block_name); new_states = auto_inline.ApplyOnBlock(new_states[0], inline_block_name);
} }
std::vector<ir::Expr> exprs = new_states[0]->ir_schedule.GetModule().GetExprs(); std::vector<ir::Expr> exprs =
new_states[0]->ir_schedule.GetModule().GetExprs();
EXPECT_EQ(exprs.size(), 1UL); EXPECT_EQ(exprs.size(), 1UL);
VLOG(6) << "Expr after AutoInline applied on block: " << exprs[0]; VLOG(6) << "Expr after AutoInline applied on block: " << exprs[0];
// build ir::Module and debug source code // build ir::Module and debug source code
auto build_module_auto = BuildIRModule(new_states[0]->ir_schedule); auto build_module_auto = BuildIRModule(new_states[0]->ir_schedule);
auto build_module_manually = auto build_module_manually = BuildIRModule(MakeIRSchedule(
BuildIRModule(MakeIRSchedule(tests::BiasBnReLUBuilder().Build(inputs_varinfo), -1, true)); tests::BiasBnReLUBuilder().Build(inputs_varinfo), -1, true));
auto source_code_auto = GenSourceCode(build_module_auto); auto source_code_auto = GenSourceCode(build_module_auto);
VLOG(6) << " auto-schedule source code:\n" << source_code_auto; VLOG(6) << " auto-schedule source code:\n" << source_code_auto;
auto source_code_manually = GenSourceCode(build_module_manually); auto source_code_manually = GenSourceCode(build_module_manually);
...@@ -305,7 +323,10 @@ TEST_F(TestAutoInline, SingleChain) { ...@@ -305,7 +323,10 @@ TEST_F(TestAutoInline, SingleChain) {
GenExecutableKernel(build_module_manually), GenExecutableKernel(build_module_manually),
input_names, input_names,
output_names, output_names,
{{conv_output_shape[1], 1, 1}, conv_output_shape, conv_output_shape, conv_output_shape}, {{conv_output_shape[1], 1, 1},
conv_output_shape,
conv_output_shape,
conv_output_shape},
{conv_output_shape, {1}, {1}, {1}, {1}, {1}, {1}}, {conv_output_shape, {1}, {1}, {1}, {1}, {1}, {1}},
target); target);
} }
...@@ -335,7 +356,8 @@ TEST_F(TestAutoInline, InlineToMultiConsumers) { ...@@ -335,7 +356,8 @@ TEST_F(TestAutoInline, InlineToMultiConsumers) {
// Construct the computation graph and convert it to ir::Expr // Construct the computation graph and convert it to ir::Expr
Context::Global().ResetNameId(); Context::Global().ResetNameId();
ir::IRSchedule ir_schedule = MakeIRSchedule(tests::ExpTwoConsumersOpBuilder().Build(inputs_varinfo)); ir::IRSchedule ir_schedule =
MakeIRSchedule(tests::ExpTwoConsumersOpBuilder().Build(inputs_varinfo));
SearchState state(ir_schedule, 0, {}); SearchState state(ir_schedule, 0, {});
std::vector<ir::Expr> func_bodys = ir_schedule.GetModule().GetExprs(); std::vector<ir::Expr> func_bodys = ir_schedule.GetModule().GetExprs();
ASSERT_EQ(func_bodys.size(), 1UL); ASSERT_EQ(func_bodys.size(), 1UL);
...@@ -343,17 +365,19 @@ TEST_F(TestAutoInline, InlineToMultiConsumers) { ...@@ -343,17 +365,19 @@ TEST_F(TestAutoInline, InlineToMultiConsumers) {
// Apply AutoInline for every block that can be inlined
AutoInline auto_inline(target_, {output_names.front()}); AutoInline auto_inline(target_, {output_names.front()});
EXPECT_EQ(auto_inline.AnalyseApplyType(state, "var_0"), RuleApplyType::kApplyAndPruneOtherRules); EXPECT_EQ(auto_inline.AnalyseApplyType(state, "var_0"),
RuleApplyType::kApplyAndPruneOtherRules);
auto new_states = auto_inline.ApplyOnBlock(state, "var_1"); auto new_states = auto_inline.ApplyOnBlock(state, "var_1");
new_states = auto_inline.ApplyOnBlock(state, "var_0"); new_states = auto_inline.ApplyOnBlock(state, "var_0");
std::vector<ir::Expr> exprs = new_states[0]->ir_schedule.GetModule().GetExprs(); std::vector<ir::Expr> exprs =
new_states[0]->ir_schedule.GetModule().GetExprs();
EXPECT_EQ(exprs.size(), 1UL); EXPECT_EQ(exprs.size(), 1UL);
VLOG(6) << "Expr after AutoInline applied on block: " << exprs[0]; VLOG(6) << "Expr after AutoInline applied on block: " << exprs[0];
// build ir::Module and debug source code // build ir::Module and debug source code
auto build_module_auto = BuildIRModule(new_states[0]->ir_schedule); auto build_module_auto = BuildIRModule(new_states[0]->ir_schedule);
auto build_module_manually = auto build_module_manually = BuildIRModule(MakeIRSchedule(
BuildIRModule(MakeIRSchedule(tests::ExpTwoConsumersOpBuilder().Build(inputs_varinfo), -1, true)); tests::ExpTwoConsumersOpBuilder().Build(inputs_varinfo), -1, true));
auto source_code_auto = GenSourceCode(build_module_auto); auto source_code_auto = GenSourceCode(build_module_auto);
VLOG(6) << " auto-schedule source code:\n" << source_code_auto; VLOG(6) << " auto-schedule source code:\n" << source_code_auto;
auto source_code_manually = GenSourceCode(build_module_manually); auto source_code_manually = GenSourceCode(build_module_manually);
...@@ -387,14 +411,20 @@ TEST_F(TestAutoInline, OnlySpatialOp) { ...@@ -387,14 +411,20 @@ TEST_F(TestAutoInline, OnlySpatialOp) {
Target target = common::DefaultNVGPUTarget(); Target target = common::DefaultNVGPUTarget();
Initialize(target); Initialize(target);
std::vector<std::string> input_names = {"x", "y"}; std::vector<std::string> input_names = {"x", "y"};
std::vector<std::string> output_names = { std::vector<std::string> output_names = {"var_6",
"var_6", "var_4", "constant_idx_last", "constant_idx_first", "var_2", "var_5"}; "var_4",
"constant_idx_last",
"constant_idx_first",
"var_2",
"var_5"};
std::vector<int32_t> input_shape{256, 256}; std::vector<int32_t> input_shape{256, 256};
std::vector<tests::VariableInfo> inputs_varinfo({{"x", input_shape}, {"y", input_shape}}); std::vector<tests::VariableInfo> inputs_varinfo(
{{"x", input_shape}, {"y", input_shape}});
// Construct the computation graph and convert it to ir::Expr // Construct the computation graph and convert it to ir::Expr
Context::Global().ResetNameId(); Context::Global().ResetNameId();
ir::IRSchedule ir_schedule = MakeIRSchedule(tests::GatherAddSubBuilder().Build(inputs_varinfo)); ir::IRSchedule ir_schedule =
MakeIRSchedule(tests::GatherAddSubBuilder().Build(inputs_varinfo));
SearchState state(ir_schedule, 0, {}); SearchState state(ir_schedule, 0, {});
std::vector<ir::Expr> func_bodys = ir_schedule.GetModule().GetExprs(); std::vector<ir::Expr> func_bodys = ir_schedule.GetModule().GetExprs();
ASSERT_EQ(func_bodys.size(), 1UL); ASSERT_EQ(func_bodys.size(), 1UL);
...@@ -402,20 +432,23 @@ TEST_F(TestAutoInline, OnlySpatialOp) { ...@@ -402,20 +432,23 @@ TEST_F(TestAutoInline, OnlySpatialOp) {
// Apply AutoInline for every block that can be inlined
AutoInline auto_inline(target_, {output_names.front()}); AutoInline auto_inline(target_, {output_names.front()});
EXPECT_EQ(auto_inline.AnalyseApplyType(state, "constant_idx_first"), RuleApplyType::kApplyAndPruneOtherRules); EXPECT_EQ(auto_inline.AnalyseApplyType(state, "constant_idx_first"),
RuleApplyType::kApplyAndPruneOtherRules);
auto new_states = auto_inline.ApplyOnBlock(state, "constant_idx_first"); auto new_states = auto_inline.ApplyOnBlock(state, "constant_idx_first");
std::vector<std::string> inline_block_names({"constant_idx_last", "var_2", "var_5", "var_4"}); std::vector<std::string> inline_block_names(
{"constant_idx_last", "var_2", "var_5", "var_4"});
for (const auto& inline_block_name : inline_block_names) { for (const auto& inline_block_name : inline_block_names) {
new_states = auto_inline.ApplyOnBlock(new_states[0], inline_block_name); new_states = auto_inline.ApplyOnBlock(new_states[0], inline_block_name);
} }
std::vector<ir::Expr> exprs = new_states[0]->ir_schedule.GetModule().GetExprs(); std::vector<ir::Expr> exprs =
new_states[0]->ir_schedule.GetModule().GetExprs();
EXPECT_EQ(exprs.size(), 1UL); EXPECT_EQ(exprs.size(), 1UL);
VLOG(6) << "Expr after AutoInline applied on block: " << exprs[0]; VLOG(6) << "Expr after AutoInline applied on block: " << exprs[0];
// build ir::Module and debug source code // build ir::Module and debug source code
auto build_module_auto = BuildIRModule(new_states[0]->ir_schedule); auto build_module_auto = BuildIRModule(new_states[0]->ir_schedule);
auto build_module_manually = auto build_module_manually = BuildIRModule(MakeIRSchedule(
BuildIRModule(MakeIRSchedule(tests::GatherAddSubBuilder().Build(inputs_varinfo), -1, true)); tests::GatherAddSubBuilder().Build(inputs_varinfo), -1, true));
auto source_code_auto = GenSourceCode(build_module_auto); auto source_code_auto = GenSourceCode(build_module_auto);
VLOG(6) << " auto-schedule source code:\n" << source_code_auto; VLOG(6) << " auto-schedule source code:\n" << source_code_auto;
auto source_code_manually = GenSourceCode(build_module_manually); auto source_code_manually = GenSourceCode(build_module_manually);
...@@ -451,7 +484,8 @@ TEST_F(TestAutoInline, NoReadBufferOp) { ...@@ -451,7 +484,8 @@ TEST_F(TestAutoInline, NoReadBufferOp) {
std::vector<tests::VariableInfo> inputs_varinfo({{"x", input_shape}}); std::vector<tests::VariableInfo> inputs_varinfo({{"x", input_shape}});
// Construct the computation graph and convert it to ir::Expr // Construct the computation graph and convert it to ir::Expr
ir::IRSchedule ir_schedule = MakeIRSchedule(tests::FillConstantAddBuilder().Build(inputs_varinfo)); ir::IRSchedule ir_schedule =
MakeIRSchedule(tests::FillConstantAddBuilder().Build(inputs_varinfo));
SearchState state(ir_schedule, 0, {}); SearchState state(ir_schedule, 0, {});
std::vector<ir::Expr> func_bodys = ir_schedule.GetModule().GetExprs(); std::vector<ir::Expr> func_bodys = ir_schedule.GetModule().GetExprs();
ASSERT_EQ(func_bodys.size(), 1UL); ASSERT_EQ(func_bodys.size(), 1UL);
...@@ -459,16 +493,18 @@ TEST_F(TestAutoInline, NoReadBufferOp) { ...@@ -459,16 +493,18 @@ TEST_F(TestAutoInline, NoReadBufferOp) {
// Apply AutoInline for every block that can be inline // Apply AutoInline for every block that can be inline
AutoInline auto_inline(target_, {output_names.front()}); AutoInline auto_inline(target_, {output_names.front()});
EXPECT_EQ(auto_inline.AnalyseApplyType(state, "fill_constant"), RuleApplyType::kApplyAndPruneOtherRules); EXPECT_EQ(auto_inline.AnalyseApplyType(state, "fill_constant"),
RuleApplyType::kApplyAndPruneOtherRules);
auto new_states = auto_inline.ApplyOnBlock(state, "fill_constant"); auto new_states = auto_inline.ApplyOnBlock(state, "fill_constant");
std::vector<ir::Expr> exprs = new_states[0]->ir_schedule.GetModule().GetExprs(); std::vector<ir::Expr> exprs =
new_states[0]->ir_schedule.GetModule().GetExprs();
EXPECT_EQ(exprs.size(), 1UL); EXPECT_EQ(exprs.size(), 1UL);
VLOG(6) << "Expr after AutoInline applied on block: " << exprs[0]; VLOG(6) << "Expr after AutoInline applied on block: " << exprs[0];
// build ir::Module and debug source code // build ir::Module and debug source code
auto build_module_auto = BuildIRModule(new_states[0]->ir_schedule); auto build_module_auto = BuildIRModule(new_states[0]->ir_schedule);
auto build_module_manually = auto build_module_manually = BuildIRModule(MakeIRSchedule(
BuildIRModule(MakeIRSchedule(tests::FillConstantAddBuilder().Build(inputs_varinfo), -1, true)); tests::FillConstantAddBuilder().Build(inputs_varinfo), -1, true));
auto source_code_auto = GenSourceCode(build_module_auto); auto source_code_auto = GenSourceCode(build_module_auto);
VLOG(6) << " auto-schedule source code:\n" << source_code_auto; VLOG(6) << " auto-schedule source code:\n" << source_code_auto;
auto source_code_manually = GenSourceCode(build_module_manually); auto source_code_manually = GenSourceCode(build_module_manually);
......
...@@ -33,11 +33,13 @@ bool AutoUnroll::MeetCondition(const ir::ScheduleBlock* schedule_block) const { ...@@ -33,11 +33,13 @@ bool AutoUnroll::MeetCondition(const ir::ScheduleBlock* schedule_block) const {
auto has_reduce_iter = [](const Expr* x) { auto has_reduce_iter = [](const Expr* x) {
auto* block_realize = x->As<ir::ScheduleBlockRealize>(); auto* block_realize = x->As<ir::ScheduleBlockRealize>();
if (block_realize) { if (block_realize) {
auto* schedule_block = block_realize->schedule_block.As<ir::ScheduleBlock>(); auto* schedule_block =
block_realize->schedule_block.As<ir::ScheduleBlock>();
CHECK(schedule_block) << "schedule_block field is not a ScheduleBlock"; CHECK(schedule_block) << "schedule_block field is not a ScheduleBlock";
for (auto&& var : schedule_block->iter_vars) { for (auto&& var : schedule_block->iter_vars) {
if (var->is_reduce_axis) { if (var->is_reduce_axis) {
VLOG(6) << "find ScheduleBlockRealize:" << *x << " has reduce_axis:" << var; VLOG(6) << "find ScheduleBlockRealize:" << *x
<< " has reduce_axis:" << var;
return true; return true;
} }
} }
...@@ -46,7 +48,8 @@ bool AutoUnroll::MeetCondition(const ir::ScheduleBlock* schedule_block) const { ...@@ -46,7 +48,8 @@ bool AutoUnroll::MeetCondition(const ir::ScheduleBlock* schedule_block) const {
}; };
// whether has any for-loop with non-serial type // whether has any for-loop with non-serial type
auto has_nonserial_loop = [](const Expr* x) { auto has_nonserial_loop = [](const Expr* x) {
if (x->As<ir::For>() && x->As<ir::For>()->for_type() != ir::ForType::Serial) { if (x->As<ir::For>() &&
x->As<ir::For>()->for_type() != ir::ForType::Serial) {
VLOG(6) << "find non-serial loop:" << *x; VLOG(6) << "find non-serial loop:" << *x;
return true; return true;
} }
...@@ -55,7 +58,9 @@ bool AutoUnroll::MeetCondition(const ir::ScheduleBlock* schedule_block) const { ...@@ -55,7 +58,9 @@ bool AutoUnroll::MeetCondition(const ir::ScheduleBlock* schedule_block) const {
auto find_target_exprs = ir::CollectIRNodesWithoutTensor( auto find_target_exprs = ir::CollectIRNodesWithoutTensor(
schedule_block->body, schedule_block->body,
[&has_reduce_iter, &has_nonserial_loop](const Expr* x) { return has_reduce_iter(x) || has_nonserial_loop(x); }); [&has_reduce_iter, &has_nonserial_loop](const Expr* x) {
return has_reduce_iter(x) || has_nonserial_loop(x);
});
return !find_target_exprs.empty(); return !find_target_exprs.empty();
} }
...@@ -74,44 +79,55 @@ RuleApplyType AutoUnroll::Init(ir::IRSchedule* ir_schedule) { ...@@ -74,44 +79,55 @@ RuleApplyType AutoUnroll::Init(ir::IRSchedule* ir_schedule) {
Expr root_block = ir_schedule_->GetRootBlock(block_realizes[i]); Expr root_block = ir_schedule_->GetRootBlock(block_realizes[i]);
auto* block_realize = root_block.As<ir::ScheduleBlockRealize>(); auto* block_realize = root_block.As<ir::ScheduleBlockRealize>();
CHECK(block_realize) << "stmt is not a ScheduleBlockRealize:" << root_block; CHECK(block_realize) << "stmt is not a ScheduleBlockRealize:" << root_block;
auto* schedule_block = block_realize->schedule_block.As<ir::ScheduleBlock>(); auto* schedule_block =
CHECK(schedule_block) << "schedule_block field is not a ScheduleBlock:" << Expr(block_realize); block_realize->schedule_block.As<ir::ScheduleBlock>();
CHECK(schedule_block) << "schedule_block field is not a ScheduleBlock:"
<< Expr(block_realize);
if (MeetCondition(schedule_block)) { if (MeetCondition(schedule_block)) {
deduplicate_results.emplace(root_block); deduplicate_results.emplace(root_block);
} }
} }
applicable_schedule_blocks_ = {deduplicate_results.begin(), deduplicate_results.end()}; applicable_schedule_blocks_ = {deduplicate_results.begin(),
deduplicate_results.end()};
num_applicable_ = applicable_schedule_blocks_.size(); num_applicable_ = applicable_schedule_blocks_.size();
VLOG(6) << "Collect applicable_schedule_blocks_:" << num_applicable_; VLOG(6) << "Collect applicable_schedule_blocks_:" << num_applicable_;
return num_applicable_ > 0 ? RuleApplyType::kApplyAndPruneOtherRules : RuleApplyType::kCannotApply; return num_applicable_ > 0 ? RuleApplyType::kApplyAndPruneOtherRules
: RuleApplyType::kCannotApply;
} }
void AutoUnroll::Apply(int index) { void AutoUnroll::Apply(int index) {
CHECK_LT(index, applicable_schedule_blocks_.size()) << "invalid apply index:" << index; CHECK_LT(index, applicable_schedule_blocks_.size())
<< "invalid apply index:" << index;
auto applied_block = applicable_schedule_blocks_.at(index); auto applied_block = applicable_schedule_blocks_.at(index);
int max_step = auto_unroll_options[std::rand() % auto_unroll_options.size()]; int max_step = auto_unroll_options[std::rand() % auto_unroll_options.size()];
ir_schedule_->Annotate(applied_block, ir::attr::auto_unroll_max_step, max_step); ir_schedule_->Annotate(
applied_block, ir::attr::auto_unroll_max_step, max_step);
return; return;
} }
RuleApplyType AutoUnroll::AnalyseApplyType(SearchState state, const std::string& block_name) const { RuleApplyType AutoUnroll::AnalyseApplyType(
SearchState state, const std::string& block_name) const {
Expr block_expr = state->ir_schedule.GetBlock(block_name); Expr block_expr = state->ir_schedule.GetBlock(block_name);
Expr root_block = state->ir_schedule.GetRootBlock(block_expr); Expr root_block = state->ir_schedule.GetRootBlock(block_expr);
auto* block_realize = root_block.As<ir::ScheduleBlockRealize>(); auto* block_realize = root_block.As<ir::ScheduleBlockRealize>();
CHECK(block_realize) << "stmt is not a ScheduleBlockRealize:" << root_block; CHECK(block_realize) << "stmt is not a ScheduleBlockRealize:" << root_block;
auto* schedule_block = block_realize->schedule_block.As<ir::ScheduleBlock>(); auto* schedule_block = block_realize->schedule_block.As<ir::ScheduleBlock>();
CHECK(schedule_block) << "schedule_block field is not a ScheduleBlock:" << Expr(block_realize); CHECK(schedule_block) << "schedule_block field is not a ScheduleBlock:"
<< Expr(block_realize);
return MeetCondition(schedule_block) ? RuleApplyType::kApplyAndPruneOtherRules : RuleApplyType::kCannotApply; return MeetCondition(schedule_block) ? RuleApplyType::kApplyAndPruneOtherRules
: RuleApplyType::kCannotApply;
} }
std::vector<SearchState> AutoUnroll::ApplyOnBlock(SearchState state, const std::string& block_name) { std::vector<SearchState> AutoUnroll::ApplyOnBlock(
SearchState state, const std::string& block_name) {
SearchState new_state = state.Copy(); SearchState new_state = state.Copy();
Expr block_expr = new_state->ir_schedule.GetBlock(block_name); Expr block_expr = new_state->ir_schedule.GetBlock(block_name);
Expr applied_block = new_state->ir_schedule.GetRootBlock(block_expr); Expr applied_block = new_state->ir_schedule.GetRootBlock(block_expr);
int max_step = auto_unroll_options[std::rand() % auto_unroll_options.size()]; int max_step = auto_unroll_options[std::rand() % auto_unroll_options.size()];
new_state->ir_schedule.Annotate(applied_block, ir::attr::auto_unroll_max_step, max_step); new_state->ir_schedule.Annotate(
applied_block, ir::attr::auto_unroll_max_step, max_step);
return {new_state}; return {new_state};
} }
......
...@@ -24,10 +24,11 @@ ...@@ -24,10 +24,11 @@
namespace cinn { namespace cinn {
namespace auto_schedule { namespace auto_schedule {
// This rule can be applied in a ScheduleBlock has reduce axis or has loops with non-serial type. // This rule can be applied in a ScheduleBlock has reduce axis or has loops with
// As a result, it will set a attribute with key named ir::attr::auto_unroll_max_step and value // non-serial type. As a result, it will set a attribute with key named
// indicating max permitted unrolled step in the applied ScheduleBlock. Finally, UnrollLoop pass // ir::attr::auto_unroll_max_step and value indicating max permitted unrolled
// will do unroll based on actual situation. // step in the applied ScheduleBlock. Finally, UnrollLoop pass will do unroll
// based on actual situation.
class AutoUnroll : public AutoGenRule { class AutoUnroll : public AutoGenRule {
public: public:
AutoUnroll(const common::Target& target) : AutoGenRule(target) {} AutoUnroll(const common::Target& target) : AutoGenRule(target) {}
...@@ -39,9 +40,11 @@ class AutoUnroll : public AutoGenRule { ...@@ -39,9 +40,11 @@ class AutoUnroll : public AutoGenRule {
std::string GetRuleName() const override { return "AutoUnroll"; } std::string GetRuleName() const override { return "AutoUnroll"; }
RuleApplyType AnalyseApplyType(SearchState state, const std::string& block_name) const override; RuleApplyType AnalyseApplyType(SearchState state,
const std::string& block_name) const override;
std::vector<SearchState> ApplyOnBlock(SearchState state, const std::string& block_name) override; std::vector<SearchState> ApplyOnBlock(SearchState state,
const std::string& block_name) override;
private: private:
bool MeetCondition(const ir::ScheduleBlock* schedule_block) const; bool MeetCondition(const ir::ScheduleBlock* schedule_block) const;
......
...@@ -39,7 +39,8 @@ TEST(AutoUnroll, Init) { ...@@ -39,7 +39,8 @@ TEST(AutoUnroll, Init) {
Target target = common::DefaultHostTarget(); Target target = common::DefaultHostTarget();
#endif #endif
auto stages = CreateStages({C}); auto stages = CreateStages({C});
auto funcs = cinn::lang::LowerVec("test_init", stages, {A, B, C}, {}, {}, nullptr, target, true); auto funcs = cinn::lang::LowerVec(
"test_init", stages, {A, B, C}, {}, {}, nullptr, target, true);
auto ast_expr = funcs[0]->body; auto ast_expr = funcs[0]->body;
ir::IRSchedule init_schedule(ir::ModuleExpr({ast_expr})); ir::IRSchedule init_schedule(ir::ModuleExpr({ast_expr}));
...@@ -58,7 +59,9 @@ TEST(AutoUnroll, UnrollableApply) { ...@@ -58,7 +59,9 @@ TEST(AutoUnroll, UnrollableApply) {
Placeholder<float> B("B", {K, N}); Placeholder<float> B("B", {K, N});
Var k(K.as_int32(), "k0"); Var k(K.as_int32(), "k0");
Tensor C = Compute( Tensor C = Compute(
{M, N}, [&](Var i, Var j) { return ReduceSum(A(i, k) * B(k, j), {k}); }, "C"); {M, N},
[&](Var i, Var j) { return ReduceSum(A(i, k) * B(k, j), {k}); },
"C");
#ifdef CINN_WITH_CUDA #ifdef CINN_WITH_CUDA
Target target = common::DefaultNVGPUTarget(); Target target = common::DefaultNVGPUTarget();
...@@ -66,11 +69,14 @@ TEST(AutoUnroll, UnrollableApply) { ...@@ -66,11 +69,14 @@ TEST(AutoUnroll, UnrollableApply) {
Target target = common::DefaultHostTarget(); Target target = common::DefaultHostTarget();
#endif #endif
auto stages = CreateStages({C}); auto stages = CreateStages({C});
auto funcs = cinn::lang::LowerVec("test_unrollable", stages, {A, B, C}, {}, {}, nullptr, target, true); auto funcs = cinn::lang::LowerVec(
"test_unrollable", stages, {A, B, C}, {}, {}, nullptr, target, true);
auto ast_expr = funcs[0]->body; auto ast_expr = funcs[0]->body;
auto* init_block_realize = ast_expr.As<ir::Block>()->stmts.front().As<ir::ScheduleBlockRealize>(); auto* init_block_realize =
auto* init_schedule_block = init_block_realize->schedule_block.As<ir::ScheduleBlock>(); ast_expr.As<ir::Block>()->stmts.front().As<ir::ScheduleBlockRealize>();
auto* init_schedule_block =
init_block_realize->schedule_block.As<ir::ScheduleBlock>();
ASSERT_NE(init_schedule_block, nullptr); ASSERT_NE(init_schedule_block, nullptr);
ASSERT_TRUE(init_schedule_block->attrs.empty()); ASSERT_TRUE(init_schedule_block->attrs.empty());
VLOG(6) << "Before auto-unroll:\n" << ast_expr; VLOG(6) << "Before auto-unroll:\n" << ast_expr;
...@@ -78,25 +84,34 @@ TEST(AutoUnroll, UnrollableApply) { ...@@ -78,25 +84,34 @@ TEST(AutoUnroll, UnrollableApply) {
AutoUnroll test_rule(target); AutoUnroll test_rule(target);
ir::IRSchedule ir_schedule(ir::ModuleExpr({ast_expr})); ir::IRSchedule ir_schedule(ir::ModuleExpr({ast_expr}));
SearchState state(ir_schedule, 0, {}); SearchState state(ir_schedule, 0, {});
ASSERT_EQ(test_rule.Init(&ir_schedule), RuleApplyType::kApplyAndPruneOtherRules); ASSERT_EQ(test_rule.Init(&ir_schedule),
RuleApplyType::kApplyAndPruneOtherRules);
EXPECT_EQ(test_rule.NumberApplicable(), 1); EXPECT_EQ(test_rule.NumberApplicable(), 1);
test_rule.ApplyRandomly(); test_rule.ApplyRandomly();
// ApplyOnBlock // ApplyOnBlock
EXPECT_EQ(test_rule.AnalyseApplyType(state, "C"), RuleApplyType::kApplyAndPruneOtherRules); EXPECT_EQ(test_rule.AnalyseApplyType(state, "C"),
std::vector<cinn::auto_schedule::SearchState> states = test_rule.ApplyOnBlock(state, "C"); RuleApplyType::kApplyAndPruneOtherRules);
std::vector<cinn::auto_schedule::SearchState> states =
test_rule.ApplyOnBlock(state, "C");
auto test_func = [](IRSchedule* ir_sch) { auto test_func = [](IRSchedule* ir_sch) {
Expr applied_expr = ir_sch->GetModule().GetExprs().front(); Expr applied_expr = ir_sch->GetModule().GetExprs().front();
auto* applied_block_realize = applied_expr.As<ir::Block>()->stmts.front().As<ir::ScheduleBlockRealize>(); auto* applied_block_realize = applied_expr.As<ir::Block>()
auto* applied_schedule_block = applied_block_realize->schedule_block.As<ir::ScheduleBlock>(); ->stmts.front()
.As<ir::ScheduleBlockRealize>();
auto* applied_schedule_block =
applied_block_realize->schedule_block.As<ir::ScheduleBlock>();
ASSERT_FALSE(applied_schedule_block->attrs.empty()); ASSERT_FALSE(applied_schedule_block->attrs.empty());
EXPECT_EQ(applied_schedule_block->attrs.count(ir::attr::auto_unroll_max_step), 1); EXPECT_EQ(
const auto& attr_value = applied_schedule_block->attrs.at(ir::attr::auto_unroll_max_step); applied_schedule_block->attrs.count(ir::attr::auto_unroll_max_step), 1);
const auto& attr_value =
applied_schedule_block->attrs.at(ir::attr::auto_unroll_max_step);
const int* max_step = absl::get_if<int>(&attr_value); const int* max_step = absl::get_if<int>(&attr_value);
EXPECT_NE(max_step, nullptr); EXPECT_NE(max_step, nullptr);
EXPECT_LE(*max_step, 128); EXPECT_LE(*max_step, 128);
VLOG(6) << "After auto-unroll:max_step=" << *max_step << ", Ast:\n" << ir_sch->GetModule().GetExprs().front(); VLOG(6) << "After auto-unroll:max_step=" << *max_step << ", Ast:\n"
<< ir_sch->GetModule().GetExprs().front();
}; };
test_func(&ir_schedule); test_func(&ir_schedule);
......
...@@ -34,7 +34,8 @@ class TestMixRules : public TestAutoGenRuleBase { ...@@ -34,7 +34,8 @@ class TestMixRules : public TestAutoGenRuleBase {
}; };
TEST_F(TestMixRules, 2DMatmulOnMultiTilingRelated) { TEST_F(TestMixRules, 2DMatmulOnMultiTilingRelated) {
frontend::Program matmul_op = tests::OpBuilder("matmul").Build({{"X", {32, 32}}, {"Y", {32, 32}}}); frontend::Program matmul_op =
tests::OpBuilder("matmul").Build({{"X", {32, 32}}, {"Y", {32, 32}}});
Initialize(common::DefaultNVGPUTarget()); Initialize(common::DefaultNVGPUTarget());
ir::IRSchedule ir_schedule = MakeIRSchedule(matmul_op); ir::IRSchedule ir_schedule = MakeIRSchedule(matmul_op);
std::vector<ir::Expr> func_bodys = ir_schedule.GetModule().GetExprs(); std::vector<ir::Expr> func_bodys = ir_schedule.GetModule().GetExprs();
...@@ -42,7 +43,8 @@ TEST_F(TestMixRules, 2DMatmulOnMultiTilingRelated) { ...@@ -42,7 +43,8 @@ TEST_F(TestMixRules, 2DMatmulOnMultiTilingRelated) {
VLOG(6) << "Original Expr:\n" << func_bodys[0]; VLOG(6) << "Original Expr:\n" << func_bodys[0];
// Apply MultiLevelTiling // Apply MultiLevelTiling
MultiLevelTiling multi_level_tiling(target_, MultiLevelTiling::kConfigs.at(target_.arch)); MultiLevelTiling multi_level_tiling(
target_, MultiLevelTiling::kConfigs.at(target_.arch));
multi_level_tiling.Init(&ir_schedule); multi_level_tiling.Init(&ir_schedule);
ASSERT_EQ(multi_level_tiling.NumberApplicable(), 1); ASSERT_EQ(multi_level_tiling.NumberApplicable(), 1);
multi_level_tiling.ApplyRandomly(); multi_level_tiling.ApplyRandomly();
...@@ -54,7 +56,8 @@ TEST_F(TestMixRules, 2DMatmulOnMultiTilingRelated) { ...@@ -54,7 +56,8 @@ TEST_F(TestMixRules, 2DMatmulOnMultiTilingRelated) {
VLOG(6) << "scheduled source code:\n" << source_code; VLOG(6) << "scheduled source code:\n" << source_code;
// execute and check precision // execute and check precision
CheckResult(GenExecutableKernel(ir_module), CheckResult(GenExecutableKernel(ir_module),
GenExecutableKernel(BuildIRModule(MakeIRSchedule(matmul_op, /* apply_manual_schedule */ true))), GenExecutableKernel(BuildIRModule(
MakeIRSchedule(matmul_op, /* apply_manual_schedule */ true))),
default_input_names, default_input_names,
default_output_names, default_output_names,
{{32, 32}, {32, 32}}, {{32, 32}, {32, 32}},
......
...@@ -38,7 +38,8 @@ ...@@ -38,7 +38,8 @@
namespace cinn { namespace cinn {
namespace auto_schedule { namespace auto_schedule {
MultiLevelTiling::MultiLevelTiling(const common::Target& target, const Config& config) MultiLevelTiling::MultiLevelTiling(const common::Target& target,
const Config& config)
: AutoGenRule(target), config_(config) { : AutoGenRule(target), config_(config) {
for (int i = 0; i < config_.tile_struct.size(); ++i) { for (int i = 0; i < config_.tile_struct.size(); ++i) {
if (config_.tile_struct[i] == 'S') { if (config_.tile_struct[i] == 'S') {
...@@ -51,7 +52,8 @@ MultiLevelTiling::MultiLevelTiling(const common::Target& target, const Config& c ...@@ -51,7 +52,8 @@ MultiLevelTiling::MultiLevelTiling(const common::Target& target, const Config& c
} }
} }
bool MultiLevelTiling::MeetCondition(const ir::ScheduleBlockRealize& sche_block_realize) const { bool MultiLevelTiling::MeetCondition(
const ir::ScheduleBlockRealize& sche_block_realize) const {
return NeedsMultiLevelTiling(sche_block_realize); return NeedsMultiLevelTiling(sche_block_realize);
} }
...@@ -61,15 +63,18 @@ RuleApplyType MultiLevelTiling::Init(ir::IRSchedule* ir_schedule) { ...@@ -61,15 +63,18 @@ RuleApplyType MultiLevelTiling::Init(ir::IRSchedule* ir_schedule) {
applicable_indices_.clear(); applicable_indices_.clear();
num_applicable_ = 0; num_applicable_ = 0;
for (size_t i = 0; i < all_block_realizes_.size(); ++i) { for (size_t i = 0; i < all_block_realizes_.size(); ++i) {
ir::ScheduleBlockRealize* sche_block_realize = all_block_realizes_[i].As<ir::ScheduleBlockRealize>(); ir::ScheduleBlockRealize* sche_block_realize =
AnalyzeScheduleBlockReadWriteBuffer(sche_block_realize->schedule_block.As<ir::ScheduleBlock>()); all_block_realizes_[i].As<ir::ScheduleBlockRealize>();
AnalyzeScheduleBlockReadWriteBuffer(
sche_block_realize->schedule_block.As<ir::ScheduleBlock>());
if (MeetCondition(*sche_block_realize)) { if (MeetCondition(*sche_block_realize)) {
++num_applicable_; ++num_applicable_;
applicable_indices_.push_back(i); applicable_indices_.push_back(i);
} }
} }
return num_applicable_ > 0 ? RuleApplyType::kApplyAndPruneOtherRules : RuleApplyType::kCannotApply; return num_applicable_ > 0 ? RuleApplyType::kApplyAndPruneOtherRules
: RuleApplyType::kCannotApply;
} }
void MultiLevelTiling::Apply(int index) { void MultiLevelTiling::Apply(int index) {
...@@ -77,12 +82,16 @@ void MultiLevelTiling::Apply(int index) { ...@@ -77,12 +82,16 @@ void MultiLevelTiling::Apply(int index) {
CHECK(num_applicable_ > 0 && applicable_indices_.size() == num_applicable_) CHECK(num_applicable_ > 0 && applicable_indices_.size() == num_applicable_)
<< "MultiLevelTiling::Apply pre-condition doesn't meet"; << "MultiLevelTiling::Apply pre-condition doesn't meet";
CHECK(index >= 0 && num_applicable_ > index) CHECK(index >= 0 && num_applicable_ > index)
<< "Invalid index for MultiLevelTiling::Apply, the index needs 0 <= index && index < NumberApplicable(), " << "Invalid index for MultiLevelTiling::Apply, the index needs 0 <= "
<< "Currently index = " << index << ", NumberApplicable() = " << num_applicable_; "index && index < NumberApplicable(), "
<< "Currently index = " << index
<< ", NumberApplicable() = " << num_applicable_;
int apply_index = applicable_indices_[index]; int apply_index = applicable_indices_[index];
std::string block_name = std::string block_name = all_block_realizes_[apply_index]
all_block_realizes_[apply_index].As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>()->name; .As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->name;
Expr block_expr = all_block_realizes_[apply_index]; Expr block_expr = all_block_realizes_[apply_index];
ApplyTiling(ir_schedule_, block_expr); ApplyTiling(ir_schedule_, block_expr);
block_expr = ir_schedule_->GetBlock(block_name); block_expr = ir_schedule_->GetBlock(block_name);
...@@ -96,16 +105,21 @@ void MultiLevelTiling::Apply(int index) { ...@@ -96,16 +105,21 @@ void MultiLevelTiling::Apply(int index) {
std::string MultiLevelTiling::GetRuleName() const { return "MultiLevelTiling"; } std::string MultiLevelTiling::GetRuleName() const { return "MultiLevelTiling"; }
RuleApplyType MultiLevelTiling::AnalyseApplyType(SearchState state, const std::string& block_name) const { RuleApplyType MultiLevelTiling::AnalyseApplyType(
SearchState state, const std::string& block_name) const {
Expr block_expr = state->ir_schedule.GetBlock(block_name); Expr block_expr = state->ir_schedule.GetBlock(block_name);
auto* block_realize = block_expr.As<ir::ScheduleBlockRealize>(); auto* block_realize = block_expr.As<ir::ScheduleBlockRealize>();
CHECK(block_realize) << "stmt is not a ScheduleBlockRealize:" << block_expr; CHECK(block_realize) << "stmt is not a ScheduleBlockRealize:" << block_expr;
AnalyzeScheduleBlockReadWriteBuffer(block_realize->schedule_block.As<ir::ScheduleBlock>()); AnalyzeScheduleBlockReadWriteBuffer(
block_realize->schedule_block.As<ir::ScheduleBlock>());
return NeedsMultiLevelTiling(*block_realize) ? RuleApplyType::kApplyAndPruneOtherRules : RuleApplyType::kCannotApply; return NeedsMultiLevelTiling(*block_realize)
? RuleApplyType::kApplyAndPruneOtherRules
: RuleApplyType::kCannotApply;
} }
std::vector<SearchState> MultiLevelTiling::ApplyOnBlock(SearchState state, const std::string& block_name) { std::vector<SearchState> MultiLevelTiling::ApplyOnBlock(
SearchState state, const std::string& block_name) {
SearchState new_state = state.Copy(); SearchState new_state = state.Copy();
ir::IRSchedule* ir_sch = &new_state->ir_schedule; ir::IRSchedule* ir_sch = &new_state->ir_schedule;
Expr block_expr = ir_sch->GetBlock(block_name); Expr block_expr = ir_sch->GetBlock(block_name);
...@@ -119,14 +133,18 @@ std::vector<SearchState> MultiLevelTiling::ApplyOnBlock(SearchState state, const ...@@ -119,14 +133,18 @@ std::vector<SearchState> MultiLevelTiling::ApplyOnBlock(SearchState state, const
return {new_state}; return {new_state};
} }
void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule, ir::Expr& block_expr) { void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule,
ir::ScheduleBlockRealize* sche_block_realize = block_expr.As<ir::ScheduleBlockRealize>(); ir::Expr& block_expr) {
ir::ScheduleBlock* sche_block = sche_block_realize->schedule_block.As<ir::ScheduleBlock>(); ir::ScheduleBlockRealize* sche_block_realize =
block_expr.As<ir::ScheduleBlockRealize>();
ir::ScheduleBlock* sche_block =
sche_block_realize->schedule_block.As<ir::ScheduleBlock>();
tile_loops_.clear(); tile_loops_.clear();
tile_loops_.resize(config_.tile_struct.size()); tile_loops_.resize(config_.tile_struct.size());
std::vector<Expr> for_exprs = ir_schedule->GetLoops(block_expr); std::vector<Expr> for_exprs = ir_schedule->GetLoops(block_expr);
VLOG(5) << "The number of loops to split in MultiLevelTiling is " << for_exprs.size(); VLOG(5) << "The number of loops to split in MultiLevelTiling is "
<< for_exprs.size();
for (int i = for_exprs.size() - 1; i >= 0; --i) { for (int i = for_exprs.size() - 1; i >= 0; --i) {
ir::For* ir_for = for_exprs[i].As<ir::For>(); ir::For* ir_for = for_exprs[i].As<ir::For>();
VLOG(6) << "Applying Split for MultiLevelTiling on: " << Expr(ir_for); VLOG(6) << "Applying Split for MultiLevelTiling on: " << Expr(ir_for);
...@@ -141,8 +159,10 @@ void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule, ir::Expr& block_ ...@@ -141,8 +159,10 @@ void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule, ir::Expr& block_
int num_split = idx->size(); int num_split = idx->size();
if (num_split > 1) { if (num_split > 1) {
std::vector<Expr> tile_split_factor = ir_schedule->SamplePerfectTile(Expr(ir_for), num_split, 64); std::vector<Expr> tile_split_factor =
std::vector<Expr> splited = ir_schedule->Split(Expr(ir_for), tile_split_factor); ir_schedule->SamplePerfectTile(Expr(ir_for), num_split, 64);
std::vector<Expr> splited =
ir_schedule->Split(Expr(ir_for), tile_split_factor);
VLOG(6) << "Finish Split for MultiLevelTiling on above loop"; VLOG(6) << "Finish Split for MultiLevelTiling on above loop";
for (int j = 0; j < num_split; ++j) { for (int j = 0; j < num_split; ++j) {
tile_loops_[idx->at(j)].push_back(splited[j]); tile_loops_[idx->at(j)].push_back(splited[j]);
...@@ -159,7 +179,8 @@ void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule, ir::Expr& block_ ...@@ -159,7 +179,8 @@ void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule, ir::Expr& block_
for (int i = 0; i < for_exprs.size(); ++i) { for (int i = 0; i < for_exprs.size(); ++i) {
loop_var_name_to_idx[for_exprs[i].As<ir::For>()->loop_var->name] = i; loop_var_name_to_idx[for_exprs[i].As<ir::For>()->loop_var->name] = i;
} }
CHECK(loop_var_name_to_idx.size() == for_exprs.size()) << "Loops contain duplicate loop var names after split"; CHECK(loop_var_name_to_idx.size() == for_exprs.size())
<< "Loops contain duplicate loop var names after split";
std::vector<Expr> splited_loops; std::vector<Expr> splited_loops;
for (auto& t : tile_loops_) { for (auto& t : tile_loops_) {
...@@ -173,7 +194,8 @@ void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule, ir::Expr& block_ ...@@ -173,7 +194,8 @@ void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule, ir::Expr& block_
} }
Expr reordered_expr = ir_schedule->Reorder(splited_loops); Expr reordered_expr = ir_schedule->Reorder(splited_loops);
VLOG(5) << "Finish Reorder in MultiLevelTiling, now do Fuse and Binding on the main loop chain"; VLOG(5) << "Finish Reorder in MultiLevelTiling, now do Fuse and Binding on "
"the main loop chain";
int num_binds = std::min(config_.bind_axis.size(), tile_loops_.size()); int num_binds = std::min(config_.bind_axis.size(), tile_loops_.size());
for (int i = 0; i < num_binds; ++i) { for (int i = 0; i < num_binds; ++i) {
...@@ -182,7 +204,8 @@ void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule, ir::Expr& block_ ...@@ -182,7 +204,8 @@ void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule, ir::Expr& block_
for (int j = 0; j < for_exprs.size(); ++j) { for (int j = 0; j < for_exprs.size(); ++j) {
loop_var_name_to_idx[for_exprs[j].As<ir::For>()->loop_var->name] = j; loop_var_name_to_idx[for_exprs[j].As<ir::For>()->loop_var->name] = j;
} }
CHECK(loop_var_name_to_idx.size() == for_exprs.size()) << "Loops contain duplicate loop var names before Fusion"; CHECK(loop_var_name_to_idx.size() == for_exprs.size())
<< "Loops contain duplicate loop var names before Fusion";
// Some loops extent may exceed the limited max factor (For example, // Some loops extent may exceed the limited max factor (For example,
// exceed the limit number of CUDA threads), here we check whether // exceed the limit number of CUDA threads), here we check whether
...@@ -209,7 +232,8 @@ void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule, ir::Expr& block_ ...@@ -209,7 +232,8 @@ void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule, ir::Expr& block_
Expr fused = ir_schedule->Fuse(tile_loops_[i]); Expr fused = ir_schedule->Fuse(tile_loops_[i]);
ir_schedule->Bind(fused, config_.bind_axis[i]); ir_schedule->Bind(fused, config_.bind_axis[i]);
} else if (first_idx_less_than_max_factor != -1) { } else if (first_idx_less_than_max_factor != -1) {
ir_schedule->Bind(for_exprs[first_idx_less_than_max_factor], config_.bind_axis[i]); ir_schedule->Bind(for_exprs[first_idx_less_than_max_factor],
config_.bind_axis[i]);
} }
} }
...@@ -229,13 +253,17 @@ void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule, ir::Expr& block_ ...@@ -229,13 +253,17 @@ void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule, ir::Expr& block_
} }
} }
if (!other_loop_chain_schedule.defined()) { if (!other_loop_chain_schedule.defined()) {
LOG(WARNING) << "Has non-main loop chain, but not corresponding ScheduleBlock in MultiLevelTiling"; LOG(WARNING) << "Has non-main loop chain, but not corresponding "
"ScheduleBlock in MultiLevelTiling";
continue; continue;
} }
std::string other_loop_schedule_name = std::string other_loop_schedule_name =
other_loop_chain_schedule.As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>()->name; other_loop_chain_schedule.As<ir::ScheduleBlockRealize>()
VLOG(6) << "Found other_loop_schedule_name = " << other_loop_schedule_name; ->schedule_block.As<ir::ScheduleBlock>()
->name;
VLOG(6) << "Found other_loop_schedule_name = "
<< other_loop_schedule_name;
int fuse_index = 0; int fuse_index = 0;
for (int i = 0; i < num_binds; ++i) { for (int i = 0; i < num_binds; ++i) {
for_exprs = ir_schedule->GetLoops(other_loop_schedule_name); for_exprs = ir_schedule->GetLoops(other_loop_schedule_name);
...@@ -250,20 +278,23 @@ void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule, ir::Expr& block_ ...@@ -250,20 +278,23 @@ void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule, ir::Expr& block_
int extent_prod = 1; int extent_prod = 1;
int first_idx_less_than_max_factor = -1; int first_idx_less_than_max_factor = -1;
for (int j = 0; j < tile_loops_[i].size(); ++j) { for (int j = 0; j < tile_loops_[i].size(); ++j) {
int extent = for_exprs[fuse_index + j].As<ir::For>()->extent.as_int32(); int extent =
for_exprs[fuse_index + j].As<ir::For>()->extent.as_int32();
extent_prod *= extent; extent_prod *= extent;
if (first_idx_less_than_max_factor == -1 && extent <= max_factor_) { if (first_idx_less_than_max_factor == -1 && extent <= max_factor_) {
first_idx_less_than_max_factor = fuse_index + j; first_idx_less_than_max_factor = fuse_index + j;
} }
} }
if (extent_prod <= max_factor_) { if (extent_prod <= max_factor_) {
std::vector<Expr> loops_to_fuse(for_exprs.begin() + fuse_index, std::vector<Expr> loops_to_fuse(
for_exprs.begin() + fuse_index,
for_exprs.begin() + fuse_index + tile_loops_[i].size()); for_exprs.begin() + fuse_index + tile_loops_[i].size());
Expr fused = ir_schedule->Fuse(loops_to_fuse); Expr fused = ir_schedule->Fuse(loops_to_fuse);
ir_schedule->Bind(fused, config_.bind_axis[i]); ir_schedule->Bind(fused, config_.bind_axis[i]);
fuse_index += 1; fuse_index += 1;
} else if (first_idx_less_than_max_factor != -1) { } else if (first_idx_less_than_max_factor != -1) {
ir_schedule->Bind(for_exprs[first_idx_less_than_max_factor], config_.bind_axis[i]); ir_schedule->Bind(for_exprs[first_idx_less_than_max_factor],
config_.bind_axis[i]);
fuse_index += tile_loops_[i].size(); fuse_index += tile_loops_[i].size();
} }
} }
...@@ -272,9 +303,12 @@ void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule, ir::Expr& block_ ...@@ -272,9 +303,12 @@ void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule, ir::Expr& block_
} }
} }
void MultiLevelTiling::ApplyCacheRead(ir::IRSchedule* ir_schedule, ir::Expr& block_expr) { void MultiLevelTiling::ApplyCacheRead(ir::IRSchedule* ir_schedule,
ir::ScheduleBlockRealize* sch_block_realize = block_expr.As<ir::ScheduleBlockRealize>(); ir::Expr& block_expr) {
ir::ScheduleBlock* sch_block = sch_block_realize->schedule_block.As<ir::ScheduleBlock>(); ir::ScheduleBlockRealize* sch_block_realize =
block_expr.As<ir::ScheduleBlockRealize>();
ir::ScheduleBlock* sch_block =
sch_block_realize->schedule_block.As<ir::ScheduleBlock>();
std::string block_name = sch_block->name; std::string block_name = sch_block->name;
// Analyze which buffers can be cached // Analyze which buffers can be cached
...@@ -302,85 +336,110 @@ void MultiLevelTiling::ApplyCacheRead(ir::IRSchedule* ir_schedule, ir::Expr& blo ...@@ -302,85 +336,110 @@ void MultiLevelTiling::ApplyCacheRead(ir::IRSchedule* ir_schedule, ir::Expr& blo
} }
// 2.Do CacheRead and get the cache block // 2.Do CacheRead and get the cache block
ir::Expr cache_block = ir_schedule->CacheRead(block_expr, read_buffer_index, config_.read_cache_memory_type); ir::Expr cache_block = ir_schedule->CacheRead(
block_expr, read_buffer_index, config_.read_cache_memory_type);
std::string cache_block_name = std::string cache_block_name =
cache_block.As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>()->name; cache_block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->name;
std::string target_for_loop_name = loops.back().As<ir::For>()->loop_var->name; std::string target_for_loop_name =
loops.back().As<ir::For>()->loop_var->name;
// 3.Place the cache_block under target_for_loop // 3.Place the cache_block under target_for_loop
// The original block expr is invalid after the CacheRead schedule, // The original block expr is invalid after the CacheRead schedule,
// so we reacquire the block expr after the schedule according to the block name // so we reacquire the block expr after the schedule according to the
// block name
block_expr = ir_schedule->GetBlock(block_name); block_expr = ir_schedule->GetBlock(block_name);
std::vector<Expr> for_exprs = ir_schedule->GetLoops(block_expr); std::vector<Expr> for_exprs = ir_schedule->GetLoops(block_expr);
for (const Expr& for_expr : for_exprs) { for (const Expr& for_expr : for_exprs) {
if (for_expr.As<ir::For>()->loop_var->name.find(target_for_loop_name) != std::string::npos) { if (for_expr.As<ir::For>()->loop_var->name.find(target_for_loop_name) !=
std::string::npos) {
ir_schedule->ComputeAt(cache_block, for_expr, true); ir_schedule->ComputeAt(cache_block, for_expr, true);
break; break;
} }
} }
// 4.Threads under the same block cooperative fetch data from global memory. // 4.Threads under the same block cooperative fetch data from global
// memory.
Expr new_cache_block = ir_schedule->GetBlock(cache_block_name); Expr new_cache_block = ir_schedule->GetBlock(cache_block_name);
auto cache_block_loops = ir_schedule->GetLoops(new_cache_block); auto cache_block_loops = ir_schedule->GetLoops(new_cache_block);
std::vector<std::string> compute_at_extra_var = utils::Split( std::vector<std::string> compute_at_extra_var = utils::Split(
absl::get<std::string>( absl::get<std::string>(new_cache_block.As<ir::ScheduleBlockRealize>()
new_cache_block.As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>()->attrs.at( ->schedule_block.As<ir::ScheduleBlock>()
"compute_at_extra_var")), ->attrs.at("compute_at_extra_var")),
","); ",");
std::vector<Expr> buffer_loops; std::vector<Expr> buffer_loops;
// int nthreads = 1; // int nthreads = 1;
for (const Expr& for_expr : cache_block_loops) { for (const Expr& for_expr : cache_block_loops) {
if (std::find(compute_at_extra_var.begin(), if (std::find(compute_at_extra_var.begin(),
compute_at_extra_var.end(), compute_at_extra_var.end(),
for_expr.As<ir::For>()->loop_var->name) != compute_at_extra_var.end()) { for_expr.As<ir::For>()->loop_var->name) !=
compute_at_extra_var.end()) {
buffer_loops.push_back(for_expr); buffer_loops.push_back(for_expr);
} }
} }
auto fused_buffer_loop = ir_schedule->Fuse(buffer_loops); auto fused_buffer_loop = ir_schedule->Fuse(buffer_loops);
// TODO(BiynXu): Implement vectorize fetching data and pass in vector length // TODO(BiynXu): Implement vectorize fetching data and pass in vector
ir_schedule->Annotate(ir_schedule->GetBlock(cache_block_name), ir::attr::cooperative_process, 0); // length
ir_schedule->Annotate(ir_schedule->GetBlock(cache_block_name),
ir::attr::cooperative_process,
0);
} }
} }
} }
void MultiLevelTiling::ApplyCacheWrite(ir::IRSchedule* ir_schedule, ir::Expr& block_expr) { void MultiLevelTiling::ApplyCacheWrite(ir::IRSchedule* ir_schedule,
ir::Expr cache_block = ir_schedule->CacheWrite(block_expr, 0, config_.write_cache_memory_type); ir::Expr& block_expr) {
ir::Expr cache_block =
ir_schedule->CacheWrite(block_expr, 0, config_.write_cache_memory_type);
for (int level : config_.write_cache_levels) { for (int level : config_.write_cache_levels) {
const auto loops = tile_loops_.at(level - 1); const auto loops = tile_loops_.at(level - 1);
if (loops.size() == 0) { if (loops.size() == 0) {
continue; continue;
} }
std::string target_for_loop_name = loops.back().As<ir::For>()->loop_var->name; std::string target_for_loop_name =
// Because the block name is changed in CacheWrite, we need to calculate the derived name loops.back().As<ir::For>()->loop_var->name;
// according to the logic of CacheWrite and find the loop structure according to the derived name. // Because the block name is changed in CacheWrite, we need to calculate the
// derived name according to the logic of CacheWrite and find the loop
// structure according to the derived name.
const std::string original_block_name = const std::string original_block_name =
block_expr.As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>()->name; block_expr.As<ir::ScheduleBlockRealize>()
const std::string derivative_block_name = ->schedule_block.As<ir::ScheduleBlock>()
original_block_name + "_" + config_.write_cache_memory_type + "_temp_buffer"; ->name;
const std::string derivative_block_name = original_block_name + "_" +
config_.write_cache_memory_type +
"_temp_buffer";
std::vector<Expr> for_exprs = ir_schedule->GetLoops(derivative_block_name); std::vector<Expr> for_exprs = ir_schedule->GetLoops(derivative_block_name);
for (const Expr& for_expr : for_exprs) { for (const Expr& for_expr : for_exprs) {
if (for_expr.As<ir::For>()->loop_var->name.find(target_for_loop_name) != std::string::npos) { if (for_expr.As<ir::For>()->loop_var->name.find(target_for_loop_name) !=
ir_schedule->ReverseComputeAt(ir_schedule->GetBlock(original_block_name), for_expr, true); std::string::npos) {
ir_schedule->ReverseComputeAt(
ir_schedule->GetBlock(original_block_name), for_expr, true);
} }
} }
const std::string reduce_init_block_name = original_block_name + "__reduce_init"; const std::string reduce_init_block_name =
original_block_name + "__reduce_init";
for_exprs = ir_schedule->GetLoops(derivative_block_name); for_exprs = ir_schedule->GetLoops(derivative_block_name);
for (const Expr& for_expr : for_exprs) { for (const Expr& for_expr : for_exprs) {
if (for_expr.As<ir::For>()->loop_var->name.find(target_for_loop_name) != std::string::npos && if (for_expr.As<ir::For>()->loop_var->name.find(target_for_loop_name) !=
std::string::npos &&
ir_schedule->HasBlock(reduce_init_block_name)) { ir_schedule->HasBlock(reduce_init_block_name)) {
ir_schedule->SimpleComputeAt(ir_schedule->GetBlock(reduce_init_block_name), for_expr); ir_schedule->SimpleComputeAt(
ir_schedule->GetBlock(reduce_init_block_name), for_expr);
} }
} }
} }
} }
const std::unordered_map<common::Target::Arch, MultiLevelTiling::Config> MultiLevelTiling::kConfigs{ const std::unordered_map<common::Target::Arch, MultiLevelTiling::Config>
MultiLevelTiling::kConfigs{
{common::Target::Arch::NVGPU, {common::Target::Arch::NVGPU,
MultiLevelTiling::Config{ MultiLevelTiling::Config{
/*bind_axis*/ std::vector<std::string>{"blockIdx.x", "threadIdx.x"}, /*bind_axis*/ std::vector<std::string>{"blockIdx.x",
"threadIdx.x"},
/*tile_struct*/ std::string("SSSRRSRS"), /*tile_struct*/ std::string("SSSRRSRS"),
/*read_cache_memory_type*/ std::string("shared"), /*read_cache_memory_type*/ std::string("shared"),
/*read_cache_levels*/ std::vector<int>{4}, /*read_cache_levels*/ std::vector<int>{4},
......
...@@ -72,9 +72,11 @@ class MultiLevelTiling : public AutoGenRule { ...@@ -72,9 +72,11 @@ class MultiLevelTiling : public AutoGenRule {
// Returns true if sche_block_realize is applicable by MultiLevelTiling // Returns true if sche_block_realize is applicable by MultiLevelTiling
bool MeetCondition(const ir::ScheduleBlockRealize& sche_block_realize) const; bool MeetCondition(const ir::ScheduleBlockRealize& sche_block_realize) const;
RuleApplyType AnalyseApplyType(SearchState state, const std::string& block_name) const override; RuleApplyType AnalyseApplyType(SearchState state,
const std::string& block_name) const override;
std::vector<SearchState> ApplyOnBlock(SearchState state, const std::string& block_name) override; std::vector<SearchState> ApplyOnBlock(SearchState state,
const std::string& block_name) override;
// Sample pair of integer type (a, b) such as a * b = extent // Sample pair of integer type (a, b) such as a * b = extent
template <typename T> template <typename T>
...@@ -101,7 +103,8 @@ class MultiLevelTiling : public AutoGenRule { ...@@ -101,7 +103,8 @@ class MultiLevelTiling : public AutoGenRule {
// Sample num_split integers whose product equals extent // Sample num_split integers whose product equals extent
template <typename T> template <typename T>
std::vector<T> SampleTileSplit(T extent, int num_split) const { std::vector<T> SampleTileSplit(T extent, int num_split) const {
CHECK_GT(num_split, 0) << "num_split in SampleTileSplit must be greater than 0"; CHECK_GT(num_split, 0)
<< "num_split in SampleTileSplit must be greater than 0";
if (num_split == 1) { if (num_split == 1) {
return {extent}; return {extent};
} }
......
...@@ -48,11 +48,13 @@ TEST(MultiLevelTile, SampleSplitTwo) { ...@@ -48,11 +48,13 @@ TEST(MultiLevelTile, SampleSplitTwo) {
Target target = common::DefaultHostTarget(); Target target = common::DefaultHostTarget();
#endif #endif
MultiLevelTiling multi_level_tiling(target, MultiLevelTiling::kConfigs.at(target.arch)); MultiLevelTiling multi_level_tiling(
target, MultiLevelTiling::kConfigs.at(target.arch));
for (int i = 0; i < 100; ++i) { for (int i = 0; i < 100; ++i) {
size_t number_to_split = rand() % 65535 + 2; // random number in [2, 2^16] size_t number_to_split = rand() % 65535 + 2; // random number in [2, 2^16]
std::vector<size_t> split = multi_level_tiling.SampleSplitTwo<size_t>(number_to_split); std::vector<size_t> split =
multi_level_tiling.SampleSplitTwo<size_t>(number_to_split);
EXPECT_EQ(split.size(), 2UL); EXPECT_EQ(split.size(), 2UL);
EXPECT_EQ(split[0] * split[1], number_to_split); EXPECT_EQ(split[0] * split[1], number_to_split);
} }
...@@ -67,12 +69,14 @@ TEST(MultiLevelTile, SampleTileSplit) { ...@@ -67,12 +69,14 @@ TEST(MultiLevelTile, SampleTileSplit) {
Target target = common::DefaultHostTarget(); Target target = common::DefaultHostTarget();
#endif #endif
MultiLevelTiling multi_level_tiling(target, MultiLevelTiling::kConfigs.at(target.arch)); MultiLevelTiling multi_level_tiling(
target, MultiLevelTiling::kConfigs.at(target.arch));
for (int i = 0; i < 100; ++i) { for (int i = 0; i < 100; ++i) {
int number_to_split = rand() % 65535 + 2; // random number in [2, 2^16] int number_to_split = rand() % 65535 + 2; // random number in [2, 2^16]
int split_size = rand() % 5 + 1; // random in [1, 5] int split_size = rand() % 5 + 1; // random in [1, 5]
std::vector<int> split = multi_level_tiling.SampleTileSplit<int>(number_to_split, split_size); std::vector<int> split =
multi_level_tiling.SampleTileSplit<int>(number_to_split, split_size);
EXPECT_EQ(split.size(), static_cast<size_t>(split_size)); EXPECT_EQ(split.size(), static_cast<size_t>(split_size));
int product = 1; int product = 1;
for (int num : split) { for (int num : split) {
...@@ -102,21 +106,31 @@ TEST(MultiLevelTile, SimpleLoops) { ...@@ -102,21 +106,31 @@ TEST(MultiLevelTile, SimpleLoops) {
poly::StageMap stages = CreateStages({C}); poly::StageMap stages = CreateStages({C});
std::vector<ir::LoweredFunc> funcs = std::vector<ir::LoweredFunc> funcs =
lang::LowerVec("TestMultiLevelTile_SimpleLoops", stages, {C}, {}, {}, nullptr, target, true); lang::LowerVec("TestMultiLevelTile_SimpleLoops",
stages,
{C},
{},
{},
nullptr,
target,
true);
ir::Expr ast_expr = funcs[0]->body; ir::Expr ast_expr = funcs[0]->body;
VLOG(6) << "Expr before MultiLevelTiling: "; VLOG(6) << "Expr before MultiLevelTiling: ";
VLOG(6) << ast_expr; VLOG(6) << ast_expr;
MultiLevelTiling multi_level_tiling(target, MultiLevelTiling::kConfigs.at(target.arch)); MultiLevelTiling multi_level_tiling(
target, MultiLevelTiling::kConfigs.at(target.arch));
ir::IRSchedule ir_schedule(ir::ModuleExpr({ast_expr})); ir::IRSchedule ir_schedule(ir::ModuleExpr({ast_expr}));
SearchState state(ir_schedule, 0, {}); SearchState state(ir_schedule, 0, {});
EXPECT_EQ(multi_level_tiling.Init(&ir_schedule), RuleApplyType::kApplyAndPruneOtherRules); EXPECT_EQ(multi_level_tiling.Init(&ir_schedule),
RuleApplyType::kApplyAndPruneOtherRules);
EXPECT_EQ(multi_level_tiling.NumberApplicable(), 1); EXPECT_EQ(multi_level_tiling.NumberApplicable(), 1);
multi_level_tiling.ApplyRandomly(); multi_level_tiling.ApplyRandomly();
// ApplyOnBlock // ApplyOnBlock
EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, "C"), RuleApplyType::kApplyAndPruneOtherRules); EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, "C"),
RuleApplyType::kApplyAndPruneOtherRules);
auto new_states = multi_level_tiling.ApplyOnBlock(state, "C"); auto new_states = multi_level_tiling.ApplyOnBlock(state, "C");
auto test_func = [](ir::IRSchedule* ir_sch) { auto test_func = [](ir::IRSchedule* ir_sch) {
...@@ -152,26 +166,30 @@ TEST(MulitLevelTile, MatrixMultiply) { ...@@ -152,26 +166,30 @@ TEST(MulitLevelTile, MatrixMultiply) {
Var k(K.as_int32(), "reduce_axis_k"); Var k(K.as_int32(), "reduce_axis_k");
ir::Tensor C = Compute( ir::Tensor C = Compute(
{M, N}, [&](Var i, Var j) { return ReduceSum(A(i, k) * B(k, j), {k}); }, "C"); {M, N}, [&](Var i, Var j) { return ReduceSum(A(i, k) * B(k, j), {k}); },
"C");
poly::StageMap stages = CreateStages({C}); poly::StageMap stages = CreateStages({C});
std::vector<ir::LoweredFunc> funcs = std::vector<ir::LoweredFunc> funcs =
lang::LowerVec("TestMultiLevelTile_MatrixMultiply", stages, {C}, {}, {}, nullptr, target, true); lang::LowerVec("TestMultiLevelTile_MatrixMultiply", stages, {C}, {}, {},
nullptr, target, true);
ir::Expr ast_expr = funcs[0]->body; ir::Expr ast_expr = funcs[0]->body;
VLOG(6) << "Expr before MultiLevelTiling: "; VLOG(6) << "Expr before MultiLevelTiling: ";
VLOG(6) << ast_expr; VLOG(6) << ast_expr;
MultiLevelTiling multi_level_tiling(target, MultiLevelTiling::kConfigs.at(target.arch)); MultiLevelTiling multi_level_tiling(target,
ir::IRSchedule ir_schedule(ir::ModuleExpr({ast_expr})); MultiLevelTiling::kConfigs.at(target.arch)); ir::IRSchedule
SearchState state(ir_schedule, 0, {}); ir_schedule(ir::ModuleExpr({ast_expr})); SearchState state(ir_schedule, 0, {});
EXPECT_EQ(multi_level_tiling.Init(&ir_schedule), RuleApplyType::kApplyAndPruneOtherRules); EXPECT_EQ(multi_level_tiling.Init(&ir_schedule),
RuleApplyType::kApplyAndPruneOtherRules);
EXPECT_EQ(multi_level_tiling.NumberApplicable(), 1); EXPECT_EQ(multi_level_tiling.NumberApplicable(), 1);
multi_level_tiling.ApplyRandomly(); multi_level_tiling.ApplyRandomly();
// ApplyOnBlock // ApplyOnBlock
EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, "C"), RuleApplyType::kApplyAndPruneOtherRules); EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, "C"),
auto new_states = multi_level_tiling.ApplyOnBlock(state, "C"); RuleApplyType::kApplyAndPruneOtherRules); auto new_states =
multi_level_tiling.ApplyOnBlock(state, "C");
auto test_func = [](ir::IRSchedule* ir_sch) { auto test_func = [](ir::IRSchedule* ir_sch) {
std::vector<ir::Expr> exprs = ir_sch->GetModule().GetExprs(); std::vector<ir::Expr> exprs = ir_sch->GetModule().GetExprs();
...@@ -201,16 +219,19 @@ TEST_F(TestMultiLevelTiling, Matmul) { ...@@ -201,16 +219,19 @@ TEST_F(TestMultiLevelTiling, Matmul) {
std::vector<int32_t> out_shape = {32, 32}; std::vector<int32_t> out_shape = {32, 32};
Initialize(common::DefaultNVGPUTarget()); Initialize(common::DefaultNVGPUTarget());
frontend::Program matmul_op = tests::OpBuilder("matmul").Build({{"X", X_shape}, {"Y", Y_shape}}); frontend::Program matmul_op =
tests::OpBuilder("matmul").Build({{"X", X_shape}, {"Y", Y_shape}});
ir::IRSchedule ir_schedule = MakeIRSchedule(matmul_op, fixed_rand_seed); ir::IRSchedule ir_schedule = MakeIRSchedule(matmul_op, fixed_rand_seed);
SearchState state(ir_schedule); SearchState state(ir_schedule);
VLOG(6) << "Original state:\n" << state->DebugString(); VLOG(6) << "Original state:\n" << state->DebugString();
// Apply MultiLevelTiling // Apply MultiLevelTiling
MultiLevelTiling multi_level_tiling(target_, MultiLevelTiling::kConfigs.at(target_.arch)); MultiLevelTiling multi_level_tiling(
target_, MultiLevelTiling::kConfigs.at(target_.arch));
EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, default_output_names[0]), EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, default_output_names[0]),
RuleApplyType::kApplyAndPruneOtherRules); RuleApplyType::kApplyAndPruneOtherRules);
auto new_states = multi_level_tiling.ApplyOnBlock(state, default_output_names[0]); auto new_states =
multi_level_tiling.ApplyOnBlock(state, default_output_names[0]);
VLOG(6) << "After MultiLevelTiling, state:\n" << new_states[0]->DebugString(); VLOG(6) << "After MultiLevelTiling, state:\n" << new_states[0]->DebugString();
std::string ir = GetIR(new_states[0]->ir_schedule); std::string ir = GetIR(new_states[0]->ir_schedule);
std::string expected_ir = R"ROC(Expr 0 { std::string expected_ir = R"ROC(Expr 0 {
...@@ -332,7 +353,8 @@ TEST_F(TestMultiLevelTiling, Matmul) { ...@@ -332,7 +353,8 @@ TEST_F(TestMultiLevelTiling, Matmul) {
// execute and check precision // execute and check precision
CheckResult( CheckResult(
GenExecutableKernel(ir_module), GenExecutableKernel(ir_module),
GenExecutableKernel(BuildIRModule(MakeIRSchedule(matmul_op, fixed_rand_seed, /* apply_manual_schedule*/ true))), GenExecutableKernel(BuildIRModule(MakeIRSchedule(
matmul_op, fixed_rand_seed, /* apply_manual_schedule*/ true))),
default_input_names, default_input_names,
default_output_names, default_output_names,
{X_shape, Y_shape}, {X_shape, Y_shape},
...@@ -349,14 +371,17 @@ TEST_F(TestMultiLevelTiling, ReduceSum) { ...@@ -349,14 +371,17 @@ TEST_F(TestMultiLevelTiling, ReduceSum) {
Initialize(common::DefaultNVGPUTarget()); Initialize(common::DefaultNVGPUTarget());
frontend::Program reduce_sum_op = frontend::Program reduce_sum_op =
tests::OpBuilder("reduce_sum").Build({{"X", X_shape}}, {{"dim", reduce_dim}, {"keep_dim", false}}); tests::OpBuilder("reduce_sum")
.Build({{"X", X_shape}}, {{"dim", reduce_dim}, {"keep_dim", false}});
ir::IRSchedule ir_schedule = MakeIRSchedule(reduce_sum_op); ir::IRSchedule ir_schedule = MakeIRSchedule(reduce_sum_op);
SearchState state(ir_schedule); SearchState state(ir_schedule);
VLOG(6) << "Original state:\n" << state->DebugString(); VLOG(6) << "Original state:\n" << state->DebugString();
// Apply MultiLevelTiling // Apply MultiLevelTiling
MultiLevelTiling multi_level_tiling(target_, MultiLevelTiling::kConfigs.at(target_.arch)); MultiLevelTiling multi_level_tiling(
// EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, default_output_names[0]), RuleApplyType::kCannotApply); target_, MultiLevelTiling::kConfigs.at(target_.arch));
// EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state,
// default_output_names[0]), RuleApplyType::kCannotApply);
} }
TEST_F(TestMultiLevelTiling, Pool2d) { TEST_F(TestMultiLevelTiling, Pool2d) {
...@@ -374,7 +399,8 @@ TEST_F(TestMultiLevelTiling, Pool2d) { ...@@ -374,7 +399,8 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
std::string data_format = "NCHW"; std::string data_format = "NCHW";
bool adaptive = false; bool adaptive = false;
std::string padding_algorithm = "EXPLICIT"; std::string padding_algorithm = "EXPLICIT";
frontend::Program pool2d_program = tests::OpBuilder("pool2d").Build({{"input", input_shape}}, frontend::Program pool2d_program = tests::OpBuilder("pool2d").Build(
{{"input", input_shape}},
{{"pool_type", pooling_type}, {{"pool_type", pooling_type},
{"kernel_size", ksize}, {"kernel_size", ksize},
{"stride_size", strides}, {"stride_size", strides},
...@@ -403,7 +429,8 @@ TEST_F(TestMultiLevelTiling, Pool2d) { ...@@ -403,7 +429,8 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
MultiLevelTiling multi_level_tiling(target_, mlt_config); MultiLevelTiling multi_level_tiling(target_, mlt_config);
EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, default_output_names[0]), EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, default_output_names[0]),
RuleApplyType::kApplyAndPruneOtherRules); RuleApplyType::kApplyAndPruneOtherRules);
auto new_states = multi_level_tiling.ApplyOnBlock(state, default_output_names[0]); auto new_states =
multi_level_tiling.ApplyOnBlock(state, default_output_names[0]);
VLOG(6) << "After MultiLevelTiling, state:\n" << new_states[0]->DebugString(); VLOG(6) << "After MultiLevelTiling, state:\n" << new_states[0]->DebugString();
std::string ir = GetIR(new_states[0]->ir_schedule); std::string ir = GetIR(new_states[0]->ir_schedule);
...@@ -534,9 +561,10 @@ Expr 1 { ...@@ -534,9 +561,10 @@ Expr 1 {
VLOG(6) << "scheduled source code:\n" << source_code; VLOG(6) << "scheduled source code:\n" << source_code;
// execute and check precision // execute and check precision
CheckResult(GenExecutableKernel(ir_module), CheckResult(
GenExecutableKernel( GenExecutableKernel(ir_module),
BuildIRModule(MakeIRSchedule(pool2d_program, fixed_rand_seed, /* apply_manual_schedule*/ true))), GenExecutableKernel(BuildIRModule(MakeIRSchedule(
pool2d_program, fixed_rand_seed, /* apply_manual_schedule*/ true))),
default_input_names, default_input_names,
default_output_names, default_output_names,
{input_shape}, {input_shape},
......
...@@ -34,11 +34,15 @@ class SkipRule : public AutoGenRule { ...@@ -34,11 +34,15 @@ class SkipRule : public AutoGenRule {
std::string GetRuleName() const override; std::string GetRuleName() const override;
RuleApplyType AnalyseApplyType(SearchState state, const std::string& block_name) const override { RuleApplyType AnalyseApplyType(SearchState state,
const std::string& block_name) const override {
return RuleApplyType::kApply; return RuleApplyType::kApply;
} }
std::vector<SearchState> ApplyOnBlock(SearchState state, const std::string& block_name) override { return {state}; } std::vector<SearchState> ApplyOnBlock(
SearchState state, const std::string& block_name) override {
return {state};
}
}; };
} // namespace auto_schedule } // namespace auto_schedule
......
...@@ -53,7 +53,8 @@ TEST(SkipRule, Basic) { ...@@ -53,7 +53,8 @@ TEST(SkipRule, Basic) {
{M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C"); {M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C");
poly::StageMap stages = CreateStages({C}); poly::StageMap stages = CreateStages({C});
std::vector<ir::LoweredFunc> funcs = lang::LowerVec("TestSkipRule_Basic", stages, {C}, {}, {}, nullptr, target, true); std::vector<ir::LoweredFunc> funcs = lang::LowerVec(
"TestSkipRule_Basic", stages, {C}, {}, {}, nullptr, target, true);
ir::Expr ast_expr = funcs[0]->body; ir::Expr ast_expr = funcs[0]->body;
VLOG(6) << "Expr before SkipRule: "; VLOG(6) << "Expr before SkipRule: ";
...@@ -69,7 +70,8 @@ TEST(SkipRule, Basic) { ...@@ -69,7 +70,8 @@ TEST(SkipRule, Basic) {
// ApplyOnBlock // ApplyOnBlock
EXPECT_EQ(skip_rule.AnalyseApplyType(state, "C"), RuleApplyType::kApply); EXPECT_EQ(skip_rule.AnalyseApplyType(state, "C"), RuleApplyType::kApply);
std::vector<cinn::auto_schedule::SearchState> states = skip_rule.ApplyOnBlock(state, "C"); std::vector<cinn::auto_schedule::SearchState> states =
skip_rule.ApplyOnBlock(state, "C");
auto test_func = [&ast_expr](ir::IRSchedule* ir_sch) { auto test_func = [&ast_expr](ir::IRSchedule* ir_sch) {
std::vector<ir::Expr> exprs = ir_sch->GetModule().GetExprs(); std::vector<ir::Expr> exprs = ir_sch->GetModule().GetExprs();
...@@ -100,7 +102,8 @@ TEST(SkipRule, ApplyOnSpecificBlock) { ...@@ -100,7 +102,8 @@ TEST(SkipRule, ApplyOnSpecificBlock) {
{M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C"); {M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C");
poly::StageMap stages = CreateStages({C}); poly::StageMap stages = CreateStages({C});
std::vector<ir::LoweredFunc> funcs = lang::LowerVec("TestSkipRule_Basic", stages, {C}, {}, {}, nullptr, target, true); std::vector<ir::LoweredFunc> funcs = lang::LowerVec(
"TestSkipRule_Basic", stages, {C}, {}, {}, nullptr, target, true);
ir::Expr ast_expr = funcs[0]->body; ir::Expr ast_expr = funcs[0]->body;
VLOG(6) << "Expr before SkipRule: "; VLOG(6) << "Expr before SkipRule: ";
...@@ -111,7 +114,8 @@ TEST(SkipRule, ApplyOnSpecificBlock) { ...@@ -111,7 +114,8 @@ TEST(SkipRule, ApplyOnSpecificBlock) {
SearchState state(ir_schedule, 0, {}); SearchState state(ir_schedule, 0, {});
EXPECT_EQ(skip_rule.AnalyseApplyType(state, "C"), RuleApplyType::kApply); EXPECT_EQ(skip_rule.AnalyseApplyType(state, "C"), RuleApplyType::kApply);
std::vector<cinn::auto_schedule::SearchState> states = skip_rule.ApplyOnBlock(state, "C"); std::vector<cinn::auto_schedule::SearchState> states =
skip_rule.ApplyOnBlock(state, "C");
std::vector<ir::Expr> exprs = states[0]->ir_schedule.GetModule().GetExprs(); std::vector<ir::Expr> exprs = states[0]->ir_schedule.GetModule().GetExprs();
EXPECT_EQ(exprs.size(), 1UL); EXPECT_EQ(exprs.size(), 1UL);
......
...@@ -46,22 +46,28 @@ void TestAutoGenRuleBase::Initialize(const common::Target& target) { ...@@ -46,22 +46,28 @@ void TestAutoGenRuleBase::Initialize(const common::Target& target) {
backend_compier_ = backends::Compiler::Create(target); backend_compier_ = backends::Compiler::Create(target);
} }
ir::IRSchedule TestAutoGenRuleBase::MakeIRSchedule(const frontend::Program& test_program, ir::IRSchedule TestAutoGenRuleBase::MakeIRSchedule(
const frontend::Program& test_program,
utils::LinearRandomEngine::StateType rand_seed, utils::LinearRandomEngine::StateType rand_seed,
bool apply_manual_schedule) { bool apply_manual_schedule) {
Context::Global().ResetNameId(); Context::Global().ResetNameId();
auto graph = std::make_shared<hlir::framework::Graph>(test_program, target_); auto graph = std::make_shared<hlir::framework::Graph>(test_program, target_);
hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); hlir::framework::ApplyPass(graph.get(), "OpFusionPass");
LOG_IF(WARNING, graph->fusion_groups.size() > 1) << "Test Graph has more than 1 group"; LOG_IF(WARNING, graph->fusion_groups.size() > 1)
auto& dtype_dict = graph->GetMutableAttrs<absl::flat_hash_map<std::string, common::Type>>("inferdtype"); << "Test Graph has more than 1 group";
auto& shape_dict = graph->GetMutableAttrs<absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape"); auto& dtype_dict =
graph->GetMutableAttrs<absl::flat_hash_map<std::string, common::Type>>(
"inferdtype");
auto& shape_dict = graph->GetMutableAttrs<
absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");
hlir::framework::OpLowerer op_lowerer(dtype_dict, shape_dict, target_); hlir::framework::OpLowerer op_lowerer(dtype_dict, shape_dict, target_);
if (apply_manual_schedule) { if (apply_manual_schedule) {
lowered_funcs_ = op_lowerer.Lower(graph->fusion_groups.front()); lowered_funcs_ = op_lowerer.Lower(graph->fusion_groups.front());
} else { } else {
lowered_funcs_ = op_lowerer.LowerWithoutSchedule(graph->fusion_groups.front()); lowered_funcs_ =
op_lowerer.LowerWithoutSchedule(graph->fusion_groups.front());
} }
CHECK(!lowered_funcs_.empty()) << "lowered_funcs_ is empty"; CHECK(!lowered_funcs_.empty()) << "lowered_funcs_ is empty";
...@@ -76,14 +82,16 @@ std::string TestAutoGenRuleBase::GetIR(const ir::IRSchedule& schedule) { ...@@ -76,14 +82,16 @@ std::string TestAutoGenRuleBase::GetIR(const ir::IRSchedule& schedule) {
const auto& exprs = schedule.GetModule().GetExprs(); const auto& exprs = schedule.GetModule().GetExprs();
std::stringstream module_stream; std::stringstream module_stream;
for (auto i = 0; i < exprs.size(); ++i) { for (auto i = 0; i < exprs.size(); ++i) {
module_stream << "Expr " << i << " {\n" << exprs.at(i) << "\n} // end Expr " << i << "\n"; module_stream << "Expr " << i << " {\n"
<< exprs.at(i) << "\n} // end Expr " << i << "\n";
} }
return module_stream.str(); return module_stream.str();
} }
ir::Module TestAutoGenRuleBase::BuildIRModule(const ir::IRSchedule& schedule) { ir::Module TestAutoGenRuleBase::BuildIRModule(const ir::IRSchedule& schedule) {
auto&& updated_bodys = schedule.GetModule().GetExprs(); auto&& updated_bodys = schedule.GetModule().GetExprs();
CHECK_EQ(lowered_funcs_.size(), updated_bodys.size()) << "associated exprs size not equal"; CHECK_EQ(lowered_funcs_.size(), updated_bodys.size())
<< "associated exprs size not equal";
ir::Module::Builder builder("test_bulder", this->target_); ir::Module::Builder builder("test_bulder", this->target_);
for (int i = 0; i < lowered_funcs_.size(); ++i) { for (int i = 0; i < lowered_funcs_.size(); ++i) {
...@@ -102,20 +110,24 @@ std::string TestAutoGenRuleBase::GenSourceCode(const ir::Module& ir_module) { ...@@ -102,20 +110,24 @@ std::string TestAutoGenRuleBase::GenSourceCode(const ir::Module& ir_module) {
if (target_ == common::DefaultNVGPUTarget()) { if (target_ == common::DefaultNVGPUTarget()) {
codegen = std::make_unique<backends::CodeGenCUDA_Dev>(this->target_); codegen = std::make_unique<backends::CodeGenCUDA_Dev>(this->target_);
} else { } else {
codegen = std::make_unique<backends::CodeGenCX86>(this->target_, CodeGenCX86::Feature::AVX512); codegen = std::make_unique<backends::CodeGenCX86>(
this->target_, CodeGenCX86::Feature::AVX512);
} }
#else #else
codegen = std::make_unique<backends::CodeGenCX86>(this->target_, CodeGenCX86::Feature::AVX512); codegen = std::make_unique<backends::CodeGenCX86>(
this->target_, CodeGenCX86::Feature::AVX512);
#endif #endif
codegen->SetInlineBuiltinCodes(false); codegen->SetInlineBuiltinCodes(false);
return codegen->Compile(ir_module, CodeGenC::OutputKind::CImpl); return codegen->Compile(ir_module, CodeGenC::OutputKind::CImpl);
} }
raw_func_type TestAutoGenRuleBase::GenExecutableKernel(const ir::Module& ir_module) { raw_func_type TestAutoGenRuleBase::GenExecutableKernel(
const ir::Module& ir_module) {
auto&& func_name = lowered_funcs_.front()->name; auto&& func_name = lowered_funcs_.front()->name;
// Compile to machine code // Compile to machine code
backend_compier_->Build(ir_module); backend_compier_->Build(ir_module);
auto test_func_ptr = reinterpret_cast<void (*)(void**, int32_t)>(backend_compier_->Lookup(func_name)); auto test_func_ptr = reinterpret_cast<void (*)(void**, int32_t)>(
backend_compier_->Lookup(func_name));
return test_func_ptr; return test_func_ptr;
} }
...@@ -138,15 +150,19 @@ void MemoryCopy(const float* src, float* dst, int numel, std::string type) { ...@@ -138,15 +150,19 @@ void MemoryCopy(const float* src, float* dst, int numel, std::string type) {
} }
} }
void AddDataToScope( void AddDataToScope(Scope* scope,
Scope* scope, const common::Target& target, float* data_ptr, std::string name, const std::vector<int>& shape) { const common::Target& target,
float* data_ptr,
std::string name,
const std::vector<int>& shape) {
auto* var = scope->Var<Tensor>(name); auto* var = scope->Var<Tensor>(name);
auto& tensor = absl::get<Tensor>(*var); auto& tensor = absl::get<Tensor>(*var);
CHECK(shape.size()) << "The size of shape can not be 0."; CHECK(shape.size()) << "The size of shape can not be 0.";
Shape cinn_shape(shape); Shape cinn_shape(shape);
tensor->Resize(cinn_shape); tensor->Resize(cinn_shape);
auto* tgt_data_ptr = tensor->mutable_data<float>(target); auto* tgt_data_ptr = tensor->mutable_data<float>(target);
std::string mem_cpy_type = target == common::DefaultNVGPUTarget() ? "DeviceToHost" : "HostToHost"; std::string mem_cpy_type =
target == common::DefaultNVGPUTarget() ? "DeviceToHost" : "HostToHost";
MemoryCopy(data_ptr, tgt_data_ptr, cinn_shape.numel(), mem_cpy_type); MemoryCopy(data_ptr, tgt_data_ptr, cinn_shape.numel(), mem_cpy_type);
} }
...@@ -159,16 +175,20 @@ void CheckResult(raw_func_type test_func, ...@@ -159,16 +175,20 @@ void CheckResult(raw_func_type test_func,
const common::Target& target) { const common::Target& target) {
CHECK(input_names.size()) << "The number of inputs must be greater than 0."; CHECK(input_names.size()) << "The number of inputs must be greater than 0.";
CHECK(output_names.size()) << "The number of outputs must be greater than 0."; CHECK(output_names.size()) << "The number of outputs must be greater than 0.";
CHECK_EQ(input_names.size(), input_shapes.size()) << "The quantity of input_names and input_shapes must be equal."; CHECK_EQ(input_names.size(), input_shapes.size())
<< "The quantity of input_names and input_shapes must be equal.";
CHECK_EQ(output_names.size(), output_shapes.size()) CHECK_EQ(output_names.size(), output_shapes.size())
<< "The quantity of output_names and output_shapes must be equal."; << "The quantity of output_names and output_shapes must be equal.";
// Initialize data // Initialize data
std::vector<float*> input_data_ptrs(input_names.size()); std::vector<float*> input_data_ptrs(input_names.size());
for (int i = 0; i < input_shapes.size(); ++i) { for (int i = 0; i < input_shapes.size(); ++i) {
int input_data_numel = int input_data_numel = std::accumulate(
std::accumulate(input_shapes[i].begin(), input_shapes[i].end(), 1, [](int a, int b) { return a * b; }); input_shapes[i].begin(), input_shapes[i].end(), 1, [](int a, int b) {
input_data_ptrs[i] = reinterpret_cast<float*>(malloc(input_data_numel * sizeof(float))); return a * b;
});
input_data_ptrs[i] =
reinterpret_cast<float*>(malloc(input_data_numel * sizeof(float)));
for (int j = 0; j < input_data_numel; ++j) { for (int j = 0; j < input_data_numel; ++j) {
input_data_ptrs[i][j] = (rand() * 1.f) / RAND_MAX; input_data_ptrs[i][j] = (rand() * 1.f) / RAND_MAX;
} }
...@@ -177,24 +197,35 @@ void CheckResult(raw_func_type test_func, ...@@ -177,24 +197,35 @@ void CheckResult(raw_func_type test_func,
std::vector<float*> expected_output_data_ptrs(output_names.size()); std::vector<float*> expected_output_data_ptrs(output_names.size());
std::vector<int> output_data_numels(output_shapes.size()); std::vector<int> output_data_numels(output_shapes.size());
for (int i = 0; i < output_shapes.size(); ++i) { for (int i = 0; i < output_shapes.size(); ++i) {
output_data_numels[i] = output_data_numels[i] = std::accumulate(
std::accumulate(output_shapes[i].begin(), output_shapes[i].end(), 1, [](int a, int b) { return a * b; }); output_shapes[i].begin(), output_shapes[i].end(), 1, [](int a, int b) {
test_output_data_ptrs[i] = reinterpret_cast<float*>(malloc(output_data_numels[i] * sizeof(float))); return a * b;
});
test_output_data_ptrs[i] =
reinterpret_cast<float*>(malloc(output_data_numels[i] * sizeof(float)));
memset(test_output_data_ptrs[i], 0, output_data_numels[i] * sizeof(float)); memset(test_output_data_ptrs[i], 0, output_data_numels[i] * sizeof(float));
expected_output_data_ptrs[i] = reinterpret_cast<float*>(malloc(output_data_numels[i] * sizeof(float))); expected_output_data_ptrs[i] =
memset(expected_output_data_ptrs[i], 0, output_data_numels[i] * sizeof(float)); reinterpret_cast<float*>(malloc(output_data_numels[i] * sizeof(float)));
memset(
expected_output_data_ptrs[i], 0, output_data_numels[i] * sizeof(float));
} }
auto launch_kernel_fn = [&](raw_func_type& raw_func, std::vector<float*>& output_data_ptrs) { auto launch_kernel_fn = [&](raw_func_type& raw_func,
std::vector<float*>& output_data_ptrs) {
// Initialize scope // Initialize scope
Scope scope; Scope scope;
// Initialize input data in scope. // Initialize input data in scope.
for (int i = 0; i < input_names.size(); ++i) { for (int i = 0; i < input_names.size(); ++i) {
AddDataToScope(&scope, target, input_data_ptrs[i], input_names[i], input_shapes[i]); AddDataToScope(
&scope, target, input_data_ptrs[i], input_names[i], input_shapes[i]);
} }
// Initialize output data in scope. // Initialize output data in scope.
for (int i = 0; i < output_names.size(); ++i) { for (int i = 0; i < output_names.size(); ++i) {
AddDataToScope(&scope, target, output_data_ptrs[i], output_names[i], output_shapes[i]); AddDataToScope(&scope,
target,
output_data_ptrs[i],
output_names[i],
output_shapes[i]);
} }
// Create Instruction and run // Create Instruction and run
...@@ -208,8 +239,11 @@ void CheckResult(raw_func_type test_func, ...@@ -208,8 +239,11 @@ void CheckResult(raw_func_type test_func,
// data // data
for (int i = 0; i < output_names.size(); ++i) { for (int i = 0; i < output_names.size(); ++i) {
const float* result_ptr = scope.GetTensor(output_names[i])->data<float>(); const float* result_ptr = scope.GetTensor(output_names[i])->data<float>();
std::string mem_cpy_type = target == common::DefaultNVGPUTarget() ? "DeviceToHost" : "HostToHost"; std::string mem_cpy_type = target == common::DefaultNVGPUTarget()
MemoryCopy(result_ptr, output_data_ptrs[i], output_data_numels[i], mem_cpy_type); ? "DeviceToHost"
: "HostToHost";
MemoryCopy(
result_ptr, output_data_ptrs[i], output_data_numels[i], mem_cpy_type);
} }
}; };
...@@ -220,7 +254,8 @@ void CheckResult(raw_func_type test_func, ...@@ -220,7 +254,8 @@ void CheckResult(raw_func_type test_func,
// Check result // Check result
for (int i = 0; i < output_shapes.size(); ++i) { for (int i = 0; i < output_shapes.size(); ++i) {
for (int j = 0; j < output_data_numels[i]; ++j) { for (int j = 0; j < output_data_numels[i]; ++j) {
ASSERT_NEAR(test_output_data_ptrs[i][j], expected_output_data_ptrs[i][j], 1e-4); ASSERT_NEAR(
test_output_data_ptrs[i][j], expected_output_data_ptrs[i][j], 1e-4);
} }
} }
......
...@@ -47,15 +47,18 @@ class TestAutoGenRuleBase : public ::testing::Test { ...@@ -47,15 +47,18 @@ class TestAutoGenRuleBase : public ::testing::Test {
// Initialize context for specified target // Initialize context for specified target
void Initialize(const common::Target& target); void Initialize(const common::Target& target);
// construct an ir::IRSchedule by lowering the specified for following AutoGenRule test // construct an ir::IRSchedule by lowering the specified for following
ir::IRSchedule MakeIRSchedule(const frontend::Program& test_program, // AutoGenRule test
ir::IRSchedule MakeIRSchedule(
const frontend::Program& test_program,
utils::LinearRandomEngine::StateType rand_seed = -1, utils::LinearRandomEngine::StateType rand_seed = -1,
bool apply_manual_schedule = false); bool apply_manual_schedule = false);
// Get the IR of bodies in IRSchedule // Get the IR of bodies in IRSchedule
std::string GetIR(const ir::IRSchedule& schedule); std::string GetIR(const ir::IRSchedule& schedule);
// build ir::Module from the original lowered funcs with their bodies updated by the schedule // build ir::Module from the original lowered funcs with their bodies updated
// by the schedule
ir::Module BuildIRModule(const ir::IRSchedule& schedule); ir::Module BuildIRModule(const ir::IRSchedule& schedule);
// generate source code with the built ir module // generate source code with the built ir module
...@@ -75,9 +78,12 @@ class TestAutoGenRuleBase : public ::testing::Test { ...@@ -75,9 +78,12 @@ class TestAutoGenRuleBase : public ::testing::Test {
* @params-2: Expected function pointer for comparison. * @params-2: Expected function pointer for comparison.
* @params-3: Names of input data. * @params-3: Names of input data.
* @params-4: Names of output data. * @params-4: Names of output data.
* @params-5: Shapes of the input data, each input corresponds to a std::vector<int>. * @params-5: Shapes of the input data, each input corresponds to a
* @params-6: Shapes of the output data, each output corresponds to a std::vector<int>. * std::vector<int>.
* @params-7: The Target expressing computing platform and architecture of the function to be tested. * @params-6: Shapes of the output data, each output corresponds to a
* std::vector<int>.
* @params-7: The Target expressing computing platform and architecture of the
* function to be tested.
* @return: void * @return: void
*/ */
void CheckResult(raw_func_type test_func, void CheckResult(raw_func_type test_func,
......
...@@ -26,20 +26,26 @@ namespace auto_schedule { ...@@ -26,20 +26,26 @@ namespace auto_schedule {
class SearchState; class SearchState;
// Select the next block to be operated for SearchState during the search process // Select the next block to be operated for SearchState during the search
// process
class BlockSampler { class BlockSampler {
public: public:
/** /**
* @brief Create a BlockSampler with the specific strategy name and necessary construct parameters. * @brief Create a BlockSampler with the specific strategy name and necessary
* construct parameters.
* @param all_blocks All possible blocks to be sampled. * @param all_blocks All possible blocks to be sampled.
* @param default_remove_policy The default option to determine whether to delete the next block after selecting it. * @param default_remove_policy The default option to determine whether to
* delete the next block after selecting it.
* @param strategy The block sampling strategy. * @param strategy The block sampling strategy.
* Currently, the available strategies are "traversal" and "probabilistic", * Currently, the available strategies are "traversal" and
* where "traversal" means to select blocks one by one until all blocks are traversed, * "probabilistic", where "traversal" means to select blocks one by one until
* and "probabilistic" means randomly picking blocks according to the given distribution. * all blocks are traversed, and "probabilistic" means randomly picking blocks
* @param weights Used for the probabilistic policy, giving each candidate a weight. * according to the given distribution.
* @param weights Used for the probabilistic policy, giving each candidate a
* weight.
*/ */
static std::unique_ptr<BlockSampler> Make(const std::vector<ir::Expr>& all_blocks, static std::unique_ptr<BlockSampler> Make(
const std::vector<ir::Expr>& all_blocks,
bool default_remove_policy = true, bool default_remove_policy = true,
const std::string& strategy = "traversal", const std::string& strategy = "traversal",
utils::LinearRandomEngine::StateType rand_seed = 0, utils::LinearRandomEngine::StateType rand_seed = 0,
...@@ -56,18 +62,22 @@ class BlockSampler { ...@@ -56,18 +62,22 @@ class BlockSampler {
protected: protected:
// A BlockSampler object should be created with the static function Make() // A BlockSampler object should be created with the static function Make()
BlockSampler(const std::vector<ir::Expr>& all_blocks, bool default_remove_policy); BlockSampler(const std::vector<ir::Expr>& all_blocks,
bool default_remove_policy);
// Select a block to apply rule // Select a block to apply rule
// The param remove is used to determine whether to delete the next block after selecting it, // The param remove is used to determine whether to delete the next block
// If remove == true, it will not be sampled in the future. // after selecting it, If remove == true, it will not be sampled in the
// future.
virtual std::string NextBlock(bool remove) = 0; virtual std::string NextBlock(bool remove) = 0;
// The names of all blocks // The names of all blocks
// Because the Block Expr will be changed in the search process, the name is saved for indexing // Because the Block Expr will be changed in the search process, the name is
// saved for indexing
std::vector<std::string> all_blocks_; std::vector<std::string> all_blocks_;
// The default policy to determine whether to delete the next block after selecting it. // The default policy to determine whether to delete the next block after
// selecting it.
bool default_remove_policy_; bool default_remove_policy_;
}; };
...@@ -75,7 +85,8 @@ class BlockSampler { ...@@ -75,7 +85,8 @@ class BlockSampler {
// witch means to select blocks one by one until all blocks are traversed. // witch means to select blocks one by one until all blocks are traversed.
class TraversalBlockSampler : public BlockSampler { class TraversalBlockSampler : public BlockSampler {
public: public:
TraversalBlockSampler(const std::vector<ir::Expr>& all_blocks, bool default_remove_policy) TraversalBlockSampler(const std::vector<ir::Expr>& all_blocks,
bool default_remove_policy)
: BlockSampler(all_blocks, default_remove_policy), cur_idx_(0) {} : BlockSampler(all_blocks, default_remove_policy), cur_idx_(0) {}
const char* Name() const override { return "traversal"; } const char* Name() const override { return "traversal"; }
......
...@@ -24,7 +24,8 @@ namespace auto_schedule { ...@@ -24,7 +24,8 @@ namespace auto_schedule {
std::vector<ir::Expr> CreateTestBlocks() { std::vector<ir::Expr> CreateTestBlocks() {
std::vector<ir::Expr> blocks; std::vector<ir::Expr> blocks;
for (int i = 0; i < 3; ++i) { for (int i = 0; i < 3; ++i) {
ir::Expr block = ir::ScheduleBlock::Make({}, {}, {}, "block_" + std::to_string(i), ir::Expr()); ir::Expr block = ir::ScheduleBlock::Make(
{}, {}, {}, "block_" + std::to_string(i), ir::Expr());
blocks.push_back(ir::ScheduleBlockRealize::Make({}, block)); blocks.push_back(ir::ScheduleBlockRealize::Make({}, block));
} }
return blocks; return blocks;
...@@ -32,9 +33,11 @@ std::vector<ir::Expr> CreateTestBlocks() { ...@@ -32,9 +33,11 @@ std::vector<ir::Expr> CreateTestBlocks() {
TEST(BlockSampler, Make) { TEST(BlockSampler, Make) {
std::vector<ir::Expr> mock_blocks = CreateTestBlocks(); std::vector<ir::Expr> mock_blocks = CreateTestBlocks();
auto traversal_block_sampler = BlockSampler::Make(mock_blocks, true, "traversal"); auto traversal_block_sampler =
BlockSampler::Make(mock_blocks, true, "traversal");
ASSERT_STREQ(traversal_block_sampler->Name(), "traversal"); ASSERT_STREQ(traversal_block_sampler->Name(), "traversal");
auto probabilistic_block_sampler = BlockSampler::Make(mock_blocks, true, "probabilistic"); auto probabilistic_block_sampler =
BlockSampler::Make(mock_blocks, true, "probabilistic");
ASSERT_STREQ(probabilistic_block_sampler->Name(), "probabilistic"); ASSERT_STREQ(probabilistic_block_sampler->Name(), "probabilistic");
} }
...@@ -55,14 +58,16 @@ TEST(TraversalBlockSampler, NextBlock) { ...@@ -55,14 +58,16 @@ TEST(TraversalBlockSampler, NextBlock) {
TEST(ProbabilisticBlockSampler, NextBlock) { TEST(ProbabilisticBlockSampler, NextBlock) {
std::vector<ir::Expr> blocks = CreateTestBlocks(); std::vector<ir::Expr> blocks = CreateTestBlocks();
auto probabilistic_block_sampler = BlockSampler::Make(blocks, false, "probabilistic", 0, {4, 2, 1}); auto probabilistic_block_sampler =
BlockSampler::Make(blocks, false, "probabilistic", 0, {4, 2, 1});
std::string block_name; std::string block_name;
for (int i = 0; i < 20; ++i) { for (int i = 0; i < 20; ++i) {
block_name = probabilistic_block_sampler->NextBlock(); block_name = probabilistic_block_sampler->NextBlock();
VLOG(6) << "next block name: " << block_name; VLOG(6) << "next block name: " << block_name;
} }
probabilistic_block_sampler = BlockSampler::Make(blocks, true, "probabilistic", 0, {4, 2, 1}); probabilistic_block_sampler =
BlockSampler::Make(blocks, true, "probabilistic", 0, {4, 2, 1});
probabilistic_block_sampler->NextBlock(); probabilistic_block_sampler->NextBlock();
probabilistic_block_sampler->NextBlock(); probabilistic_block_sampler->NextBlock();
probabilistic_block_sampler->NextBlock(); probabilistic_block_sampler->NextBlock();
......
...@@ -30,16 +30,21 @@ class SearchState; ...@@ -30,16 +30,21 @@ class SearchState;
class RuleSampler { class RuleSampler {
public: public:
/** /**
* @brief Create a RuleSampler with the specific strategy name and necessary construct parameters. * @brief Create a RuleSampler with the specific strategy name and necessary
* construct parameters.
* @param potential_rules All possible rules to be sampled. * @param potential_rules All possible rules to be sampled.
* @param default_remove_policy The default option to determine whether to delete the next block after selecting it. * @param default_remove_policy The default option to determine whether to
* delete the next block after selecting it.
* @param strategy The rule sampling strategy. * @param strategy The rule sampling strategy.
* Currently, the available strategies are "traversal" and "probabilistic", * Currently, the available strategies are "traversal" and
* where "traversal" means to select rules one by one until all rules are traversed, * "probabilistic", where "traversal" means to select rules one by one until
* and "probabilistic" means randomly picking rules according to the given distribution. * all rules are traversed, and "probabilistic" means randomly picking rules
* @param weights Used for the probabilistic policy, giving each candidate a weight. * according to the given distribution.
* @param weights Used for the probabilistic policy, giving each candidate a
* weight.
*/ */
static std::unique_ptr<RuleSampler> Make(const std::vector<AutoGenRule*>& potential_rules, static std::unique_ptr<RuleSampler> Make(
const std::vector<AutoGenRule*>& potential_rules,
bool default_remove_policy = true, bool default_remove_policy = true,
const std::string& strategy = "traversal", const std::string& strategy = "traversal",
utils::LinearRandomEngine::StateType rand_seed = 0, utils::LinearRandomEngine::StateType rand_seed = 0,
...@@ -55,18 +60,21 @@ class RuleSampler { ...@@ -55,18 +60,21 @@ class RuleSampler {
protected: protected:
// A RuleSampler object should be created with the static function Make() // A RuleSampler object should be created with the static function Make()
RuleSampler(const std::vector<AutoGenRule*>& potential_rules, bool default_remove_policy) RuleSampler(const std::vector<AutoGenRule*>& potential_rules,
: potential_rules_(&potential_rules), default_remove_policy_(default_remove_policy) {} bool default_remove_policy)
: potential_rules_(&potential_rules),
default_remove_policy_(default_remove_policy) {}
// Select a rule to apply. // Select a rule to apply.
// The param remove is used to determine whether to delete the next rule after selecting it, // The param remove is used to determine whether to delete the next rule after
// If remove == true, it will not be sampled in the future. // selecting it, If remove == true, it will not be sampled in the future.
virtual AutoGenRule* NextRule(bool remove) = 0; virtual AutoGenRule* NextRule(bool remove) = 0;
// The pointer refers to all potential rules // The pointer refers to all potential rules
const std::vector<AutoGenRule*>* potential_rules_; const std::vector<AutoGenRule*>* potential_rules_;
// The default policy to determine whether to delete the next rule after selecting it. // The default policy to determine whether to delete the next rule after
// selecting it.
bool default_remove_policy_; bool default_remove_policy_;
}; };
...@@ -74,7 +82,8 @@ class RuleSampler { ...@@ -74,7 +82,8 @@ class RuleSampler {
// witch means to select rules one by one until all rules are traversed. // witch means to select rules one by one until all rules are traversed.
class TraversalRuleSampler : public RuleSampler { class TraversalRuleSampler : public RuleSampler {
public: public:
TraversalRuleSampler(const std::vector<AutoGenRule*>& potential_rules, bool default_remove_policy) TraversalRuleSampler(const std::vector<AutoGenRule*>& potential_rules,
bool default_remove_policy)
: RuleSampler(potential_rules, default_remove_policy), cur_idx_(0) {} : RuleSampler(potential_rules, default_remove_policy), cur_idx_(0) {}
const char* Name() const override { return "traversal"; } const char* Name() const override { return "traversal"; }
......
...@@ -28,13 +28,16 @@ Target target = common::DefaultNVGPUTarget(); ...@@ -28,13 +28,16 @@ Target target = common::DefaultNVGPUTarget();
Target target = common::DefaultHostTarget(); Target target = common::DefaultHostTarget();
#endif #endif
std::vector<AutoGenRule*> GenerateTestRules() { return {new AutoUnroll(target), new SkipRule(target)}; } std::vector<AutoGenRule*> GenerateTestRules() {
return {new AutoUnroll(target), new SkipRule(target)};
}
TEST(RuleSampler, Make) { TEST(RuleSampler, Make) {
std::vector<AutoGenRule*> rules = GenerateTestRules(); std::vector<AutoGenRule*> rules = GenerateTestRules();
auto traversal_block_sampler = RuleSampler::Make(rules, true, "traversal"); auto traversal_block_sampler = RuleSampler::Make(rules, true, "traversal");
ASSERT_STREQ(traversal_block_sampler->Name(), "traversal"); ASSERT_STREQ(traversal_block_sampler->Name(), "traversal");
auto probabilistic_block_sampler = RuleSampler::Make(rules, true, "probabilistic"); auto probabilistic_block_sampler =
RuleSampler::Make(rules, true, "probabilistic");
ASSERT_STREQ(probabilistic_block_sampler->Name(), "probabilistic"); ASSERT_STREQ(probabilistic_block_sampler->Name(), "probabilistic");
} }
...@@ -58,14 +61,16 @@ TEST(TraversalRuleSampler, NextRule) { ...@@ -58,14 +61,16 @@ TEST(TraversalRuleSampler, NextRule) {
TEST(ProbabilisticRuleSampler, NextRule) { TEST(ProbabilisticRuleSampler, NextRule) {
std::vector<AutoGenRule*> rules = GenerateTestRules(); std::vector<AutoGenRule*> rules = GenerateTestRules();
auto probabilistic_rule_sampler = RuleSampler::Make(rules, false, "probabilistic", 0, {4, 1}); auto probabilistic_rule_sampler =
RuleSampler::Make(rules, false, "probabilistic", 0, {4, 1});
AutoGenRule* rule; AutoGenRule* rule;
for (int i = 0; i < 20; ++i) { for (int i = 0; i < 20; ++i) {
rule = probabilistic_rule_sampler->NextRule(); rule = probabilistic_rule_sampler->NextRule();
VLOG(6) << "next rule name: " << rule->GetRuleName(); VLOG(6) << "next rule name: " << rule->GetRuleName();
} }
probabilistic_rule_sampler = RuleSampler::Make(rules, true, "probabilistic", 0, {4, 1}); probabilistic_rule_sampler =
RuleSampler::Make(rules, true, "probabilistic", 0, {4, 1});
probabilistic_rule_sampler->NextRule(); probabilistic_rule_sampler->NextRule();
probabilistic_rule_sampler->NextRule(); probabilistic_rule_sampler->NextRule();
ASSERT_EQ(nullptr, probabilistic_rule_sampler->NextRule()); ASSERT_EQ(nullptr, probabilistic_rule_sampler->NextRule());
......
...@@ -39,18 +39,23 @@ DECLARE_bool(auto_schedule_use_cost_model); ...@@ -39,18 +39,23 @@ DECLARE_bool(auto_schedule_use_cost_model);
namespace cinn { namespace cinn {
namespace auto_schedule { namespace auto_schedule {
SearchSpace::SearchSpace(const TuneTask& tune_task, utils::LinearRandomEngine::StateType rand_seed) SearchSpace::SearchSpace(const TuneTask& tune_task,
: tune_task_(tune_task), rand_seed_(utils::LinearRandomEngine::NormalizeState(rand_seed)) { utils::LinearRandomEngine::StateType rand_seed)
: tune_task_(tune_task),
rand_seed_(utils::LinearRandomEngine::NormalizeState(rand_seed)) {
const auto& target = tune_task_.target; const auto& target = tune_task_.target;
// initialize a set of rules and they are commonly used by all states // initialize a set of rules and they are commonly used by all states
// TODO(zhhsplendid): pass correct output names to AutoInline // TODO(zhhsplendid): pass correct output names to AutoInline
// sketch_rules_.emplace_back(new AutoInline(target, tune_task_.output_names)); // sketch_rules_.emplace_back(new AutoInline(target,
sketch_rules_.emplace_back(new MultiLevelTiling(target, MultiLevelTiling::kConfigs.at(target.arch))); // tune_task_.output_names));
sketch_rules_.emplace_back(
new MultiLevelTiling(target, MultiLevelTiling::kConfigs.at(target.arch)));
sketch_rules_.emplace_back(new AutoUnroll(target)); sketch_rules_.emplace_back(new AutoUnroll(target));
sketch_rules_.emplace_back(new SkipRule(target)); sketch_rules_.emplace_back(new SkipRule(target));
} }
SearchState SearchSpace::GetScheduleMutate(const SearchState& state, const ExprCostModel& cost_model) { SearchState SearchSpace::GetScheduleMutate(const SearchState& state,
const ExprCostModel& cost_model) {
bool has_manual_schedule = false; bool has_manual_schedule = false;
if (has_manual_schedule) { if (has_manual_schedule) {
SearchState ret = ManualScheduleMutate(state); SearchState ret = ManualScheduleMutate(state);
...@@ -58,9 +63,11 @@ SearchState SearchSpace::GetScheduleMutate(const SearchState& state, const ExprC ...@@ -58,9 +63,11 @@ SearchState SearchSpace::GetScheduleMutate(const SearchState& state, const ExprC
} }
SearchState ret = RandomScheduleMutate(state); SearchState ret = RandomScheduleMutate(state);
if (FLAGS_auto_schedule_use_cost_model) { if (FLAGS_auto_schedule_use_cost_model) {
ret->predicted_cost = cost_model.Predict(ret->ir_schedule.GetModule(), tune_task_.target); ret->predicted_cost =
cost_model.Predict(ret->ir_schedule.GetModule(), tune_task_.target);
} }
VLOG(4) << JoinStatesDebugString("SearchSpace::GetScheduleMutate", {state}, /*verbose=*/VLOG_IS_ON(5)); VLOG(4) << JoinStatesDebugString(
"SearchSpace::GetScheduleMutate", {state}, /*verbose=*/VLOG_IS_ON(5));
return ret; return ret;
} }
...@@ -79,7 +86,8 @@ SearchState SearchSpace::RandomScheduleMutate(const SearchState& state) { ...@@ -79,7 +86,8 @@ SearchState SearchSpace::RandomScheduleMutate(const SearchState& state) {
for (int idx = 0; idx != ret->applicable_rules.size(); ++idx) { for (int idx = 0; idx != ret->applicable_rules.size(); ++idx) {
AutoGenRule* rule = ret->applicable_rules.at(idx); AutoGenRule* rule = ret->applicable_rules.at(idx);
RuleApplyType apply_type = rule->Init(&ret->ir_schedule); RuleApplyType apply_type = rule->Init(&ret->ir_schedule);
VLOG(6) << "Evaluate rule:" << rule->GetRuleName() << "=" << static_cast<int>(apply_type); VLOG(6) << "Evaluate rule:" << rule->GetRuleName() << "="
<< static_cast<int>(apply_type);
apply_types[idx] = apply_type; apply_types[idx] = apply_type;
if (apply_type != RuleApplyType::kCannotApply) { if (apply_type != RuleApplyType::kCannotApply) {
weight_to_rule_index[cur_weight] = idx; weight_to_rule_index[cur_weight] = idx;
...@@ -94,7 +102,8 @@ SearchState SearchSpace::RandomScheduleMutate(const SearchState& state) { ...@@ -94,7 +102,8 @@ SearchState SearchSpace::RandomScheduleMutate(const SearchState& state) {
} }
// 3. Sample a schedule on the distribution // 3. Sample a schedule on the distribution
int sample_weighted_index = utils::SampleUniformInt(0, cur_weight, &rand_seed_); int sample_weighted_index =
utils::SampleUniformInt(0, cur_weight, &rand_seed_);
auto iter = weight_to_rule_index.upper_bound(sample_weighted_index); auto iter = weight_to_rule_index.upper_bound(sample_weighted_index);
--iter; --iter;
...@@ -102,13 +111,15 @@ SearchState SearchSpace::RandomScheduleMutate(const SearchState& state) { ...@@ -102,13 +111,15 @@ SearchState SearchSpace::RandomScheduleMutate(const SearchState& state) {
int sample_rule_index = iter->second; int sample_rule_index = iter->second;
CHECK_LT(sample_rule_index, ret->applicable_rules.size()); CHECK_LT(sample_rule_index, ret->applicable_rules.size());
AutoGenRule* sample_rule = ret->applicable_rules.at(sample_rule_index); AutoGenRule* sample_rule = ret->applicable_rules.at(sample_rule_index);
VLOG(7) << "Apply rule: " << sample_rule->GetRuleName() << " with index=" << sample_weighted_index - iter->first; VLOG(7) << "Apply rule: " << sample_rule->GetRuleName()
<< " with index=" << sample_weighted_index - iter->first;
// 4. Apply the schedule change // 4. Apply the schedule change
sample_rule->Apply(sample_weighted_index - iter->first); sample_rule->Apply(sample_weighted_index - iter->first);
// 5. Remove the rule after applying it // 5. Remove the rule after applying it
if (apply_types.at(sample_rule_index) != RuleApplyType::kCannotApply) { if (apply_types.at(sample_rule_index) != RuleApplyType::kCannotApply) {
ret->applicable_rules.erase(ret->applicable_rules.begin() + sample_rule_index); ret->applicable_rules.erase(ret->applicable_rules.begin() +
sample_rule_index);
} }
return ret; return ret;
...@@ -116,17 +127,20 @@ SearchState SearchSpace::RandomScheduleMutate(const SearchState& state) { ...@@ -116,17 +127,20 @@ SearchState SearchSpace::RandomScheduleMutate(const SearchState& state) {
std::vector<SearchState> SearchSpace::InitSketchWithRandomStrategy(int num) { std::vector<SearchState> SearchSpace::InitSketchWithRandomStrategy(int num) {
VLOG(5) << "SearchSpace::GetRandomInitialSketch with num=" << num; VLOG(5) << "SearchSpace::GetRandomInitialSketch with num=" << num;
ir::IRSchedule init_schedule(ir::ModuleExpr(tune_task_.GetLoweredFuncBodyExprs()), ir::IRSchedule init_schedule(
ir::ModuleExpr(tune_task_.GetLoweredFuncBodyExprs()),
utils::ForkRandomState(&rand_seed_)); utils::ForkRandomState(&rand_seed_));
std::vector<AutoGenRule*> init_rules; std::vector<AutoGenRule*> init_rules;
std::transform(sketch_rules_.begin(), sketch_rules_.end(), std::back_inserter(init_rules), [](const auto& rule) { std::transform(sketch_rules_.begin(),
return rule.get(); sketch_rules_.end(),
}); std::back_inserter(init_rules),
[](const auto& rule) { return rule.get(); });
std::vector<SearchState> result; std::vector<SearchState> result;
while (result.size() < num) { while (result.size() < num) {
SearchState state(init_schedule, SearchState::NOT_INIT_COST, init_rules); SearchState state(init_schedule, SearchState::NOT_INIT_COST, init_rules);
for (int i = 0; i < init_sketch_random_depth_; ++i) { for (int i = 0; i < init_sketch_random_depth_; ++i) {
VLOG(6) << "Generating random sketch with RandomScheduleMutate at depth: " << i; VLOG(6) << "Generating random sketch with RandomScheduleMutate at depth: "
<< i;
state = RandomScheduleMutate(state); state = RandomScheduleMutate(state);
if (state->applicable_rules.empty()) { if (state->applicable_rules.empty()) {
break; break;
...@@ -134,7 +148,9 @@ std::vector<SearchState> SearchSpace::InitSketchWithRandomStrategy(int num) { ...@@ -134,7 +148,9 @@ std::vector<SearchState> SearchSpace::InitSketchWithRandomStrategy(int num) {
} }
VLOG(5) << JoinStatesDebugString( VLOG(5) << JoinStatesDebugString(
"SearchSpace::GetRandomInitialSketch-New_Sketch", {state}, /*verbose=*/VLOG_IS_ON(6)); "SearchSpace::GetRandomInitialSketch-New_Sketch",
{state},
/*verbose=*/VLOG_IS_ON(6));
result.emplace_back(std::move(state)); result.emplace_back(std::move(state));
} }
return result; return result;
...@@ -142,15 +158,18 @@ std::vector<SearchState> SearchSpace::InitSketchWithRandomStrategy(int num) { ...@@ -142,15 +158,18 @@ std::vector<SearchState> SearchSpace::InitSketchWithRandomStrategy(int num) {
std::vector<SearchState> SearchSpace::InitSketchWithRandomPrunedStrategy() { std::vector<SearchState> SearchSpace::InitSketchWithRandomPrunedStrategy() {
VLOG(5) << "SearchSpace::InitSketchWithRandomPrunedStrategy"; VLOG(5) << "SearchSpace::InitSketchWithRandomPrunedStrategy";
ir::IRSchedule init_schedule(ir::ModuleExpr(tune_task_.GetLoweredFuncBodyExprs()), ir::IRSchedule init_schedule(
ir::ModuleExpr(tune_task_.GetLoweredFuncBodyExprs()),
utils::ForkRandomState(&rand_seed_)); utils::ForkRandomState(&rand_seed_));
auto all_blocks = init_schedule.GetAllBlocks(); auto all_blocks = init_schedule.GetAllBlocks();
auto block_sampler = BlockSampler::Make(all_blocks, true, "probabilistic", utils::ForkRandomState(&rand_seed_)); auto block_sampler = BlockSampler::Make(
all_blocks, true, "probabilistic", utils::ForkRandomState(&rand_seed_));
std::vector<AutoGenRule*> init_rules; std::vector<AutoGenRule*> init_rules;
std::transform(sketch_rules_.begin(), sketch_rules_.end() - 1, std::back_inserter(init_rules), [](const auto& rule) { std::transform(sketch_rules_.begin(),
return rule.get(); sketch_rules_.end() - 1,
}); std::back_inserter(init_rules),
[](const auto& rule) { return rule.get(); });
CHECK(init_rules.size() > 0) << "number of init rules cannot be 0"; CHECK(init_rules.size() > 0) << "number of init rules cannot be 0";
SearchState init_state(init_schedule, SearchState::NOT_INIT_COST, {}); SearchState init_state(init_schedule, SearchState::NOT_INIT_COST, {});
...@@ -159,7 +178,8 @@ std::vector<SearchState> SearchSpace::InitSketchWithRandomPrunedStrategy() { ...@@ -159,7 +178,8 @@ std::vector<SearchState> SearchSpace::InitSketchWithRandomPrunedStrategy() {
std::vector<SearchState>* p_states_next = &states_buf2; std::vector<SearchState>* p_states_next = &states_buf2;
int total_steps = 0, steps; int total_steps = 0, steps;
std::string block_name; std::string block_name;
while ("" != (block_name = block_sampler->NextBlock()) && total_steps < init_sketch_random_depth_) { while ("" != (block_name = block_sampler->NextBlock()) &&
total_steps < init_sketch_random_depth_) {
steps = utils::SampleUniformInt(1, init_rules.size() + 1, &rand_seed_); steps = utils::SampleUniformInt(1, init_rules.size() + 1, &rand_seed_);
if (total_steps + steps > init_sketch_random_depth_) { if (total_steps + steps > init_sketch_random_depth_) {
steps = init_sketch_random_depth_ - total_steps; steps = init_sketch_random_depth_ - total_steps;
...@@ -167,29 +187,39 @@ std::vector<SearchState> SearchSpace::InitSketchWithRandomPrunedStrategy() { ...@@ -167,29 +187,39 @@ std::vector<SearchState> SearchSpace::InitSketchWithRandomPrunedStrategy() {
total_steps += steps; total_steps += steps;
p_states_next->clear(); p_states_next->clear();
for (const auto& state : *p_states_cur) { for (const auto& state : *p_states_cur) {
auto rule_sampler = RuleSampler::Make(init_rules, true, "probabilistic", utils::ForkRandomState(&rand_seed_)); auto rule_sampler =
auto new_states = ApplySketchRule(state, block_name, rule_sampler.get(), steps, false, 1); RuleSampler::Make(init_rules,
p_states_next->insert(p_states_next->end(), new_states.begin(), new_states.end()); true,
"probabilistic",
utils::ForkRandomState(&rand_seed_));
auto new_states = ApplySketchRule(
state, block_name, rule_sampler.get(), steps, false, 1);
p_states_next->insert(
p_states_next->end(), new_states.begin(), new_states.end());
} }
std::swap(p_states_cur, p_states_next); std::swap(p_states_cur, p_states_next);
} }
VLOG(5) << JoinStatesDebugString( VLOG(5) << JoinStatesDebugString(
"SearchSpace::InitSketchWithRandomPrunedStrategy", *p_states_cur, /*verbose=*/VLOG_IS_ON(6)); "SearchSpace::InitSketchWithRandomPrunedStrategy",
*p_states_cur,
/*verbose=*/VLOG_IS_ON(6));
return *p_states_cur; return *p_states_cur;
} }
std::vector<SearchState> SearchSpace::InitSketchWithRulePrunedStrategy() { std::vector<SearchState> SearchSpace::InitSketchWithRulePrunedStrategy() {
VLOG(5) << "SearchSpace::InitSketchWithRulePrunedStrategy"; VLOG(5) << "SearchSpace::InitSketchWithRulePrunedStrategy";
ir::IRSchedule init_schedule(ir::ModuleExpr(tune_task_.GetLoweredFuncBodyExprs()), ir::IRSchedule init_schedule(
ir::ModuleExpr(tune_task_.GetLoweredFuncBodyExprs()),
utils::ForkRandomState(&rand_seed_)); utils::ForkRandomState(&rand_seed_));
auto all_blocks = init_schedule.GetAllBlocks(); auto all_blocks = init_schedule.GetAllBlocks();
std::reverse(all_blocks.begin(), all_blocks.end()); std::reverse(all_blocks.begin(), all_blocks.end());
auto block_sampler = BlockSampler::Make(all_blocks, true, "traversal"); auto block_sampler = BlockSampler::Make(all_blocks, true, "traversal");
std::vector<AutoGenRule*> init_rules; std::vector<AutoGenRule*> init_rules;
std::transform(sketch_rules_.begin(), sketch_rules_.end() - 1, std::back_inserter(init_rules), [](const auto& rule) { std::transform(sketch_rules_.begin(),
return rule.get(); sketch_rules_.end() - 1,
}); std::back_inserter(init_rules),
[](const auto& rule) { return rule.get(); });
CHECK(init_rules.size() > 0) << "number of init rules cannot be 0"; CHECK(init_rules.size() > 0) << "number of init rules cannot be 0";
SearchState init_state(init_schedule, SearchState::NOT_INIT_COST, {}); SearchState init_state(init_schedule, SearchState::NOT_INIT_COST, {});
...@@ -201,17 +231,22 @@ std::vector<SearchState> SearchSpace::InitSketchWithRulePrunedStrategy() { ...@@ -201,17 +231,22 @@ std::vector<SearchState> SearchSpace::InitSketchWithRulePrunedStrategy() {
p_states_next->clear(); p_states_next->clear();
for (const auto& state : *p_states_cur) { for (const auto& state : *p_states_cur) {
auto rule_sampler = RuleSampler::Make(init_rules, true, "traversal"); auto rule_sampler = RuleSampler::Make(init_rules, true, "traversal");
auto new_states = ApplySketchRule(state, block_name, rule_sampler.get(), 0, true); auto new_states =
p_states_next->insert(p_states_next->end(), new_states.begin(), new_states.end()); ApplySketchRule(state, block_name, rule_sampler.get(), 0, true);
p_states_next->insert(
p_states_next->end(), new_states.begin(), new_states.end());
} }
std::swap(p_states_cur, p_states_next); std::swap(p_states_cur, p_states_next);
} }
VLOG(5) << JoinStatesDebugString( VLOG(5) << JoinStatesDebugString(
"SearchSpace::InitSketchWithRulePrunedStrategy", *p_states_cur, /*verbose=*/VLOG_IS_ON(6)); "SearchSpace::InitSketchWithRulePrunedStrategy",
*p_states_cur,
/*verbose=*/VLOG_IS_ON(6));
return *p_states_cur; return *p_states_cur;
} }
std::vector<SearchState> SearchSpace::GenerateSketches(int num, const std::string& strategy) { std::vector<SearchState> SearchSpace::GenerateSketches(
int num, const std::string& strategy) {
VLOG(4) << "SearchSpace::GenerateSketches with num = " << num; VLOG(4) << "SearchSpace::GenerateSketches with num = " << num;
if (strategy == "random") { if (strategy == "random") {
...@@ -239,11 +274,13 @@ std::vector<SearchState> SearchSpace::GenerateSketches(int num, const std::strin ...@@ -239,11 +274,13 @@ std::vector<SearchState> SearchSpace::GenerateSketches(int num, const std::strin
} }
} }
} }
VLOG(4) << JoinStatesDebugString("SearchSpace::GenerateSketches", result, /*verbose=*/VLOG_IS_ON(5)); VLOG(4) << JoinStatesDebugString(
"SearchSpace::GenerateSketches", result, /*verbose=*/VLOG_IS_ON(5));
return result; return result;
} }
std::vector<SearchState> SearchSpace::ApplySketchRule(const SearchState& state, std::vector<SearchState> SearchSpace::ApplySketchRule(
const SearchState& state,
const std::string& block_name, const std::string& block_name,
RuleSampler* rule_sampler, RuleSampler* rule_sampler,
int steps, int steps,
...@@ -252,15 +289,18 @@ std::vector<SearchState> SearchSpace::ApplySketchRule(const SearchState& state, ...@@ -252,15 +289,18 @@ std::vector<SearchState> SearchSpace::ApplySketchRule(const SearchState& state,
std::list<SearchState> layer{state}; std::list<SearchState> layer{state};
int step = 0; int step = 0;
AutoGenRule* rule; AutoGenRule* rule;
// After determining a SearchState and a block, each rule has two possibilities: apply and not apply. // After determining a SearchState and a block, each rule has two
// In all transfer spaces, select a rule at each step, and collect all possible new states arrived by apply and not // possibilities: apply and not apply. In all transfer spaces, select a rule
// apply. This forms a tree, and we can use rule pruning or random pruning to reduce the number of sketches. // at each step, and collect all possible new states arrived by apply and not
// apply. This forms a tree, and we can use rule pruning or random pruning to
// reduce the number of sketches.
VLOG(6) << "Collect the states of all transfers within steps: " << steps; VLOG(6) << "Collect the states of all transfers within steps: " << steps;
while ((step++ < steps || steps == 0) && (rule = rule_sampler->NextRule())) { while ((step++ < steps || steps == 0) && (rule = rule_sampler->NextRule())) {
VLOG(7) << "step = " << step << ", rule: " << rule->GetRuleName(); VLOG(7) << "step = " << step << ", rule: " << rule->GetRuleName();
std::list<SearchState> new_states; std::list<SearchState> new_states;
int id = 0; int id = 0;
for (std::list<SearchState>::iterator iter = layer.begin(); iter != layer.end();) { for (std::list<SearchState>::iterator iter = layer.begin();
iter != layer.end();) {
// Some rules will reduce the number of blocks, such as AutoInline, // Some rules will reduce the number of blocks, such as AutoInline,
// so we need to check whether the SearchState still has the block. // so we need to check whether the SearchState still has the block.
if (!(*iter)->ir_schedule.HasBlock(block_name)) { if (!(*iter)->ir_schedule.HasBlock(block_name)) {
...@@ -268,21 +308,26 @@ std::vector<SearchState> SearchSpace::ApplySketchRule(const SearchState& state, ...@@ -268,21 +308,26 @@ std::vector<SearchState> SearchSpace::ApplySketchRule(const SearchState& state,
continue; continue;
} }
auto type = rule->AnalyseApplyType(*iter, block_name); auto type = rule->AnalyseApplyType(*iter, block_name);
VLOG(7) << "At SearchState " << ++id VLOG(7)
<< ", apply type = " << static_cast<typename std::underlying_type<RuleApplyType>::type>(type); << "At SearchState " << ++id << ", apply type = "
<< static_cast<typename std::underlying_type<RuleApplyType>::type>(
type);
// if cannot apply the rule, skip it // if cannot apply the rule, skip it
if (type == RuleApplyType::kCannotApply) { if (type == RuleApplyType::kCannotApply) {
++iter; ++iter;
continue; continue;
} }
// if can apply the rule, apply it and determine whether to prune the branch that do not apply // if can apply the rule, apply it and determine whether to prune the
std::vector<SearchState> tmp_states = rule->ApplyOnBlock(*iter, block_name); // branch that do not apply
std::vector<SearchState> tmp_states =
rule->ApplyOnBlock(*iter, block_name);
new_states.insert(new_states.end(), tmp_states.begin(), tmp_states.end()); new_states.insert(new_states.end(), tmp_states.begin(), tmp_states.end());
bool need_prune = false; bool need_prune = false;
if (prune_by_rule) { if (prune_by_rule) {
need_prune = (type == RuleApplyType::kApplyAndPruneOtherRules); need_prune = (type == RuleApplyType::kApplyAndPruneOtherRules);
} else { } else {
need_prune = (utils::SampleUniformDouble(0, 1, &rand_seed_) < prune_probability); need_prune =
(utils::SampleUniformDouble(0, 1, &rand_seed_) < prune_probability);
} }
if (need_prune) { if (need_prune) {
iter = layer.erase(iter); iter = layer.erase(iter);
...@@ -290,10 +335,12 @@ std::vector<SearchState> SearchSpace::ApplySketchRule(const SearchState& state, ...@@ -290,10 +335,12 @@ std::vector<SearchState> SearchSpace::ApplySketchRule(const SearchState& state,
++iter; ++iter;
} }
} }
VLOG(7) << "apply on block: " << block_name << ", generate " << new_states.size() << " new states at step " << step; VLOG(7) << "apply on block: " << block_name << ", generate "
<< new_states.size() << " new states at step " << step;
layer.splice(layer.end(), std::move(new_states)); layer.splice(layer.end(), std::move(new_states));
} }
VLOG(6) << "apply on block: " << block_name << ", generate " << layer.size() - 1 << " more states at all"; VLOG(6) << "apply on block: " << block_name << ", generate "
<< layer.size() - 1 << " more states at all";
return std::vector<SearchState>(layer.begin(), layer.end()); return std::vector<SearchState>(layer.begin(), layer.end());
} }
......
...@@ -40,24 +40,31 @@ namespace auto_schedule { ...@@ -40,24 +40,31 @@ namespace auto_schedule {
*/ */
class SearchSpace { class SearchSpace {
public: public:
SearchSpace(const TuneTask& tune_task, utils::LinearRandomEngine::StateType rand_seed = -1); SearchSpace(const TuneTask& tune_task,
utils::LinearRandomEngine::StateType rand_seed = -1);
// Sketch mutate, returns the mutated ModuleExpr and estimited cost // Sketch mutate, returns the mutated ModuleExpr and estimited cost
virtual SearchState GetScheduleMutate(const SearchState& state, const ExprCostModel& cost_model); virtual SearchState GetScheduleMutate(const SearchState& state,
const ExprCostModel& cost_model);
/** /**
* \brief Generate sketch as initial population of evolutionary search. * \brief Generate sketch as initial population of evolutionary search.
* @param num The number of sketches to generate. * @param num The number of sketches to generate.
* @param strategy The strategy to generate sketchs, * @param strategy The strategy to generate sketchs,
* Current optional strategies are "rule_prune" or "random_prune" or "random". * Current optional strategies are "rule_prune" or "random_prune" or
* - "rule_prune": will use rules to prune and generate sketches as efficiently as possible. * "random".
* - "random_prune": will use the new interface ApplySketchRules() to simulate the random generation of sketches, * - "rule_prune": will use rules to prune and generate sketches as
* and supports the function of a rule returning multiple SearchStates and random pruning by probability. * efficiently as possible.
* - "random": will randomly select a block and a rule to apply and repeat this step several times, * - "random_prune": will use the new interface ApplySketchRules() to simulate
* however, each rule can only be used on one SearchState at most once. * the random generation of sketches, and supports the function of a rule
* returning multiple SearchStates and random pruning by probability.
* - "random": will randomly select a block and a rule to apply and repeat
* this step several times, however, each rule can only be used on one
* SearchState at most once.
* @return Generated sketchs. * @return Generated sketchs.
*/ */
virtual std::vector<SearchState> GenerateSketches(int num, const std::string& strategy); virtual std::vector<SearchState> GenerateSketches(
int num, const std::string& strategy);
private: private:
// TODO(zhhsplendid): mutate by manual schedule. // TODO(zhhsplendid): mutate by manual schedule.
...@@ -69,20 +76,24 @@ class SearchSpace { ...@@ -69,20 +76,24 @@ class SearchSpace {
// Generate num sketchs, each with several rounds of SketchMutate // Generate num sketchs, each with several rounds of SketchMutate
std::vector<SearchState> InitSketchWithRandomStrategy(int num); std::vector<SearchState> InitSketchWithRandomStrategy(int num);
// Generate sketch pruned randomly as initial population of evolutionary search // Generate sketch pruned randomly as initial population of evolutionary
// search
std::vector<SearchState> InitSketchWithRandomPrunedStrategy(); std::vector<SearchState> InitSketchWithRandomPrunedStrategy();
// Generate sketch pruned by rules as initial population of evolutionary search // Generate sketch pruned by rules as initial population of evolutionary
// search
std::vector<SearchState> InitSketchWithRulePrunedStrategy(); std::vector<SearchState> InitSketchWithRulePrunedStrategy();
/** /**
* @brief Collect the new states that may be transferred to after applying several rules on a block from a certain * @brief Collect the new states that may be transferred to after applying
* state. * several rules on a block from a certain state.
* @param state Starting point of state transition. * @param state Starting point of state transition.
* @param block_name Name of the block to apply the rules to. * @param block_name Name of the block to apply the rules to.
* @param rule_sampler Sampler that samples the new rule to apply on the block. * @param rule_sampler Sampler that samples the new rule to apply on the
* block.
* @param steps Number of steps to apply the rule. * @param steps Number of steps to apply the rule.
* @param prune_by_rule If true, prune the state transition tree by rule, otherwise prune randomly. * @param prune_by_rule If true, prune the state transition tree by rule,
* otherwise prune randomly.
* @param prune_probability Pruning probability of random pruning. * @param prune_probability Pruning probability of random pruning.
*/ */
std::vector<SearchState> ApplySketchRule(const SearchState& state, std::vector<SearchState> ApplySketchRule(const SearchState& state,
......
...@@ -29,7 +29,9 @@ ...@@ -29,7 +29,9 @@
namespace cinn { namespace cinn {
namespace auto_schedule { namespace auto_schedule {
SearchState::SearchState(ir::IRSchedule ir_sch, float cost, const std::vector<AutoGenRule*>& rules) SearchState::SearchState(ir::IRSchedule ir_sch,
float cost,
const std::vector<AutoGenRule*>& rules)
: common::Shared<_SearchState_>(common::make_shared<_SearchState_>()) { : common::Shared<_SearchState_>(common::make_shared<_SearchState_>()) {
auto* state = get(); auto* state = get();
state->ir_schedule = std::move(ir_sch); state->ir_schedule = std::move(ir_sch);
...@@ -37,13 +39,16 @@ SearchState::SearchState(ir::IRSchedule ir_sch, float cost, const std::vector<Au ...@@ -37,13 +39,16 @@ SearchState::SearchState(ir::IRSchedule ir_sch, float cost, const std::vector<Au
state->predicted_cost = cost; state->predicted_cost = cost;
} }
SearchState SearchState::Copy() const { return SearchState((*this)->ir_schedule, (*this)->predicted_cost, {}); } SearchState SearchState::Copy() const {
return SearchState((*this)->ir_schedule, (*this)->predicted_cost, {});
}
std::string _SearchState_::DebugString() const { std::string _SearchState_::DebugString() const {
const auto& exprs = ir_schedule.GetModule().GetExprs(); const auto& exprs = ir_schedule.GetModule().GetExprs();
std::stringstream module_stream; std::stringstream module_stream;
for (auto i = 0; i < exprs.size(); ++i) { for (auto i = 0; i < exprs.size(); ++i) {
module_stream << "Expr " << i << " {\n" << exprs.at(i) << "\n} // end Expr"; module_stream << "Expr " << i << " {\n"
<< exprs.at(i) << "\n} // end Expr";
} }
const char* fmt_str = R"ROC( const char* fmt_str = R"ROC(
...@@ -55,8 +60,10 @@ ScheduleDesc { ...@@ -55,8 +60,10 @@ ScheduleDesc {
} // end ScheduleDesc } // end ScheduleDesc
predicted_cost: %f)ROC"; predicted_cost: %f)ROC";
return utils::StringFormat( return utils::StringFormat(fmt_str,
fmt_str, module_stream.str().c_str(), ir_schedule.GetTraceDesc().DebugString().c_str(), predicted_cost); module_stream.str().c_str(),
ir_schedule.GetTraceDesc().DebugString().c_str(),
predicted_cost);
} }
bool operator<(const SearchState& left, const SearchState& right) { bool operator<(const SearchState& left, const SearchState& right) {
...@@ -119,7 +126,8 @@ size_t SearchStateHash::operator()(const SearchState& s) const { ...@@ -119,7 +126,8 @@ size_t SearchStateHash::operator()(const SearchState& s) const {
return hash_key; return hash_key;
} }
bool SearchStateEqual::operator()(const SearchState& lhs, const SearchState& rhs) const { bool SearchStateEqual::operator()(const SearchState& lhs,
const SearchState& rhs) const {
const auto& lhs_exprs = lhs->ir_schedule.GetModule().GetExprs(); const auto& lhs_exprs = lhs->ir_schedule.GetModule().GetExprs();
const auto& rhs_exprs = rhs->ir_schedule.GetModule().GetExprs(); const auto& rhs_exprs = rhs->ir_schedule.GetModule().GetExprs();
// compare exprs size firstly // compare exprs size firstly
...@@ -127,20 +135,24 @@ bool SearchStateEqual::operator()(const SearchState& lhs, const SearchState& rhs ...@@ -127,20 +135,24 @@ bool SearchStateEqual::operator()(const SearchState& lhs, const SearchState& rhs
// compare every expr one by one with ir::IrEqualVisitor // compare every expr one by one with ir::IrEqualVisitor
for (int i = 0; i < lhs_exprs.size(); ++i) { for (int i = 0; i < lhs_exprs.size(); ++i) {
ir::IrEqualVisitor compartor(/*allow_name_suffix_diff=*/true); // ignore suffix difference in name ir::IrEqualVisitor compartor(
/*allow_name_suffix_diff=*/true); // ignore suffix difference in name
if (!compartor.Compare(lhs_exprs[i], rhs_exprs[i])) return false; if (!compartor.Compare(lhs_exprs[i], rhs_exprs[i])) return false;
} }
return true; return true;
} }
std::string JoinStatesDebugString(const std::string& title, const std::vector<SearchState>& states, bool verbose) { std::string JoinStatesDebugString(const std::string& title,
const std::vector<SearchState>& states,
bool verbose) {
std::stringstream ss; std::stringstream ss;
ss << title << " states size:" << states.size() << "\n"; ss << title << " states size:" << states.size() << "\n";
SearchStateHash state_hasher; SearchStateHash state_hasher;
for (size_t i = 0; i < states.size(); ++i) { for (size_t i = 0; i < states.size(); ++i) {
uint64_t hash_key = state_hasher(states[i]); uint64_t hash_key = state_hasher(states[i]);
if (verbose) { if (verbose) {
ss << "\tState-" << i << " hash:" << hash_key << "\t content:------>" << states[i]->DebugString() << "\n<------"; ss << "\tState-" << i << " hash:" << hash_key << "\t content:------>"
<< states[i]->DebugString() << "\n<------";
} else { } else {
ss << "\tState-" << i << " hash:" << hash_key << "\n"; ss << "\tState-" << i << " hash:" << hash_key << "\n";
} }
......
...@@ -35,7 +35,9 @@ class SearchState : public common::Shared<_SearchState_> { ...@@ -35,7 +35,9 @@ class SearchState : public common::Shared<_SearchState_> {
public: public:
SearchState() = default; SearchState() = default;
// create a new SearchState // create a new SearchState
explicit SearchState(ir::IRSchedule ir_sch, float cost = NOT_INIT_COST, const std::vector<AutoGenRule*>& rules = {}); explicit SearchState(ir::IRSchedule ir_sch,
float cost = NOT_INIT_COST,
const std::vector<AutoGenRule*>& rules = {});
// Constant standing for a cost not being initialized // Constant standing for a cost not being initialized
static constexpr float NOT_INIT_COST = std::numeric_limits<float>::max(); static constexpr float NOT_INIT_COST = std::numeric_limits<float>::max();
...@@ -62,12 +64,14 @@ struct _SearchState_ : public common::Object { ...@@ -62,12 +64,14 @@ struct _SearchState_ : public common::Object {
static constexpr char* __type_info__ = "auto_schedule_state"; static constexpr char* __type_info__ = "auto_schedule_state";
}; };
// SearchStateHash hash functor that visits every AST node and combine their hash of node_type in dfs order // SearchStateHash hash functor that visits every AST node and combine their
// hash of node_type in dfs order
struct SearchStateHash { struct SearchStateHash {
size_t operator()(const SearchState& s) const; size_t operator()(const SearchState& s) const;
}; };
// SearchStateHash equal functor, use ir::IrEqualVisitor to compare their AST struct and fields // SearchStateHash equal functor, use ir::IrEqualVisitor to compare their AST
// struct and fields
struct SearchStateEqual { struct SearchStateEqual {
bool operator()(const SearchState& lhs, const SearchState& rhs) const; bool operator()(const SearchState& lhs, const SearchState& rhs) const;
}; };
......
...@@ -36,15 +36,34 @@ TEST(TestSearchState, SearchStateHash_Equal) { ...@@ -36,15 +36,34 @@ TEST(TestSearchState, SearchStateHash_Equal) {
{M, N}, [&](Var i, Var j) { return A(i, j) + B(i, j); }, "C"); {M, N}, [&](Var i, Var j) { return A(i, j) + B(i, j); }, "C");
cinn::common::Context::Global().ResetNameId(); cinn::common::Context::Global().ResetNameId();
auto a_plus_const_funcs_1 = auto a_plus_const_funcs_1 = lang::LowerVec("A_plus_const",
lang::LowerVec("A_plus_const", poly::CreateStages({A, B}), {A, B}, {}, {}, nullptr, target, true); poly::CreateStages({A, B}),
{A, B},
{},
{},
nullptr,
target,
true);
cinn::common::Context::Global().ResetNameId(); cinn::common::Context::Global().ResetNameId();
auto a_plus_const_funcs_2 = auto a_plus_const_funcs_2 = lang::LowerVec("A_plus_const",
lang::LowerVec("A_plus_const", poly::CreateStages({A, B}), {A, B}, {}, {}, nullptr, target, true); poly::CreateStages({A, B}),
{A, B},
{},
{},
nullptr,
target,
true);
cinn::common::Context::Global().ResetNameId(); cinn::common::Context::Global().ResetNameId();
auto a_plus_b_funcs = lang::LowerVec("A_plus_B", poly::CreateStages({A, C}), {A, C}, {}, {}, nullptr, target, true); auto a_plus_b_funcs = lang::LowerVec("A_plus_B",
poly::CreateStages({A, C}),
{A, C},
{},
{},
nullptr,
target,
true);
std::string a_plus_const_funcs_1_str = R"ROC(function A_plus_const (_A, _B) std::string a_plus_const_funcs_1_str = R"ROC(function A_plus_const (_A, _B)
{ {
...@@ -114,19 +133,25 @@ TEST(TestSearchState, SearchStateHash_Equal) { ...@@ -114,19 +133,25 @@ TEST(TestSearchState, SearchStateHash_Equal) {
})ROC"; })ROC";
ASSERT_EQ(a_plus_const_funcs_1.size(), 1); ASSERT_EQ(a_plus_const_funcs_1.size(), 1);
EXPECT_EQ(a_plus_const_funcs_1_str, utils::GetStreamCnt(a_plus_const_funcs_1.front())); EXPECT_EQ(a_plus_const_funcs_1_str,
utils::GetStreamCnt(a_plus_const_funcs_1.front()));
ASSERT_EQ(a_plus_const_funcs_2.size(), 1); ASSERT_EQ(a_plus_const_funcs_2.size(), 1);
EXPECT_EQ(a_plus_const_funcs_2_str, utils::GetStreamCnt(a_plus_const_funcs_2.front())); EXPECT_EQ(a_plus_const_funcs_2_str,
utils::GetStreamCnt(a_plus_const_funcs_2.front()));
ASSERT_EQ(a_plus_b_funcs.size(), 1); ASSERT_EQ(a_plus_b_funcs.size(), 1);
EXPECT_EQ(a_plus_b_funcs_str, utils::GetStreamCnt(a_plus_b_funcs.front())); EXPECT_EQ(a_plus_b_funcs_str, utils::GetStreamCnt(a_plus_b_funcs.front()));
SearchState a_plus_const_state1(ir::IRSchedule(ir::ModuleExpr({a_plus_const_funcs_1.front()->body}))); SearchState a_plus_const_state1(
SearchState a_plus_const_state2(ir::IRSchedule(ir::ModuleExpr({a_plus_const_funcs_2.front()->body}))); ir::IRSchedule(ir::ModuleExpr({a_plus_const_funcs_1.front()->body})));
SearchState a_plus_b_state(ir::IRSchedule(ir::ModuleExpr({a_plus_b_funcs.front()->body}))); SearchState a_plus_const_state2(
ir::IRSchedule(ir::ModuleExpr({a_plus_const_funcs_2.front()->body})));
SearchState a_plus_b_state(
ir::IRSchedule(ir::ModuleExpr({a_plus_b_funcs.front()->body})));
SearchStateHash hash_functor; SearchStateHash hash_functor;
SearchStateEqual equal_functor; SearchStateEqual equal_functor;
ASSERT_EQ(hash_functor(a_plus_const_state1), hash_functor(a_plus_const_state2)); ASSERT_EQ(hash_functor(a_plus_const_state1),
hash_functor(a_plus_const_state2));
ASSERT_TRUE(equal_functor(a_plus_const_state1, a_plus_const_state2)); ASSERT_TRUE(equal_functor(a_plus_const_state1, a_plus_const_state2));
ASSERT_NE(hash_functor(a_plus_const_state1), hash_functor(a_plus_b_state)); ASSERT_NE(hash_functor(a_plus_const_state1), hash_functor(a_plus_b_state));
ASSERT_FALSE(equal_functor(a_plus_const_state1, a_plus_b_state)); ASSERT_FALSE(equal_functor(a_plus_const_state1, a_plus_b_state));
......
...@@ -41,7 +41,8 @@ DECLARE_bool(auto_schedule_use_cost_model); ...@@ -41,7 +41,8 @@ DECLARE_bool(auto_schedule_use_cost_model);
namespace cinn { namespace cinn {
namespace auto_schedule { namespace auto_schedule {
EvolutionarySearch::EvolutionarySearch(const TuneTask& tune_task, EvolutionarySearch::EvolutionarySearch(
const TuneTask& tune_task,
const ExprCostModel& cost_model, const ExprCostModel& cost_model,
Database* database, Database* database,
utils::LinearRandomEngine::StateType rand_seed, utils::LinearRandomEngine::StateType rand_seed,
...@@ -51,7 +52,8 @@ EvolutionarySearch::EvolutionarySearch(const TuneTask& tune_task, ...@@ -51,7 +52,8 @@ EvolutionarySearch::EvolutionarySearch(const TuneTask& tune_task,
database_(database), database_(database),
rand_seed_(utils::LinearRandomEngine::NormalizeState(rand_seed)), rand_seed_(utils::LinearRandomEngine::NormalizeState(rand_seed)),
mutators_(mutate_rules) { mutators_(mutate_rules) {
search_space_ = std::make_unique<SearchSpace>(tune_task, utils::ForkRandomState(&rand_seed_)); search_space_ = std::make_unique<SearchSpace>(
tune_task, utils::ForkRandomState(&rand_seed_));
if (mutators_.empty()) { if (mutators_.empty()) {
mutators_.push_back(std::make_tuple("mutate_tile_size", 1.0)); mutators_.push_back(std::make_tuple("mutate_tile_size", 1.0));
} }
...@@ -59,7 +61,8 @@ EvolutionarySearch::EvolutionarySearch(const TuneTask& tune_task, ...@@ -59,7 +61,8 @@ EvolutionarySearch::EvolutionarySearch(const TuneTask& tune_task,
for (const auto& mutator : mutators_) { for (const auto& mutator : mutators_) {
if (std::get<1>(mutator) > 0) { if (std::get<1>(mutator) > 0) {
accum_weight += std::get<1>(mutator); accum_weight += std::get<1>(mutator);
weighted_mutators_.insert(std::make_pair(accum_weight, MutateRule::Make(std::get<0>(mutator)))); weighted_mutators_.insert(
std::make_pair(accum_weight, MutateRule::Make(std::get<0>(mutator))));
} }
} }
...@@ -72,46 +75,66 @@ SearchState EvolutionarySearch::SearchModuleExpr(const TuningOptions& options) { ...@@ -72,46 +75,66 @@ SearchState EvolutionarySearch::SearchModuleExpr(const TuningOptions& options) {
return SearchModuleExprBests(options)[0]; return SearchModuleExprBests(options)[0];
} }
std::vector<SearchState> EvolutionarySearch::SearchModuleExprBests(const TuningOptions& options) { std::vector<SearchState> EvolutionarySearch::SearchModuleExprBests(
VLOG(4) << "start SearchModuleExprBests with initial statistics: visited_candidates size=" const TuningOptions& options) {
VLOG(4) << "start SearchModuleExprBests with initial statistics: "
"visited_candidates size="
<< visited_candidates_.size(); << visited_candidates_.size();
std::vector<SearchState> init_population; std::vector<SearchState> init_population;
std::vector<SearchState> topk_from_database = GetTopKCandidatesFromDatabase(options.evolution_pick_database_topk); std::vector<SearchState> topk_from_database =
GetTopKCandidatesFromDatabase(options.evolution_pick_database_topk);
VLOG(4) << JoinStatesDebugString( VLOG(4) << JoinStatesDebugString(
"EvolutionarySearch::GetTopKCandidatesFromDatabase", topk_from_database, /*verbose=*/VLOG_IS_ON(5)); "EvolutionarySearch::GetTopKCandidatesFromDatabase",
int init_num = options.evolution_init_population_num - topk_from_database.size(); topk_from_database,
/*verbose=*/VLOG_IS_ON(5));
int init_num =
options.evolution_init_population_num - topk_from_database.size();
std::vector<SearchState> init_sketch = InitSketch(init_num, "rule_prune"); std::vector<SearchState> init_sketch = InitSketch(init_num, "rule_prune");
VLOG(4) << JoinStatesDebugString("EvolutionarySearch::InitSketch", init_sketch, /*verbose=*/VLOG_IS_ON(5)); VLOG(4) << JoinStatesDebugString(
"EvolutionarySearch::InitSketch", init_sketch, /*verbose=*/VLOG_IS_ON(5));
init_population.insert(init_population.end(), topk_from_database.begin(), topk_from_database.end()); init_population.insert(init_population.end(),
init_population.insert(init_population.end(), init_sketch.begin(), init_sketch.end()); topk_from_database.begin(),
topk_from_database.end());
init_population.insert(
init_population.end(), init_sketch.begin(), init_sketch.end());
std::vector<SearchState> picked_bests = std::vector<SearchState> picked_bests =
Evolve(init_population, options.evolution_cross_over_num, options.num_samples_per_iteration); Evolve(init_population,
VLOG(4) << JoinStatesDebugString("EvolutionarySearch::Evolve", picked_bests, /*verbose=*/VLOG_IS_ON(5)); options.evolution_cross_over_num,
options.num_samples_per_iteration);
VLOG(4) << JoinStatesDebugString(
"EvolutionarySearch::Evolve", picked_bests, /*verbose=*/VLOG_IS_ON(5));
return picked_bests; return picked_bests;
} }
std::vector<SearchState> EvolutionarySearch::SearchModuleExprEpsGreedy(const TuningOptions& options) { std::vector<SearchState> EvolutionarySearch::SearchModuleExprEpsGreedy(
const TuningOptions& options) {
std::vector<SearchState> picked_bests = SearchModuleExprBests(options); std::vector<SearchState> picked_bests = SearchModuleExprBests(options);
int random_num = options.evolution_init_population_num - options.evolution_pick_database_topk; int random_num = options.evolution_init_population_num -
auto results = PickNextGenerationEpsGreedy(picked_bests, options.evolution_pick_database_topk;
auto results =
PickNextGenerationEpsGreedy(picked_bests,
InitSketch(random_num, "random_prune"), InitSketch(random_num, "random_prune"),
options.num_samples_per_iteration, options.num_samples_per_iteration,
options.evolution_eps_greedy); options.evolution_eps_greedy);
VLOG(4) << JoinStatesDebugString( VLOG(4) << JoinStatesDebugString(
"EvolutionarySearch::PickNextGenerationEpsGreedy", results, /*verbose=*/VLOG_IS_ON(5)); "EvolutionarySearch::PickNextGenerationEpsGreedy",
results,
/*verbose=*/VLOG_IS_ON(5));
return results; return results;
} }
std::vector<SearchState> EvolutionarySearch::GetTopKCandidatesFromDatabase(int topk) { std::vector<SearchState> EvolutionarySearch::GetTopKCandidatesFromDatabase(
int topk) {
std::vector<SearchState> results; std::vector<SearchState> results;
const auto& task_key = tune_task_.serialized_key; const auto& task_key = tune_task_.serialized_key;
auto records = database_->GetTopK(task_key, topk); auto records = database_->GetTopK(task_key, topk);
InitialTaskRegistry* task_registry = InitialTaskRegistry::Global(); InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
for (auto&& record : records) { for (auto&& record : records) {
ir::IRSchedule ir_sch(optim::IRCopy(task_registry->Get(task_key)->module_expr), ir::IRSchedule ir_sch(
optim::IRCopy(task_registry->Get(task_key)->module_expr),
utils::ForkRandomState(&rand_seed_)); utils::ForkRandomState(&rand_seed_));
ir::ScheduleDesc::ReplayWithProto(record.trace, &ir_sch); ir::ScheduleDesc::ReplayWithProto(record.trace, &ir_sch);
results.emplace_back(SearchState(std::move(ir_sch), record.predicted_cost)); results.emplace_back(SearchState(std::move(ir_sch), record.predicted_cost));
...@@ -119,7 +142,8 @@ std::vector<SearchState> EvolutionarySearch::GetTopKCandidatesFromDatabase(int t ...@@ -119,7 +142,8 @@ std::vector<SearchState> EvolutionarySearch::GetTopKCandidatesFromDatabase(int t
return results; return results;
} }
void ApplyPostScheduleRules(ir::IRSchedule* schedule, void ApplyPostScheduleRules(
ir::IRSchedule* schedule,
const std::vector<std::unique_ptr<PostScheduleRule>>& post_schedule_rules) { const std::vector<std::unique_ptr<PostScheduleRule>>& post_schedule_rules) {
schedule->TagPostSchedule(); schedule->TagPostSchedule();
for (const auto& post_rule : post_schedule_rules) { for (const auto& post_rule : post_schedule_rules) {
...@@ -127,25 +151,33 @@ void ApplyPostScheduleRules(ir::IRSchedule* schedule, ...@@ -127,25 +151,33 @@ void ApplyPostScheduleRules(ir::IRSchedule* schedule,
} }
} }
std::vector<SearchState> EvolutionarySearch::InitSketch(int num, const std::string& strategy) { std::vector<SearchState> EvolutionarySearch::InitSketch(
int num, const std::string& strategy) {
VLOG(4) << "InitSketch with num:" << num << ", strategy: " << strategy; VLOG(4) << "InitSketch with num:" << num << ", strategy: " << strategy;
std::vector<SearchState> states = search_space_->GenerateSketches(num, strategy); std::vector<SearchState> states =
search_space_->GenerateSketches(num, strategy);
auto post_schedule_fn = [this, &states](int index) { auto post_schedule_fn = [this, &states](int index) {
ApplyPostScheduleRules(&states[index]->ir_schedule, post_schedule_rules_); ApplyPostScheduleRules(&states[index]->ir_schedule, post_schedule_rules_);
}; };
utils::parallel_run(post_schedule_fn, utils::SequenceDispatcher(0, states.size()), states.size()); utils::parallel_run(post_schedule_fn,
utils::SequenceDispatcher(0, states.size()),
states.size());
return states; return states;
} }
SearchState EvolutionarySearch::CrossOver(const SearchState& state1, const SearchState& state2) { SearchState EvolutionarySearch::CrossOver(const SearchState& state1,
const SearchState& state2) {
// TODO(CtfGo): tracing CrossOver with IRSchedule // TODO(CtfGo): tracing CrossOver with IRSchedule
std::vector<ir::Expr> cross_over_exprs; std::vector<ir::Expr> cross_over_exprs;
std::vector<ir::Expr> father_exprs = state1->ir_schedule.GetModule().GetExprs(); std::vector<ir::Expr> father_exprs =
std::vector<ir::Expr> mother_exprs = state2->ir_schedule.GetModule().GetExprs(); state1->ir_schedule.GetModule().GetExprs();
std::vector<ir::Expr> mother_exprs =
state2->ir_schedule.GetModule().GetExprs();
CHECK_EQ(father_exprs.size(), mother_exprs.size()) CHECK_EQ(father_exprs.size(), mother_exprs.size())
<< "CrossOver ModuleExpr in EvolutionarySearch must have same number of AST"; << "CrossOver ModuleExpr in EvolutionarySearch must have same number of "
"AST";
for (size_t i = 0; i < father_exprs.size(); ++i) { for (size_t i = 0; i < father_exprs.size(); ++i) {
if (utils::SampleUniformInt(0, 2, &rand_seed_) == 0) { if (utils::SampleUniformInt(0, 2, &rand_seed_) == 0) {
...@@ -154,16 +186,22 @@ SearchState EvolutionarySearch::CrossOver(const SearchState& state1, const Searc ...@@ -154,16 +186,22 @@ SearchState EvolutionarySearch::CrossOver(const SearchState& state1, const Searc
cross_over_exprs.push_back(optim::IRCopy(mother_exprs[i])); cross_over_exprs.push_back(optim::IRCopy(mother_exprs[i]));
} }
} }
auto res = SearchState(ir::IRSchedule(ir::ModuleExpr(cross_over_exprs), utils::ForkRandomState(&rand_seed_))); auto res = SearchState(ir::IRSchedule(ir::ModuleExpr(cross_over_exprs),
utils::ForkRandomState(&rand_seed_)));
if (FLAGS_auto_schedule_use_cost_model) { if (FLAGS_auto_schedule_use_cost_model) {
res->predicted_cost = cost_model_.Predict(res->ir_schedule.GetModule(), tune_task_.target); res->predicted_cost =
cost_model_.Predict(res->ir_schedule.GetModule(), tune_task_.target);
} }
VLOG(5) << JoinStatesDebugString("EvolutionarySearch::CrossOver", {state1, state2, res}, /*verbose=*/VLOG_IS_ON(6)); VLOG(5) << JoinStatesDebugString("EvolutionarySearch::CrossOver",
{state1, state2, res},
/*verbose=*/VLOG_IS_ON(6));
return res; return res;
} }
SearchState EvolutionarySearch::Mutate(const SearchState& state, utils::LinearRandomEngine::StateType* rand_seed) { SearchState EvolutionarySearch::Mutate(
CHECK_GT(weighted_mutators_.size(), 0) << "There is no mutate rule can be applied."; const SearchState& state, utils::LinearRandomEngine::StateType* rand_seed) {
CHECK_GT(weighted_mutators_.size(), 0)
<< "There is no mutate rule can be applied.";
double accu_weight = (weighted_mutators_.rbegin())->first; double accu_weight = (weighted_mutators_.rbegin())->first;
CHECK_GT(accu_weight, 0) << "The accumulate weight must be greater than 0."; CHECK_GT(accu_weight, 0) << "The accumulate weight must be greater than 0.";
// sample a mutate rule // sample a mutate rule
...@@ -174,24 +212,31 @@ SearchState EvolutionarySearch::Mutate(const SearchState& state, utils::LinearRa ...@@ -174,24 +212,31 @@ SearchState EvolutionarySearch::Mutate(const SearchState& state, utils::LinearRa
// apply mutation on the trace of SearchState // apply mutation on the trace of SearchState
auto trace = state->ir_schedule.GetTraceDesc(); auto trace = state->ir_schedule.GetTraceDesc();
auto new_trace = mutator->Apply(trace, rand_seed); auto new_trace = mutator->Apply(trace, rand_seed);
// replay the mutated trace on original ModuleExpr to generate a new ir_schedule // replay the mutated trace on original ModuleExpr to generate a new
// ir_schedule
const auto& task_key = tune_task_.serialized_key; const auto& task_key = tune_task_.serialized_key;
InitialTaskRegistry* task_registry = InitialTaskRegistry::Global(); InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
ir::IRSchedule new_ir_sch(optim::IRCopy(task_registry->Get(task_key)->module_expr), ir::IRSchedule new_ir_sch(
optim::IRCopy(task_registry->Get(task_key)->module_expr),
utils::ForkRandomState(rand_seed)); utils::ForkRandomState(rand_seed));
new_trace.Replay(&new_ir_sch, true); new_trace.Replay(&new_ir_sch, true);
ApplyPostScheduleRules(&new_ir_sch, post_schedule_rules_); ApplyPostScheduleRules(&new_ir_sch, post_schedule_rules_);
auto res = SearchState(std::move(new_ir_sch)); auto res = SearchState(std::move(new_ir_sch));
VLOG(5) << JoinStatesDebugString("EvolutionarySearch::Mutate", {state, res}, /*verbose=*/VLOG_IS_ON(6)); VLOG(5) << JoinStatesDebugString(
"EvolutionarySearch::Mutate", {state, res}, /*verbose=*/VLOG_IS_ON(6));
return res; return res;
} }
std::vector<SearchState> EvolutionarySearch::Evolve(const std::vector<SearchState>& population, std::vector<SearchState> EvolutionarySearch::Evolve(
const std::vector<SearchState>& population,
int cross_over_num, int cross_over_num,
int ret_num) { int ret_num) {
VLOG(4) << utils::StringFormat( VLOG(4) << utils::StringFormat(
"Evolve with population size=%lu,cross_over_num:%lu,ret_num:%lu", population.size(), cross_over_num, ret_num); "Evolve with population size=%lu,cross_over_num:%lu,ret_num:%lu",
population.size(),
cross_over_num,
ret_num);
int generation_num = population.size(); int generation_num = population.size();
if (generation_num == 0) { if (generation_num == 0) {
return std::vector<SearchState>(); return std::vector<SearchState>();
...@@ -199,40 +244,56 @@ std::vector<SearchState> EvolutionarySearch::Evolve(const std::vector<SearchStat ...@@ -199,40 +244,56 @@ std::vector<SearchState> EvolutionarySearch::Evolve(const std::vector<SearchStat
// init evolution // init evolution
std::vector<SearchState> evolution(population); std::vector<SearchState> evolution(population);
for (SearchState& search_state : evolution) { for (SearchState& search_state : evolution) {
if (search_state->predicted_cost == SearchState::NOT_INIT_COST && FLAGS_auto_schedule_use_cost_model) { if (search_state->predicted_cost == SearchState::NOT_INIT_COST &&
search_state->predicted_cost = cost_model_.Predict(search_state->ir_schedule.GetModule(), tune_task_.target); FLAGS_auto_schedule_use_cost_model) {
search_state->predicted_cost = cost_model_.Predict(
search_state->ir_schedule.GetModule(), tune_task_.target);
} }
} }
VLOG(4) << JoinStatesDebugString("EvolutionarySearch::Evolve: Init evolution:", evolution, /*verbose=*/VLOG_IS_ON(5)); VLOG(4) << JoinStatesDebugString(
"EvolutionarySearch::Evolve: Init evolution:",
evolution,
/*verbose=*/VLOG_IS_ON(5));
// cross over // cross over
for (int i = 0; i < cross_over_num; ++i) { for (int i = 0; i < cross_over_num; ++i) {
int first_rand_idx = utils::SampleUniformInt(0, generation_num, &rand_seed_); int first_rand_idx =
int second_rand_idx = utils::SampleUniformInt(0, generation_num, &rand_seed_); utils::SampleUniformInt(0, generation_num, &rand_seed_);
int second_rand_idx =
utils::SampleUniformInt(0, generation_num, &rand_seed_);
while (first_rand_idx == second_rand_idx) { while (first_rand_idx == second_rand_idx) {
second_rand_idx = utils::SampleUniformInt(0, generation_num, &rand_seed_); second_rand_idx = utils::SampleUniformInt(0, generation_num, &rand_seed_);
} }
evolution.push_back(CrossOver(population[first_rand_idx], population[second_rand_idx])); evolution.push_back(
CrossOver(population[first_rand_idx], population[second_rand_idx]));
} }
VLOG(4) << JoinStatesDebugString( VLOG(4) << JoinStatesDebugString(
"EvolutionarySearch::Evolve: after CrossOver evolution:", evolution, /*verbose=*/VLOG_IS_ON(5)); "EvolutionarySearch::Evolve: after CrossOver evolution:",
evolution,
/*verbose=*/VLOG_IS_ON(5));
// mutate // mutate
std::vector<SearchState> mutated_individuals(evolution.size()); std::vector<SearchState> mutated_individuals(evolution.size());
std::vector<utils::LinearRandomEngine::StateType> rand_seeds(evolution.size()); std::vector<utils::LinearRandomEngine::StateType> rand_seeds(
evolution.size());
for (int i = 0; i < rand_seeds.size(); ++i) { for (int i = 0; i < rand_seeds.size(); ++i) {
rand_seeds[i] = utils::ForkRandomState(&rand_seed_); rand_seeds[i] = utils::ForkRandomState(&rand_seed_);
} }
auto mutate_fn = [this, &evolution, &mutated_individuals, &rand_seeds](int index) { auto mutate_fn = [this, &evolution, &mutated_individuals, &rand_seeds](
int index) {
mutated_individuals[index] = Mutate(evolution[index], &rand_seeds[index]); mutated_individuals[index] = Mutate(evolution[index], &rand_seeds[index]);
}; };
utils::parallel_run(mutate_fn, utils::SequenceDispatcher(0, evolution.size()), evolution.size()); utils::parallel_run(mutate_fn,
utils::SequenceDispatcher(0, evolution.size()),
evolution.size());
if (FLAGS_auto_schedule_use_cost_model) { if (FLAGS_auto_schedule_use_cost_model) {
for (size_t i = 0; i < mutated_individuals.size(); ++i) { for (size_t i = 0; i < mutated_individuals.size(); ++i) {
mutated_individuals[i]->predicted_cost = mutated_individuals[i]->predicted_cost = cost_model_.Predict(
cost_model_.Predict(mutated_individuals[i]->ir_schedule.GetModule(), tune_task_.target); mutated_individuals[i]->ir_schedule.GetModule(), tune_task_.target);
} }
} }
VLOG(4) << JoinStatesDebugString( VLOG(4) << JoinStatesDebugString(
"EvolutionarySearch::Evolve: mutated individuals:", mutated_individuals, /*verbose=*/VLOG_IS_ON(5)); "EvolutionarySearch::Evolve: mutated individuals:",
mutated_individuals,
/*verbose=*/VLOG_IS_ON(5));
// select top ret_num with predicted cost // select top ret_num with predicted cost
utils::SizedMultiSet<SearchState> evolution_with_cost(ret_num); utils::SizedMultiSet<SearchState> evolution_with_cost(ret_num);
for (size_t i = 0; i < evolution.size(); ++i) { for (size_t i = 0; i < evolution.size(); ++i) {
...@@ -241,14 +302,18 @@ std::vector<SearchState> EvolutionarySearch::Evolve(const std::vector<SearchStat ...@@ -241,14 +302,18 @@ std::vector<SearchState> EvolutionarySearch::Evolve(const std::vector<SearchStat
for (size_t i = 0; i < mutated_individuals.size(); ++i) { for (size_t i = 0; i < mutated_individuals.size(); ++i) {
evolution_with_cost.Push(mutated_individuals[i]); evolution_with_cost.Push(mutated_individuals[i]);
} }
auto selected_individuals = evolution_with_cost.ReturnAsContainer<std::vector<SearchState>>(); auto selected_individuals =
evolution_with_cost.ReturnAsContainer<std::vector<SearchState>>();
VLOG(4) << JoinStatesDebugString( VLOG(4) << JoinStatesDebugString(
"EvolutionarySearch::Evolve: selected individuals:", selected_individuals, /*verbose=*/VLOG_IS_ON(5)); "EvolutionarySearch::Evolve: selected individuals:",
selected_individuals,
/*verbose=*/VLOG_IS_ON(5));
return selected_individuals; return selected_individuals;
} }
std::vector<SearchState> EvolutionarySearch::PickNextGenerationEpsGreedy(const std::vector<SearchState>& picked_bests, std::vector<SearchState> EvolutionarySearch::PickNextGenerationEpsGreedy(
const std::vector<SearchState>& picked_bests,
const std::vector<SearchState>& random_init, const std::vector<SearchState>& random_init,
int num, int num,
float eps_greedy) { float eps_greedy) {
...@@ -276,18 +341,23 @@ std::vector<SearchState> EvolutionarySearch::PickNextGenerationEpsGreedy(const s ...@@ -276,18 +341,23 @@ std::vector<SearchState> EvolutionarySearch::PickNextGenerationEpsGreedy(const s
if (!visited_candidates_.count(selected)) { // deduplicate if (!visited_candidates_.count(selected)) { // deduplicate
VLOG(4) << JoinStatesDebugString( VLOG(4) << JoinStatesDebugString(
"EvolutionarySearch::PickNextGenerationEpsGreedy-Selected", {selected}, /*verbose=*/VLOG_IS_ON(5)); "EvolutionarySearch::PickNextGenerationEpsGreedy-Selected",
{selected},
/*verbose=*/VLOG_IS_ON(5));
visited_candidates_.insert(selected); visited_candidates_.insert(selected);
result.push_back(selected); result.push_back(selected);
} else { } else {
++deduplicated_cnt; ++deduplicated_cnt;
VLOG(4) << JoinStatesDebugString( VLOG(4) << JoinStatesDebugString(
"EvolutionarySearch::PickNextGenerationEpsGreedy-Deduplicated", {selected}, /*verbose=*/VLOG_IS_ON(5)); "EvolutionarySearch::PickNextGenerationEpsGreedy-Deduplicated",
{selected},
/*verbose=*/VLOG_IS_ON(5));
} }
} }
VLOG(4) << utils::StringFormat( VLOG(4) << utils::StringFormat(
"PickNextGenerationEpsGreedy: picked_bests size=%lu,random_init size=%lu,num=%d," "PickNextGenerationEpsGreedy: picked_bests size=%lu,random_init "
"size=%lu,num=%d,"
"eps_greedy=%f,deduplicated_cnt=%d,result size=%lu", "eps_greedy=%f,deduplicated_cnt=%d,result size=%lu",
picked_bests.size(), picked_bests.size(),
random_init.size(), random_init.size(),
......
...@@ -41,7 +41,8 @@ class EvolutionarySearch { ...@@ -41,7 +41,8 @@ class EvolutionarySearch {
* @param tune_task: the TuneTask this class works on. This class doesn't * @param tune_task: the TuneTask this class works on. This class doesn't
* take ownership of the pointer. * take ownership of the pointer.
*/ */
EvolutionarySearch(const TuneTask& tune_task, EvolutionarySearch(
const TuneTask& tune_task,
const ExprCostModel& cost_model, const ExprCostModel& cost_model,
Database* database, Database* database,
utils::LinearRandomEngine::StateType rand_seed = -1, utils::LinearRandomEngine::StateType rand_seed = -1,
...@@ -55,14 +56,16 @@ class EvolutionarySearch { ...@@ -55,14 +56,16 @@ class EvolutionarySearch {
/** /**
* Run the evolutionary search for one iteration. * Run the evolutionary search for one iteration.
* *
* @return SearchState containing the best ir::ModuleExpr searched in this iteration * @return SearchState containing the best ir::ModuleExpr searched in this
* iteration
*/ */
SearchState SearchModuleExpr(const TuningOptions& options); SearchState SearchModuleExpr(const TuningOptions& options);
/** /**
* Run the evolutionary search for one iteration. * Run the evolutionary search for one iteration.
* *
* @return SearchState(s) containing best ir::ModuleExpr(s) searched in this iteration * @return SearchState(s) containing best ir::ModuleExpr(s) searched in this
* iteration
*/ */
std::vector<SearchState> SearchModuleExprBests(const TuningOptions& options); std::vector<SearchState> SearchModuleExprBests(const TuningOptions& options);
...@@ -77,7 +80,8 @@ class EvolutionarySearch { ...@@ -77,7 +80,8 @@ class EvolutionarySearch {
* "eps * total_return_size" random samples and * "eps * total_return_size" random samples and
* "(1 - eps) * total_return_size" best searched samples. * "(1 - eps) * total_return_size" best searched samples.
*/ */
std::vector<SearchState> SearchModuleExprEpsGreedy(const TuningOptions& options); std::vector<SearchState> SearchModuleExprEpsGreedy(
const TuningOptions& options);
#ifdef CINN_WITH_TEST #ifdef CINN_WITH_TEST
/** /**
...@@ -87,13 +91,23 @@ class EvolutionarySearch { ...@@ -87,13 +91,23 @@ class EvolutionarySearch {
* @param search_space: the mock search space, note that EvolutionarySearch * @param search_space: the mock search space, note that EvolutionarySearch
* takes the ownership. * takes the ownership.
*/ */
void SetSearchSpace(SearchSpace* search_space) { search_space_.reset(search_space); } void SetSearchSpace(SearchSpace* search_space) {
search_space_.reset(search_space);
}
// Method only be called during testing, it is a wrapper of private method InitSketch(). // Method only be called during testing, it is a wrapper of private method
std::vector<SearchState> TestInitSketch(int num, const std::string& strategy) { return InitSketch(num, strategy); } // InitSketch().
std::vector<SearchState> TestInitSketch(int num,
const std::string& strategy) {
return InitSketch(num, strategy);
}
// Method only be called during testing, it is a wrapper of private method Evolve(). // Method only be called during testing, it is a wrapper of private method
std::vector<SearchState> TestEvolve(const std::vector<SearchState>& population, int cross_over_num, int ret_num) { // Evolve().
std::vector<SearchState> TestEvolve(
const std::vector<SearchState>& population,
int cross_over_num,
int ret_num) {
return Evolve(population, cross_over_num, ret_num); return Evolve(population, cross_over_num, ret_num);
} }
#endif #endif
...@@ -105,23 +119,31 @@ class EvolutionarySearch { ...@@ -105,23 +119,31 @@ class EvolutionarySearch {
* \brief Generate sketch as initial population of evolutionary search. * \brief Generate sketch as initial population of evolutionary search.
* @param num The number of sketches to generate. * @param num The number of sketches to generate.
* @param strategy The strategy to generate sketches, * @param strategy The strategy to generate sketches,
* Current optional strategies are "rule_prune" or "random_prune" or "random". * Current optional strategies are "rule_prune" or "random_prune" or
* - "rule_prune": will use rules to prune and generate sketches as efficiently as possible. * "random".
* - "random_prune": will use the new interface ApplySketchRules() to simulate the random generation of sketches, * - "rule_prune": will use rules to prune and generate sketches as
* and supports the function of a rule returning multiple SearchStates and random pruning by probability. * efficiently as possible.
* - "random": will randomly select a block and a rule to apply and repeat this step several times, * - "random_prune": will use the new interface ApplySketchRules() to simulate
* however, each rule can only be used on one SearchState at most once. * the random generation of sketches, and supports the function of a rule
* returning multiple SearchStates and random pruning by probability.
* - "random": will randomly select a block and a rule to apply and repeat
* this step several times, however, each rule can only be used on one
* SearchState at most once.
* @return Generated sketches. * @return Generated sketches.
*/ */
std::vector<SearchState> InitSketch(int num, const std::string& strategy); std::vector<SearchState> InitSketch(int num, const std::string& strategy);
SearchState Mutate(const SearchState& state, utils::LinearRandomEngine::StateType* rand_seed); SearchState Mutate(const SearchState& state,
utils::LinearRandomEngine::StateType* rand_seed);
SearchState CrossOver(const SearchState& state1, const SearchState& state2); SearchState CrossOver(const SearchState& state1, const SearchState& state2);
std::vector<SearchState> Evolve(const std::vector<SearchState>& population, int cross_over_num, int ret_num); std::vector<SearchState> Evolve(const std::vector<SearchState>& population,
int cross_over_num,
int ret_num);
std::vector<SearchState> PickNextGenerationEpsGreedy(const std::vector<SearchState>& population, std::vector<SearchState> PickNextGenerationEpsGreedy(
const std::vector<SearchState>& population,
const std::vector<SearchState>& random_init, const std::vector<SearchState>& random_init,
int num, int num,
float eps_greedy); float eps_greedy);
...@@ -132,7 +154,8 @@ class EvolutionarySearch { ...@@ -132,7 +154,8 @@ class EvolutionarySearch {
const ExprCostModel& cost_model_; // not owned const ExprCostModel& cost_model_; // not owned
Database* database_; // not owned Database* database_; // not owned
// used to duplicate states with the same structural IR // used to duplicate states with the same structural IR
std::unordered_set<SearchState, SearchStateHash, SearchStateEqual> visited_candidates_; std::unordered_set<SearchState, SearchStateHash, SearchStateEqual>
visited_candidates_;
// mutate rule names and their weights // mutate rule names and their weights
std::vector<std::tuple<std::string, double>> mutators_; std::vector<std::tuple<std::string, double>> mutators_;
// mutate rules, the key is the accumulate weight of each mutate rule // mutate rules, the key is the accumulate weight of each mutate rule
......
...@@ -34,17 +34,23 @@ ...@@ -34,17 +34,23 @@
namespace cinn { namespace cinn {
namespace auto_schedule { namespace auto_schedule {
std::vector<TuneTask> CreateTasks(const frontend::Program& program, const Target& target) { std::vector<TuneTask> CreateTasks(const frontend::Program& program,
const Target& target) {
auto graph = std::make_shared<hlir::framework::Graph>(program, target); auto graph = std::make_shared<hlir::framework::Graph>(program, target);
TaskCreator task_creator; TaskCreator task_creator;
auto tasks = task_creator.CreateTuneTaskOpLevel(graph.get()); auto tasks = task_creator.CreateTuneTaskOpLevel(graph.get());
const auto& dtype_dict = graph->GetAttrs<absl::flat_hash_map<std::string, common::Type>>("inferdtype"); const auto& dtype_dict =
const auto& shape_dict = graph->GetAttrs<absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape"); graph->GetAttrs<absl::flat_hash_map<std::string, common::Type>>(
auto op_lowerer = std::make_unique<hlir::framework::OpLowerer>(dtype_dict, shape_dict, target); "inferdtype");
const auto& shape_dict = graph->GetAttrs<
absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");
auto op_lowerer = std::make_unique<hlir::framework::OpLowerer>(
dtype_dict, shape_dict, target);
InitialTaskRegistry* task_registry = InitialTaskRegistry::Global(); InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
for (auto i = 0; i < tasks.size(); ++i) { for (auto i = 0; i < tasks.size(); ++i) {
tasks[i].Initialize(shape_dict, dtype_dict, op_lowerer.get()); tasks[i].Initialize(shape_dict, dtype_dict, op_lowerer.get());
task_registry->Regist(tasks[i].serialized_key, ir::ModuleExpr(tasks[i].GetLoweredFuncBodyExprs())); task_registry->Regist(tasks[i].serialized_key,
ir::ModuleExpr(tasks[i].GetLoweredFuncBodyExprs()));
} }
return tasks; return tasks;
} }
...@@ -64,7 +70,8 @@ class MockSearchSpace : public SearchSpace { ...@@ -64,7 +70,8 @@ class MockSearchSpace : public SearchSpace {
int GetModuleExprSize() const { return module_expr_size_; } int GetModuleExprSize() const { return module_expr_size_; }
std::vector<SearchState> GenerateSketches(int num, const std::string& strategy) override { std::vector<SearchState> GenerateSketches(
int num, const std::string& strategy) override {
std::vector<SearchState> ret; std::vector<SearchState> ret;
for (int i = 0; i < num; ++i) { for (int i = 0; i < num; ++i) {
std::vector<ir::Expr> exprs; std::vector<ir::Expr> exprs;
...@@ -83,7 +90,8 @@ class MockSearchSpace : public SearchSpace { ...@@ -83,7 +90,8 @@ class MockSearchSpace : public SearchSpace {
}; };
class MockCostModel : public ExprCostModel { class MockCostModel : public ExprCostModel {
float Predict(const ir::ModuleExpr& sample, const common::Target& target) const override { float Predict(const ir::ModuleExpr& sample,
const common::Target& target) const override {
float cost = 0.0f; float cost = 0.0f;
std::vector<ir::Expr> exprs = sample.GetExprs(); std::vector<ir::Expr> exprs = sample.GetExprs();
for (const ir::Expr& expr : exprs) { for (const ir::Expr& expr : exprs) {
...@@ -100,7 +108,8 @@ TEST(EvolutionarySearch, GetOneBest) { ...@@ -100,7 +108,8 @@ TEST(EvolutionarySearch, GetOneBest) {
mock_tune_task.serialized_key = "mock_task"; mock_tune_task.serialized_key = "mock_task";
mock_tune_task.target = common::DefaultTarget(); mock_tune_task.target = common::DefaultTarget();
InitialTaskRegistry* task_registry = InitialTaskRegistry::Global(); InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
task_registry->Regist(mock_tune_task.serialized_key, ir::ModuleExpr({ir::Expr(0)})); task_registry->Regist(mock_tune_task.serialized_key,
ir::ModuleExpr({ir::Expr(0)}));
MockCostModel cost_model; MockCostModel cost_model;
TuningOptions options; TuningOptions options;
Database db(2); Database db(2);
...@@ -122,7 +131,8 @@ TEST(EvolutionarySearch, GetEpsGreedy) { ...@@ -122,7 +131,8 @@ TEST(EvolutionarySearch, GetEpsGreedy) {
mock_tune_task.serialized_key = "mock_task"; mock_tune_task.serialized_key = "mock_task";
mock_tune_task.target = common::DefaultTarget(); mock_tune_task.target = common::DefaultTarget();
InitialTaskRegistry* task_registry = InitialTaskRegistry::Global(); InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
task_registry->Regist(mock_tune_task.serialized_key, ir::ModuleExpr({ir::Expr(0)})); task_registry->Regist(mock_tune_task.serialized_key,
ir::ModuleExpr({ir::Expr(0)}));
ExprCostModel cost_model; ExprCostModel cost_model;
TuningOptions options; TuningOptions options;
Database db(2); Database db(2);
...@@ -131,10 +141,12 @@ TEST(EvolutionarySearch, GetEpsGreedy) { ...@@ -131,10 +141,12 @@ TEST(EvolutionarySearch, GetEpsGreedy) {
MockSearchSpace* mock_search_space = new MockSearchSpace(mock_tune_task); MockSearchSpace* mock_search_space = new MockSearchSpace(mock_tune_task);
// Ownership is transferred so don't delete mock_search_space // Ownership is transferred so don't delete mock_search_space
evolutionary_search.SetSearchSpace(mock_search_space); evolutionary_search.SetSearchSpace(mock_search_space);
std::vector<SearchState> search_states = evolutionary_search.SearchModuleExprEpsGreedy(options); std::vector<SearchState> search_states =
evolutionary_search.SearchModuleExprEpsGreedy(options);
EXPECT_GE(search_states.size(), 1UL); EXPECT_GE(search_states.size(), 1UL);
size_t expr_size = static_cast<size_t>(mock_search_space->GetModuleExprSize()); size_t expr_size =
static_cast<size_t>(mock_search_space->GetModuleExprSize());
for (const SearchState& state : search_states) { for (const SearchState& state : search_states) {
EXPECT_EQ(state->ir_schedule.GetModule().GetExprs().size(), expr_size); EXPECT_EQ(state->ir_schedule.GetModule().GetExprs().size(), expr_size);
} }
...@@ -142,7 +154,9 @@ TEST(EvolutionarySearch, GetEpsGreedy) { ...@@ -142,7 +154,9 @@ TEST(EvolutionarySearch, GetEpsGreedy) {
TEST(EvolutionarySearch, Evolve) { TEST(EvolutionarySearch, Evolve) {
auto target = common::DefaultNVGPUTarget(); auto target = common::DefaultNVGPUTarget();
auto tasks = CreateTasks(tests::OpBuilder("matmul").Build({{"X", {32, 32}}, {"Y", {32, 32}}}), target); auto tasks = CreateTasks(
tests::OpBuilder("matmul").Build({{"X", {32, 32}}, {"Y", {32, 32}}}),
target);
CHECK_EQ(tasks.size(), 1); CHECK_EQ(tasks.size(), 1);
ExprCostModel cost_model; ExprCostModel cost_model;
std::vector<const ir::ModuleExpr*> cost_model_samples(1); std::vector<const ir::ModuleExpr*> cost_model_samples(1);
...@@ -161,7 +175,8 @@ TEST(EvolutionarySearch, Evolve) { ...@@ -161,7 +175,8 @@ TEST(EvolutionarySearch, Evolve) {
EvolutionarySearch evolutionary_search(tasks[0], cost_model, &db); EvolutionarySearch evolutionary_search(tasks[0], cost_model, &db);
int num_population = 10; int num_population = 10;
std::vector<SearchState> init_sketch = evolutionary_search.TestInitSketch(num_population, "rule_prune"); std::vector<SearchState> init_sketch =
evolutionary_search.TestInitSketch(num_population, "rule_prune");
for (int i = 0; i < num_population; ++i) { for (int i = 0; i < num_population; ++i) {
ir::ModuleExpr me(init_sketch[i]->ir_schedule.GetModule()); ir::ModuleExpr me(init_sketch[i]->ir_schedule.GetModule());
cost_model_samples[0] = &me; cost_model_samples[0] = &me;
...@@ -172,10 +187,12 @@ TEST(EvolutionarySearch, Evolve) { ...@@ -172,10 +187,12 @@ TEST(EvolutionarySearch, Evolve) {
for (auto s : init_sketch) { for (auto s : init_sketch) {
VLOG(6) << "cost = " << s->predicted_cost; VLOG(6) << "cost = " << s->predicted_cost;
} }
std::vector<SearchState>*population_pre_ptr = &init_sketch, *population_next_ptr; std::vector<SearchState>*population_pre_ptr = &init_sketch,
*population_next_ptr;
std::vector<SearchState> population; std::vector<SearchState> population;
for (int i = 0; i < 10; ++i) { for (int i = 0; i < 10; ++i) {
population = evolutionary_search.TestEvolve(*population_pre_ptr, /*cross_over_num*/ 0, /*ret_num*/ 10); population = evolutionary_search.TestEvolve(
*population_pre_ptr, /*cross_over_num*/ 0, /*ret_num*/ 10);
population_next_ptr = &population; population_next_ptr = &population;
VLOG(6) << "population[" << i + 1 << "] costs:"; VLOG(6) << "population[" << i + 1 << "] costs:";
double total_cost_pre = 0.0, total_cost_next = 0.0; double total_cost_pre = 0.0, total_cost_next = 0.0;
......
...@@ -34,11 +34,14 @@ class MutateRule { ...@@ -34,11 +34,14 @@ class MutateRule {
* @param rand_seed The random seed for mutation. * @param rand_seed The random seed for mutation.
* @return The mutated trace. * @return The mutated trace.
*/ */
virtual ir::ScheduleDesc Apply(const ir::ScheduleDesc& trace, utils::LinearRandomEngine::StateType* rand_seed) = 0; virtual ir::ScheduleDesc Apply(
const ir::ScheduleDesc& trace,
utils::LinearRandomEngine::StateType* rand_seed) = 0;
/** /**
* @brief Create a MutateRule with name. * @brief Create a MutateRule with name.
* @param name The name of mutate rule, consisting of lowercase letters and underscores * @param name The name of mutate rule, consisting of lowercase letters and
* underscores
* @return The created MutateRule. * @return The created MutateRule.
*/ */
static std::unique_ptr<MutateRule> Make(const std::string& name); static std::unique_ptr<MutateRule> Make(const std::string& name);
......
...@@ -44,8 +44,11 @@ std::vector<SampledTile> FindSampledTiles(const ScheduleDesc& trace) { ...@@ -44,8 +44,11 @@ std::vector<SampledTile> FindSampledTiles(const ScheduleDesc& trace) {
break; break;
} }
if (step.type == "SamplePerfectTile") { if (step.type == "SamplePerfectTile") {
std::vector<int> tile_factors = absl::get<std::vector<int>>(step.attrs.at("decision")); std::vector<int> tile_factors =
CHECK(tile_factors.size() >= 2) << "factors size must be greater equal than 2, which is " << tile_factors.size(); absl::get<std::vector<int>>(step.attrs.at("decision"));
CHECK(tile_factors.size() >= 2)
<< "factors size must be greater equal than 2, which is "
<< tile_factors.size();
tiles.push_back(std::make_tuple(step, tile_factors, step_idx)); tiles.push_back(std::make_tuple(step, tile_factors, step_idx));
} }
++step_idx; ++step_idx;
...@@ -89,10 +92,13 @@ ScheduleDesc DoMutateTileSize(const ScheduleDesc& trace, ...@@ -89,10 +92,13 @@ ScheduleDesc DoMutateTileSize(const ScheduleDesc& trace,
// Step 2. Choose the divisor for mutate. // Step 2. Choose the divisor for mutate.
int divisor; int divisor;
if (loop_y == split_size - 1) { if (loop_y == split_size - 1) {
int max_innermost_factor = absl::get<int>(step.attrs.at("max_innermost_factor")); int max_innermost_factor =
absl::get<int>(step.attrs.at("max_innermost_factor"));
int max_optional_factor_idx = optional_factors.size() - 1; int max_optional_factor_idx = optional_factors.size() - 1;
for (; max_optional_factor_idx > 0; --max_optional_factor_idx) { for (; max_optional_factor_idx > 0; --max_optional_factor_idx) {
if (optional_factors.at(max_optional_factor_idx) * tile_factors.at(loop_y) <= max_innermost_factor) { if (optional_factors.at(max_optional_factor_idx) *
tile_factors.at(loop_y) <=
max_innermost_factor) {
break; break;
} }
} }
...@@ -103,27 +109,32 @@ ScheduleDesc DoMutateTileSize(const ScheduleDesc& trace, ...@@ -103,27 +109,32 @@ ScheduleDesc DoMutateTileSize(const ScheduleDesc& trace,
} }
continue; continue;
} }
divisor = optional_factors.at(utils::SampleUniformInt(1, max_optional_factor_idx + 1, rand_seed)); divisor = optional_factors.at(
utils::SampleUniformInt(1, max_optional_factor_idx + 1, rand_seed));
} else { } else {
divisor = optional_factors.at(utils::SampleUniformInt(1, optional_factors.size(), rand_seed)); divisor = optional_factors.at(
utils::SampleUniformInt(1, optional_factors.size(), rand_seed));
} }
// Step 3. Determine the new tile value // Step 3. Determine the new tile value
VLOG(6) << "DoMutateTileSize: divisor = " << divisor << ", before mutate: \n" VLOG(6) << "DoMutateTileSize: divisor = " << divisor
<< "factors[" << loop_x << "] = " << tile_factors[loop_x] << ", factors[" << loop_y << ", before mutate: \n"
<< "] = " << tile_factors[loop_y]; << "factors[" << loop_x << "] = " << tile_factors[loop_x]
<< ", factors[" << loop_y << "] = " << tile_factors[loop_y];
tile_factors[loop_x] /= divisor; tile_factors[loop_x] /= divisor;
tile_factors[loop_y] *= divisor; tile_factors[loop_y] *= divisor;
VLOG(6) << "after mutate: \n" VLOG(6) << "after mutate: \n"
<< "factors[" << loop_x << "] = " << tile_factors[loop_x] << ", factors[" << loop_y << "factors[" << loop_x << "] = " << tile_factors[loop_x]
<< "] = " << tile_factors[loop_y]; << ", factors[" << loop_y << "] = " << tile_factors[loop_y];
// Step 4. Create a new step with new tile values and return the new trace // Step 4. Create a new step with new tile values and return the new trace
int step_idx = std::get<2>(tile); int step_idx = std::get<2>(tile);
return trace.ForkAndUpdate(step_idx, tile_factors, true); return trace.ForkAndUpdate(step_idx, tile_factors, true);
} }
} }
ScheduleDesc MutateTileSize::Apply(const ScheduleDesc& trace, LinearRandomEngine::StateType* rand_seed) { ScheduleDesc MutateTileSize::Apply(const ScheduleDesc& trace,
VLOG(6) << "Start applying MutateTileSize, old trace: \n" << trace.DebugString(); LinearRandomEngine::StateType* rand_seed) {
VLOG(6) << "Start applying MutateTileSize, old trace: \n"
<< trace.DebugString();
std::vector<ScheduleDesc::Step> sample_tile_steps; std::vector<ScheduleDesc::Step> sample_tile_steps;
std::vector<std::vector<int>> sample_tile_data; std::vector<std::vector<int>> sample_tile_data;
...@@ -132,9 +143,12 @@ ScheduleDesc MutateTileSize::Apply(const ScheduleDesc& trace, LinearRandomEngine ...@@ -132,9 +143,12 @@ ScheduleDesc MutateTileSize::Apply(const ScheduleDesc& trace, LinearRandomEngine
VLOG(6) << "MutateTileSize failed, try other mutate rules."; VLOG(6) << "MutateTileSize failed, try other mutate rules.";
return trace; return trace;
} }
int sample_step_idx = utils::SampleUniformInt(0, sampled_tiles.size(), rand_seed); int sample_step_idx =
auto new_trace = DoMutateTileSize(trace, sampled_tiles.at(sample_step_idx), rand_seed); utils::SampleUniformInt(0, sampled_tiles.size(), rand_seed);
VLOG(6) << "End applying MutateTileSize, new trace: \n" << new_trace.DebugString(); auto new_trace =
DoMutateTileSize(trace, sampled_tiles.at(sample_step_idx), rand_seed);
VLOG(6) << "End applying MutateTileSize, new trace: \n"
<< new_trace.DebugString();
return new_trace; return new_trace;
} }
......
...@@ -20,13 +20,16 @@ namespace cinn { ...@@ -20,13 +20,16 @@ namespace cinn {
namespace auto_schedule { namespace auto_schedule {
/** /**
* The rule to mutate tile size, witch will modify the factors of the Split primitive. * The rule to mutate tile size, witch will modify the factors of the Split
* primitive.
*/ */
class MutateTileSize : public MutateRule { class MutateTileSize : public MutateRule {
public: public:
MutateTileSize() = default; MutateTileSize() = default;
ir::ScheduleDesc Apply(const ir::ScheduleDesc& trace, utils::LinearRandomEngine::StateType* rand_seed) override; ir::ScheduleDesc Apply(
const ir::ScheduleDesc& trace,
utils::LinearRandomEngine::StateType* rand_seed) override;
}; };
} // namespace auto_schedule } // namespace auto_schedule
......
...@@ -42,17 +42,27 @@ TEST(MutateTileSize, Basic) { ...@@ -42,17 +42,27 @@ TEST(MutateTileSize, Basic) {
Var k(K.as_int32(), "reduce_axis_k"); Var k(K.as_int32(), "reduce_axis_k");
ir::Tensor C = Compute( ir::Tensor C = Compute(
{M, N}, [&](Var i, Var j) { return ReduceSum(A(i, k) * B(k, j), {k}); }, "C"); {M, N},
[&](Var i, Var j) { return ReduceSum(A(i, k) * B(k, j), {k}); },
"C");
poly::StageMap stages = CreateStages({A, B, C}); poly::StageMap stages = CreateStages({A, B, C});
std::vector<ir::LoweredFunc> funcs = std::vector<ir::LoweredFunc> funcs =
lang::LowerVec("TestMutateTileSize_Basic", stages, {A, B, C}, {}, {}, nullptr, target, true); lang::LowerVec("TestMutateTileSize_Basic",
stages,
{A, B, C},
{},
{},
nullptr,
target,
true);
ir::Expr ast_expr = funcs[0]->body; ir::Expr ast_expr = funcs[0]->body;
VLOG(6) << "Original Expr: "; VLOG(6) << "Original Expr: ";
VLOG(6) << ast_expr; VLOG(6) << ast_expr;
ir::ModuleExpr module_expr({ast_expr}); ir::ModuleExpr module_expr({ast_expr});
// We need to fix the seed as a constant to ensure that the result can be repeated. // We need to fix the seed as a constant to ensure that the result can be
// repeated.
utils::LinearRandomEngine::StateType rand_seed = 123; utils::LinearRandomEngine::StateType rand_seed = 123;
ir::IRSchedule ir_schedule(module_expr, rand_seed); ir::IRSchedule ir_schedule(module_expr, rand_seed);
ir::IRSchedule new_ir_schedule(ir_schedule); ir::IRSchedule new_ir_schedule(ir_schedule);
...@@ -64,10 +74,13 @@ TEST(MutateTileSize, Basic) { ...@@ -64,10 +74,13 @@ TEST(MutateTileSize, Basic) {
// apply mutate // apply mutate
MutateTileSize mutator; MutateTileSize mutator;
ir::ScheduleDesc sch_desc = mutator.Apply(ir_schedule.GetTraceDesc(), &rand_seed); ir::ScheduleDesc sch_desc =
mutator.Apply(ir_schedule.GetTraceDesc(), &rand_seed);
sch_desc.Replay(&new_ir_schedule, true); sch_desc.Replay(&new_ir_schedule, true);
VLOG(6) << "Expr before mutate tile size: \n" << ir_schedule.GetModule().GetExprs()[0]; VLOG(6) << "Expr before mutate tile size: \n"
VLOG(6) << "Expr after mutate tile size: \n" << new_ir_schedule.GetModule().GetExprs()[0]; << ir_schedule.GetModule().GetExprs()[0];
VLOG(6) << "Expr after mutate tile size: \n"
<< new_ir_schedule.GetModule().GetExprs()[0];
std::string target_new_ir = R"ROC({ std::string target_new_ir = R"ROC({
ScheduleBlock(root) ScheduleBlock(root)
...@@ -111,7 +124,8 @@ TEST(MutateTileSize, Basic) { ...@@ -111,7 +124,8 @@ TEST(MutateTileSize, Basic) {
sch_desc = mutator.Apply(sch_desc, &rand_seed); sch_desc = mutator.Apply(sch_desc, &rand_seed);
for (auto&& step : sch_desc.Steps()) { for (auto&& step : sch_desc.Steps()) {
if (step.type == "SamplePerfectTile") { if (step.type == "SamplePerfectTile") {
std::vector<int> tile_factors = absl::get<std::vector<int>>(step.attrs.at("decision")); std::vector<int> tile_factors =
absl::get<std::vector<int>>(step.attrs.at("decision"));
ASSERT_EQ(tile_factors.size(), last_tile_factors.size()); ASSERT_EQ(tile_factors.size(), last_tile_factors.size());
ASSERT_NE(tile_factors[0], last_tile_factors[0]); ASSERT_NE(tile_factors[0], last_tile_factors[0]);
ASSERT_NE(tile_factors[1], last_tile_factors[1]); ASSERT_NE(tile_factors[1], last_tile_factors[1]);
......
...@@ -36,7 +36,8 @@ using ::cinn::hlir::framework::NodeData; ...@@ -36,7 +36,8 @@ using ::cinn::hlir::framework::NodeData;
std::vector<TuneTask> TaskCreator::CreateTuneTaskOpLevel(Graph* graph) { std::vector<TuneTask> TaskCreator::CreateTuneTaskOpLevel(Graph* graph) {
std::vector<TuneTask> ret_tasks; std::vector<TuneTask> ret_tasks;
const std::vector<std::shared_ptr<Graph::Group>>* groups = &graph->fusion_groups; const std::vector<std::shared_ptr<Graph::Group>>* groups =
&graph->fusion_groups;
std::vector<std::shared_ptr<Graph::Group>> non_fused_groups; std::vector<std::shared_ptr<Graph::Group>> non_fused_groups;
// The input graph doesn't run Op Fusion // The input graph doesn't run Op Fusion
if (graph->fusion_groups.empty()) { if (graph->fusion_groups.empty()) {
......
...@@ -45,7 +45,8 @@ class TaskOptimizer { ...@@ -45,7 +45,8 @@ class TaskOptimizer {
std::string from; std::string from;
double cost; double cost;
FunctionGroup functions; FunctionGroup functions;
Result(const std::string& from_type) : from(from_type), cost(std::numeric_limits<double>::max()) {} Result(const std::string& from_type)
: from(from_type), cost(std::numeric_limits<double>::max()) {}
}; };
Result OptimizeByManual(bool need_measure); Result OptimizeByManual(bool need_measure);
...@@ -53,7 +54,9 @@ class TaskOptimizer { ...@@ -53,7 +54,9 @@ class TaskOptimizer {
Result OptimizeByEvolution(const TuningOptions& options); Result OptimizeByEvolution(const TuningOptions& options);
// call search candidates once by EvolutionarySearch and prune invalid ones // call search candidates once by EvolutionarySearch and prune invalid ones
std::vector<SearchState> SearchOneRound(const TuningOptions& options, std::vector<MeasureInput>* measure_candidates); std::vector<SearchState> SearchOneRound(
const TuningOptions& options,
std::vector<MeasureInput>* measure_candidates);
private: private:
// the max retry times if continuously get empty result // the max retry times if continuously get empty result
......
...@@ -31,7 +31,8 @@ struct InitialTaskInfo { ...@@ -31,7 +31,8 @@ struct InitialTaskInfo {
std::string task_key; std::string task_key;
ir::ModuleExpr module_expr; ir::ModuleExpr module_expr;
InitialTaskInfo(const std::string& task_key, const ir::ModuleExpr& module_expr) InitialTaskInfo(const std::string& task_key,
const ir::ModuleExpr& module_expr)
: task_key(task_key), module_expr(module_expr) {} : task_key(task_key), module_expr(module_expr) {}
}; };
...@@ -45,19 +46,25 @@ class InitialTaskRegistry : public Registry<InitialTaskInfo> { ...@@ -45,19 +46,25 @@ class InitialTaskRegistry : public Registry<InitialTaskInfo> {
// Get the initial ModuleExpr of a task. // Get the initial ModuleExpr of a task.
inline const InitialTaskInfo* Get(const std::string& task_key) { inline const InitialTaskInfo* Get(const std::string& task_key) {
const InitialTaskInfo* task_info = Registry<InitialTaskInfo>::Find(task_key); const InitialTaskInfo* task_info =
CHECK(task_info) << "InitialTaskInfo [" << task_key << "] is not registered"; Registry<InitialTaskInfo>::Find(task_key);
CHECK(task_info) << "InitialTaskInfo [" << task_key
<< "] is not registered";
return task_info; return task_info;
} }
// Check if the task info with task_key exists; // Check if the task info with task_key exists;
inline const bool Has(const std::string& task_key) { return nullptr != Registry<InitialTaskInfo>::Find(task_key); } inline const bool Has(const std::string& task_key) {
return nullptr != Registry<InitialTaskInfo>::Find(task_key);
}
// Regist the initial ModuleExpr of a task into the map // Regist the initial ModuleExpr of a task into the map
inline void Regist(const std::string& task_key, const ir::ModuleExpr& module_expr) { inline void Regist(const std::string& task_key,
const ir::ModuleExpr& module_expr) {
std::lock_guard<std::mutex> guard(registering_mutex); std::lock_guard<std::mutex> guard(registering_mutex);
if (fmap_.count(task_key) == 0) { if (fmap_.count(task_key) == 0) {
InitialTaskInfo* task_info = new InitialTaskInfo(task_key, optim::IRCopy(module_expr)); InitialTaskInfo* task_info =
new InitialTaskInfo(task_key, optim::IRCopy(module_expr));
__REGISTER__(task_key, task_info); __REGISTER__(task_key, task_info);
} }
} }
...@@ -67,7 +74,8 @@ class InitialTaskRegistry : public Registry<InitialTaskInfo> { ...@@ -67,7 +74,8 @@ class InitialTaskRegistry : public Registry<InitialTaskInfo> {
CINN_DISALLOW_COPY_AND_ASSIGN(InitialTaskRegistry); CINN_DISALLOW_COPY_AND_ASSIGN(InitialTaskRegistry);
// Regist the initial ModuleExpr of a task. // Regist the initial ModuleExpr of a task.
inline InitialTaskInfo* __REGISTER__(const std::string& task_key, InitialTaskInfo* task_info) { inline InitialTaskInfo* __REGISTER__(const std::string& task_key,
InitialTaskInfo* task_info) {
fmap_[task_key] = task_info; fmap_[task_key] = task_info;
const_list_.push_back(task_info); const_list_.push_back(task_info);
entry_list_.push_back(task_info); entry_list_.push_back(task_info);
......
...@@ -27,7 +27,9 @@ int EfficiencyPriority::NextTaskId() { ...@@ -27,7 +27,9 @@ int EfficiencyPriority::NextTaskId() {
return -1; return -1;
} }
bool EfficiencyPriority::IsTaskToTune(const TuneTask* task) { return config_.minimum_gain_threshold > 0.0; } bool EfficiencyPriority::IsTaskToTune(const TuneTask* task) {
return config_.minimum_gain_threshold > 0.0;
}
} // namespace auto_schedule } // namespace auto_schedule
} // namespace cinn } // namespace cinn
...@@ -25,7 +25,8 @@ namespace auto_schedule { ...@@ -25,7 +25,8 @@ namespace auto_schedule {
// is picking a task with the maximum earnings ratio. // is picking a task with the maximum earnings ratio.
class EfficiencyPriority : public TaskScheduler { class EfficiencyPriority : public TaskScheduler {
public: public:
EfficiencyPriority(const std::vector<TuneTask>& tasks, const Config& config) : TaskScheduler(tasks, config) {} EfficiencyPriority(const std::vector<TuneTask>& tasks, const Config& config)
: TaskScheduler(tasks, config) {}
const char* Name() const override { return "efficiency_priority"; }; const char* Name() const override { return "efficiency_priority"; };
......
...@@ -25,7 +25,8 @@ namespace auto_schedule { ...@@ -25,7 +25,8 @@ namespace auto_schedule {
// is picking a task to tune once a time iteratively. // is picking a task to tune once a time iteratively.
class RoundRobin : public TaskScheduler { class RoundRobin : public TaskScheduler {
public: public:
RoundRobin(const std::vector<TuneTask>& tasks, const Config& config) : TaskScheduler(tasks, config) {} RoundRobin(const std::vector<TuneTask>& tasks, const Config& config)
: TaskScheduler(tasks, config) {}
const char* Name() const override { return "round_robin"; }; const char* Name() const override { return "round_robin"; };
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册